[tint] Early evaluation errors for subgroupShuffle
This covers the functions subgroupShuffle, subgroupShuffleUp,
subgroupShuffleDown, and subgroupShuffleXor.
There is a CTS in the works:
https://github.com/gpuweb/cts/pull/4065/
Bug: 380862306
Change-Id: I0077557f62b4140bcbdd8601cbe6bc0a1933cf56
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/217074
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: Peter McNeeley <petermcneeley@google.com>
diff --git a/src/tint/lang/wgsl/resolver/builtin_validation_test.cc b/src/tint/lang/wgsl/resolver/builtin_validation_test.cc
index ade78ee..01bf7a1 100644
--- a/src/tint/lang/wgsl/resolver/builtin_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtin_validation_test.cc
@@ -834,6 +834,75 @@
R"(error: workgroupUniformLoad must not be called with an argument that contains an atomic type)");
}
+TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustBeNonNeg) {
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.u32(),
+ Vector{
+ Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, -1_i))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be greater than or equal to zero)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustLessThan128Signed) {
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.u32(),
+ Vector{
+ Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, 128_i))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be less than 128)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustLessThan128) {
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.u32(),
+ Vector{
+ Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, 128_u))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be less than 128)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleUpDeltaArgMustLessThan128) {
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.u32(),
+ Vector{
+ Return(Call("subgroupShuffleUp", 1_u, Expr(Source{{12, 34}}, 128_u))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: the delta argument of subgroupShuffleUp must be less than 128)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleDownDeltaArgMustLessThan128) {
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.u32(),
+ Vector{
+ Return(Call("subgroupShuffleDown", 1_u, Expr(Source{{12, 34}}, 128_u))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: the delta argument of subgroupShuffleDown must be less than 128)");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleXorMaskArgMustLessThan128) {
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.u32(),
+ Vector{
+ Return(Call("subgroupShuffleXor", 1_u, Expr(Source{{12, 34}}, 128_u))),
+ });
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(12:34 error: the mask argument of subgroupShuffleXor must be less than 128)");
+}
+
TEST_F(ResolverBuiltinValidationTest, SubgroupBallotWithoutExtension) {
// fn func { return subgroupBallot(true); }
Func("func", tint::Empty, ty.vec4<u32>(),
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 3bc3f75..894f5cb 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2491,6 +2491,14 @@
return nullptr;
}
break;
+ case wgsl::BuiltinFn::kSubgroupShuffle:
+ case wgsl::BuiltinFn::kSubgroupShuffleUp:
+ case wgsl::BuiltinFn::kSubgroupShuffleDown:
+ case wgsl::BuiltinFn::kSubgroupShuffleXor:
+ if (!validator_.SubgroupShuffleFunction(fn, call)) {
+ return nullptr;
+ }
+ break;
case wgsl::BuiltinFn::kQuadBroadcast:
if (!validator_.QuadBroadcast(call)) {
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 8cf0c3d..a17be30 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -1888,6 +1888,61 @@
return true;
}
+bool Validator::SubgroupShuffleFunction(wgsl::BuiltinFn fn, const sem::Call* call) const {
+ auto* builtin = call->Target()->As<sem::BuiltinFn>();
+ if (!builtin) {
+ return false;
+ }
+
+ TINT_ASSERT(call->Arguments().Length() == 2);
+ auto* id = call->Arguments()[1];
+ auto* constant_value = id->ConstantValue();
+
+ if (!constant_value) {
+ // Non const values are allowed as parameters.
+ return true;
+ }
+
+ // User friendly param name.
+ std::string paramName = "sourceLaneIndex";
+ switch (fn) {
+ case wgsl::BuiltinFn::kSubgroupShuffleXor:
+ paramName = "mask";
+ break;
+ case wgsl::BuiltinFn::kSubgroupShuffleUp:
+ case wgsl::BuiltinFn::kSubgroupShuffleDown:
+ paramName = "delta";
+ break;
+ default:
+ break;
+ }
+
+ if (id->Type()->IsSignedIntegerScalar() && constant_value->ValueAs<i32>() < 0) {
+ AddError(id->Declaration()->source)
+ << "the " << paramName << " argument of " << builtin->str()
+ << " must be greater than or equal to zero";
+ return false;
+ }
+
+ if (id->Type()->IsSignedIntegerScalar() &&
+ constant_value->ValueAs<i32>() >= tint::internal_limits::kMaxSubgroupSize) {
+ AddError(id->Declaration()->source)
+ << "the " << paramName << " argument of " << builtin->str() << " must be less than "
+ << tint::internal_limits::kMaxSubgroupSize;
+ return false;
+ }
+
+ if (id->Type()->IsUnsignedIntegerScalar() &&
+ constant_value->ValueAs<u32>() >= tint::internal_limits::kMaxSubgroupSize) {
+ AddError(id->Declaration()->source)
+ << "the " << paramName << " argument of " << builtin->str() << " must be less than "
+ << tint::internal_limits::kMaxSubgroupSize;
+ return false;
+ }
+
+ return true;
+}
+
bool Validator::TextureBuiltinFn(const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::BuiltinFn>();
if (!builtin) {
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index a46bb8c..36b34f7 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -528,6 +528,12 @@
/// @returns true on success, false otherwise
bool ArrayConstructor(const ast::CallExpression* ctor, const sem::Array* arr_type) const;
+ /// Validates a subgroupShuffle builtin functions including Up,Down, and Xor.
+ /// @param fn the builtin call type
+ /// @param call the builtin call to validate
+ /// @returns true on success, false otherwise
+ bool SubgroupShuffleFunction(wgsl::BuiltinFn fn, const sem::Call* call) const;
+
/// Validates a texture builtin function
/// @param call the builtin call to validate
/// @returns true on success, false otherwise