[ir] Validate matrix constructor arguments

Use the intrinsics table to check that all matrix construct
instructions have valid type signatures.

Fix one invalid test and a bug in a SPIR-V reader transform.

Fixed: 427965903, 433565172
Change-Id: I8c21eee81022e8652bc99742de18a057a973feb9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/255014
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index da472ad..3934eb3 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -3153,8 +3153,16 @@
         // TODO(crbug.com/427964608): This needs special handling as Element() produces nullptr.
     } else if (result_type->Is<core::type::Vector>()) {
         // TODO(crbug.com/427964205): This needs special handling as there are many cases.
-    } else if (result_type->Is<core::type::Matrix>()) {
-        // TODO(crbug.com/427965903): This needs special handling as there are many cases.
+    } else if (auto* mat = result_type->As<core::type::Matrix>()) {
+        auto table = intrinsic::Table<intrinsic::Dialect>(type_mgr_, symbols_);
+        auto ctor_conv = intrinsic::MatrixCtorConv(mat->Columns(), mat->Rows());
+        auto arg_types = Transform<8>(args, [&](auto* v) { return v->Type(); });
+        auto match = table.Lookup(ctor_conv, Vector{mat->Type()}, std::move(arg_types),
+                                  core::EvaluationStage::kConstant);
+        if (match != Success) {
+            AddError(construct) << "no matching overload for " << mat->FriendlyName()
+                                << " constructor";
+        }
     } else if (result_type->Is<core::type::Array>()) {
         check_args_match_elements();
     } else if (auto* str = As<core::type::Struct>(result_type)) {
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index e4a38aa..dc06f9a 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -210,6 +210,69 @@
 )")) << res.Failure();
 }
 
+TEST_F(IR_ValidatorTest, Construct_Matrix_NoArgs) {
+    auto* f = b.Function("f", ty.void_());
+    b.Append(f->Block(), [&] {
+        b.Construct(ty.mat2x2<f32>());
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_Scalar) {
+    auto* f = b.Function("f", ty.void_());
+    b.Append(f->Block(), [&] {
+        b.Construct(ty.mat2x2<f32>(), 1_f, 2_f, 3_f, 4_f);
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_ColumnVectors) {
+    auto* f = b.Function("f", ty.void_());
+    b.Append(f->Block(), [&] {
+        auto* v1 = b.Composite(ty.vec2<f32>(), 1_f, 2_f);
+        auto* v2 = b.Composite(ty.vec2<f32>(), 3_f, 4_f);
+        b.Construct(ty.mat2x2<f32>(), v1, v2);
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_MixedScalarVector) {
+    auto* f = b.Function("f", ty.void_());
+    b.Append(f->Block(), [&] {
+        b.Construct(ty.mat2x2<f32>(), 1_f, b.Composite(ty.vec2<f32>(), 2_f, 3_f));
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr("error: construct: no matching overload for mat2x2<f32> constructor"));
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_Scalar_WrongType) {
+    auto* f = b.Function("f", ty.void_());
+    b.Append(f->Block(), [&] {
+        b.Construct(ty.mat2x2<f32>(), 1_f, 2_f, 3_h, 4_f);
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_THAT(
+        res.Failure().reason,
+        testing::HasSubstr("error: construct: no matching overload for mat2x2<f32> constructor"));
+}
+
 TEST_F(IR_ValidatorTest, Construct_Struct_ZeroValue) {
     auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
                                                               {mod.symbols.New("a"), ty.i32()},
diff --git a/src/tint/lang/hlsl/writer/construct_test.cc b/src/tint/lang/hlsl/writer/construct_test.cc
index 1c6e11c..55323c2 100644
--- a/src/tint/lang/hlsl/writer/construct_test.cc
+++ b/src/tint/lang/hlsl/writer/construct_test.cc
@@ -141,7 +141,7 @@
 TEST_F(HlslWriterTest, ConstructMatrix) {
     auto* f = b.Function("a", ty.mat2x2<f32>());
     b.Append(f->Block(), [&] {
-        auto* v = b.Var("v", 2_f);
+        auto* v = b.Let("v", 2_f);
         b.Return(f, b.Construct(ty.mat2x2<f32>(), v, v, v, v));
     });
 
@@ -149,7 +149,8 @@
     EXPECT_EQ(output_.hlsl, R"(
 float2x2 a() {
   float v = 2.0f;
-  return float2x2(v, v, v, v);
+  float2 v_1 = float2(v, v);
+  return float2x2(v_1, float2(v, v));
 }
 
 [numthreads(1, 1, 1)]
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 2b7046d..9f1caa4 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -1287,10 +1287,10 @@
                     // a * fkgj - b * ekgi + c * ejfi
                     auto* r_33 = sub_add_mul3(ma, fkgj, mb, ekgi, mc, ejfi);
 
-                    auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02, r_03);
-                    auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12, r_13);
-                    auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22, r_23);
-                    auto* r4 = b.Construct(ty.vec3(elem_ty), r_30, r_31, r_32, r_33);
+                    auto* r1 = b.Construct(ty.vec4(elem_ty), r_00, r_01, r_02, r_03);
+                    auto* r2 = b.Construct(ty.vec4(elem_ty), r_10, r_11, r_12, r_13);
+                    auto* r3 = b.Construct(ty.vec4(elem_ty), r_20, r_21, r_22, r_23);
+                    auto* r4 = b.Construct(ty.vec4(elem_ty), r_30, r_31, r_32, r_33);
 
                     auto* m = b.Construct(mat_ty, r1, r2, r3, r4);
                     auto* inv = b.Multiply(mat_ty, inv_det, m);
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index 684703e..26e9340 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -532,10 +532,10 @@
     %155:f32 = mul %7, %74
     %156:f32 = sub %153, %154
     %157:f32 = add %156, %155
-    %158:vec3<f32> = construct %80, %85, %90, %95
-    %159:vec3<f32> = construct %102, %107, %112, %117
-    %160:vec3<f32> = construct %122, %127, %132, %137
-    %161:vec3<f32> = construct %142, %147, %152, %157
+    %158:vec4<f32> = construct %80, %85, %90, %95
+    %159:vec4<f32> = construct %102, %107, %112, %117
+    %160:vec4<f32> = construct %122, %127, %132, %137
+    %161:vec4<f32> = construct %142, %147, %152, %157
     %162:mat4x4<f32> = construct %158, %159, %160, %161
     %163:mat4x4<f32> = mul %4, %162
     ret
@@ -728,10 +728,10 @@
     %155:f16 = mul %7, %74
     %156:f16 = sub %153, %154
     %157:f16 = add %156, %155
-    %158:vec3<f16> = construct %80, %85, %90, %95
-    %159:vec3<f16> = construct %102, %107, %112, %117
-    %160:vec3<f16> = construct %122, %127, %132, %137
-    %161:vec3<f16> = construct %142, %147, %152, %157
+    %158:vec4<f16> = construct %80, %85, %90, %95
+    %159:vec4<f16> = construct %102, %107, %112, %117
+    %160:vec4<f16> = construct %122, %127, %132, %137
+    %161:vec4<f16> = construct %142, %147, %152, %157
     %162:mat4x4<f16> = construct %158, %159, %160, %161
     %163:mat4x4<f16> = mul %4, %162
     ret