[wgsl] Handle subgroup matrix constructors

These have to be resolved manually due to the arbitrary column and row
counts.

Support zero value and single value constructors only.

Bug: 348702031
Change-Id: I550adb23c27411fdddb09aa54bac026967a91c14
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/224055
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 7c19a50..a181969 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2270,6 +2270,36 @@
 
                 return arr_or_str_init(str, call_target);
             },
+            [&](const core::type::SubgroupMatrix* m) -> sem::Call* {
+                auto* call_target = subgroup_matrix_ctors_.GetOrAdd(
+                    SubgroupMatrixConstructorSig{{m, args.Length()}},
+                    [&]() -> sem::ValueConstructor* {
+                        auto params = tint::Transform(args, [&](auto, size_t i) {
+                            return b.create<sem::Parameter>(nullptr,  // declaration
+                                                            static_cast<uint32_t>(i),  // index
+                                                            m->Type());
+                        });
+                        return b.create<sem::ValueConstructor>(m, std::move(params),
+                                                               core::EvaluationStage::kRuntime);
+                    });
+
+                if (DAWN_UNLIKELY(!MaybeMaterializeAndLoadArguments(args, call_target))) {
+                    return nullptr;
+                }
+
+                if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m))) {
+                    return nullptr;
+                }
+
+                // Subgroup matrix constructors are never const-evaluated.
+                auto stage = core::EvaluationStage::kRuntime;
+                if (not_evaluated_.Contains(expr)) {
+                    stage = core::EvaluationStage::kNotEvaluated;
+                }
+
+                return b.create<sem::Call>(expr, call_target, stage, std::move(args),
+                                           current_statement_, nullptr, has_side_effects);
+            },
             [&](Default) {
                 AddError(expr->source) << "type is not constructible";
                 return nullptr;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index eb811eb..12e73e2 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -662,6 +662,11 @@
     using StructConstructorSig = tint::UnorderedKeyWrapper<
         std::tuple<const core::type::Struct*, size_t, core::EvaluationStage>>;
 
+    // SubgroupMatrixConstructorSig represents a unique subgroup matrix constructor signature.
+    // It is a tuple of the subgroup matrix type and the number of arguments provided.
+    using SubgroupMatrixConstructorSig =
+        tint::UnorderedKeyWrapper<std::tuple<const core::type::SubgroupMatrix*, size_t>>;
+
     /// ExprEvalStageConstraint describes a constraint on when expressions can be evaluated.
     struct ExprEvalStageConstraint {
         /// The latest stage that the expression can be evaluated
@@ -701,6 +706,7 @@
     Hashmap<OverrideId, const sem::Variable*, 8> override_ids_;
     Hashmap<ArrayConstructorSig, sem::CallTarget*, 8> array_ctors_;
     Hashmap<StructConstructorSig, sem::CallTarget*, 8> struct_ctors_;
+    Hashmap<SubgroupMatrixConstructorSig, sem::CallTarget*, 8> subgroup_matrix_ctors_;
     sem::Function* current_function_ = nullptr;
     sem::Statement* current_statement_ = nullptr;
     sem::CompoundStatement* current_compound_statement_ = nullptr;
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index daf419e..31b9f2c 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -27,6 +27,7 @@
 
 #include "src/tint/lang/wgsl/resolver/resolver.h"
 #include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
+#include "src/tint/lang/wgsl/sem/value_constructor.h"
 
 #include "gmock/gmock.h"
 
@@ -222,5 +223,69 @@
               R"(error: subgroup matrix row count must be a constant positive integer)");
 }
 
+TEST_F(ResolverSubgroupMatrixTest, ZeroValueConstructor) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Assign(Phony(), call),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto call_sem = Sem().Get(call)->As<sem::Call>();
+    ASSERT_NE(call_sem, nullptr);
+    auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+    ASSERT_NE(target, nullptr);
+    EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+    EXPECT_EQ(target->Parameters().Length(), 0u);
+    EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_a);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Assign(Phony(), call),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto call_sem = Sem().Get(call)->As<sem::Call>();
+    ASSERT_NE(call_sem, nullptr);
+    auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+    ASSERT_NE(target, nullptr);
+    EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+    EXPECT_EQ(target->Parameters().Length(), 1u);
+    EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ConstructorTooManyArgs) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_f, 2_f);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Assign(Phony(), call),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              R"(error: subgroup_matrix constructor can only have zero or one elements)");
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ConstructorWrongType) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    auto* call = Call(Ident("subgroup_matrix_result", ty.u32(), 8_a, 8_a), 1_f);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Assign(Phony(), call),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              R"(error: 'f32' cannot be used to construct a subgroup matrix of 'u32')");
+}
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index a3f559f..bc900a9 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -2323,6 +2323,29 @@
     return true;
 }
 
+bool Validator::SubgroupMatrixConstructor(
+    const ast::CallExpression* ctor,
+    const core::type::SubgroupMatrix* subgroup_matrix_type) const {
+    auto& values = ctor->args;
+    if (values.Length() == 1) {
+        auto* elem_ty = subgroup_matrix_type->Type();
+        auto* value_ty = sem_.TypeOf(values[0])->UnwrapRef();
+        if (core::type::Type::ConversionRank(value_ty, elem_ty) ==
+            core::type::Type::kNoConversion) {
+            AddError(values[0]->source) << style::Type(sem_.TypeNameOf(value_ty))
+                                        << " cannot be used to construct a subgroup matrix of "
+                                        << style::Type(sem_.TypeNameOf(elem_ty));
+            return false;
+        }
+    } else if (values.Length() > 1) {
+        AddError(ctor->target->source)
+            << "subgroup_matrix constructor can only have zero or one elements";
+        return false;
+    }
+
+    return true;
+}
+
 bool Validator::Vector(const core::type::Type* el_ty, const Source& source) const {
     if (!el_ty->Is<core::type::Scalar>()) {
         AddError(source) << "vector element type must be " << style::Type("bool") << ", "
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index 4d533f6..7907f9d 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -529,11 +529,18 @@
     bool Vector(const core::type::Type* el_ty, const Source& source) const;
 
     /// Validates an array constructor
-    /// @param ctor the call expresion to validate
+    /// @param ctor the call expression to validate
     /// @param arr_type the type of the array
     /// @returns true on success, false otherwise
     bool ArrayConstructor(const ast::CallExpression* ctor, const sem::Array* arr_type) const;
 
+    /// Validates a subgroup matrix constructor
+    /// @param ctor the call expression to validate
+    /// @param subgroup_matrix_type the type of the subgroup matrix
+    /// @returns true on success, false otherwise
+    bool SubgroupMatrixConstructor(const ast::CallExpression* ctor,
+                                   const core::type::SubgroupMatrix* subgroup_matrix_type) const;
+
     /// Validates a subgroupShuffle builtin functions including Up,Down, and Xor.
     /// @param fn the builtin call type
     /// @param call the builtin call to validate