Validate multiply of invalid vector/matrix sizes

Added tests that test all combos of vec*mat, mat*vec, and mat*mat for 2,
3, and 4 dimensions.

Bug: tint:698
Change-Id: I4a407228261cf8ea2a93bc7077544e5a9244d854
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46660
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/program_builder.h b/src/program_builder.h
index 76c823d..613fb62 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -928,7 +928,7 @@
   /// @param rhs the right hand argument to the addition operation
   /// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs`
   template <typename LHS, typename RHS>
-  ast::Expression* Add(LHS&& lhs, RHS&& rhs) {
+  ast::BinaryExpression* Add(LHS&& lhs, RHS&& rhs) {
     return create<ast::BinaryExpression>(ast::BinaryOp::kAdd,
                                          Expr(std::forward<LHS>(lhs)),
                                          Expr(std::forward<RHS>(rhs)));
@@ -938,7 +938,7 @@
   /// @param rhs the right hand argument to the subtraction operation
   /// @returns a `ast::BinaryExpression` subtracting `rhs` from `lhs`
   template <typename LHS, typename RHS>
-  ast::Expression* Sub(LHS&& lhs, RHS&& rhs) {
+  ast::BinaryExpression* Sub(LHS&& lhs, RHS&& rhs) {
     return create<ast::BinaryExpression>(ast::BinaryOp::kSubtract,
                                          Expr(std::forward<LHS>(lhs)),
                                          Expr(std::forward<RHS>(rhs)));
@@ -948,17 +948,28 @@
   /// @param rhs the right hand argument to the multiplication operation
   /// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
   template <typename LHS, typename RHS>
-  ast::Expression* Mul(LHS&& lhs, RHS&& rhs) {
+  ast::BinaryExpression* Mul(LHS&& lhs, RHS&& rhs) {
     return create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
                                          Expr(std::forward<LHS>(lhs)),
                                          Expr(std::forward<RHS>(rhs)));
   }
 
+  /// @param source the source information
+  /// @param lhs the left hand argument to the multiplication operation
+  /// @param rhs the right hand argument to the multiplication operation
+  /// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
+  template <typename LHS, typename RHS>
+  ast::BinaryExpression* Mul(const Source& source, LHS&& lhs, RHS&& rhs) {
+    return create<ast::BinaryExpression>(source, ast::BinaryOp::kMultiply,
+                                         Expr(std::forward<LHS>(lhs)),
+                                         Expr(std::forward<RHS>(rhs)));
+  }
+
   /// @param arr the array argument for the array accessor expression
   /// @param idx the index argument for the array accessor expression
   /// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
   template <typename ARR, typename IDX>
-  ast::Expression* IndexAccessor(ARR&& arr, IDX&& idx) {
+  ast::ArrayAccessorExpression* IndexAccessor(ARR&& arr, IDX&& idx) {
     return create<ast::ArrayAccessorExpression>(Expr(std::forward<ARR>(arr)),
                                                 Expr(std::forward<IDX>(idx)));
   }
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 1158ad6..0c79dd7 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1040,8 +1040,10 @@
   auto* rhs_vec_elem_type =
       rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
 
-  const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
-                                       (lhs_vec_elem_type == rhs_vec_elem_type);
+  const bool matching_vec_elem_types =
+      lhs_vec_elem_type && rhs_vec_elem_type &&
+      (lhs_vec_elem_type == rhs_vec_elem_type) &&
+      (lhs_vec->size() == rhs_vec->size());
 
   const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
 
@@ -1106,19 +1108,22 @@
 
     // Vector times matrix
     if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
-        rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
+        rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
+        (lhs_vec->size() == rhs_mat->rows())) {
       return true;
     }
 
     // Matrix times vector
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
-        rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) {
+        rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
+        (lhs_mat->columns() == rhs_vec->size())) {
       return true;
     }
 
     // Matrix times matrix
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
-        rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
+        rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
+        (lhs_mat->columns() == rhs_mat->rows())) {
       return true;
     }
   }
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 12d019e..7ee58c6 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1263,11 +1263,11 @@
   Global("lhs", lhs_type, ast::StorageClass::kNone);
   Global("rhs", rhs_type, ast::StorageClass::kNone);
 
-  auto* expr = create<ast::BinaryExpression>(
-      Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs"));
+  auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, params.op,
+                                             Expr("lhs"), Expr("rhs"));
   WrapInFunction(expr);
 
-  ASSERT_FALSE(r()->Resolve()) << r()->error();
+  ASSERT_FALSE(r()->Resolve());
   ASSERT_EQ(r()->error(),
             "12:34 error: Binary expression operand types are invalid for "
             "this operation: " +
@@ -1280,6 +1280,99 @@
     Expr_Binary_Test_Invalid,
     testing::Combine(testing::ValuesIn(all_valid_cases),
                      testing::ValuesIn(all_create_type_funcs)));
+
+using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
+    ResolverTestWithParam<std::tuple<bool, uint32_t, uint32_t, uint32_t>>;
+TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
+  bool vec_by_mat = std::get<0>(GetParam());
+  uint32_t vec_size = std::get<1>(GetParam());
+  uint32_t mat_rows = std::get<2>(GetParam());
+  uint32_t mat_cols = std::get<3>(GetParam());
+
+  type::Type* lhs_type;
+  type::Type* rhs_type;
+  type::Type* result_type;
+  bool is_valid_expr;
+
+  if (vec_by_mat) {
+    lhs_type = create<type::Vector>(ty.f32(), vec_size);
+    rhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
+    result_type = create<type::Vector>(ty.f32(), mat_cols);
+    is_valid_expr = vec_size == mat_rows;
+  } else {
+    lhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
+    rhs_type = create<type::Vector>(ty.f32(), vec_size);
+    result_type = create<type::Vector>(ty.f32(), mat_rows);
+    is_valid_expr = vec_size == mat_cols;
+  }
+
+  Global("lhs", lhs_type, ast::StorageClass::kNone);
+  Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+  auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
+  WrapInFunction(expr);
+
+  if (is_valid_expr) {
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
+    ASSERT_TRUE(TypeOf(expr) == result_type);
+  } else {
+    ASSERT_FALSE(r()->Resolve());
+    ASSERT_EQ(r()->error(),
+              "12:34 error: Binary expression operand types are invalid for "
+              "this operation: " +
+                  lhs_type->FriendlyName(Symbols()) + " " +
+                  FriendlyName(expr->op()) + " " +
+                  rhs_type->FriendlyName(Symbols()));
+  }
+}
+auto all_dimension_values = testing::Values(2u, 3u, 4u);
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+                         Expr_Binary_Test_Invalid_VectorMatrixMultiply,
+                         testing::Combine(testing::Values(true, false),
+                                          all_dimension_values,
+                                          all_dimension_values,
+                                          all_dimension_values));
+
+using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
+    ResolverTestWithParam<std::tuple<uint32_t, uint32_t, uint32_t, uint32_t>>;
+TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
+  uint32_t lhs_mat_rows = std::get<0>(GetParam());
+  uint32_t lhs_mat_cols = std::get<1>(GetParam());
+  uint32_t rhs_mat_rows = std::get<2>(GetParam());
+  uint32_t rhs_mat_cols = std::get<3>(GetParam());
+
+  auto* lhs_type = create<type::Matrix>(ty.f32(), lhs_mat_rows, lhs_mat_cols);
+  auto* rhs_type = create<type::Matrix>(ty.f32(), rhs_mat_rows, rhs_mat_cols);
+  auto* result_type =
+      create<type::Matrix>(ty.f32(), lhs_mat_rows, rhs_mat_cols);
+
+  Global("lhs", lhs_type, ast::StorageClass::kNone);
+  Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+  auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
+  WrapInFunction(expr);
+
+  bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
+  if (is_valid_expr) {
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
+    ASSERT_TRUE(TypeOf(expr) == result_type);
+  } else {
+    ASSERT_FALSE(r()->Resolve());
+    ASSERT_EQ(r()->error(),
+              "12:34 error: Binary expression operand types are invalid for "
+              "this operation: " +
+                  lhs_type->FriendlyName(Symbols()) + " " +
+                  FriendlyName(expr->op()) + " " +
+                  rhs_type->FriendlyName(Symbols()));
+  }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+                         Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
+                         testing::Combine(all_dimension_values,
+                                          all_dimension_values,
+                                          all_dimension_values,
+                                          all_dimension_values));
+
 }  // namespace ExprBinaryTest
 
 using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;