Fix binary expression resolving and validation with type aliases

* Fixed resolving logical compares with lhs alias
* Fixed resolving multiply with lhs or rhs alias
* Fixed resolving ops with vecN<alias>and matNxM<alias>
* Fixed validation with lhs or rhs alias
* Fixed spir-v generation with lhs/rhs alias and added missing error
message
* Added tests for all valid binary expressions with lhs, rhs, or both as
alias

Bug: tint:680
Change-Id: I095255a3c63ec20b2e974c6866be9470e7e6ec6a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46560
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index f1a5338..900e169 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1030,18 +1030,21 @@
   using Matrix = type::Matrix;
   using Vector = type::Vector;
 
-  auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
-  auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
+  auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
+  auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
 
   auto* lhs_vec = lhs_type->As<Vector>();
-  auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
+  auto* lhs_vec_elem_type =
+      lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
   auto* rhs_vec = rhs_type->As<Vector>();
-  auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
+  auto* rhs_vec_elem_type =
+      rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
 
-  const bool matching_types = lhs_type == rhs_type;
   const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
                                        (lhs_vec_elem_type == rhs_vec_elem_type);
 
+  const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
+
   // Binary logical expressions
   if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
     if (matching_types && lhs_type->Is<Bool>()) {
@@ -1085,9 +1088,11 @@
     }
 
     auto* lhs_mat = lhs_type->As<Matrix>();
-    auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
+    auto* lhs_mat_elem_type =
+        lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
     auto* rhs_mat = rhs_type->As<Matrix>();
-    auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
+    auto* rhs_mat_elem_type =
+        rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
 
     // Multiplication of a matrix and a scalar
     if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
@@ -1195,7 +1200,7 @@
       expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
       expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
     auto* bool_type = builder_->create<type::Bool>();
-    auto* param_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
+    auto* param_type = TypeOf(expr->lhs())->UnwrapAll();
     type::Type* result_type = bool_type;
     if (auto* vec = param_type->As<type::Vector>()) {
       result_type = builder_->create<type::Vector>(bool_type, vec->size());
@@ -1204,8 +1209,8 @@
     return true;
   }
   if (expr->IsMultiply()) {
-    auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
-    auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
+    auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
+    auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
 
     // Note, the ordering here matters. The later checks depend on the prior
     // checks having been done.
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index eadc0bc..1c785d9 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1143,6 +1143,11 @@
   auto* rhs_type = params.create_rhs_type(ty);
   auto* result_type = params.create_result_type(ty);
 
+  std::stringstream ss;
+  ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
+     << rhs_type->FriendlyName(Symbols());
+  SCOPED_TRACE(ss.str());
+
   Global("lhs", lhs_type, ast::StorageClass::kNone);
   Global("rhs", rhs_type, ast::StorageClass::kNone);
 
@@ -1158,6 +1163,70 @@
                          Expr_Binary_Test_Valid,
                          testing::ValuesIn(all_valid_cases));
 
+enum class BinaryExprSide { Left, Right, Both };
+using Expr_Binary_Test_WithAlias_Valid =
+    ResolverTestWithParam<std::tuple<Params, BinaryExprSide>>;
+TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
+  const Params& params = std::get<0>(GetParam());
+  BinaryExprSide side = std::get<1>(GetParam());
+
+  auto* lhs_type = params.create_lhs_type(ty);
+  auto* rhs_type = params.create_rhs_type(ty);
+
+  std::stringstream ss;
+  ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
+     << rhs_type->FriendlyName(Symbols());
+
+  // For vectors and matrices, wrap the sub type in an alias
+  auto make_alias = [this](type::Type* type) -> type::Type* {
+    type::Type* result;
+    if (auto* v = type->As<type::Vector>()) {
+      result = create<type::Vector>(
+          create<type::Alias>(Symbols().New(), v->type()), v->size());
+    } else if (auto* m = type->As<type::Matrix>()) {
+      result =
+          create<type::Matrix>(create<type::Alias>(Symbols().New(), m->type()),
+                               m->rows(), m->columns());
+    } else {
+      result = create<type::Alias>(Symbols().New(), type);
+    }
+    return result;
+  };
+
+  // Wrap in alias
+  if (side == BinaryExprSide::Left || side == BinaryExprSide::Both) {
+    lhs_type = make_alias(lhs_type);
+  }
+  if (side == BinaryExprSide::Right || side == BinaryExprSide::Both) {
+    rhs_type = make_alias(rhs_type);
+  }
+
+  ss << ", After aliasing: " << lhs_type->FriendlyName(Symbols()) << " "
+     << params.op << " " << rhs_type->FriendlyName(Symbols());
+  SCOPED_TRACE(ss.str());
+
+  Global("lhs", lhs_type, ast::StorageClass::kNone);
+  Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+  auto* expr =
+      create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
+  WrapInFunction(expr);
+
+  ASSERT_TRUE(r()->Resolve()) << r()->error();
+  ASSERT_NE(TypeOf(expr), nullptr);
+  // TODO(amaiorano): Bring this back once we have a way to get the canonical
+  // type
+  // auto* result_type = params.create_result_type(ty);
+  // ASSERT_TRUE(TypeOf(expr) == result_type);
+}
+INSTANTIATE_TEST_SUITE_P(
+    ResolverTest,
+    Expr_Binary_Test_WithAlias_Valid,
+    testing::Combine(testing::ValuesIn(all_valid_cases),
+                     testing::Values(BinaryExprSide::Left,
+                                     BinaryExprSide::Right,
+                                     BinaryExprSide::Both)));
+
 using Expr_Binary_Test_Invalid =
     ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
 TEST_P(Expr_Binary_Test_Invalid, All) {
@@ -1186,6 +1255,11 @@
     return;
   }
 
+  std::stringstream ss;
+  ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
+     << rhs_type->FriendlyName(Symbols());
+  SCOPED_TRACE(ss.str());
+
   Global("lhs", lhs_type, ast::StorageClass::kNone);
   Global("rhs", rhs_type, ast::StorageClass::kNone);
 
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index ca555ec..af2ec2f 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1723,8 +1723,8 @@
 
   // Handle int and float and the vectors of those types. Other types
   // should have been rejected by validation.
-  auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
-  auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
+  auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
+  auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
   bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
   bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
 
@@ -1820,6 +1820,7 @@
       // float matrix * matrix
       op = spv::Op::OpMatrixTimesMatrix;
     } else {
+      error_ = "invalid multiply expression";
       return 0;
     }
   } else if (expr->IsNotEqual()) {