Resolver: merge binary expression validation and resolving logic together

This avoid duplicating the logic in two places, and makes it easier to
implement according to the spec.

Bug: tint:376
Change-Id: If62f508e2c76b5b661e66aae9ff20b8e874a65d8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52323
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 15e7f3e..3be647b 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2035,7 +2035,14 @@
   return true;
 }
 
-bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
+bool Resolver::Binary(ast::BinaryExpression* expr) {
+  Mark(expr->lhs());
+  Mark(expr->rhs());
+
+  if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
+    return false;
+  }
+
   using Bool = sem::Bool;
   using F32 = sem::F32;
   using I32 = sem::I32;
@@ -2061,14 +2068,17 @@
   // Binary logical expressions
   if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
     if (matching_types && lhs_type->Is<Bool>()) {
+      SetType(expr, lhs_type);
       return true;
     }
   }
   if (expr->IsOr() || expr->IsAnd()) {
     if (matching_types && lhs_type->Is<Bool>()) {
+      SetType(expr, lhs_type);
       return true;
     }
     if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
+      SetType(expr, lhs_type);
       return true;
     }
   }
@@ -2076,13 +2086,15 @@
   // Arithmetic expressions
   if (expr->IsArithmetic()) {
     // Binary arithmetic expressions over scalars
-    if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) {
+    if (matching_types && lhs_type->is_numeric_scalar()) {
+      SetType(expr, lhs_type);
       return true;
     }
 
     // Binary arithmetic expressions over vectors
     if (matching_types && lhs_vec_elem_type &&
-        lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) {
+        lhs_vec_elem_type->is_numeric_scalar()) {
+      SetType(expr, lhs_type);
       return true;
     }
   }
@@ -2093,10 +2105,12 @@
     // Multiplication of a vector and a scalar
     if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
         rhs_vec_elem_type->Is<F32>()) {
+      SetType(expr, rhs_type);
       return true;
     }
     if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
         rhs_type->Is<F32>()) {
+      SetType(expr, lhs_type);
       return true;
     }
 
@@ -2108,10 +2122,12 @@
     // Multiplication of a matrix and a scalar
     if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
         rhs_mat_elem_type->Is<F32>()) {
+      SetType(expr, rhs_type);
       return true;
     }
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
         rhs_type->Is<F32>()) {
+      SetType(expr, lhs_type);
       return true;
     }
 
@@ -2119,6 +2135,8 @@
     if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
         rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
         (lhs_vec->size() == rhs_mat->rows())) {
+      SetType(expr, builder_->create<sem::Vector>(lhs_vec->type(),
+                                                  rhs_mat->columns()));
       return true;
     }
 
@@ -2126,6 +2144,8 @@
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
         rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
         (lhs_mat->columns() == rhs_vec->size())) {
+      SetType(expr,
+              builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
       return true;
     }
 
@@ -2133,6 +2153,10 @@
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
         rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
         (lhs_mat->columns() == rhs_mat->rows())) {
+      SetType(expr, builder_->create<sem::Matrix>(
+                        builder_->create<sem::Vector>(lhs_mat_elem_type,
+                                                      lhs_mat->rows()),
+                        rhs_mat->columns()));
       return true;
     }
   }
@@ -2142,11 +2166,13 @@
     if (matching_types) {
       // Special case for bools: only == and !=
       if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
+        SetType(expr, builder_->create<sem::Bool>());
         return true;
       }
 
       // For the rest, we can compare i32, u32, and f32
       if (lhs_type->IsAnyOf<I32, U32, F32>()) {
+        SetType(expr, builder_->create<sem::Bool>());
         return true;
       }
     }
@@ -2155,10 +2181,14 @@
     if (matching_vec_elem_types) {
       if (lhs_vec_elem_type->Is<Bool>() &&
           (expr->IsEqual() || expr->IsNotEqual())) {
+        SetType(expr, builder_->create<sem::Vector>(
+                          builder_->create<sem::Bool>(), lhs_vec->size()));
         return true;
       }
 
-      if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) {
+      if (lhs_vec_elem_type->is_numeric_scalar()) {
+        SetType(expr, builder_->create<sem::Vector>(
+                          builder_->create<sem::Bool>(), lhs_vec->size()));
         return true;
       }
     }
@@ -2167,6 +2197,7 @@
   // Binary bitwise operations
   if (expr->IsBitwise()) {
     if (matching_types && lhs_type->IsAnyOf<I32, U32>()) {
+      SetType(expr, lhs_type);
       return true;
     }
   }
@@ -2178,11 +2209,13 @@
     // logical depending on lhs type).
 
     if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
+      SetType(expr, lhs_type);
       return true;
     }
 
     if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
         rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
+      SetType(expr, lhs_type);
       return true;
     }
   }
@@ -2196,86 +2229,6 @@
   return false;
 }
 
-bool Resolver::Binary(ast::BinaryExpression* expr) {
-  Mark(expr->lhs());
-  Mark(expr->rhs());
-  if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
-    return false;
-  }
-
-  if (!ValidateBinary(expr)) {
-    return false;
-  }
-
-  // Result type matches first parameter type
-  if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
-      expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
-      expr->IsDivide() || expr->IsModulo()) {
-    SetType(expr, TypeOf(expr->lhs())->UnwrapRef());
-    return true;
-  }
-  // Result type is a scalar or vector of boolean type
-  if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
-      expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
-      expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
-    auto* bool_type = builder_->create<sem::Bool>();
-    auto* param_type = TypeOf(expr->lhs())->UnwrapRef();
-    sem::Type* result_type = bool_type;
-    if (auto* vec = param_type->As<sem::Vector>()) {
-      result_type = builder_->create<sem::Vector>(bool_type, vec->size());
-    }
-    SetType(expr, result_type);
-    return true;
-  }
-  if (expr->IsMultiply()) {
-    auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
-    auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
-
-    // Note, the ordering here matters. The later checks depend on the prior
-    // checks having been done.
-    auto* lhs_mat = lhs_type->As<sem::Matrix>();
-    auto* rhs_mat = rhs_type->As<sem::Matrix>();
-    auto* lhs_vec = lhs_type->As<sem::Vector>();
-    auto* rhs_vec = rhs_type->As<sem::Vector>();
-    const sem::Type* result_type = nullptr;
-    if (lhs_mat && rhs_mat) {
-      auto* column_type =
-          builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
-      result_type =
-          builder_->create<sem::Matrix>(column_type, rhs_mat->columns());
-    } else if (lhs_mat && rhs_vec) {
-      result_type =
-          builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
-    } else if (lhs_vec && rhs_mat) {
-      result_type =
-          builder_->create<sem::Vector>(rhs_mat->type(), rhs_mat->columns());
-    } else if (lhs_mat) {
-      // matrix * scalar
-      result_type = lhs_type;
-    } else if (rhs_mat) {
-      // scalar * matrix
-      result_type = rhs_type;
-    } else if (lhs_vec && rhs_vec) {
-      result_type = lhs_type;
-    } else if (lhs_vec) {
-      // Vector * scalar
-      result_type = lhs_type;
-    } else if (rhs_vec) {
-      // Scalar * vector
-      result_type = rhs_type;
-    } else {
-      // Scalar * Scalar
-      result_type = lhs_type;
-    }
-
-    SetType(expr, result_type);
-    return true;
-  }
-
-  diagnostics_.add_error("Unknown binary expression", expr->source());
-  return false;
-}
-
 bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
   Mark(unary->expr());
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index f002f97..8bf89ac 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -236,7 +236,6 @@
                                      uint32_t el_align,
                                      const Source& source);
   bool ValidateAssignment(const ast::AssignmentStatement* a);
-  bool ValidateBinary(ast::BinaryExpression* expr);
   bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
   bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
   bool ValidateGlobalVariable(const VariableInfo* var);