[tint] Check uniformity for subgroup matrix constructors

They must only be called from uniform control flow.

Fixed: 403610976
Change-Id: Iaca57415a3ae316f145df60c4ccffba9fafa42b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/235656
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index a9b0bec..2bb5cdc 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -401,6 +401,9 @@
     /// The function currently being analyzed.
     FunctionInfo* current_function_;
 
+    /// A map from composite type to true/false to indicate whether it contains a subgroup matrix.
+    Hashmap<const core::type::Type*, bool, 16> composite_subgroup_matrix_info_;
+
     /// Create a new node.
     /// @param tag_list a string list that will be used to identify the node for debugging purposes
     /// @param ast the optional AST node that this node corresponds to
@@ -1746,9 +1749,18 @@
                 function_tag = info->function_tag;
                 func_info = info.value;
             },
-            [&](const sem::ValueConstructor*) {
-                callsite_tag = {CallSiteTag::CallSiteNoRestriction};
-                function_tag = NoRestriction;
+            [&](const sem::ValueConstructor* construct) {
+                if (ContainsSubgroupMatrix(construct->ReturnType())) {
+                    // Get the severity of subgroup matrix uniformity violations in this context.
+                    auto severity = sem_.DiagnosticSeverity(
+                        call, wgsl::ChromiumDiagnosticRule::kSubgroupMatrixUniformity);
+                    if (severity != wgsl::DiagnosticSeverity::kOff) {
+                        callsite_tag = {CallSiteTag::CallSiteRequiredToBeUniform, severity};
+                    }
+                } else {
+                    callsite_tag = {CallSiteTag::CallSiteNoRestriction};
+                    function_tag = NoRestriction;
+                }
             },
             [&](const sem::ValueConversion*) {
                 callsite_tag = {CallSiteTag::CallSiteNoRestriction};
@@ -1908,8 +1920,8 @@
         const ast::CallExpression* call,
         wgsl::DiagnosticSeverity severity) {
         auto* target = SemCall(call)->Target();
-        if (target->Is<sem::BuiltinFn>()) {
-            // This is a call to a builtin, so we must be done.
+        if (target->IsAnyOf<sem::BuiltinFn, sem::ValueConstructor>()) {
+            // This is a call to a builtin or constructor, so we must be done.
             return call;
         } else if (auto* user = target->As<sem::Function>()) {
             // This is a call to a user-defined function, so inspect the functions called by that
@@ -2156,6 +2168,28 @@
     const sem::Call* SemCall(const ast::CallExpression* expr) const {
         return sem_.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
     }
+
+    /// @returns true if @p type is or contains a subgroup matrix type
+    bool ContainsSubgroupMatrix(const core::type::Type* type) {
+        if (type->Is<core::type::SubgroupMatrix>()) {
+            return true;
+        }
+        return composite_subgroup_matrix_info_.GetOrAdd(type, [&] {
+            return Switch(
+                type,  //
+                [&](const core::type::Array* arr) {
+                    return ContainsSubgroupMatrix(arr->ElemType());
+                },
+                [&](const core::type::Struct* str) {
+                    for (auto* member : str->Members()) {
+                        if (ContainsSubgroupMatrix(member->Type())) {
+                            return true;
+                        }
+                    }
+                    return false;
+                });
+        });
+    }
 };
 
 }  // namespace
diff --git a/src/tint/lang/wgsl/resolver/uniformity_test.cc b/src/tint/lang/wgsl/resolver/uniformity_test.cc
index 7365faf..500dfda 100644
--- a/src/tint/lang/wgsl/resolver/uniformity_test.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity_test.cc
@@ -171,6 +171,7 @@
         kQuadSwapY,
         kQuadSwapDiagonal,
         // Subgroup matrix functions:
+        kSubgroupMatrixConstruct,
         kSubgroupMatrixLoad,
         kSubgroupMatrixStore,
         kSubgroupMatrixMultiply,
@@ -309,6 +310,8 @@
                 return "_ = quadSwapY(1.0)";
             case kQuadSwapDiagonal:
                 return "_ = quadSwapDiagonal(1.0)";
+            case kSubgroupMatrixConstruct:
+                return "_ = subgroup_matrix_result<f32, 8, 8>()";
             case kSubgroupMatrixLoad:
                 return "_ = subgroupMatrixLoad<subgroup_matrix_result<f32, 8, 8>>("
                        "&subgroup_matrix_data, 0, false, 4)";
@@ -409,6 +412,7 @@
             CASE(kQuadSwapX);
             CASE(kQuadSwapY);
             CASE(kQuadSwapDiagonal);
+            CASE(kSubgroupMatrixConstruct);
             CASE(kSubgroupMatrixLoad);
             CASE(kSubgroupMatrixStore);
             CASE(kSubgroupMatrixMultiply);
@@ -9787,6 +9791,76 @@
     RunTest(src, true);
 }
 
+TEST_F(UniformityAnalysisTest, SubgroupMatrixConstructor_NestedInStruct) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+struct Inner {
+  u : u32,
+  m : subgroup_matrix_result<f32, 8, 8>,
+}
+
+struct S {
+  u : u32,
+  inner : Inner,
+}
+
+fn foo() {
+  if (non_uniform_global == 0) {
+    _ = S();
+  }
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:18:9 error: 'S' must only be called from uniform control flow
+    _ = S();
+        ^^^
+
+test:17:3 note: control flow depends on possibly non-uniform value
+  if (non_uniform_global == 0) {
+  ^^
+
+test:17:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+  if (non_uniform_global == 0) {
+      ^^^^^^^^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, SubgroupMatrixConstructor_NestedInArray) {
+    std::string src = R"(
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read_write> non_uniform_global : i32;
+
+alias ArrayType = array<array<subgroup_matrix_result<f32, 8, 8>, 4>, 4>;
+
+fn foo() {
+  if (non_uniform_global == 0) {
+    _ = ArrayType();
+  }
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:10:9 error: 'ArrayType' must only be called from uniform control flow
+    _ = ArrayType();
+        ^^^^^^^^^^^
+
+test:9:3 note: control flow depends on possibly non-uniform value
+  if (non_uniform_global == 0) {
+  ^^
+
+test:9:7 note: reading from read_write storage buffer 'non_uniform_global' may result in a non-uniform value
+  if (non_uniform_global == 0) {
+      ^^^^^^^^^^^^^^^^^^
+)");
+}
+
 TEST_F(UniformityAnalysisTest, StressGraphTraversalDepth) {
     // Create a function with a very long sequence of variable declarations and assignments to
     // test traversals of very deep graphs. This requires a non-recursive traversal algorithm.
@@ -9920,7 +9994,7 @@
     }
 }
 
-TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive_SubgroupMatrixUniformity_Callsite) {
+TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive_SubgroupMatrixUniformity_BuiltinFunction) {
     auto& param = GetParam();
     StringStream ss;
     ss << "enable chromium_experimental_subgroup_matrix;\n"
@@ -9947,6 +10021,33 @@
     }
 }
 
+TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive_SubgroupMatrixUniformity_Constructor) {
+    auto& param = GetParam();
+    StringStream ss;
+    ss << "enable chromium_experimental_subgroup_matrix;\n"
+       << "diagnostic(" << param << ", chromium.subgroup_matrix_uniformity);" << R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+@group(0) @binding(1) var<storage, read_write> data : array<f32>;
+
+fn foo() {
+  if (non_uniform == 42) {
+    _ = subgroup_matrix_left<f32, 8, 8>(1);
+  }
+}
+)";
+
+    RunTest(ss.str(), param != wgsl::DiagnosticSeverity::kError);
+
+    if (param == wgsl::DiagnosticSeverity::kOff) {
+        EXPECT_TRUE(error_.empty());
+    } else {
+        StringStream err;
+        err << ToStr(param) << ": 'subgroup_matrix_left' must only be called";
+        EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
+    }
+}
+
 TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_DerivativeUniformity) {
     auto& param = GetParam();
     StringStream ss;