[spirv-reader][ir] Add support for OpGroupNonUniform{IF}Mul
Add support to convert the `OpGroupNonUniformFMul` and
`OpGroupNonUniformIMul` instructions into the equivalent
`subgroupMul`, `subgroupInclusiveMul` or `subgroupExclusiveMul`
instructions.
Fixed: 431031712, 431031615
Change-Id: I5f9a566493f8b413c687f8e63b7be03382bd8685
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/252715
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/parser/builtin_test.cc b/src/tint/lang/spirv/reader/parser/builtin_test.cc
index ffb3934..454a781 100644
--- a/src/tint/lang/spirv/reader/parser/builtin_test.cc
+++ b/src/tint/lang/spirv/reader/parser/builtin_test.cc
@@ -3852,5 +3852,422 @@
SPV_ENV_VULKAN_1_1);
}
+TEST_F(SpirvParserTest, NonUniformIMul_Reduce_Scalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformIMul %uint %uint_3 Reduce %uint_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = subgroupMul 1u
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformIMul_Reduce_Vector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %int = OpTypeInt 32 1
+ %int_1 = OpConstant %int 1
+ %int_3 = OpConstant %int 3
+ %v3int = OpTypeVector %int 3
+ %12 = OpConstantComposite %v3int %int_1 %int_3 %int_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformIMul %v3int %uint_3 Reduce %12
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<i32> = subgroupMul vec3<i32>(1i, 3i, 1i)
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformFMul_Reduce_Scalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformFMul %float %uint_3 Reduce %float_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:f32 = subgroupMul 1.0f
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformFMul_Reduce_Vector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %float_3 = OpConstant %float 3
+ %v3float = OpTypeVector %float 3
+ %12 = OpConstantComposite %v3float %float_1 %float_3 %float_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformFMul %v3float %uint_3 Reduce %12
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<f32> = subgroupMul vec3<f32>(1.0f, 3.0f, 1.0f)
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformIMul_InclusiveScan_Scalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformIMul %uint %uint_3 InclusiveScan %uint_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = subgroupInclusiveMul 1u
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformIMul_InclusiveScan_Vector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %int = OpTypeInt 32 1
+ %int_1 = OpConstant %int 1
+ %int_3 = OpConstant %int 3
+ %v3int = OpTypeVector %int 3
+ %12 = OpConstantComposite %v3int %int_1 %int_3 %int_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformIMul %v3int %uint_3 InclusiveScan %12
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<i32> = subgroupInclusiveMul vec3<i32>(1i, 3i, 1i)
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformFMul_InclusiveScan_Scalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformFMul %float %uint_3 InclusiveScan %float_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:f32 = subgroupInclusiveMul 1.0f
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformFMul_InclusiveScan_Vector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %float_3 = OpConstant %float 3
+ %v3float = OpTypeVector %float 3
+ %12 = OpConstantComposite %v3float %float_1 %float_3 %float_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformFMul %v3float %uint_3 InclusiveScan %12
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<f32> = subgroupInclusiveMul vec3<f32>(1.0f, 3.0f, 1.0f)
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformIMul_ExclusiveScan_Scalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformIMul %uint %uint_3 ExclusiveScan %uint_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:u32 = subgroupExclusiveMul 1u
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformIMul_ExclusiveScan_Vector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %int = OpTypeInt 32 1
+ %int_1 = OpConstant %int 1
+ %int_3 = OpConstant %int 3
+ %v3int = OpTypeVector %int 3
+ %12 = OpConstantComposite %v3int %int_1 %int_3 %int_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformIMul %v3int %uint_3 ExclusiveScan %12
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<i32> = subgroupExclusiveMul vec3<i32>(1i, 3i, 1i)
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformFMul_ExclusiveScan_Scalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformFMul %float %uint_3 ExclusiveScan %float_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:f32 = subgroupExclusiveMul 1.0f
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
+TEST_F(SpirvParserTest, NonUniformFMul_ExclusiveScan_Vector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformArithmetic
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %float_3 = OpConstant %float 3
+ %v3float = OpTypeVector %float 3
+ %12 = OpConstantComposite %v3float %float_1 %float_3 %float_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %8 = OpGroupNonUniformFMul %v3float %uint_3 ExclusiveScan %12
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<f32> = subgroupExclusiveMul vec3<f32>(1.0f, 3.0f, 1.0f)
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
} // namespace
} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index ae1e03c..c81e951 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -2264,6 +2264,10 @@
case spv::Op::OpGroupNonUniformFAdd:
EmitSubgroupAdd(inst);
break;
+ case spv::Op::OpGroupNonUniformIMul:
+ case spv::Op::OpGroupNonUniformFMul:
+ EmitSubgroupMul(inst);
+ break;
default:
TINT_UNIMPLEMENTED()
<< "unhandled SPIR-V instruction: " << static_cast<uint32_t>(inst.opcode());
@@ -2282,6 +2286,26 @@
}
}
+ void EmitSubgroupMul(spvtools::opt::Instruction& inst) {
+ ValidateScope(inst);
+
+ core::BuiltinFn fn = core::BuiltinFn::kNone;
+
+ auto group = inst.GetSingleWordInOperand(1);
+ if (static_cast<spv::GroupOperation>(group) == spv::GroupOperation::Reduce) {
+ fn = core::BuiltinFn::kSubgroupMul;
+ } else if (static_cast<spv::GroupOperation>(group) == spv::GroupOperation::InclusiveScan) {
+ fn = core::BuiltinFn::kSubgroupInclusiveMul;
+ } else if (static_cast<spv::GroupOperation>(group) == spv::GroupOperation::ExclusiveScan) {
+ fn = core::BuiltinFn::kSubgroupExclusiveMul;
+ } else {
+ TINT_ICE() << "GroupNonUniform Mul instruction must have a group of `Reduce`, "
+ "`InclusiveScan`, or `ExclusiveScan`";
+ }
+
+ Emit(b_.Call(Type(inst.type_id()), fn, Args(inst, 4)), inst.result_id());
+ }
+
void EmitSubgroupAdd(spvtools::opt::Instruction& inst) {
ValidateScope(inst);