Validate binary operations

This change validates that the operand types and result type of every
binary operation is valid.

* Added two unit tests which test all valid and invalid param combos. I
also removed the old tests, many of which failed once I added this
validation, and the rest are obviated by the new tests.

* Fixed VertexPulling transform, as well as many tests, that were using
invalid operand types for binary operations.

Fixed: tint:354
Change-Id: Ia3f48384256993da61b341f17ba5583741011819
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44341
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/ast/binary_expression.h b/src/ast/binary_expression.h
index 84834ac..b9b8c1c 100644
--- a/src/ast/binary_expression.h
+++ b/src/ast/binary_expression.h
@@ -23,11 +23,11 @@
 /// The operator type
 enum class BinaryOp {
   kNone = 0,
-  kAnd,
-  kOr,
+  kAnd,  // &
+  kOr,   // |
   kXor,
-  kLogicalAnd,
-  kLogicalOr,
+  kLogicalAnd,  // &&
+  kLogicalOr,   // ||
   kEqual,
   kNotEqual,
   kLessThan,
@@ -98,6 +98,14 @@
   bool IsDivide() const { return op_ == BinaryOp::kDivide; }
   /// @returns true if the op is modulo
   bool IsModulo() const { return op_ == BinaryOp::kModulo; }
+  /// @returns true if the op is an arithmetic operation
+  bool IsArithmetic() const;
+  /// @returns true if the op is a comparison operation
+  bool IsComparison() const;
+  /// @returns true if the op is a bitwise operation
+  bool IsBitwise() const;
+  /// @returns true if the op is a bit shift operation
+  bool IsBitshift() const;
 
   /// @returns the left side expression
   Expression* lhs() const { return lhs_; }
@@ -126,6 +134,54 @@
   Expression* const rhs_;
 };
 
+inline bool BinaryExpression::IsArithmetic() const {
+  switch (op_) {
+    case ast::BinaryOp::kAdd:
+    case ast::BinaryOp::kSubtract:
+    case ast::BinaryOp::kMultiply:
+    case ast::BinaryOp::kDivide:
+    case ast::BinaryOp::kModulo:
+      return true;
+    default:
+      return false;
+  }
+}
+
+inline bool BinaryExpression::IsComparison() const {
+  switch (op_) {
+    case ast::BinaryOp::kEqual:
+    case ast::BinaryOp::kNotEqual:
+    case ast::BinaryOp::kLessThan:
+    case ast::BinaryOp::kLessThanEqual:
+    case ast::BinaryOp::kGreaterThan:
+    case ast::BinaryOp::kGreaterThanEqual:
+      return true;
+    default:
+      return false;
+  }
+}
+
+inline bool BinaryExpression::IsBitwise() const {
+  switch (op_) {
+    case ast::BinaryOp::kAnd:
+    case ast::BinaryOp::kOr:
+    case ast::BinaryOp::kXor:
+      return true;
+    default:
+      return false;
+  }
+}
+
+inline bool BinaryExpression::IsBitshift() const {
+  switch (op_) {
+    case ast::BinaryOp::kShiftLeft:
+    case ast::BinaryOp::kShiftRight:
+      return true;
+    default:
+      return false;
+  }
+}
+
 inline std::ostream& operator<<(std::ostream& out, BinaryOp op) {
   switch (op) {
     case BinaryOp::kNone:
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 8b27d47..9198c39 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -877,11 +877,167 @@
   return true;
 }
 
+bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
+  using Bool = type::Bool;
+  using F32 = type::F32;
+  using I32 = type::I32;
+  using U32 = type::U32;
+  using Matrix = type::Matrix;
+  using Vector = type::Vector;
+
+  auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
+  auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
+
+  auto* lhs_vec = lhs_type->As<Vector>();
+  auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
+  auto* rhs_vec = rhs_type->As<Vector>();
+  auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
+
+  const bool matching_types = lhs_type == rhs_type;
+  const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
+                                       (lhs_vec_elem_type == rhs_vec_elem_type);
+
+  // Binary logical expressions
+  if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
+    if (matching_types && lhs_type->Is<Bool>()) {
+      return true;
+    }
+  }
+  if (expr->IsOr() || expr->IsAnd()) {
+    if (matching_types && lhs_type->Is<Bool>()) {
+      return true;
+    }
+    if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
+      return true;
+    }
+  }
+
+  // Arithmetic expressions
+  if (expr->IsArithmetic()) {
+    // Binary arithmetic expressions over scalars
+    if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) {
+      return true;
+    }
+
+    // Binary arithmetic expressions over vectors
+    if (matching_types && lhs_vec_elem_type &&
+        lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) {
+      return true;
+    }
+  }
+
+  // Binary arithmetic expressions with mixed scalar, vector, and matrix
+  // operands
+  if (expr->IsMultiply()) {
+    // Multiplication of a vector and a scalar
+    if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
+        rhs_vec_elem_type->Is<F32>()) {
+      return true;
+    }
+    if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
+        rhs_type->Is<F32>()) {
+      return true;
+    }
+
+    auto* lhs_mat = lhs_type->As<Matrix>();
+    auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
+    auto* rhs_mat = rhs_type->As<Matrix>();
+    auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
+
+    // Multiplication of a matrix and a scalar
+    if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
+        rhs_mat_elem_type->Is<F32>()) {
+      return true;
+    }
+    if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
+        rhs_type->Is<F32>()) {
+      return true;
+    }
+
+    // Vector times matrix
+    if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
+        rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
+      return true;
+    }
+
+    // Matrix times vector
+    if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
+        rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) {
+      return true;
+    }
+
+    // Matrix times matrix
+    if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
+        rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
+      return true;
+    }
+  }
+
+  // Comparison expressions
+  if (expr->IsComparison()) {
+    if (matching_types) {
+      // Special case for bools: only == and !=
+      if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
+        return true;
+      }
+
+      // For the rest, we can compare i32, u32, and f32
+      if (lhs_type->IsAnyOf<I32, U32, F32>()) {
+        return true;
+      }
+    }
+
+    // Same for vectors
+    if (matching_vec_elem_types) {
+      if (lhs_vec_elem_type->Is<Bool>() &&
+          (expr->IsEqual() || expr->IsNotEqual())) {
+        return true;
+      }
+
+      if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) {
+        return true;
+      }
+    }
+  }
+
+  // Binary bitwise operations
+  if (expr->IsBitwise()) {
+    if (matching_types && lhs_type->IsAnyOf<I32, U32>()) {
+      return true;
+    }
+  }
+
+  // Bit shift expressions
+  if (expr->IsBitshift()) {
+    // Type validation rules are the same for left or right shift, despite
+    // differences in computation rules (i.e. right shift can be arithmetic or
+    // logical depending on lhs type).
+
+    if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
+      return true;
+    }
+
+    if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
+        rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
+      return true;
+    }
+  }
+
+  diagnostics_.add_error(
+      "Binary expression operand types are invalid for this operation",
+      expr->source());
+  return false;
+}
+
 bool Resolver::Binary(ast::BinaryExpression* expr) {
   if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
     return false;
   }
 
+  if (!ValidateBinary(expr)) {
+    return false;
+  }
+
   // Result type matches first parameter type
   if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
       expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ddf7c11..d8a3d81 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -171,6 +171,7 @@
   // AST and Type traversal methods
   // Each return true on success, false on failure.
   bool ArrayAccessor(ast::ArrayAccessorExpression*);
+  bool ValidateBinary(ast::BinaryExpression* expr);
   bool Binary(ast::BinaryExpression*);
   bool Bitcast(ast::BitcastExpression*);
   bool BlockStatement(const ast::BlockStatement*);
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index caf608e..7ce593f 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -14,6 +14,8 @@
 
 #include "src/resolver/resolver.h"
 
+#include <tuple>
+
 #include "gmock/gmock.h"
 #include "src/ast/assignment_statement.h"
 #include "src/ast/bitcast_expression.h"
@@ -971,246 +973,276 @@
   EXPECT_TRUE(TypeOf(expr)->Is<type::F32>());
 }
 
-using Expr_Binary_BitwiseTest = ResolverTestWithParam<ast::BinaryOp>;
-TEST_P(Expr_Binary_BitwiseTest, Scalar) {
-  auto op = GetParam();
+namespace ExprBinaryTest {
 
-  Global("val", ty.i32(), ast::StorageClass::kNone);
+using create_type_func_ptr =
+    type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
 
-  auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
-  WrapInFunction(expr);
+struct Params {
+  ast::BinaryOp op;
+  create_type_func_ptr create_lhs_type;
+  create_type_func_ptr create_rhs_type;
+  create_type_func_ptr create_result_type;
+};
 
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
+// Helpers and typedefs to make building the table below more succinct
+
+using i32 = ProgramBuilder::i32;
+using u32 = ProgramBuilder::u32;
+using f32 = ProgramBuilder::f32;
+using Op = ast::BinaryOp;
+
+type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.bool_();
+}
+type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.i32();
+}
+type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.u32();
+}
+type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.f32();
 }
 
-TEST_P(Expr_Binary_BitwiseTest, Vector) {
-  auto op = GetParam();
+template <typename T>
+type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.vec3<T>();
+}
 
-  Global("val", ty.vec3<i32>(), ast::StorageClass::kNone);
+template <typename T>
+type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat3x3<T>();
+}
 
-  auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
+static constexpr create_type_func_ptr all_create_type_funcs[] = {
+    ty_bool_,       ty_u32,         ty_i32,        ty_f32,
+    ty_vec3<bool>,  ty_vec3<i32>,   ty_vec3<u32>,  ty_vec3<f32>,
+    ty_mat3x3<i32>, ty_mat3x3<u32>, ty_mat3x3<f32>};
+
+// A list of all valid test cases for 'lhs op rhs', except that for vecN and
+// matNxN, we only test N=3.
+static constexpr Params all_valid_cases[] = {
+    // Logical expressions
+    // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
+
+    // Binary logical expressions
+    Params{Op::kLogicalAnd, ty_bool_, ty_bool_, ty_bool_},
+    Params{Op::kLogicalOr, ty_bool_, ty_bool_, ty_bool_},
+
+    Params{Op::kAnd, ty_bool_, ty_bool_, ty_bool_},
+    Params{Op::kOr, ty_bool_, ty_bool_, ty_bool_},
+    Params{Op::kAnd, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
+    Params{Op::kOr, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
+
+    // Arithmetic expressions
+    // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
+
+    // Binary arithmetic expressions over scalars
+    Params{Op::kAdd, ty_i32, ty_i32, ty_i32},
+    Params{Op::kSubtract, ty_i32, ty_i32, ty_i32},
+    Params{Op::kMultiply, ty_i32, ty_i32, ty_i32},
+    Params{Op::kDivide, ty_i32, ty_i32, ty_i32},
+    Params{Op::kModulo, ty_i32, ty_i32, ty_i32},
+
+    Params{Op::kAdd, ty_u32, ty_u32, ty_u32},
+    Params{Op::kSubtract, ty_u32, ty_u32, ty_u32},
+    Params{Op::kMultiply, ty_u32, ty_u32, ty_u32},
+    Params{Op::kDivide, ty_u32, ty_u32, ty_u32},
+    Params{Op::kModulo, ty_u32, ty_u32, ty_u32},
+
+    Params{Op::kAdd, ty_f32, ty_f32, ty_f32},
+    Params{Op::kSubtract, ty_f32, ty_f32, ty_f32},
+    Params{Op::kMultiply, ty_f32, ty_f32, ty_f32},
+    Params{Op::kDivide, ty_f32, ty_f32, ty_f32},
+    Params{Op::kModulo, ty_f32, ty_f32, ty_f32},
+
+    // Binary arithmetic expressions over vectors
+    Params{Op::kAdd, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
+    Params{Op::kSubtract, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
+    Params{Op::kMultiply, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
+    Params{Op::kDivide, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
+    Params{Op::kModulo, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
+
+    Params{Op::kAdd, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
+    Params{Op::kSubtract, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
+    Params{Op::kMultiply, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
+    Params{Op::kDivide, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
+    Params{Op::kModulo, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
+
+    Params{Op::kAdd, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
+    Params{Op::kSubtract, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
+    Params{Op::kMultiply, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
+    Params{Op::kDivide, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
+    Params{Op::kModulo, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
+
+    // Binary arithmetic expressions with mixed scalar, vector, and matrix
+    // operands
+    Params{Op::kMultiply, ty_vec3<f32>, ty_f32, ty_vec3<f32>},
+    Params{Op::kMultiply, ty_f32, ty_vec3<f32>, ty_vec3<f32>},
+
+    Params{Op::kMultiply, ty_mat3x3<f32>, ty_f32, ty_mat3x3<f32>},
+    Params{Op::kMultiply, ty_f32, ty_mat3x3<f32>, ty_mat3x3<f32>},
+
+    Params{Op::kMultiply, ty_vec3<f32>, ty_mat3x3<f32>, ty_vec3<f32>},
+    Params{Op::kMultiply, ty_mat3x3<f32>, ty_vec3<f32>, ty_vec3<f32>},
+    Params{Op::kMultiply, ty_mat3x3<f32>, ty_mat3x3<f32>, ty_mat3x3<f32>},
+
+    // Comparison expressions
+    // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
+
+    // Comparisons over scalars
+    Params{Op::kEqual, ty_bool_, ty_bool_, ty_bool_},
+    Params{Op::kNotEqual, ty_bool_, ty_bool_, ty_bool_},
+
+    Params{Op::kEqual, ty_i32, ty_i32, ty_bool_},
+    Params{Op::kNotEqual, ty_i32, ty_i32, ty_bool_},
+    Params{Op::kLessThan, ty_i32, ty_i32, ty_bool_},
+    Params{Op::kLessThanEqual, ty_i32, ty_i32, ty_bool_},
+    Params{Op::kGreaterThan, ty_i32, ty_i32, ty_bool_},
+    Params{Op::kGreaterThanEqual, ty_i32, ty_i32, ty_bool_},
+
+    Params{Op::kEqual, ty_u32, ty_u32, ty_bool_},
+    Params{Op::kNotEqual, ty_u32, ty_u32, ty_bool_},
+    Params{Op::kLessThan, ty_u32, ty_u32, ty_bool_},
+    Params{Op::kLessThanEqual, ty_u32, ty_u32, ty_bool_},
+    Params{Op::kGreaterThan, ty_u32, ty_u32, ty_bool_},
+    Params{Op::kGreaterThanEqual, ty_u32, ty_u32, ty_bool_},
+
+    Params{Op::kEqual, ty_f32, ty_f32, ty_bool_},
+    Params{Op::kNotEqual, ty_f32, ty_f32, ty_bool_},
+    Params{Op::kLessThan, ty_f32, ty_f32, ty_bool_},
+    Params{Op::kLessThanEqual, ty_f32, ty_f32, ty_bool_},
+    Params{Op::kGreaterThan, ty_f32, ty_f32, ty_bool_},
+    Params{Op::kGreaterThanEqual, ty_f32, ty_f32, ty_bool_},
+
+    // Comparisons over vectors
+    Params{Op::kEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
+    Params{Op::kNotEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
+
+    Params{Op::kEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
+    Params{Op::kNotEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
+    Params{Op::kLessThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
+    Params{Op::kLessThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
+    Params{Op::kGreaterThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
+    Params{Op::kGreaterThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
+
+    Params{Op::kEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
+    Params{Op::kNotEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
+    Params{Op::kLessThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
+    Params{Op::kLessThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
+    Params{Op::kGreaterThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
+    Params{Op::kGreaterThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
+
+    Params{Op::kEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
+    Params{Op::kNotEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
+    Params{Op::kLessThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
+    Params{Op::kLessThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
+    Params{Op::kGreaterThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
+    Params{Op::kGreaterThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
+
+    // Bit expressions
+    // https://gpuweb.github.io/gpuweb/wgsl.html#bit-expr
+
+    // Binary bitwise operations
+    Params{Op::kOr, ty_i32, ty_i32, ty_i32},
+    Params{Op::kAnd, ty_i32, ty_i32, ty_i32},
+    Params{Op::kXor, ty_i32, ty_i32, ty_i32},
+
+    Params{Op::kOr, ty_u32, ty_u32, ty_u32},
+    Params{Op::kAnd, ty_u32, ty_u32, ty_u32},
+    Params{Op::kXor, ty_u32, ty_u32, ty_u32},
+
+    // Bit shift expressions
+    Params{Op::kShiftLeft, ty_i32, ty_u32, ty_i32},
+    Params{Op::kShiftLeft, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
+
+    Params{Op::kShiftLeft, ty_u32, ty_u32, ty_u32},
+    Params{Op::kShiftLeft, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
+
+    Params{Op::kShiftRight, ty_i32, ty_u32, ty_i32},
+    Params{Op::kShiftRight, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
+
+    Params{Op::kShiftRight, ty_u32, ty_u32, ty_u32},
+    Params{Op::kShiftRight, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>}};
+
+using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
+TEST_P(Expr_Binary_Test_Valid, All) {
+  auto& params = GetParam();
+
+  auto* lhs_type = params.create_lhs_type(ty);
+  auto* rhs_type = params.create_rhs_type(ty);
+  auto* result_type = params.create_result_type(ty);
+
+  SCOPED_TRACE(testing::Message()
+               << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
+               << rhs_type->FriendlyName(Symbols()));
+
+  Global("lhs", lhs_type, ast::StorageClass::kNone);
+  Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+  auto* expr =
+      create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
   WrapInFunction(expr);
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
   ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::I32>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
+  ASSERT_TRUE(TypeOf(expr) == result_type);
 }
 INSTANTIATE_TEST_SUITE_P(ResolverTest,
-                         Expr_Binary_BitwiseTest,
-                         testing::Values(ast::BinaryOp::kAnd,
-                                         ast::BinaryOp::kOr,
-                                         ast::BinaryOp::kXor,
-                                         ast::BinaryOp::kShiftLeft,
-                                         ast::BinaryOp::kShiftRight,
-                                         ast::BinaryOp::kAdd,
-                                         ast::BinaryOp::kSubtract,
-                                         ast::BinaryOp::kDivide,
-                                         ast::BinaryOp::kModulo));
+                         Expr_Binary_Test_Valid,
+                         testing::ValuesIn(all_valid_cases));
 
-using Expr_Binary_LogicalTest = ResolverTestWithParam<ast::BinaryOp>;
-TEST_P(Expr_Binary_LogicalTest, Scalar) {
-  auto op = GetParam();
+using Expr_Binary_Test_Invalid =
+    ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
+TEST_P(Expr_Binary_Test_Invalid, All) {
+  const Params& params = std::get<0>(GetParam());
+  const create_type_func_ptr& create_type_func = std::get<1>(GetParam());
 
-  Global("val", ty.bool_(), ast::StorageClass::kNone);
+  // Currently, for most operations, for a given lhs type, there is exactly one
+  // rhs type allowed.  The only exception is for multiplication, which allows
+  // any permutation of f32, vecN<f32>, and matNxN<f32>. We are fed valid inputs
+  // only via `params`, and all possible types via `create_type_func`, so we
+  // test invalid combinations by testing every other rhs type, modulo
+  // exceptions.
 
-  auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
+  // Skip valid rhs type
+  if (params.create_rhs_type == create_type_func) {
+    return;
+  }
+
+  auto* lhs_type = params.create_lhs_type(ty);
+  auto* rhs_type = create_type_func(ty);
+
+  // Skip exceptions: multiplication of f32, vecN<f32>, and matNxN<f32>
+  if (params.op == Op::kMultiply &&
+      lhs_type->is_float_scalar_or_vector_or_matrix() &&
+      rhs_type->is_float_scalar_or_vector_or_matrix()) {
+    return;
+  }
+
+  SCOPED_TRACE(testing::Message()
+               << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
+               << rhs_type->FriendlyName(Symbols()));
+
+  Global("lhs", lhs_type, ast::StorageClass::kNone);
+  Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+  auto* expr = create<ast::BinaryExpression>(
+      Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs"));
   WrapInFunction(expr);
 
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
+  ASSERT_FALSE(r()->Resolve()) << r()->error();
+  ASSERT_EQ(r()->error(),
+            "12:34 error: Binary expression operand types are invalid for "
+            "this operation");
 }
-
-TEST_P(Expr_Binary_LogicalTest, Vector) {
-  auto op = GetParam();
-
-  Global("val", ty.vec3<bool>(), ast::StorageClass::kNone);
-
-  auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
-}
-INSTANTIATE_TEST_SUITE_P(ResolverTest,
-                         Expr_Binary_LogicalTest,
-                         testing::Values(ast::BinaryOp::kLogicalAnd,
-                                         ast::BinaryOp::kLogicalOr));
-
-using Expr_Binary_CompareTest = ResolverTestWithParam<ast::BinaryOp>;
-TEST_P(Expr_Binary_CompareTest, Scalar) {
-  auto op = GetParam();
-
-  Global("val", ty.i32(), ast::StorageClass::kNone);
-
-  auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
-}
-
-TEST_P(Expr_Binary_CompareTest, Vector) {
-  auto op = GetParam();
-
-  Global("val", ty.vec3<i32>(), ast::StorageClass::kNone);
-
-  auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
-}
-INSTANTIATE_TEST_SUITE_P(ResolverTest,
-                         Expr_Binary_CompareTest,
-                         testing::Values(ast::BinaryOp::kEqual,
-                                         ast::BinaryOp::kNotEqual,
-                                         ast::BinaryOp::kLessThan,
-                                         ast::BinaryOp::kGreaterThan,
-                                         ast::BinaryOp::kLessThanEqual,
-                                         ast::BinaryOp::kGreaterThanEqual));
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Scalar) {
-  Global("val", ty.i32(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("val", "val");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Scalar) {
-  Global("scalar", ty.f32(), ast::StorageClass::kNone);
-  Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("vector", "scalar");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Vector) {
-  Global("scalar", ty.f32(), ast::StorageClass::kNone);
-  Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("scalar", "vector");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Vector) {
-  Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("vector", "vector");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Scalar) {
-  Global("scalar", ty.f32(), ast::StorageClass::kNone);
-  Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("matrix", "scalar");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
-
-  auto* mat = TypeOf(expr)->As<type::Matrix>();
-  EXPECT_TRUE(mat->type()->Is<type::F32>());
-  EXPECT_EQ(mat->rows(), 3u);
-  EXPECT_EQ(mat->columns(), 2u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Matrix) {
-  Global("scalar", ty.f32(), ast::StorageClass::kNone);
-  Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("scalar", "matrix");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
-
-  auto* mat = TypeOf(expr)->As<type::Matrix>();
-  EXPECT_TRUE(mat->type()->Is<type::F32>());
-  EXPECT_EQ(mat->rows(), 3u);
-  EXPECT_EQ(mat->columns(), 2u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Vector) {
-  Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
-  Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("matrix", "vector");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Matrix) {
-  Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
-  Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("vector", "matrix");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
-  EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
-  EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 2u);
-}
-
-TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Matrix) {
-  Global("mat3x4", ty.mat3x4<f32>(), ast::StorageClass::kNone);
-  Global("mat4x3", ty.mat4x3<f32>(), ast::StorageClass::kNone);
-
-  auto* expr = Mul("mat3x4", "mat4x3");
-  WrapInFunction(expr);
-
-  ASSERT_TRUE(r()->Resolve()) << r()->error();
-  ASSERT_NE(TypeOf(expr), nullptr);
-  ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
-
-  auto* mat = TypeOf(expr)->As<type::Matrix>();
-  EXPECT_TRUE(mat->type()->Is<type::F32>());
-  EXPECT_EQ(mat->rows(), 4u);
-  EXPECT_EQ(mat->columns(), 4u);
-}
+INSTANTIATE_TEST_SUITE_P(
+    ResolverTest,
+    Expr_Binary_Test_Invalid,
+    testing::Combine(testing::ValuesIn(all_valid_cases),
+                     testing::ValuesIn(all_create_type_funcs)));
+}  // namespace ExprBinaryTest
 
 using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
 TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {
diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc
index 67a831c..7eedc18 100644
--- a/src/transform/bound_array_accessors_test.cc
+++ b/src/transform/bound_array_accessors_test.cc
@@ -104,7 +104,7 @@
   auto* src = R"(
 var a : array<f32, 3>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[c + 2 - 3];
@@ -114,7 +114,7 @@
   auto* expect = R"(
 var a : array<f32, 3>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
@@ -196,7 +196,7 @@
   auto* src = R"(
 var a : vec3<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[c + 2 - 3];
@@ -206,7 +206,7 @@
   auto* expect = R"(
 var a : vec3<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
@@ -244,7 +244,7 @@
   auto* src = R"(
 var a : vec3<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a.xy[c];
@@ -254,7 +254,7 @@
   auto* expect = R"(
 var a : vec3<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a.xy[min(u32(c), 1u)];
@@ -269,7 +269,7 @@
   auto* src = R"(
 var a : vec3<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a.xy[c + 2 - 3];
@@ -279,7 +279,7 @@
   auto* expect = R"(
 var a : vec3<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a.xy[min(u32(((c + 2) - 3)), 1u)];
@@ -361,7 +361,7 @@
   auto* src = R"(
 var a : mat3x2<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[c + 2 - 3][1];
@@ -371,7 +371,7 @@
   auto* expect = R"(
 var a : mat3x2<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[min(u32(((c + 2) - 3)), 2u)][1];
@@ -387,7 +387,7 @@
   auto* src = R"(
 var a : mat3x2<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[1][c + 2 - 3];
@@ -397,7 +397,7 @@
   auto* expect = R"(
 var a : mat3x2<f32>;
 
-var c : u32;
+var c : i32;
 
 fn f() -> void {
   var b : f32 = a[1][min(u32(((c + 2) - 3)), 1u)];
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index b82a4d5..78e96c7 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -132,7 +132,7 @@
       Source{},                                        // source
       ctx.dst->Symbols().Register(vertex_index_name),  // symbol
       ast::StorageClass::kInput,                       // storage_class
-      GetI32Type(),                                    // type
+      GetU32Type(),                                    // type
       false,                                           // is_const
       nullptr,                                         // constructor
       ast::DecorationList{
@@ -179,7 +179,7 @@
       Source{},                                          // source
       ctx.dst->Symbols().Register(instance_index_name),  // symbol
       ast::StorageClass::kInput,                         // storage_class
-      GetI32Type(),                                      // type
+      GetU32Type(),                                      // type
       false,                                             // is_const
       nullptr,                                           // constructor
       ast::DecorationList{
@@ -273,7 +273,7 @@
                     Source{},                                         // source
                     ctx.dst->Symbols().Register(kPullingPosVarName),  // symbol
                     ast::StorageClass::kFunction,  // storage_class
-                    GetI32Type(),                  // type
+                    GetU32Type(),                  // type
                     false,                         // is_const
                     nullptr,                       // constructor
                     ast::DecorationList{}));       // decorations
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index d36dcaa..c0b6e19 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -89,7 +89,7 @@
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
   }
 }
 )";
@@ -113,7 +113,7 @@
 )";
 
   auto* expect = R"(
-[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
+[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
 [[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
 
@@ -127,7 +127,7 @@
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
     _tint_pulling_pos = ((_tint_pulling_vertex_index * 4u) + 0u);
     var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
   }
@@ -155,7 +155,7 @@
 )";
 
   auto* expect = R"(
-[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : i32;
+[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : u32;
 
 [[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
 
@@ -169,7 +169,7 @@
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
     _tint_pulling_pos = ((_tint_pulling_instance_index * 4u) + 0u);
     var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
   }
@@ -197,7 +197,7 @@
 )";
 
   auto* expect = R"(
-[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
+[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
 [[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
 
@@ -211,7 +211,7 @@
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
     _tint_pulling_pos = ((_tint_pulling_vertex_index * 4u) + 0u);
     var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
   }
@@ -236,8 +236,8 @@
   auto* src = R"(
 [[location(0)]] var<in> var_a : f32;
 [[location(1)]] var<in> var_b : f32;
-[[builtin(vertex_index)]] var<in> custom_vertex_index : i32;
-[[builtin(instance_index)]] var<in> custom_instance_index : i32;
+[[builtin(vertex_index)]] var<in> custom_vertex_index : u32;
+[[builtin(instance_index)]] var<in> custom_instance_index : u32;
 
 [[stage(vertex)]]
 fn main() -> void {}
@@ -257,14 +257,14 @@
 
 var<private> var_b : f32;
 
-[[builtin(vertex_index)]] var<in> custom_vertex_index : i32;
+[[builtin(vertex_index)]] var<in> custom_vertex_index : u32;
 
-[[builtin(instance_index)]] var<in> custom_instance_index : i32;
+[[builtin(instance_index)]] var<in> custom_instance_index : u32;
 
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
     _tint_pulling_pos = ((custom_vertex_index * 4u) + 0u);
     var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
     _tint_pulling_pos = ((custom_instance_index * 4u) + 0u);
@@ -305,7 +305,7 @@
 )";
 
   auto* expect = R"(
-[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
+[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
 [[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
 
@@ -321,7 +321,7 @@
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
     _tint_pulling_pos = ((_tint_pulling_vertex_index * 16u) + 0u);
     var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
     _tint_pulling_pos = ((_tint_pulling_vertex_index * 16u) + 0u);
@@ -355,7 +355,7 @@
 )";
 
   auto* expect = R"(
-[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
+[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
 [[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
 
@@ -377,7 +377,7 @@
 [[stage(vertex)]]
 fn main() -> void {
   {
-    var _tint_pulling_pos : i32;
+    var _tint_pulling_pos : u32;
     _tint_pulling_pos = ((_tint_pulling_vertex_index * 8u) + 0u);
     var_a = vec2<f32>(bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[((_tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[((_tint_pulling_pos + 4u) / 4u)]));
     _tint_pulling_pos = ((_tint_pulling_vertex_index * 12u) + 0u);
diff --git a/src/type/type.cc b/src/type/type.cc
index 2a7d6ba..f74c20c 100644
--- a/src/type/type.cc
+++ b/src/type/type.cc
@@ -92,6 +92,10 @@
   return is_float_scalar() || is_float_vector();
 }
 
+bool Type::is_float_scalar_or_vector_or_matrix() const {
+  return is_float_scalar() || is_float_vector() || is_float_matrix();
+}
+
 bool Type::is_integer_scalar() const {
   return IsAnyOf<U32, I32>();
 }
diff --git a/src/type/type.h b/src/type/type.h
index c4e7a36..0dd9eaf 100644
--- a/src/type/type.h
+++ b/src/type/type.h
@@ -77,6 +77,8 @@
   bool is_float_vector() const;
   /// @returns true if this type is a float scalar or vector
   bool is_float_scalar_or_vector() const;
+  /// @returns true if this type is a float scalar or vector or matrix
+  bool is_float_scalar_or_vector_or_matrix() const;
   /// @returns true if this type is an integer scalar
   bool is_integer_scalar() const;
   /// @returns true if this type is a signed integer vector
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 4b748b9..0f65366 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -36,6 +36,14 @@
 TEST_P(HlslBinaryTest, Emit_f32) {
   auto params = GetParam();
 
+  // Skip ops that are illegal for this type
+  if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
+      params.op == ast::BinaryOp::kXor ||
+      params.op == ast::BinaryOp::kShiftLeft ||
+      params.op == ast::BinaryOp::kShiftRight) {
+    return;
+  }
+
   Global("left", ty.f32(), ast::StorageClass::kFunction);
   Global("right", ty.f32(), ast::StorageClass::kFunction);
 
@@ -72,6 +80,12 @@
 TEST_P(HlslBinaryTest, Emit_i32) {
   auto params = GetParam();
 
+  // Skip ops that are illegal for this type
+  if (params.op == ast::BinaryOp::kShiftLeft ||
+      params.op == ast::BinaryOp::kShiftRight) {
+    return;
+  }
+
   Global("left", ty.i32(), ast::StorageClass::kFunction);
   Global("right", ty.i32(), ast::StorageClass::kFunction);
 
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index 6a44c9a..55000ce 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -58,6 +58,12 @@
 TEST_P(BinaryArithSignedIntegerTest, Vector) {
   auto param = GetParam();
 
+  // Skip ops that are illegal for this type
+  if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr ||
+      param.op == ast::BinaryOp::kXor) {
+    return;
+  }
+
   auto* lhs = vec3<i32>(1, 1, 1);
   auto* rhs = vec3<i32>(1, 1, 1);
 
@@ -111,15 +117,13 @@
 INSTANTIATE_TEST_SUITE_P(
     BuilderTest,
     BinaryArithSignedIntegerTest,
+    // NOTE: No left and right shift as they require u32 for rhs operand
     testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
                     BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
                     BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
                     BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
                     BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
                     BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
-                    BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
-                    BinaryData{ast::BinaryOp::kShiftRight,
-                               "OpShiftRightArithmetic"},
                     BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
                     BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
 
@@ -149,6 +153,12 @@
 TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
   auto param = GetParam();
 
+  // Skip ops that are illegal for this type
+  if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr ||
+      param.op == ast::BinaryOp::kXor) {
+    return;
+  }
+
   auto* lhs = vec3<u32>(1u, 1u, 1u);
   auto* rhs = vec3<u32>(1u, 1u, 1u);