[spirv-reader][ir] Fix constant-id check for NonUniform calls.
The original check that the `value` was constant was a) too broadly
applied and b) the wrong operand to check. This Cl constrains the check
to just the Broadcast, QuadBroadcast and QuadSwap instructions and
corrects the check to use the `Invocation Id` instead of the `Value`.
Fixed: 432807736
Change-Id: I643d3e44a77aa65ed8f9ce26c1fb9ade2c5b8fd6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/256374
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/builtin_test.cc b/src/tint/lang/spirv/reader/parser/builtin_test.cc
index 88fef85..9ca470e 100644
--- a/src/tint/lang/spirv/reader/parser/builtin_test.cc
+++ b/src/tint/lang/spirv/reader/parser/builtin_test.cc
@@ -1924,6 +1924,42 @@
SPV_ENV_VULKAN_1_1);
}
+TEST_F(SpirvParserTest, NonUniformBroadcast_NonConstant_NumericVector) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformBallot
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %v3uint = OpTypeVector %uint 3
+ %12 = OpConstantComposite %v3uint %uint_1 %uint_3 %uint_1
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %7 = OpCopyObject %v3uint %12
+ %8 = OpGroupNonUniformBroadcast %v3uint %uint_3 %7 %uint_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:vec3<u32> = let vec3<u32>(1u, 3u, 1u)
+ %3:vec3<u32> = spirv.group_non_uniform_broadcast 3u, %2, 1u
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
TEST_F(SpirvParserTest, NonUniformBroadcastFirst_Constant_BoolScalar) {
EXPECT_IR_SPV(R"(
OpCapability Shader
@@ -2389,6 +2425,40 @@
SPV_ENV_VULKAN_1_1);
}
+TEST_F(SpirvParserTest, NonUniformShuffle_NonConstant_BoolScalar) {
+ EXPECT_IR_SPV(R"(
+ OpCapability Shader
+ OpCapability GroupNonUniformShuffle
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %main "main"
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_3 = OpConstant %uint 3
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %void = OpTypeVoid
+ %23 = OpTypeFunction %void
+ %main = OpFunction %void None %23
+ %24 = OpLabel
+ %7 = OpCopyObject %bool %true
+ %8 = OpGroupNonUniformShuffle %bool %uint_3 %7 %uint_1
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:bool = let true
+ %3:bool = spirv.group_non_uniform_shuffle 3u, %2, 1u
+ ret
+ }
+}
+)",
+ SPV_ENV_VULKAN_1_1);
+}
+
TEST_F(SpirvParserTest, NonUniformShuffle_Constant_BoolVector) {
EXPECT_IR_SPV(R"(
OpCapability Shader
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 1280cac..9f6a7bd 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -2256,6 +2256,17 @@
EmitPhi(inst);
break;
+ case spv::Op::OpGroupNonUniformBroadcast:
+ EmitSubgroupBuiltinConstantId(inst,
+ spirv::BuiltinFn::kGroupNonUniformBroadcast);
+ break;
+ case spv::Op::OpGroupNonUniformQuadBroadcast:
+ EmitSubgroupBuiltinConstantId(inst,
+ spirv::BuiltinFn::kGroupNonUniformQuadBroadcast);
+ break;
+ case spv::Op::OpGroupNonUniformQuadSwap:
+ EmitSubgroupBuiltinConstantId(inst, spirv::BuiltinFn::kGroupNonUniformQuadSwap);
+ break;
case spv::Op::OpGroupNonUniformAll:
EmitSubgroupBuiltin(inst, core::BuiltinFn::kSubgroupAll);
break;
@@ -2271,9 +2282,6 @@
case spv::Op::OpGroupNonUniformBroadcastFirst:
EmitSubgroupBuiltin(inst, spirv::BuiltinFn::kGroupNonUniformBroadcastFirst);
break;
- case spv::Op::OpGroupNonUniformBroadcast:
- EmitSubgroupBuiltin(inst, spirv::BuiltinFn::kGroupNonUniformBroadcast);
- break;
case spv::Op::OpGroupNonUniformShuffle:
EmitSubgroupBuiltin(inst, spirv::BuiltinFn::kGroupNonUniformShuffle);
break;
@@ -2286,12 +2294,6 @@
case spv::Op::OpGroupNonUniformShuffleUp:
EmitSubgroupBuiltin(inst, spirv::BuiltinFn::kGroupNonUniformShuffleUp);
break;
- case spv::Op::OpGroupNonUniformQuadBroadcast:
- EmitSubgroupBuiltin(inst, spirv::BuiltinFn::kGroupNonUniformQuadBroadcast);
- break;
- case spv::Op::OpGroupNonUniformQuadSwap:
- EmitSubgroupBuiltin(inst, spirv::BuiltinFn::kGroupNonUniformQuadSwap);
- break;
case spv::Op::OpGroupNonUniformSMin:
EmitSubgroupMinMax(inst, spirv::BuiltinFn::kGroupNonUniformSMin);
break;
@@ -2419,16 +2421,16 @@
inst.result_id());
}
- void EmitSubgroupBuiltin(spvtools::opt::Instruction& inst, spirv::BuiltinFn fn) {
- auto val = Value(inst.GetSingleWordInOperand(1));
+ void EmitSubgroupBuiltinConstantId(spvtools::opt::Instruction& inst, spirv::BuiltinFn fn) {
+ auto id = Value(inst.GetSingleWordInOperand(2));
// TODO(431054356): Convert core::BuiltinFn::kSubgroupBroadcast non-constant values into a
// `subgroupShuffle` when we support SPIR-V >= 1.5 source.
//
// For QuadBroadcast this will remain an error as there is no WGSL equivalent.
// For QuadSwap this will remain an error as there is no WGSL equivalent.
- if (!val->Is<core::ir::Constant>()) {
- TINT_ICE() << "non-constant GroupNonUniform `value` not supported";
+ if (!id->Is<core::ir::Constant>()) {
+ TINT_ICE() << "non-constant GroupNonUniform `Invocation Id` not supported";
}
ValidateScope(inst);
@@ -2436,6 +2438,12 @@
inst.result_id());
}
+ void EmitSubgroupBuiltin(spvtools::opt::Instruction& inst, spirv::BuiltinFn fn) {
+ ValidateScope(inst);
+ Emit(b_.Call<spirv::ir::BuiltinCall>(Type(inst.type_id()), fn, Args(inst, 2)),
+ inst.result_id());
+ }
+
void EmitSubgroupBuiltin(spvtools::opt::Instruction& inst, core::BuiltinFn fn) {
ValidateScope(inst);
Emit(b_.Call(Type(inst.type_id()), fn, Args(inst, 3)), inst.result_id());