[tint] Check uniformity for subgroup matrix builtin arguments
Requires changing some tests that stored subgroup matrix values in
var<private> declarations, which are currently always considered to be
non-uniform.
Fixed: 403611487
Change-Id: I3c1148596407187c50271c187ecff818136e8039
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/236054
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index 2ff11a6..ccedf7a 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -410,11 +410,13 @@
TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
- auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
- auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
+ auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
+ auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), left, right);
Func("foo", Empty, ty.void_(),
Vector{
+ Decl(left),
+ Decl(right),
Assign(Phony(), call),
});
@@ -435,11 +437,13 @@
TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MissingTemplateArg) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
- auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
- auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
+ auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
+ auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
auto* call = Call(wgsl::BuiltinFn::kSubgroupMatrixMultiply, left, right);
Func("foo", Empty, ty.void_(),
Vector{
+ Decl(left),
+ Decl(right),
Assign(Phony(), call),
});
@@ -450,11 +454,13 @@
TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MismatchDimensions) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
- auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 4_u, 2_u));
- auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 2_u, 8_u));
+ auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 4_u, 2_u));
+ auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 2_u, 8_u));
auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), left, right);
Func("foo", Empty, ty.void_(),
Vector{
+ Decl(left),
+ Decl(right),
Assign(Phony(), call),
});
@@ -465,11 +471,13 @@
TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MismatchTypes) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
- auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.u32(), 8_u, 8_u));
- auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.i32(), 8_u, 8_u));
+ auto* left = Var("left", function, ty("subgroup_matrix_left", ty.u32(), 8_u, 8_u));
+ auto* right = Var("right", function, ty("subgroup_matrix_right", ty.i32(), 8_u, 8_u));
auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), left, right);
Func("foo", Empty, ty.void_(),
Vector{
+ Decl(left),
+ Decl(right),
Assign(Phony(), call),
});
@@ -480,11 +488,13 @@
TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MismatchKinds) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
- auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 8_u, 8_u));
- auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 8_u));
+ auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 8_u, 8_u));
+ auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 8_u));
auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), right, left);
Func("foo", Empty, ty.void_(),
Vector{
+ Decl(left),
+ Decl(right),
Assign(Phony(), call),
});
@@ -495,12 +505,15 @@
TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiplyAccumulate) {
Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
- auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
- auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
- auto* acc = GlobalVar("acc", private_, ty("subgroup_matrix_result", ty.f32(), 8_u, 4_u));
+ auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
+ auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
+ auto* acc = Var("acc", function, ty("subgroup_matrix_result", ty.f32(), 8_u, 4_u));
auto* call = Call(wgsl::BuiltinFn::kSubgroupMatrixMultiplyAccumulate, left, right, acc);
Func("foo", Empty, ty.void_(),
Vector{
+ Decl(left),
+ Decl(right),
+ Decl(acc),
Assign(Phony(), call),
});
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index 551af00..2cf5545 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -1854,6 +1854,7 @@
}
} else {
auto* builtin = sem->Target()->As<sem::BuiltinFn>();
+ auto* construct = sem->Target()->As<sem::ValueConstructor>();
if (builtin && builtin->Fn() == wgsl::BuiltinFn::kWorkgroupUniformLoad) {
// The workgroupUniformLoad builtin requires its parameter to be uniform.
current_function_->RequiredToBeUniform(default_severity)->AddEdge(args[i]);
@@ -1870,6 +1871,16 @@
if (severity != wgsl::DiagnosticSeverity::kOff) {
current_function_->RequiredToBeUniform(severity)->AddEdge(args[i]);
}
+ } else if (((builtin && builtin->IsSubgroupMatrix()) ||
+ (construct &&
+ construct->ReturnType()->Is<core::type::SubgroupMatrix>()))) {
+ // For all subgroup matrix builtins and constructors, all arguments must be
+ // uniform.
+ auto severity = sem_.DiagnosticSeverity(
+ call->args[i], wgsl::ChromiumDiagnosticRule::kSubgroupMatrixUniformity);
+ if (severity != wgsl::DiagnosticSeverity::kOff) {
+ current_function_->RequiredToBeUniform(severity)->AddEdge(args[i]);
+ }
} else {
// All other builtin function parameters are RequiredToBeUniformForReturnValue,
// as are parameters for value constructors and value conversions.
@@ -2137,29 +2148,37 @@
cause->type == Node::kFunctionCallArgumentContents) {
bool is_value = (cause->type == Node::kFunctionCallArgumentValue);
- auto* user_func = target->As<sem::Function>();
- if (user_func) {
- // Recurse into the called function to show the reason for the requirement.
- auto next_function = functions_.Get(user_func->Declaration());
- auto& param_info = next_function->parameters[cause->arg_index];
- MakeError(*next_function,
- is_value ? param_info.value : param_info.ptr_input_contents, severity);
+ Switch(
+ target, //
+ [&](const sem::Function* user_func) {
+ // Recurse into the called function to show the reason for the requirement.
+ auto next_function = functions_.Get(user_func->Declaration());
+ auto& param_info = next_function->parameters[cause->arg_index];
+ MakeError(*next_function,
+ is_value ? param_info.value : param_info.ptr_input_contents,
+ severity);
- // Show the place where the non-uniform argument was passed.
- // If this is a builtin, this will be the trigger location for the failure.
- StringStream ss;
- ss << "possibly non-uniform value passed" << (is_value ? "" : " via pointer")
- << " here";
- report(call->args[cause->arg_index]->source, ss.str(), /* note */ true);
- } else {
- // The uniformity requirement must come from a builtin function.
- auto* builtin = target->As<sem::BuiltinFn>();
- TINT_ASSERT(builtin);
- StringStream ss;
- ss << "'" << builtin->Fn() << "' requires argument " << cause->arg_index << " to "
- << (is_value ? "be uniform" : "have uniform contents");
- report(call->args[cause->arg_index]->source, ss.str(), /* note */ false);
- }
+ // Show the place where the non-uniform argument was passed.
+ // If this is a builtin, this will be the trigger location for the failure.
+ StringStream ss;
+ ss << "possibly non-uniform value passed" << (is_value ? "" : " via pointer")
+ << " here";
+ report(call->args[cause->arg_index]->source, ss.str(), /* note */ true);
+ },
+ [&](const sem::BuiltinFn* builtin) {
+ StringStream ss;
+ ss << "'" << builtin->Fn() << "' requires argument " << cause->arg_index
+ << " to " << (is_value ? "be uniform" : "have uniform contents");
+ report(call->args[cause->arg_index]->source, ss.str(), /* note */ false);
+ },
+ [&](const sem::ValueConstructor* construct) {
+ StringStream ss;
+ ss << construct->ReturnType()->FriendlyName()
+ << " constructor requires argument " << cause->arg_index << " to "
+ << (is_value ? "be uniform" : "have uniform contents");
+ report(call->args[cause->arg_index]->source, ss.str(), /* note */ false);
+ },
+ TINT_ICE_ON_NO_MATCH);
// Show the origin of non-uniformity for the value or data that is being passed.
ShowSourceOfNonUniformity(source_node->visited_from);
diff --git a/src/tint/lang/wgsl/resolver/uniformity_test.cc b/src/tint/lang/wgsl/resolver/uniformity_test.cc
index 0ffa2a6..0109a3c 100644
--- a/src/tint/lang/wgsl/resolver/uniformity_test.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity_test.cc
@@ -447,10 +447,6 @@
@group(2) @binding(0) var<storage, read_write> subgroup_matrix_data : array<f32>;
-var<private> subgroup_matrix_left_zero: subgroup_matrix_left<f32, 8, 8>;
-var<private> subgroup_matrix_right_zero: subgroup_matrix_right<f32, 8, 8>;
-var<private> subgroup_matrix_result_zero: subgroup_matrix_result<f32, 8, 8>;
-
const module_const : i32 = 42;
@id(42) override pipeline_overridable : i32;
@@ -464,6 +460,10 @@
let let_uniform_rhs = 7;
let let_nonuniform_rhs = rw;
+ var subgroup_matrix_left_zero: subgroup_matrix_left<f32, 8, 8>;
+ var subgroup_matrix_right_zero: subgroup_matrix_right<f32, 8, 8>;
+ var subgroup_matrix_result_zero: subgroup_matrix_result<f32, 8, 8>;
+
var func_uniform = 7;
var func_non_uniform = 7;
func_non_uniform = rw;
@@ -9983,6 +9983,406 @@
)");
}
+TEST_F(UniformityAnalysisTest, SubgroupMatrixConstructor_NonUniformArgument) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: f32;
+
+fn foo() {
+ _ = subgroup_matrix_result<f32, 8, 8>(non_uniform);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(
+ error_,
+ R"(test:7:41 error: subgroup_matrix_result<f32, 8, 8> constructor requires argument 0 to be uniform
+ _ = subgroup_matrix_result<f32, 8, 8>(non_uniform);
+ ^^^^^^^^^^^
+
+test:7:41 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ _ = subgroup_matrix_result<f32, 8, 8>(non_uniform);
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixLoad_NonUniformPointer) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<array<f32, 64>, 4>;
+
+fn foo() {
+ let p = &buffer[non_uniform];
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(p, 0, false, 4);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:10:61 error: 'subgroupMatrixLoad' requires argument 0 to be uniform
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(p, 0, false, 4);
+ ^
+
+test:9:19 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ let p = &buffer[non_uniform];
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixLoad_NonUniformOffset) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, non_uniform, false, 4);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:9:70 error: 'subgroupMatrixLoad' requires argument 1 to be uniform
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, non_uniform, false, 4);
+ ^^^^^^^^^^^
+
+test:9:70 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, non_uniform, false, 4);
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixLoad_NonUniformStride) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, 0, false, non_uniform);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:9:80 error: 'subgroupMatrixLoad' requires argument 3 to be uniform
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, 0, false, non_uniform);
+ ^^^^^^^^^^^
+
+test:9:80 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, 0, false, non_uniform);
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformPointer) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<array<f32, 64>, 4>;
+
+fn foo() {
+ let p = &buffer[non_uniform];
+ let value = subgroup_matrix_result<f32, 8, 8>();
+ subgroupMatrixStore(p, 0, value, false, 4);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:11:23 error: 'subgroupMatrixStore' requires argument 0 to be uniform
+ subgroupMatrixStore(p, 0, value, false, 4);
+ ^
+
+test:9:19 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ let p = &buffer[non_uniform];
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformOffset) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+ let value = subgroup_matrix_result<f32, 8, 8>();
+ subgroupMatrixStore(&buffer, non_uniform, value, false, 4);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:10:32 error: 'subgroupMatrixStore' requires argument 1 to be uniform
+ subgroupMatrixStore(&buffer, non_uniform, value, false, 4);
+ ^^^^^^^^^^^
+
+test:10:32 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ subgroupMatrixStore(&buffer, non_uniform, value, false, 4);
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformValue) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_result<f32, 8, 8> {
+ var value1 = subgroup_matrix_result<f32, 8, 8>(1);
+ var value2 = subgroup_matrix_result<f32, 8, 8>(2);
+ if (non_uniform == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+}
+
+fn foo() {
+ subgroupMatrixStore(&buffer, 0, bar(), false, 4);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:19:35 error: 'subgroupMatrixStore' requires argument 2 to be uniform
+ subgroupMatrixStore(&buffer, 0, bar(), false, 4);
+ ^^^^^
+
+test:19:35 note: return value of 'bar' may be non-uniform
+ subgroupMatrixStore(&buffer, 0, bar(), false, 4);
+ ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformStride) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+ let value = subgroup_matrix_result<f32, 8, 8>();
+ subgroupMatrixStore(&buffer, 0, value, false, non_uniform);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:10:49 error: 'subgroupMatrixStore' requires argument 4 to be uniform
+ subgroupMatrixStore(&buffer, 0, value, false, non_uniform);
+ ^^^^^^^^^^^
+
+test:10:49 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+ subgroupMatrixStore(&buffer, 0, value, false, non_uniform);
+ ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiply_NonUniformLHS) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_left<f32, 8, 8> {
+ var value1 = subgroup_matrix_left<f32, 8, 8>(1);
+ var value2 = subgroup_matrix_left<f32, 8, 8>(2);
+ if (non_uniform == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+}
+
+fn foo() {
+ let rhs = subgroup_matrix_right<f32, 8, 8>();
+ _ = subgroupMatrixMultiply<f32>(bar(), rhs);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:20:35 error: 'subgroupMatrixMultiply' requires argument 0 to be uniform
+ _ = subgroupMatrixMultiply<f32>(bar(), rhs);
+ ^^^^^
+
+test:20:35 note: return value of 'bar' may be non-uniform
+ _ = subgroupMatrixMultiply<f32>(bar(), rhs);
+ ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiply_NonUniformRHS) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_right<f32, 8, 8> {
+ var value1 = subgroup_matrix_right<f32, 8, 8>(1);
+ var value2 = subgroup_matrix_right<f32, 8, 8>(2);
+ if (non_uniform == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+}
+
+fn foo() {
+ let lhs = subgroup_matrix_left<f32, 8, 8>();
+ _ = subgroupMatrixMultiply<f32>(lhs, bar());
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(error_,
+ R"(test:20:40 error: 'subgroupMatrixMultiply' requires argument 1 to be uniform
+ _ = subgroupMatrixMultiply<f32>(lhs, bar());
+ ^^^^^
+
+test:20:40 note: return value of 'bar' may be non-uniform
+ _ = subgroupMatrixMultiply<f32>(lhs, bar());
+ ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiplyAccumulate_NonUniformLHS) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_left<f32, 8, 8> {
+ var value1 = subgroup_matrix_left<f32, 8, 8>(1);
+ var value2 = subgroup_matrix_left<f32, 8, 8>(2);
+ if (non_uniform == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+}
+
+fn foo() {
+ let rhs = subgroup_matrix_right<f32, 8, 8>();
+ let acc = subgroup_matrix_result<f32, 8, 8>();
+ _ = subgroupMatrixMultiplyAccumulate(bar(), rhs, acc);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(
+ error_,
+ R"(test:21:40 error: 'subgroupMatrixMultiplyAccumulate' requires argument 0 to be uniform
+ _ = subgroupMatrixMultiplyAccumulate(bar(), rhs, acc);
+ ^^^^^
+
+test:21:40 note: return value of 'bar' may be non-uniform
+ _ = subgroupMatrixMultiplyAccumulate(bar(), rhs, acc);
+ ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiplyAccumulate_NonUniformRHS) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_right<f32, 8, 8> {
+ var value1 = subgroup_matrix_right<f32, 8, 8>(1);
+ var value2 = subgroup_matrix_right<f32, 8, 8>(2);
+ if (non_uniform == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+}
+
+fn foo() {
+ let lhs = subgroup_matrix_left<f32, 8, 8>();
+ let acc = subgroup_matrix_result<f32, 8, 8>();
+ _ = subgroupMatrixMultiplyAccumulate(lhs, bar(), acc);
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(
+ error_,
+ R"(test:21:45 error: 'subgroupMatrixMultiplyAccumulate' requires argument 1 to be uniform
+ _ = subgroupMatrixMultiplyAccumulate(lhs, bar(), acc);
+ ^^^^^
+
+test:21:45 note: return value of 'bar' may be non-uniform
+ _ = subgroupMatrixMultiplyAccumulate(lhs, bar(), acc);
+ ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiplyAccumulate_NonUniformAcc) {
+ std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_result<f32, 8, 8> {
+ var value1 = subgroup_matrix_result<f32, 8, 8>(1);
+ var value2 = subgroup_matrix_result<f32, 8, 8>(2);
+ if (non_uniform == 1) {
+ return value1;
+ } else {
+ return value2;
+ }
+}
+
+fn foo() {
+ let lhs = subgroup_matrix_left<f32, 8, 8>();
+ let rhs = subgroup_matrix_right<f32, 8, 8>();
+ _ = subgroupMatrixMultiplyAccumulate(lhs, rhs, bar());
+}
+)";
+
+ RunTest(src, false);
+ EXPECT_EQ(
+ error_,
+ R"(test:21:50 error: 'subgroupMatrixMultiplyAccumulate' requires argument 2 to be uniform
+ _ = subgroupMatrixMultiplyAccumulate(lhs, rhs, bar());
+ ^^^^^
+
+test:21:50 note: return value of 'bar' may be non-uniform
+ _ = subgroupMatrixMultiplyAccumulate(lhs, rhs, bar());
+ ^^^^^
+)");
+}
+
TEST_F(UniformityAnalysisTest, StressGraphTraversalDepth) {
// Create a function with a very long sequence of variable declarations and assignments to
// test traversals of very deep graphs. This requires a non-recursive traversal algorithm.
@@ -10197,6 +10597,32 @@
}
}
+TEST_P(UniformityAnalysisDiagnosticFilterTest,
+ Directive_SubgroupMatrixUniformity_BuiltinFunctionArgument) {
+ auto& param = GetParam();
+ StringStream ss;
+ ss << "enable chromium_experimental_subgroup_matrix;\n"
+ << "diagnostic(" << param << ", chromium.subgroup_matrix_uniformity);" << R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : u32;
+
+@group(0) @binding(1) var<storage, read_write> data : array<f32>;
+
+fn foo() {
+ _ = subgroupMatrixLoad<subgroup_matrix_left<f32, 8, 8>>(&data, non_uniform, false, 4);
+}
+)";
+
+ RunTest(ss.str(), param != wgsl::DiagnosticSeverity::kError);
+
+ if (param == wgsl::DiagnosticSeverity::kOff) {
+ EXPECT_TRUE(error_.empty());
+ } else {
+ StringStream err;
+ err << ToStr(param) << ": 'subgroupMatrixLoad' requires argument 1 to be uniform";
+ EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
+ }
+}
+
TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_DerivativeUniformity) {
auto& param = GetParam();
StringStream ss;