[tint] Add subgroup builtin function validation
Subgroup builtin functions are either enabled with
`chromium_experimental_subgroups`, or with `subgroups` and
(optionally) `subgroups_f16`.
Update error messages to use the new extension names.
Remove the `RequiredExtensions()` as it cannot describe this
relationship, and was only used for subgroups.
Bug: 349125464
Change-Id: Iea39249eb99822b445eec3b43355b344d170a453
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195797
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/wgsl/resolver/builtin_validation_test.cc b/src/tint/lang/wgsl/resolver/builtin_validation_test.cc
index 728a7b8..3470f7c 100644
--- a/src/tint/lang/wgsl/resolver/builtin_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/builtin_validation_test.cc
@@ -844,10 +844,23 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- R"(12:34 error: cannot call built-in function 'subgroupBallot' without extension chromium_experimental_subgroups)");
+ R"(12:34 error: cannot call built-in function 'subgroupBallot' without extension 'subgroups')");
}
TEST_F(ResolverBuiltinValidationTest, SubgroupBallotWithExtension) {
+ // enable subgroups;
+ // fn func -> vec4<u32> { return subgroupBallot(); }
+ Enable(wgsl::Extension::kSubgroups);
+
+ Func("func", tint::Empty, ty.vec4<u32>(),
+ Vector{
+ Return(Call("subgroupBallot")),
+ });
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBallotWithExperimentalExtension) {
// enable chromium_experimental_subgroups;
// fn func -> vec4<u32> { return subgroupBallot(); }
Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
@@ -870,10 +883,23 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
- R"(12:34 error: cannot call built-in function 'subgroupBroadcast' without extension chromium_experimental_subgroups)");
+ R"(12:34 error: cannot call built-in function 'subgroupBroadcast' without extension 'subgroups')");
}
TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithExtension) {
+ // enable subgroups;
+ // fn func -> i32 { return subgroupBroadcast(1,0); }
+ Enable(wgsl::Extension::kSubgroups);
+
+ Func("func", tint::Empty, ty.i32(),
+ Vector{
+ Return(Call("subgroupBroadcast", 1_i, 0_u)),
+ });
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithExperimentalExtension) {
// enable chromium_experimental_subgroups;
// fn func -> i32 { return subgroupBroadcast(1,0); }
Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
@@ -886,10 +912,108 @@
EXPECT_TRUE(r()->Resolve());
}
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithoutExtension_F16) {
+ // enable f16;
+ // enable subgroups;
+ // fn func -> f16 { return subgroupBroadcast(1.h,0); }
+ Enable(wgsl::Extension::kF16);
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.f16(),
+ Vector{
+ Return(Call(Source{{12, 34}}, "subgroupBroadcast", 1_h, 0_u)),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: cannot call built-in function 'subgroupBroadcast' without extension 'subgroups_f16')");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithExtensions_F16) {
+ // enable f16;
+ // enable subgroups;
+ // enable subgroups_f16;
+ // fn func -> f16 { return subgroupBroadcast(1.h,0); }
+ Enable(wgsl::Extension::kF16);
+ Enable(wgsl::Extension::kSubgroups);
+ Enable(wgsl::Extension::kSubgroupsF16);
+
+ Func("func", tint::Empty, ty.f16(),
+ Vector{
+ Return(Call("subgroupBroadcast", 1_h, 0_u)),
+ });
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithExperimentalExtension_F16) {
+ // enable f16;
+ // enable chromium_experimental_subgroups;
+ // fn func -> f16 { return subgroupBroadcast(1.h,0); }
+ Enable(wgsl::Extension::kF16);
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+
+ Func("func", tint::Empty, ty.f16(),
+ Vector{
+ Return(Call("subgroupBroadcast", 1_h, 0_u)),
+ });
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithoutExtension_VecF16) {
+ // enable f16;
+ // enable subgroups;
+ // fn func -> vec4<f16> { return subgroupBroadcast(vec4(1.h),0); }
+ Enable(wgsl::Extension::kF16);
+ Enable(wgsl::Extension::kSubgroups);
+ Func("func", tint::Empty, ty.vec4<f16>(),
+ Vector{
+ Return(Call(Source{{12, 34}}, "subgroupBroadcast", Call(ty.vec4<f16>(), 1_h), 0_u)),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(12:34 error: cannot call built-in function 'subgroupBroadcast' without extension 'subgroups_f16')");
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithExtensions_VecF16) {
+ // enable f16;
+ // enable subgroups;
+ // enable subgroups_f16;
+ // fn func -> vec4<f16> { return subgroupBroadcast(vec4(1.h),0); }
+ Enable(wgsl::Extension::kF16);
+ Enable(wgsl::Extension::kSubgroups);
+ Enable(wgsl::Extension::kSubgroupsF16);
+
+ Func("func", tint::Empty, ty.vec4<f16>(),
+ Vector{
+ Return(Call("subgroupBroadcast", Call(ty.vec4<f16>(), 1_h), 0_u)),
+ });
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
+TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastWithExperimentalExtension_VecF16) {
+ // enable f16;
+ // enable chromium_experimental_subgroups;
+ // fn func -> vec4<f16> { return subgroupBroadcast(vec4(1.h),0); }
+ Enable(wgsl::Extension::kF16);
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+
+ Func("func", tint::Empty, ty.vec4<f16>(),
+ Vector{
+ Return(Call("subgroupBroadcast", Call(ty.vec4<f16>(), 1_h), 0_u)),
+ });
+
+ EXPECT_TRUE(r()->Resolve());
+}
+
TEST_F(ResolverBuiltinValidationTest, SubroupBroadcastInComputeStage) {
// @vertex fn func { dpdx(1.0); }
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
auto* call = Call("subgroupBroadcast", 1_f, 0_u);
Func(Source{{1, 2}}, "func", tint::Empty, ty.void_(), Vector{Ignore(call)},
@@ -904,7 +1028,7 @@
TEST_F(ResolverBuiltinValidationTest, SubroupBroadcastInVertexStageIsError) {
// @vertex fn func { dpdx(1.0); }
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
auto* call = Call(Source{{3, 4}}, "subgroupBroadcast", 1_f, 0_u);
Func("func", tint::Empty, ty.vec4<f32>(), Vector{Ignore(call), Return(Call(ty.vec4<f32>()))},
@@ -920,7 +1044,7 @@
TEST_F(ResolverBuiltinValidationTest, SubroupBroadcastInFragmentStageIsError) {
// @vertex fn func { dpdx(1.0); }
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
auto* call = Call(Source{{3, 4}}, "subgroupBroadcast", 1_f, 0_u);
Func("func",
@@ -935,7 +1059,7 @@
}
TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastValueF32) {
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.f32(),
Vector{
Return(Call("subgroupBroadcast", 1_f, 0_u)),
@@ -944,7 +1068,7 @@
}
TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastValueI32) {
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.i32(),
Vector{
Return(Call("subgroupBroadcast", 1_i, 0_u)),
@@ -953,7 +1077,7 @@
}
TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastValueU32) {
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupBroadcast", 1_u, 0_u)),
@@ -962,7 +1086,7 @@
}
TEST_F(ResolverBuiltinValidationTest, SubgroupBroadcastLaneArgMustBeConst) {
- Enable(wgsl::Extension::kChromiumExperimentalSubgroups);
+ Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.void_(),
Vector{
Decl(Let("lane", Expr(1_u))),
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 9aa025a..7b19533 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -1935,13 +1935,20 @@
return true;
}
- const auto extension = builtin->RequiredExtension();
- if (extension != wgsl::Extension::kUndefined) {
- if (!enabled_extensions_.Contains(extension)) {
- AddError(call->Declaration()->source)
- << "cannot call built-in function " << style::Function(builtin->Fn())
- << " without extension " << extension;
- return false;
+ if (builtin->IsSubgroup()) {
+ // The `chromium_experimental_subgroups` extension enables all subgroup features. Otherwise,
+ // we need `subgroups`, or `subgroups_f16` for f16 functions.
+ if (!enabled_extensions_.Contains(wgsl::Extension::kChromiumExperimentalSubgroups)) {
+ auto ext = wgsl::Extension::kSubgroups;
+ if (builtin->ReturnType()->DeepestElement()->Is<core::type::F16>()) {
+ ext = wgsl::Extension::kSubgroupsF16;
+ }
+ if (!enabled_extensions_.Contains(ext)) {
+ AddError(call->Declaration()->source)
+ << "cannot call built-in function " << style::Function(builtin->Fn())
+ << " without extension " << style::Code(wgsl::ToString(ext));
+ return false;
+ }
}
}
diff --git a/src/tint/lang/wgsl/sem/builtin_fn.cc b/src/tint/lang/wgsl/sem/builtin_fn.cc
index b0218eb..c602fc5 100644
--- a/src/tint/lang/wgsl/sem/builtin_fn.cc
+++ b/src/tint/lang/wgsl/sem/builtin_fn.cc
@@ -112,13 +112,6 @@
return wgsl::HasSideEffects(fn_);
}
-wgsl::Extension BuiltinFn::RequiredExtension() const {
- if (IsSubgroup()) {
- return wgsl::Extension::kChromiumExperimentalSubgroups;
- }
- return wgsl::Extension::kUndefined;
-}
-
wgsl::LanguageFeature BuiltinFn::RequiredLanguageFeature() const {
if (fn_ == wgsl::BuiltinFn::kTextureBarrier) {
return wgsl::LanguageFeature::kReadonlyAndReadwriteStorageTextures;
diff --git a/src/tint/lang/wgsl/sem/builtin_fn.h b/src/tint/lang/wgsl/sem/builtin_fn.h
index 18684cc..f9abb9f 100644
--- a/src/tint/lang/wgsl/sem/builtin_fn.h
+++ b/src/tint/lang/wgsl/sem/builtin_fn.h
@@ -117,10 +117,6 @@
/// one of its inputs)
bool HasSideEffects() const;
- /// @returns the required extension of this builtin function. Returns
- /// wgsl::Extension::kUndefined if no extension is required.
- wgsl::Extension RequiredExtension() const;
-
/// @returns the required language feature of this builtin function. Returns
/// wgsl::LanguageFeature::kUndefined if no language feature is required.
wgsl::LanguageFeature RequiredLanguageFeature() const;