[wgsl] Fix 8-bit subgroup matrix constructors

The signatures for subgroup matrix constructors need to use the shader
scalar types, which are 32-bit integers for the 8-bit matrices.

Fixed: 458773229
Change-Id: I710af6fd705cde43d0978f70c1a0a191d7cc175f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/272174
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 655f3f1..d740796 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2097,10 +2097,17 @@
                 auto* call_target = subgroup_matrix_ctors_.GetOrAdd(
                     SubgroupMatrixConstructorSig{{m, args.Length()}},
                     [&]() -> sem::ValueConstructor* {
+                        // 8-bit integer matrices use 32-bit shader scalar types.
+                        auto* scalar_ty = m->Type();
+                        if (m->Type()->Is<core::type::I8>()) {
+                            scalar_ty = b.create<core::type::I32>();
+                        } else if (m->Type()->Is<core::type::U8>()) {
+                            scalar_ty = b.create<core::type::U32>();
+                        }
                         auto params = tint::Transform(args, [&](auto, size_t i) {
                             return b.create<sem::Parameter>(nullptr,  // declaration
                                                             static_cast<uint32_t>(i),  // index
-                                                            m->Type());
+                                                            scalar_ty);
                         });
                         return b.create<sem::ValueConstructor>(m, std::move(params),
                                                                core::EvaluationStage::kRuntime);
@@ -2110,7 +2117,7 @@
                     return nullptr;
                 }
 
-                if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m))) {
+                if (DAWN_UNLIKELY(!validator_.SubgroupMatrixConstructor(expr, m, call_target))) {
                     return nullptr;
                 }
 
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index caa327e..a4e06f4 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -292,6 +292,44 @@
     EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
 }
 
+TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor_U8_Abstract) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    auto* call = Call(Ident("subgroup_matrix_result", ty.u8(), 8_a, 8_a), 1_a);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Assign(Phony(), call),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto call_sem = Sem().Get(call)->As<sem::Call>();
+    ASSERT_NE(call_sem, nullptr);
+    auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+    ASSERT_NE(target, nullptr);
+    EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+    EXPECT_EQ(target->Parameters().Length(), 1u);
+    EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverSubgroupMatrixTest, SingleValueConstructor_U8_U32) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    auto* call = Call(Ident("subgroup_matrix_result", ty.u8(), 8_a, 8_a), 1_u);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Assign(Phony(), call),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+    auto call_sem = Sem().Get(call)->As<sem::Call>();
+    ASSERT_NE(call_sem, nullptr);
+    auto* target = call_sem->Target()->As<sem::ValueConstructor>();
+    ASSERT_NE(target, nullptr);
+    EXPECT_TRUE(target->ReturnType()->Is<core::type::SubgroupMatrix>());
+    EXPECT_EQ(target->Parameters().Length(), 1u);
+    EXPECT_EQ(target->Stage(), core::EvaluationStage::kRuntime);
+}
+
 TEST_F(ResolverSubgroupMatrixTest, ConstructorTooManyArgs) {
     Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
     auto* call = Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 1_f, 2_f);
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index 59ae08b..2da058c 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -2389,14 +2389,15 @@
     return true;
 }
 
-bool Validator::SubgroupMatrixConstructor(
-    const ast::CallExpression* ctor,
-    const core::type::SubgroupMatrix* subgroup_matrix_type) const {
+bool Validator::SubgroupMatrixConstructor(const ast::CallExpression* ctor,
+                                          const core::type::SubgroupMatrix* subgroup_matrix_type,
+                                          const sem::CallTarget* signature) const {
     auto& values = ctor->args;
     if (values.Length() == 1) {
         auto* elem_ty = subgroup_matrix_type->Type();
         auto* value_ty = sem_.TypeOf(values[0])->UnwrapRef();
-        if (core::type::Type::ConversionRank(value_ty, elem_ty) ==
+        auto* expected_ty = signature->Parameters()[0]->Type();
+        if (core::type::Type::ConversionRank(value_ty, expected_ty) ==
             core::type::Type::kNoConversion) {
             AddError(values[0]->source) << style::Type(sem_.TypeNameOf(value_ty))
                                         << " cannot be used to construct a subgroup matrix of "
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index de3b8b4..7142ec2 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -536,9 +536,11 @@
     /// Validates a subgroup matrix constructor
     /// @param ctor the call expression to validate
     /// @param subgroup_matrix_type the type of the subgroup matrix
+    /// @param signature the construct signature to validate against
     /// @returns true on success, false otherwise
     bool SubgroupMatrixConstructor(const ast::CallExpression* ctor,
-                                   const core::type::SubgroupMatrix* subgroup_matrix_type) const;
+                                   const core::type::SubgroupMatrix* subgroup_matrix_type,
+                                   const sem::CallTarget* signature) const;
 
     /// Validates a subgroupShuffle builtin functions including Up,Down, and Xor.
     /// @param fn the builtin call type