resolver: Refactor binary operator type resolution

This same logic will be used for resolving and validating compound
assignment statements, so pull the core out into a separate function
that decouples it from ast::BinaryExpression.

Bug: tint:1325
Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/ast/binary_expression.h b/src/tint/ast/binary_expression.h
index 7f0f5e7..bcacdd4 100644
--- a/src/tint/ast/binary_expression.h
+++ b/src/tint/ast/binary_expression.h
@@ -122,7 +122,9 @@
   const Expression* const rhs;
 };
 
-inline bool BinaryExpression::IsArithmetic() const {
+/// @param op the operator
+/// @returns true if the op is an arithmetic operation
+inline bool IsArithmetic(BinaryOp op) {
   switch (op) {
     case ast::BinaryOp::kAdd:
     case ast::BinaryOp::kSubtract:
@@ -135,7 +137,9 @@
   }
 }
 
-inline bool BinaryExpression::IsComparison() const {
+/// @param op the operator
+/// @returns true if the op is a comparison operation
+inline bool IsComparison(BinaryOp op) {
   switch (op) {
     case ast::BinaryOp::kEqual:
     case ast::BinaryOp::kNotEqual:
@@ -149,7 +153,9 @@
   }
 }
 
-inline bool BinaryExpression::IsBitwise() const {
+/// @param op the operator
+/// @returns true if the op is a bitwise operation
+inline bool IsBitwise(BinaryOp op) {
   switch (op) {
     case ast::BinaryOp::kAnd:
     case ast::BinaryOp::kOr:
@@ -160,7 +166,9 @@
   }
 }
 
-inline bool BinaryExpression::IsBitshift() const {
+/// @param op the operator
+/// @returns true if the op is a bit shift operation
+inline bool IsBitshift(BinaryOp op) {
   switch (op) {
     case ast::BinaryOp::kShiftLeft:
     case ast::BinaryOp::kShiftRight:
@@ -180,6 +188,22 @@
   }
 }
 
+inline bool BinaryExpression::IsArithmetic() const {
+  return ast::IsArithmetic(op);
+}
+
+inline bool BinaryExpression::IsComparison() const {
+  return ast::IsComparison(op);
+}
+
+inline bool BinaryExpression::IsBitwise() const {
+  return ast::IsBitwise(op);
+}
+
+inline bool BinaryExpression::IsBitshift() const {
+  return ast::IsBitshift(op);
+}
+
 /// @returns the human readable name of the given BinaryOp
 /// @param op the BinaryOp
 constexpr const char* FriendlyName(BinaryOp op) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 6068a73..e10425a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1838,6 +1838,33 @@
 }
 
 sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
+  auto* lhs = Sem(expr->lhs);
+  auto* rhs = Sem(expr->rhs);
+  auto* lhs_ty = lhs->Type()->UnwrapRef();
+  auto* rhs_ty = rhs->Type()->UnwrapRef();
+
+  auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op);
+  if (!ty) {
+    AddError(
+        "Binary expression operand types are invalid for this operation: " +
+            TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
+            TypeNameOf(rhs_ty),
+        expr->source);
+    return nullptr;
+  }
+
+  auto val = EvaluateConstantValue(expr, ty);
+  bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
+  auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
+                                                val, has_side_effects);
+  sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
+
+  return sem;
+}
+
+const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
+                                        const sem::Type* rhs_ty,
+                                        ast::BinaryOp op) {
   using Bool = sem::Bool;
   using F32 = sem::F32;
   using I32 = sem::I32;
@@ -1845,12 +1872,6 @@
   using Matrix = sem::Matrix;
   using Vector = sem::Vector;
 
-  auto* lhs = Sem(expr->lhs);
-  auto* rhs = Sem(expr->rhs);
-
-  auto* lhs_ty = lhs->Type()->UnwrapRef();
-  auto* rhs_ty = rhs->Type()->UnwrapRef();
-
   auto* lhs_vec = lhs_ty->As<Vector>();
   auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
   auto* rhs_vec = rhs_ty->As<Vector>();
@@ -1863,51 +1884,42 @@
 
   const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
 
-  auto build = [&](const sem::Type* ty) {
-    auto val = EvaluateConstantValue(expr, ty);
-    bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
-    auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
-                                                  val, has_side_effects);
-    sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
-    return sem;
-  };
-
   // Binary logical expressions
-  if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
+  if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
     if (matching_types && lhs_ty->Is<Bool>()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
   }
-  if (expr->IsOr() || expr->IsAnd()) {
+  if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
     if (matching_types && lhs_ty->Is<Bool>()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
     if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
   }
 
   // Arithmetic expressions
-  if (expr->IsArithmetic()) {
+  if (ast::IsArithmetic(op)) {
     // Binary arithmetic expressions over scalars
     if (matching_types && lhs_ty->is_numeric_scalar()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
 
     // Binary arithmetic expressions over vectors
     if (matching_types && lhs_vec_elem_type &&
         lhs_vec_elem_type->is_numeric_scalar()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
 
     // Binary arithmetic expressions with mixed scalar and vector operands
     if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
         rhs_ty->is_numeric_scalar()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
     if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
         lhs_ty->is_numeric_scalar()) {
-      return build(rhs_ty);
+      return rhs_ty;
     }
   }
 
@@ -1917,106 +1929,101 @@
   auto* rhs_mat = rhs_ty->As<Matrix>();
   auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
   // Addition and subtraction of float matrices
-  if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
-      lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
+  if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) &&
+      lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
       rhs_mat_elem_type->Is<F32>() &&
       (lhs_mat->columns() == rhs_mat->columns()) &&
       (lhs_mat->rows() == rhs_mat->rows())) {
-    return build(rhs_ty);
+    return rhs_ty;
   }
-  if (expr->IsMultiply()) {
+  if (op == ast::BinaryOp::kMultiply) {
     // Multiplication of a matrix and a scalar
     if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
         rhs_mat_elem_type->Is<F32>()) {
-      return build(rhs_ty);
+      return rhs_ty;
     }
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
         rhs_ty->Is<F32>()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
 
     // Vector times matrix
     if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
         rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
         (lhs_vec->Width() == rhs_mat->rows())) {
-      return build(
-          builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
+      return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
     }
 
     // Matrix times vector
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
         rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
         (lhs_mat->columns() == rhs_vec->Width())) {
-      return build(
-          builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
+      return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
     }
 
     // Matrix times matrix
     if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
         rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
         (lhs_mat->columns() == rhs_mat->rows())) {
-      return build(builder_->create<sem::Matrix>(
+      return builder_->create<sem::Matrix>(
           builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
-          rhs_mat->columns()));
+          rhs_mat->columns());
     }
   }
 
   // Comparison expressions
-  if (expr->IsComparison()) {
+  if (ast::IsComparison(op)) {
     if (matching_types) {
       // Special case for bools: only == and !=
-      if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
-        return build(builder_->create<sem::Bool>());
+      if (lhs_ty->Is<Bool>() &&
+          (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
+        return builder_->create<sem::Bool>();
       }
 
       // For the rest, we can compare i32, u32, and f32
       if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
-        return build(builder_->create<sem::Bool>());
+        return builder_->create<sem::Bool>();
       }
     }
 
     // Same for vectors
     if (matching_vec_elem_types) {
       if (lhs_vec_elem_type->Is<Bool>() &&
-          (expr->IsEqual() || expr->IsNotEqual())) {
-        return build(builder_->create<sem::Vector>(
-            builder_->create<sem::Bool>(), lhs_vec->Width()));
+          (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
+        return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
+                                             lhs_vec->Width());
       }
 
       if (lhs_vec_elem_type->is_numeric_scalar()) {
-        return build(builder_->create<sem::Vector>(
-            builder_->create<sem::Bool>(), lhs_vec->Width()));
+        return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
+                                             lhs_vec->Width());
       }
     }
   }
 
   // Binary bitwise operations
-  if (expr->IsBitwise()) {
+  if (ast::IsBitwise(op)) {
     if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
   }
 
   // Bit shift expressions
-  if (expr->IsBitshift()) {
+  if (ast::IsBitshift(op)) {
     // Type validation rules are the same for left or right shift, despite
     // differences in computation rules (i.e. right shift can be arithmetic or
     // logical depending on lhs type).
 
     if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
 
     if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
         rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
-      return build(lhs_ty);
+      return lhs_ty;
     }
   }
 
-  AddError("Binary expression operand types are invalid for this operation: " +
-               TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
-               TypeNameOf(rhs_ty),
-           expr->source);
   return nullptr;
 }
 
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index fe7e865..7c3d217 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -229,6 +229,12 @@
   sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
   bool Statements(const ast::StatementList&);
 
+  // Resolve the result type of a binary operator.
+  // Returns nullptr if the types are not valid for this operator.
+  const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
+                                const sem::Type* rhs_ty,
+                                ast::BinaryOp op);
+
   // AST and Type validation methods
   // Each return true on success, false on failure.
   bool ValidateAlias(const ast::Alias*);