[ir] Add checker for subgroup_id
Prevents the validator from crashing when subgroup_id is used, which
was triggering fuzzer bugs.
Fixed: 416777692
Change-Id: I89fe35050d2818ebcc7f1dad9917ca6829461c8a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241834
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index c3561e3..c0dc46e 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -414,6 +414,15 @@
/* type_error */ "sample_index must be an u32",
};
+constexpr BuiltinChecker kSubgroupIdChecker{
+ /* name */ "subgroup_id",
+ /* stages */
+ EnumSet<Function::PipelineStage>(Function::PipelineStage::kCompute),
+ /* direction */ BuiltinChecker::IODirection::kInput,
+ /* type_check */ [](const core::type::Type* ty) -> bool { return ty->Is<core::type::U32>(); },
+ /* type_error */ "subgroup_id must be an u32",
+};
+
constexpr BuiltinChecker kSubgroupInvocationIdChecker{
/* name */ "subgroup_invocation_id",
/* stages */
@@ -476,6 +485,8 @@
return kNumWorkgroupsChecker;
case BuiltinValue::kSampleIndex:
return kSampleIndexChecker;
+ case BuiltinValue::kSubgroupId:
+ return kSubgroupIdChecker;
case BuiltinValue::kSubgroupInvocationId:
return kSubgroupInvocationIdChecker;
case BuiltinValue::kSubgroupSize:
diff --git a/src/tint/lang/core/ir/validator_builtin_test.cc b/src/tint/lang/core/ir/validator_builtin_test.cc
index e55705e..79e034e 100644
--- a/src/tint/lang/core/ir/validator_builtin_test.cc
+++ b/src/tint/lang/core/ir/validator_builtin_test.cc
@@ -814,6 +814,52 @@
)")) << res.Failure();
}
+TEST_F(IR_ValidatorTest, Builtin_SubgroupId_WrongStage) {
+ auto* f = VertexEntryPoint();
+ AddBuiltinParam(f, "id", BuiltinValue::kSubgroupId, ty.u32());
+
+ b.Append(f->Block(), [&] { b.Unreachable(); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(res.Failure().reason,
+ testing::HasSubstr(
+ R"(:1:19 error: subgroup_id must be used in a compute shader entry point
+%f = @vertex func(%id:u32 [@subgroup_id]):vec4<f32> [@position] {
+ ^^^^^^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Builtin_SubgroupId_WrongIODirection) {
+ auto* f = ComputeEntryPoint();
+ AddBuiltinReturn(f, "id", BuiltinValue::kSubgroupId, ty.u32());
+
+ b.Append(f->Block(), [&] { b.Unreachable(); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(
+ res.Failure().reason,
+ testing::HasSubstr(R"(:1:1 error: subgroup_id must be an input of a shader entry point
+%f = @compute @workgroup_size(1u, 1u, 1u) func():u32 [@subgroup_id] {
+^^
+)")) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Builtin_SubgroupId_WrongType) {
+ auto* f = ComputeEntryPoint();
+ AddBuiltinParam(f, "id", BuiltinValue::kSubgroupId, ty.i32());
+
+ b.Append(f->Block(), [&] { b.Unreachable(); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(res.Failure().reason, testing::HasSubstr(R"(:1:48 error: subgroup_id must be an u32
+%f = @compute @workgroup_size(1u, 1u, 1u) func(%id:i32 [@subgroup_id]):void {
+ ^^^^^^^
+)")) << res.Failure();
+}
+
TEST_F(IR_ValidatorTest, Builtin_SubgroupSize_WrongStage) {
auto* f = VertexEntryPoint();
AddBuiltinParam(f, "size", BuiltinValue::kSubgroupSize, ty.u32());