[tint] Check uniformity for subgroup matrix builtin arguments

Requires changing some tests that stored subgroup matrix values in
var<private> declarations, which are currently always considered to be
non-uniform.

Fixed: 403611487
Change-Id: I3c1148596407187c50271c187ecff818136e8039
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/236054
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index 2ff11a6..ccedf7a 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -410,11 +410,13 @@
 
 TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
-    auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
-    auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
+    auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
+    auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
     auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), left, right);
     Func("foo", Empty, ty.void_(),
          Vector{
+             Decl(left),
+             Decl(right),
              Assign(Phony(), call),
          });
 
@@ -435,11 +437,13 @@
 
 TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MissingTemplateArg) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
-    auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
-    auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
+    auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
+    auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
     auto* call = Call(wgsl::BuiltinFn::kSubgroupMatrixMultiply, left, right);
     Func("foo", Empty, ty.void_(),
          Vector{
+             Decl(left),
+             Decl(right),
              Assign(Phony(), call),
          });
 
@@ -450,11 +454,13 @@
 
 TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MismatchDimensions) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
-    auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 4_u, 2_u));
-    auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 2_u, 8_u));
+    auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 4_u, 2_u));
+    auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 2_u, 8_u));
     auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), left, right);
     Func("foo", Empty, ty.void_(),
          Vector{
+             Decl(left),
+             Decl(right),
              Assign(Phony(), call),
          });
 
@@ -465,11 +471,13 @@
 
 TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MismatchTypes) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
-    auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.u32(), 8_u, 8_u));
-    auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.i32(), 8_u, 8_u));
+    auto* left = Var("left", function, ty("subgroup_matrix_left", ty.u32(), 8_u, 8_u));
+    auto* right = Var("right", function, ty("subgroup_matrix_right", ty.i32(), 8_u, 8_u));
     auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), left, right);
     Func("foo", Empty, ty.void_(),
          Vector{
+             Decl(left),
+             Decl(right),
              Assign(Phony(), call),
          });
 
@@ -480,11 +488,13 @@
 
 TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiply_MismatchKinds) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
-    auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 8_u, 8_u));
-    auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 8_u));
+    auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 8_u, 8_u));
+    auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 8_u));
     auto* call = Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixMultiply, ty.f32()), right, left);
     Func("foo", Empty, ty.void_(),
          Vector{
+             Decl(left),
+             Decl(right),
              Assign(Phony(), call),
          });
 
@@ -495,12 +505,15 @@
 
 TEST_F(ResolverSubgroupMatrixTest, SubgroupMatrixMultiplyAccumulate) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
-    auto* left = GlobalVar("left", private_, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
-    auto* right = GlobalVar("right", private_, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
-    auto* acc = GlobalVar("acc", private_, ty("subgroup_matrix_result", ty.f32(), 8_u, 4_u));
+    auto* left = Var("left", function, ty("subgroup_matrix_left", ty.f32(), 2_u, 4_u));
+    auto* right = Var("right", function, ty("subgroup_matrix_right", ty.f32(), 8_u, 2_u));
+    auto* acc = Var("acc", function, ty("subgroup_matrix_result", ty.f32(), 8_u, 4_u));
     auto* call = Call(wgsl::BuiltinFn::kSubgroupMatrixMultiplyAccumulate, left, right, acc);
     Func("foo", Empty, ty.void_(),
          Vector{
+             Decl(left),
+             Decl(right),
+             Decl(acc),
              Assign(Phony(), call),
          });
 
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index 551af00..2cf5545 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -1854,6 +1854,7 @@
                 }
             } else {
                 auto* builtin = sem->Target()->As<sem::BuiltinFn>();
+                auto* construct = sem->Target()->As<sem::ValueConstructor>();
                 if (builtin && builtin->Fn() == wgsl::BuiltinFn::kWorkgroupUniformLoad) {
                     // The workgroupUniformLoad builtin requires its parameter to be uniform.
                     current_function_->RequiredToBeUniform(default_severity)->AddEdge(args[i]);
@@ -1870,6 +1871,16 @@
                     if (severity != wgsl::DiagnosticSeverity::kOff) {
                         current_function_->RequiredToBeUniform(severity)->AddEdge(args[i]);
                     }
+                } else if (((builtin && builtin->IsSubgroupMatrix()) ||
+                            (construct &&
+                             construct->ReturnType()->Is<core::type::SubgroupMatrix>()))) {
+                    // For all subgroup matrix builtins and constructors, all arguments must be
+                    // uniform.
+                    auto severity = sem_.DiagnosticSeverity(
+                        call->args[i], wgsl::ChromiumDiagnosticRule::kSubgroupMatrixUniformity);
+                    if (severity != wgsl::DiagnosticSeverity::kOff) {
+                        current_function_->RequiredToBeUniform(severity)->AddEdge(args[i]);
+                    }
                 } else {
                     // All other builtin function parameters are RequiredToBeUniformForReturnValue,
                     // as are parameters for value constructors and value conversions.
@@ -2137,29 +2148,37 @@
             cause->type == Node::kFunctionCallArgumentContents) {
             bool is_value = (cause->type == Node::kFunctionCallArgumentValue);
 
-            auto* user_func = target->As<sem::Function>();
-            if (user_func) {
-                // Recurse into the called function to show the reason for the requirement.
-                auto next_function = functions_.Get(user_func->Declaration());
-                auto& param_info = next_function->parameters[cause->arg_index];
-                MakeError(*next_function,
-                          is_value ? param_info.value : param_info.ptr_input_contents, severity);
+            Switch(
+                target,  //
+                [&](const sem::Function* user_func) {
+                    // Recurse into the called function to show the reason for the requirement.
+                    auto next_function = functions_.Get(user_func->Declaration());
+                    auto& param_info = next_function->parameters[cause->arg_index];
+                    MakeError(*next_function,
+                              is_value ? param_info.value : param_info.ptr_input_contents,
+                              severity);
 
-                // Show the place where the non-uniform argument was passed.
-                // If this is a builtin, this will be the trigger location for the failure.
-                StringStream ss;
-                ss << "possibly non-uniform value passed" << (is_value ? "" : " via pointer")
-                   << " here";
-                report(call->args[cause->arg_index]->source, ss.str(), /* note */ true);
-            } else {
-                // The uniformity requirement must come from a builtin function.
-                auto* builtin = target->As<sem::BuiltinFn>();
-                TINT_ASSERT(builtin);
-                StringStream ss;
-                ss << "'" << builtin->Fn() << "' requires argument " << cause->arg_index << " to "
-                   << (is_value ? "be uniform" : "have uniform contents");
-                report(call->args[cause->arg_index]->source, ss.str(), /* note */ false);
-            }
+                    // Show the place where the non-uniform argument was passed.
+                    // If this is a builtin, this will be the trigger location for the failure.
+                    StringStream ss;
+                    ss << "possibly non-uniform value passed" << (is_value ? "" : " via pointer")
+                       << " here";
+                    report(call->args[cause->arg_index]->source, ss.str(), /* note */ true);
+                },
+                [&](const sem::BuiltinFn* builtin) {
+                    StringStream ss;
+                    ss << "'" << builtin->Fn() << "' requires argument " << cause->arg_index
+                       << " to " << (is_value ? "be uniform" : "have uniform contents");
+                    report(call->args[cause->arg_index]->source, ss.str(), /* note */ false);
+                },
+                [&](const sem::ValueConstructor* construct) {
+                    StringStream ss;
+                    ss << construct->ReturnType()->FriendlyName()
+                       << " constructor requires argument " << cause->arg_index << " to "
+                       << (is_value ? "be uniform" : "have uniform contents");
+                    report(call->args[cause->arg_index]->source, ss.str(), /* note */ false);
+                },
+                TINT_ICE_ON_NO_MATCH);
 
             // Show the origin of non-uniformity for the value or data that is being passed.
             ShowSourceOfNonUniformity(source_node->visited_from);
diff --git a/src/tint/lang/wgsl/resolver/uniformity_test.cc b/src/tint/lang/wgsl/resolver/uniformity_test.cc
index 0ffa2a6..0109a3c 100644
--- a/src/tint/lang/wgsl/resolver/uniformity_test.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity_test.cc
@@ -447,10 +447,6 @@
 
 @group(2) @binding(0) var<storage, read_write> subgroup_matrix_data : array<f32>;
 
-var<private> subgroup_matrix_left_zero: subgroup_matrix_left<f32, 8, 8>;
-var<private> subgroup_matrix_right_zero: subgroup_matrix_right<f32, 8, 8>;
-var<private> subgroup_matrix_result_zero: subgroup_matrix_result<f32, 8, 8>;
-
 const module_const : i32 = 42;
 @id(42) override pipeline_overridable : i32;
 
@@ -464,6 +460,10 @@
   let let_uniform_rhs = 7;
   let let_nonuniform_rhs = rw;
 
+  var subgroup_matrix_left_zero: subgroup_matrix_left<f32, 8, 8>;
+  var subgroup_matrix_right_zero: subgroup_matrix_right<f32, 8, 8>;
+  var subgroup_matrix_result_zero: subgroup_matrix_result<f32, 8, 8>;
+
   var func_uniform = 7;
   var func_non_uniform = 7;
   func_non_uniform = rw;
@@ -9983,6 +9983,406 @@
 )");
 }
 
+TEST_F(UniformityAnalysisTest, SubgroupMatrixConstructor_NonUniformArgument) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: f32;
+
+fn foo() {
+  _ = subgroup_matrix_result<f32, 8, 8>(non_uniform);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(
+        error_,
+        R"(test:7:41 error: subgroup_matrix_result<f32, 8, 8> constructor requires argument 0 to be uniform
+  _ = subgroup_matrix_result<f32, 8, 8>(non_uniform);
+                                        ^^^^^^^^^^^
+
+test:7:41 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  _ = subgroup_matrix_result<f32, 8, 8>(non_uniform);
+                                        ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixLoad_NonUniformPointer) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<array<f32, 64>, 4>;
+
+fn foo() {
+  let p = &buffer[non_uniform];
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(p, 0, false, 4);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:10:61 error: 'subgroupMatrixLoad' requires argument 0 to be uniform
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(p, 0, false, 4);
+                                                            ^
+
+test:9:19 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  let p = &buffer[non_uniform];
+                  ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixLoad_NonUniformOffset) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, non_uniform, false, 4);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:9:70 error: 'subgroupMatrixLoad' requires argument 1 to be uniform
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, non_uniform, false, 4);
+                                                                     ^^^^^^^^^^^
+
+test:9:70 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, non_uniform, false, 4);
+                                                                     ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixLoad_NonUniformStride) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, 0, false, non_uniform);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:9:80 error: 'subgroupMatrixLoad' requires argument 3 to be uniform
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, 0, false, non_uniform);
+                                                                               ^^^^^^^^^^^
+
+test:9:80 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  _ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>(&buffer, 0, false, non_uniform);
+                                                                               ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformPointer) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<array<f32, 64>, 4>;
+
+fn foo() {
+  let p = &buffer[non_uniform];
+  let value = subgroup_matrix_result<f32, 8, 8>();
+  subgroupMatrixStore(p, 0, value, false, 4);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:11:23 error: 'subgroupMatrixStore' requires argument 0 to be uniform
+  subgroupMatrixStore(p, 0, value, false, 4);
+                      ^
+
+test:9:19 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  let p = &buffer[non_uniform];
+                  ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformOffset) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+  let value = subgroup_matrix_result<f32, 8, 8>();
+  subgroupMatrixStore(&buffer, non_uniform, value, false, 4);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:10:32 error: 'subgroupMatrixStore' requires argument 1 to be uniform
+  subgroupMatrixStore(&buffer, non_uniform, value, false, 4);
+                               ^^^^^^^^^^^
+
+test:10:32 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  subgroupMatrixStore(&buffer, non_uniform, value, false, 4);
+                               ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformValue) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_result<f32, 8, 8> {
+  var value1 = subgroup_matrix_result<f32, 8, 8>(1);
+  var value2 = subgroup_matrix_result<f32, 8, 8>(2);
+  if (non_uniform == 1) {
+    return value1;
+  } else {
+    return value2;
+  }
+}
+
+fn foo() {
+  subgroupMatrixStore(&buffer, 0, bar(), false, 4);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:19:35 error: 'subgroupMatrixStore' requires argument 2 to be uniform
+  subgroupMatrixStore(&buffer, 0, bar(), false, 4);
+                                  ^^^^^
+
+test:19:35 note: return value of 'bar' may be non-uniform
+  subgroupMatrixStore(&buffer, 0, bar(), false, 4);
+                                  ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixStore_NonUniformStride) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn foo() {
+  let value = subgroup_matrix_result<f32, 8, 8>();
+  subgroupMatrixStore(&buffer, 0, value, false, non_uniform);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:10:49 error: 'subgroupMatrixStore' requires argument 4 to be uniform
+  subgroupMatrixStore(&buffer, 0, value, false, non_uniform);
+                                                ^^^^^^^^^^^
+
+test:10:49 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value
+  subgroupMatrixStore(&buffer, 0, value, false, non_uniform);
+                                                ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiply_NonUniformLHS) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_left<f32, 8, 8> {
+  var value1 = subgroup_matrix_left<f32, 8, 8>(1);
+  var value2 = subgroup_matrix_left<f32, 8, 8>(2);
+  if (non_uniform == 1) {
+    return value1;
+  } else {
+    return value2;
+  }
+}
+
+fn foo() {
+  let rhs = subgroup_matrix_right<f32, 8, 8>();
+  _ = subgroupMatrixMultiply<f32>(bar(), rhs);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:20:35 error: 'subgroupMatrixMultiply' requires argument 0 to be uniform
+  _ = subgroupMatrixMultiply<f32>(bar(), rhs);
+                                  ^^^^^
+
+test:20:35 note: return value of 'bar' may be non-uniform
+  _ = subgroupMatrixMultiply<f32>(bar(), rhs);
+                                  ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiply_NonUniformRHS) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_right<f32, 8, 8> {
+  var value1 = subgroup_matrix_right<f32, 8, 8>(1);
+  var value2 = subgroup_matrix_right<f32, 8, 8>(2);
+  if (non_uniform == 1) {
+    return value1;
+  } else {
+    return value2;
+  }
+}
+
+fn foo() {
+  let lhs = subgroup_matrix_left<f32, 8, 8>();
+  _ = subgroupMatrixMultiply<f32>(lhs, bar());
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:20:40 error: 'subgroupMatrixMultiply' requires argument 1 to be uniform
+  _ = subgroupMatrixMultiply<f32>(lhs, bar());
+                                       ^^^^^
+
+test:20:40 note: return value of 'bar' may be non-uniform
+  _ = subgroupMatrixMultiply<f32>(lhs, bar());
+                                       ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiplyAccumulate_NonUniformLHS) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_left<f32, 8, 8> {
+  var value1 = subgroup_matrix_left<f32, 8, 8>(1);
+  var value2 = subgroup_matrix_left<f32, 8, 8>(2);
+  if (non_uniform == 1) {
+    return value1;
+  } else {
+    return value2;
+  }
+}
+
+fn foo() {
+  let rhs = subgroup_matrix_right<f32, 8, 8>();
+  let acc = subgroup_matrix_result<f32, 8, 8>();
+  _ = subgroupMatrixMultiplyAccumulate(bar(), rhs, acc);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(
+        error_,
+        R"(test:21:40 error: 'subgroupMatrixMultiplyAccumulate' requires argument 0 to be uniform
+  _ = subgroupMatrixMultiplyAccumulate(bar(), rhs, acc);
+                                       ^^^^^
+
+test:21:40 note: return value of 'bar' may be non-uniform
+  _ = subgroupMatrixMultiplyAccumulate(bar(), rhs, acc);
+                                       ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiplyAccumulate_NonUniformRHS) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_right<f32, 8, 8> {
+  var value1 = subgroup_matrix_right<f32, 8, 8>(1);
+  var value2 = subgroup_matrix_right<f32, 8, 8>(2);
+  if (non_uniform == 1) {
+    return value1;
+  } else {
+    return value2;
+  }
+}
+
+fn foo() {
+  let lhs = subgroup_matrix_left<f32, 8, 8>();
+  let acc = subgroup_matrix_result<f32, 8, 8>();
+  _ = subgroupMatrixMultiplyAccumulate(lhs, bar(), acc);
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(
+        error_,
+        R"(test:21:45 error: 'subgroupMatrixMultiplyAccumulate' requires argument 1 to be uniform
+  _ = subgroupMatrixMultiplyAccumulate(lhs, bar(), acc);
+                                            ^^^^^
+
+test:21:45 note: return value of 'bar' may be non-uniform
+  _ = subgroupMatrixMultiplyAccumulate(lhs, bar(), acc);
+                                            ^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixMultiplyAccumulate_NonUniformAcc) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+var<private> non_uniform: u32;
+
+var<workgroup> buffer : array<f32, 64>;
+
+fn bar() -> subgroup_matrix_result<f32, 8, 8> {
+  var value1 = subgroup_matrix_result<f32, 8, 8>(1);
+  var value2 = subgroup_matrix_result<f32, 8, 8>(2);
+  if (non_uniform == 1) {
+    return value1;
+  } else {
+    return value2;
+  }
+}
+
+fn foo() {
+  let lhs = subgroup_matrix_left<f32, 8, 8>();
+  let rhs = subgroup_matrix_right<f32, 8, 8>();
+  _ = subgroupMatrixMultiplyAccumulate(lhs, rhs, bar());
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(
+        error_,
+        R"(test:21:50 error: 'subgroupMatrixMultiplyAccumulate' requires argument 2 to be uniform
+  _ = subgroupMatrixMultiplyAccumulate(lhs, rhs, bar());
+                                                 ^^^^^
+
+test:21:50 note: return value of 'bar' may be non-uniform
+  _ = subgroupMatrixMultiplyAccumulate(lhs, rhs, bar());
+                                                 ^^^^^
+)");
+}
+
 TEST_F(UniformityAnalysisTest, StressGraphTraversalDepth) {
     // Create a function with a very long sequence of variable declarations and assignments to
     // test traversals of very deep graphs. This requires a non-recursive traversal algorithm.
@@ -10197,6 +10597,32 @@
     }
 }
 
+TEST_P(UniformityAnalysisDiagnosticFilterTest,
+       Directive_SubgroupMatrixUniformity_BuiltinFunctionArgument) {
+    auto& param = GetParam();
+    StringStream ss;
+    ss << "enable chromium_experimental_subgroup_matrix;\n"
+       << "diagnostic(" << param << ", chromium.subgroup_matrix_uniformity);" << R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : u32;
+
+@group(0) @binding(1) var<storage, read_write> data : array<f32>;
+
+fn foo() {
+  _ = subgroupMatrixLoad<subgroup_matrix_left<f32, 8, 8>>(&data, non_uniform, false, 4);
+}
+)";
+
+    RunTest(ss.str(), param != wgsl::DiagnosticSeverity::kError);
+
+    if (param == wgsl::DiagnosticSeverity::kOff) {
+        EXPECT_TRUE(error_.empty());
+    } else {
+        StringStream err;
+        err << ToStr(param) << ": 'subgroupMatrixLoad' requires argument 1 to be uniform";
+        EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
+    }
+}
+
 TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_DerivativeUniformity) {
     auto& param = GetParam();
     StringStream ss;