Resolver: Enforce matrix constructor type rules

Added enforcement for matrix constructor type rules according to the
table in https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.

Fixed: tint:633
Change-Id: I97fc7f558f04780ed03252d94c071af3e0e07e26
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45020
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Arman Uguray <armansito@chromium.org>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 6e16b2e..d0e24d9 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -63,6 +63,13 @@
   T old_value_;
 };
 
+// Helper function that returns the range union of two source locations. The
+// `start` and `end` locations are assumed to refer to the same source file.
+Source CombineSourceRange(const Source& start, const Source& end) {
+  return Source(Source::Range(start.range.begin, end.range.end),
+                start.file_path, start.file_content);
+}
+
 }  // namespace
 
 Resolver::Resolver(ProgramBuilder* builder)
@@ -572,9 +579,11 @@
     // obey the constructor type rules laid out in
     // https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
     if (auto* vec_type = type_ctor->type()->As<type::Vector>()) {
-      return VectorConstructor(*vec_type, type_ctor->values());
+      return VectorConstructor(vec_type, type_ctor->values());
     }
-    // TODO(crbug.com/tint/633): Validate matrix constructor
+    if (auto* mat_type = type_ctor->type()->As<type::Matrix>()) {
+      return MatrixConstructor(mat_type, type_ctor->values());
+    }
     // TODO(crbug.com/tint/634): Validate array constructor
   } else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
     SetType(expr, scalar_ctor->literal()->type());
@@ -584,9 +593,9 @@
   return true;
 }
 
-bool Resolver::VectorConstructor(const type::Vector& vec_type,
+bool Resolver::VectorConstructor(const type::Vector* vec_type,
                                  const ast::ExpressionList& values) {
-  type::Type* elem_type = vec_type.type()->UnwrapAll();
+  type::Type* elem_type = vec_type->type()->UnwrapAll();
   size_t value_cardinality_sum = 0;
   for (auto* value : values) {
     type::Type* value_type = TypeOf(value)->UnwrapAll();
@@ -635,26 +644,63 @@
   // A correct vector constructor must either be a zero-value expression
   // or the number of components of all constructor arguments must add up
   // to the vector cardinality.
-  if (value_cardinality_sum > 0 && value_cardinality_sum != vec_type.size()) {
+  if (value_cardinality_sum > 0 && value_cardinality_sum != vec_type->size()) {
     if (values.empty()) {
       TINT_ICE(diagnostics_)
           << "constructor arguments expected to be non-empty!";
     }
     const Source& values_start = values[0]->source();
     const Source& values_end = values[values.size() - 1]->source();
-    const Source src(
-        Source::Range(values_start.range.begin, values_end.range.end),
-        values_start.file_path, values_start.file_content);
     diagnostics_.add_error(
         "attempted to construct '" +
-            vec_type.FriendlyName(builder_->Symbols()) + "' with " +
+            vec_type->FriendlyName(builder_->Symbols()) + "' with " +
             std::to_string(value_cardinality_sum) + " component(s)",
-        src);
+        CombineSourceRange(values_start, values_end));
     return false;
   }
   return true;
 }
 
+bool Resolver::MatrixConstructor(const type::Matrix* matrix_type,
+                                 const ast::ExpressionList& values) {
+  // Zero Value expression
+  if (values.empty()) {
+    return true;
+  }
+
+  type::Type* elem_type = matrix_type->type()->UnwrapAll();
+  if (matrix_type->columns() != values.size()) {
+    const Source& values_start = values[0]->source();
+    const Source& values_end = values[values.size() - 1]->source();
+    diagnostics_.add_error(
+        "expected " + std::to_string(matrix_type->columns()) + " '" +
+            VectorPretty(matrix_type->rows(), elem_type) + "' arguments in '" +
+            matrix_type->FriendlyName(builder_->Symbols()) +
+            "' constructor, found " + std::to_string(values.size()),
+        CombineSourceRange(values_start, values_end));
+    return false;
+  }
+
+  for (auto* value : values) {
+    type::Type* value_type = TypeOf(value)->UnwrapAll();
+    auto* value_vec = value_type->As<type::Vector>();
+
+    if (!value_vec || value_vec->size() != matrix_type->rows() ||
+        elem_type != value_vec->type()->UnwrapAll()) {
+      diagnostics_.add_error(
+          "expected argument type '" +
+              VectorPretty(matrix_type->rows(), elem_type) + "' in '" +
+              matrix_type->FriendlyName(builder_->Symbols()) +
+              "' constructor, found '" +
+              value_type->FriendlyName(builder_->Symbols()) + "'",
+          value->source());
+      return false;
+    }
+  }
+
+  return true;
+}
+
 bool Resolver::Identifier(ast::IdentifierExpression* expr) {
   auto symbol = expr->symbol();
   VariableInfo* var;
@@ -1501,6 +1547,11 @@
   return callback();
 }
 
+std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) {
+  type::Vector vec_type(element_type, size);
+  return vec_type.FriendlyName(builder_->Symbols());
+}
+
 Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
     : declaration(decl), storage_class(decl->declared_storage_class()) {}
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 7bb56b6..63e5850 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -194,7 +194,9 @@
   bool Call(ast::CallExpression*);
   bool CaseStatement(ast::CaseStatement*);
   bool Constructor(ast::ConstructorExpression*);
-  bool VectorConstructor(const type::Vector& vec_type,
+  bool VectorConstructor(const type::Vector* vec_type,
+                         const ast::ExpressionList& values);
+  bool MatrixConstructor(const type::Matrix* matrix_type,
                          const ast::ExpressionList& values);
   bool Expression(ast::Expression*);
   bool Expressions(const ast::ExpressionList&);
@@ -247,6 +249,13 @@
   template <typename F>
   bool BlockScope(BlockInfo::Type type, F&& callback);
 
+  /// Returns a human-readable string representation of the vector type name
+  /// with the given parameters.
+  /// @param size the vector dimension
+  /// @param element_type scalar vector sub-element type
+  /// @return pretty string representation
+  std::string VectorPretty(uint32_t size, type::Type* element_type);
+
   ProgramBuilder* const builder_;
   std::unique_ptr<IntrinsicTable> const intrinsic_table_;
   diag::List diagnostics_;
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 9dc2e0d..8cc4093 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -1664,6 +1664,357 @@
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
+struct MatrixDimensions {
+  uint32_t rows;
+  uint32_t columns;
+};
+
+std::string MatrixStr(const MatrixDimensions& dimensions,
+                      std::string subtype = "f32") {
+  return "mat" + std::to_string(dimensions.columns) + "x" +
+         std::to_string(dimensions.rows) + "<" + subtype + ">";
+}
+
+std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
+  return "vec" + std::to_string(dimensions) + "<" + subtype + ">";
+}
+
+using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
+  // matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
+
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns - 1; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:1 error: expected " + std::to_string(param.columns) + " '" +
+                VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
+                "' constructor, found " + std::to_string(param.columns - 1));
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
+  // matNxM<f32>(vecM<f32>(), ...); with N + 1 arguments
+
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns + 1; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:1 error: expected " + std::to_string(param.columns) + " '" +
+                VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
+                "' constructor, found " + std::to_string(param.columns + 1));
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
+  // matNxM<f32>(1.0, 1.0, ...); N arguments
+
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::ScalarConstructorExpression>(Source{{12, i}},
+                                                            Literal(1.0f)));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
+                              VecStr(param.rows) + "' in '" + MatrixStr(param) +
+                              "' constructor, found 'f32'");
+}
+
+TEST_P(MatrixConstructorTest,
+       Expr_Constructor_Error_TooFewRowsInVectorArgument) {
+  // matNxM<f32>(vecM<f32>(),...,vecM-1<f32>());
+
+  const auto param = GetParam();
+
+  // Skip the test if parameters would have resuled in an invalid vec1 type.
+  if (param.rows == 2) {
+    return;
+  }
+
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* valid_vec_type = create<type::Vector>(ty.f32(), param.rows);
+  auto* invalid_vec_type = create<type::Vector>(ty.f32(), param.rows - 1);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns - 1; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, valid_vec_type, ExprList()));
+  }
+  const size_t kInvalidLoc = 2 * (param.columns - 1);
+  args.push_back(create<ast::TypeConstructorExpression>(
+      Source{{12, kInvalidLoc}}, invalid_vec_type, ExprList()));
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
+                              " error: expected argument type '" +
+                              VecStr(param.rows) + "' in '" + MatrixStr(param) +
+                              "' constructor, found '" +
+                              VecStr(param.rows - 1) + "'");
+}
+
+TEST_P(MatrixConstructorTest,
+       Expr_Constructor_Error_TooManyRowsInVectorArgument) {
+  // matNxM<f32>(vecM<f32>(),...,vecM+1<f32>());
+
+  const auto param = GetParam();
+
+  // Skip the test if parameters would have resuled in an invalid vec5 type.
+  if (param.rows == 4) {
+    return;
+  }
+
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* valid_vec_type = create<type::Vector>(ty.f32(), param.rows);
+  auto* invalid_vec_type = create<type::Vector>(ty.f32(), param.rows + 1);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns - 1; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, valid_vec_type, ExprList()));
+  }
+  const size_t kInvalidLoc = 2 * (param.columns - 1);
+  args.push_back(create<ast::TypeConstructorExpression>(
+      Source{{12, kInvalidLoc}}, invalid_vec_type, ExprList()));
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
+                              " error: expected argument type '" +
+                              VecStr(param.rows) + "' in '" + MatrixStr(param) +
+                              "' constructor, found '" +
+                              VecStr(param.rows + 1) + "'");
+}
+
+TEST_P(MatrixConstructorTest,
+       Expr_Constructor_Error_ArgumentVectorElementTypeMismatch) {
+  // matNxM<f32>(vecM<u32>(), ...); with N arguments
+
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.u32(), param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
+                              VecStr(param.rows) + "' in '" + MatrixStr(param) +
+                              "' constructor, found '" +
+                              VecStr(param.rows, "u32") + "'");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
+  // matNxM<f32>();
+
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 40}},
+                                                    matrix_type, ExprList());
+  WrapInFunction(tc);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
+  // matNxM<f32>(vecM<f32>(), ...); with N arguments
+
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
+  // matNxM<Float32>(vecM<u32>(), ...); with N arguments
+
+  const auto param = GetParam();
+  auto* f32_alias = ty.alias("Float32", ty.f32());
+  auto* matrix_type =
+      create<type::Matrix>(f32_alias, param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.u32(), param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:1 error: expected argument type '" + VecStr(param.rows) +
+                "' in '" + MatrixStr(param, "Float32") +
+                "' constructor, found '" + VecStr(param.rows, "u32") + "'");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
+  // matNxM<Float32>(vecM<f32>(), ...); with N arguments
+
+  const auto param = GetParam();
+  auto* f32_alias = ty.alias("Float32", ty.f32());
+  auto* matrix_type =
+      create<type::Matrix>(f32_alias, param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
+  auto* vec2_alias = ty.alias("VectorUnsigned2", ty.vec2<u32>());
+  auto* tc = mat2x2<f32>(create<ast::TypeConstructorExpression>(
+                             Source{{12, 34}}, vec2_alias, ExprList()),
+                         vec2<f32>());
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: expected argument type 'vec2<f32>' in 'mat2x2<f32>' "
+            "constructor, found 'vec2<u32>'");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+  auto* vec_alias = ty.alias("VectorFloat2", vec_type);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_alias, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* f32_alias = ty.alias("UnsignedInt", ty.u32());
+  auto* vec_type = create<type::Vector>(f32_alias, param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
+                              VecStr(param.rows) + "' in '" + MatrixStr(param) +
+                              "' constructor, found '" +
+                              VecStr(param.rows, "UnsignedInt") + "'");
+}
+
+TEST_P(MatrixConstructorTest,
+       Expr_Constructor_ArgumentElementTypeAlias_Success) {
+  const auto param = GetParam();
+  auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+  auto* f32_alias = ty.alias("Float32", ty.f32());
+  auto* vec_type = create<type::Vector>(f32_alias, param.rows);
+
+  ast::ExpressionList args;
+  for (uint32_t i = 1; i <= param.columns; i++) {
+    args.push_back(create<ast::TypeConstructorExpression>(
+        Source{{12, i}}, vec_type, ExprList()));
+  }
+
+  auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+                                                    std::move(args));
+  WrapInFunction(tc);
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverValidationTest,
+                         MatrixConstructorTest,
+                         testing::Values(MatrixDimensions{2, 2},
+                                         MatrixDimensions{3, 2},
+                                         MatrixDimensions{4, 2},
+                                         MatrixDimensions{2, 3},
+                                         MatrixDimensions{3, 3},
+                                         MatrixDimensions{4, 3},
+                                         MatrixDimensions{2, 4},
+                                         MatrixDimensions{3, 4},
+                                         MatrixDimensions{4, 4}));
+
 }  // namespace
 }  // namespace resolver
 }  // namespace tint