Add support for binary arithmetic expressions with mixed scalar and vector operands

Bug: tint:376
Change-Id: I2994ff7394efa903050b470a850b41628d5b775c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52324
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 3be647b..baceffa 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2097,23 +2097,35 @@
       SetType(expr, lhs_type);
       return true;
     }
+
+    // Binary arithmetic expressions with mixed scalar and vector operands
+    if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_type)) {
+      if (expr->IsModulo()) {
+        if (rhs_type->is_integer_scalar()) {
+          SetType(expr, lhs_type);
+          return true;
+        }
+      } else if (rhs_type->is_numeric_scalar()) {
+        SetType(expr, lhs_type);
+        return true;
+      }
+    }
+    if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_type)) {
+      if (expr->IsModulo()) {
+        if (lhs_type->is_integer_scalar()) {
+          SetType(expr, rhs_type);
+          return true;
+        }
+      } else if (lhs_type->is_numeric_scalar()) {
+        SetType(expr, rhs_type);
+        return true;
+      }
+    }
   }
 
-  // Binary arithmetic expressions with mixed scalar, vector, and matrix
-  // operands
+  // Matrix arithmetic
+  // TODO(amaiorano): matrix-matrix addition and subtraction
   if (expr->IsMultiply()) {
-    // Multiplication of a vector and a scalar
-    if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
-        rhs_vec_elem_type->Is<F32>()) {
-      SetType(expr, rhs_type);
-      return true;
-    }
-    if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
-        rhs_type->Is<F32>()) {
-      SetType(expr, lhs_type);
-      return true;
-    }
-
     auto* lhs_mat = lhs_type->As<Matrix>();
     auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
     auto* rhs_mat = rhs_type->As<Matrix>();
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 7a088c7..06b6ce1 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1298,16 +1298,52 @@
     Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
     Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
 
-    // Binary arithmetic expressions with mixed scalar, vector, and matrix
-    // operands
-    Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
-    Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
+    // Binary arithmetic expressions with mixed scalar and vector operands
+    Params{Op::kAdd, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
+    Params{Op::kSubtract, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
+    Params{Op::kMultiply, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
+    Params{Op::kDivide, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
+    Params{Op::kModulo, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
 
+    Params{Op::kAdd, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
+    Params{Op::kSubtract, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
+    Params{Op::kMultiply, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
+    Params{Op::kDivide, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
+    Params{Op::kModulo, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
+
+    Params{Op::kAdd, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
+    Params{Op::kSubtract, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
+    Params{Op::kMultiply, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
+    Params{Op::kDivide, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
+    Params{Op::kModulo, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
+
+    Params{Op::kAdd, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
+    Params{Op::kSubtract, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
+    Params{Op::kMultiply, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
+    Params{Op::kDivide, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
+    Params{Op::kModulo, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
+
+    Params{Op::kAdd, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
+    Params{Op::kSubtract, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
+    Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
+    Params{Op::kDivide, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
+    // NOTE: no kModulo for ast_vec3<f32>, ast_f32
+    // Params{Op::kModulo, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
+
+    Params{Op::kAdd, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
+    Params{Op::kSubtract, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
+    Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
+    Params{Op::kDivide, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
+    // NOTE: no kModulo for ast_f32, ast_vec3<f32>
+    // Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
+
+    // Matrix arithmetic
     Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
     Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
 
     Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
     Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
+    // TODO(amaiorano): add mat+mat and mat-mat
     Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
            sem_mat3x3<sem_f32>},
 
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 41d9f1b..6f96b21 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1719,6 +1719,31 @@
   return result_id;
 }
 
+uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
+  // Create a new vector to splat scalar into
+  auto splat_vector = result_op();
+  auto* splat_vector_type =
+      builder_.create<sem::Pointer>(vec_type, ast::StorageClass::kFunction);
+  push_function_var(
+      {Operand::Int(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
+       Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)),
+       Operand::Int(GenerateConstantNullIfNeeded(vec_type))});
+
+  // Splat scalar into vector
+  auto splat_result = result_op();
+  OperandList ops;
+  ops.push_back(Operand::Int(GenerateTypeIfNeeded(vec_type)));
+  ops.push_back(splat_result);
+  for (size_t i = 0; i < vec_type->As<sem::Vector>()->size(); ++i) {
+    ops.push_back(Operand::Int(scalar_id));
+  }
+  if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
+    return 0;
+  }
+
+  return splat_result.to_i();
+}
+
 uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
   // There is special logic for short circuiting operators.
   if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
@@ -1749,6 +1774,33 @@
   // should have been rejected by validation.
   auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
   auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
+
+  // For vector-scalar arithmetic operations, splat scalar into a vector. We
+  // skip this for multiply as we can use OpVectorTimesScalar.
+  const bool is_float_scalar_vector_multiply =
+      expr->IsMultiply() &&
+      ((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
+       (lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
+
+  if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
+    if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) {
+      uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
+      if (splat_vector_id == 0) {
+        return 0;
+      }
+      rhs_id = splat_vector_id;
+      rhs_type = lhs_type;
+
+    } else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) {
+      uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
+      if (splat_vector_id == 0) {
+        return 0;
+      }
+      lhs_id = splat_vector_id;
+      lhs_type = rhs_type;
+    }
+  }
+
   bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
   bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
 
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 0774e7a..3daaa18 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -473,6 +473,13 @@
   /// @returns true if the vector was successfully generated
   bool GenerateVectorType(const sem::Vector* vec, const Operand& result);
 
+  /// Generates instructions to splat `scalar_id` into a vector of type
+  /// `vec_type`
+  /// @param scalar_id scalar to splat
+  /// @param vec_type type of vector
+  /// @returns id of the new vector
+  uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
+
   /// Converts AST image format to SPIR-V and pushes an appropriate capability.
   /// @param format AST image format type
   /// @returns SPIR-V image format type
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index bdf19a2..e15d4ec 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -863,6 +863,222 @@
 )");
 }
 
+enum class Type { f32, i32, u32 };
+ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
+  switch (type) {
+    case Type::f32:
+      return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
+    case Type::i32:
+      return builder->vec3<ProgramBuilder::i32>(1, 1, 1);
+    case Type::u32:
+      return builder->vec3<ProgramBuilder::u32>(1u, 1u, 1u);
+  }
+  return nullptr;
+}
+ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
+  switch (type) {
+    case Type::f32:
+      return builder->Expr(1.f);
+    case Type::i32:
+      return builder->Expr(1);
+    case Type::u32:
+      return builder->Expr(1u);
+  }
+  return nullptr;
+}
+std::string OpTypeDecl(Type type) {
+  switch (type) {
+    case Type::f32:
+      return "OpTypeFloat 32";
+    case Type::i32:
+      return "OpTypeInt 32 1";
+    case Type::u32:
+      return "OpTypeInt 32 0";
+  }
+  return {};
+}
+
+struct Param {
+  Type type;
+  ast::BinaryOp op;
+  std::string name;
+};
+
+using MixedBinaryArithTest = TestParamHelper<Param>;
+TEST_P(MixedBinaryArithTest, VectorScalar) {
+  auto& param = GetParam();
+
+  ast::Expression* lhs = MakeVectorExpr(this, param.type);
+  ast::Expression* rhs = MakeScalarExpr(this, param.type);
+  std::string op_type_decl = OpTypeDecl(param.type);
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%6 = )" + op_type_decl + R"(
+%5 = OpTypeVector %6 3
+%7 = OpConstant %6 1
+%8 = OpConstantComposite %5 %7 %7 %7
+%11 = OpTypePointer Function %5
+%12 = OpConstantNull %5
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%10 = OpVariable %11 Function %12
+%13 = OpCompositeConstruct %5 %7 %7 %7
+%9 = )" + param.name + R"( %5 %8 %13
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+TEST_P(MixedBinaryArithTest, ScalarVector) {
+  auto& param = GetParam();
+
+  ast::Expression* lhs = MakeScalarExpr(this, param.type);
+  ast::Expression* rhs = MakeVectorExpr(this, param.type);
+  std::string op_type_decl = OpTypeDecl(param.type);
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%5 = )" + op_type_decl + R"(
+%6 = OpConstant %5 1
+%7 = OpTypeVector %5 3
+%8 = OpConstantComposite %7 %6 %6 %6
+%11 = OpTypePointer Function %7
+%12 = OpConstantNull %7
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%10 = OpVariable %11 Function %12
+%13 = OpCompositeConstruct %7 %6 %6 %6
+%9 = )" + param.name + R"( %7 %13 %8
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+INSTANTIATE_TEST_SUITE_P(
+    BuilderTest,
+    MixedBinaryArithTest,
+    testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
+                    Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
+                    // NOTE: Modulo not allowed on mixed float scalar-vector
+                    // Param{Type::f32, ast::BinaryOp::kModulo, "OpFMod"},
+                    // NOTE: We test f32 multiplies separately as we emit
+                    // OpVectorTimesScalar for this case
+                    // Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
+                    Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
+
+                    Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
+                    Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
+                    Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
+                    Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
+                    Param{Type::i32, ast::BinaryOp::kSubtract, "OpISub"},
+
+                    Param{Type::u32, ast::BinaryOp::kAdd, "OpIAdd"},
+                    Param{Type::u32, ast::BinaryOp::kDivide, "OpUDiv"},
+                    Param{Type::u32, ast::BinaryOp::kModulo, "OpUMod"},
+                    Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
+                    Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
+
+using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
+TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
+  auto& param = GetParam();
+
+  ast::Expression* lhs = MakeVectorExpr(this, param.type);
+  ast::Expression* rhs = MakeScalarExpr(this, param.type);
+  std::string op_type_decl = OpTypeDecl(param.type);
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%6 = )" + op_type_decl + R"(
+%5 = OpTypeVector %6 3
+%7 = OpConstant %6 1
+%8 = OpConstantComposite %5 %7 %7 %7
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%9 = OpVectorTimesScalar %5 %8 %7
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
+  auto& param = GetParam();
+
+  ast::Expression* lhs = MakeScalarExpr(this, param.type);
+  ast::Expression* rhs = MakeVectorExpr(this, param.type);
+  std::string op_type_decl = OpTypeDecl(param.type);
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%5 = )" + op_type_decl + R"(
+%6 = OpConstant %5 1
+%7 = OpTypeVector %5 3
+%8 = OpConstantComposite %7 %6 %6 %6
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%9 = OpVectorTimesScalar %7 %8 %6
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+INSTANTIATE_TEST_SUITE_P(BuilderTest,
+                         MixedBinaryArithMultiplyTest,
+                         testing::Values(Param{
+                             Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/test/expressions/binary_expressions.wgsl b/test/expressions/binary_expressions.wgsl
new file mode 100644
index 0000000..a58eece
--- /dev/null
+++ b/test/expressions/binary_expressions.wgsl
@@ -0,0 +1,70 @@
+fn vector_scalar_f32() {

+    var v : vec3<f32>;

+    var s : f32;

+    var r : vec3<f32>;

+    r = v + s;

+    r = v - s;

+    r = v * s;

+    r = v / s;

+    //r = v % s;

+}

+

+fn vector_scalar_i32() {

+    var v : vec3<i32>;

+    var s : i32;

+    var r : vec3<i32>;

+    r = v + s;

+    r = v - s;

+    r = v * s;

+    r = v / s;

+    r = v % s;

+}

+

+fn vector_scalar_u32() {

+    var v : vec3<u32>;

+    var s : u32;

+    var r : vec3<u32>;

+    r = v + s;

+    r = v - s;

+    r = v * s;

+    r = v / s;

+    r = v % s;

+}

+

+fn scalar_vector_f32() {

+    var v : vec3<f32>;

+    var s : f32;

+    var r : vec3<f32>;

+    r = s + v;

+    r = s - v;

+    r = s * v;

+    r = s / v;

+    //r = s % v;

+}

+

+fn scalar_vector_i32() {

+    var v : vec3<i32>;

+    var s : i32;

+    var r : vec3<i32>;

+    r = s + v;

+    r = s - v;

+    r = s * v;

+    r = s / v;

+    r = s % v;

+}

+

+fn scalar_vector_u32() {

+    var v : vec3<u32>;

+    var s : u32;

+    var r : vec3<u32>;

+    r = s + v;

+    r = s - v;

+    r = s * v;

+    r = s / v;

+    r = s % v;

+}

+

+[[stage(fragment)]]

+fn main() -> [[location(0)]] vec4<f32> {

+    return vec4<f32>(0.0,0.0,0.0,0.0);

+}

diff --git a/test/expressions/binary_expressions.wgsl.expected.hlsl b/test/expressions/binary_expressions.wgsl.expected.hlsl
new file mode 100644
index 0000000..01f2d2b
--- /dev/null
+++ b/test/expressions/binary_expressions.wgsl.expected.hlsl
@@ -0,0 +1,73 @@
+struct tint_symbol {
+  float4 value : SV_Target0;
+};
+
+void vector_scalar_f32() {
+  float3 v = float3(0.0f, 0.0f, 0.0f);
+  float s = 0.0f;
+  float3 r = float3(0.0f, 0.0f, 0.0f);
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+}
+
+void vector_scalar_i32() {
+  int3 v = int3(0, 0, 0);
+  int s = 0;
+  int3 r = int3(0, 0, 0);
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+  r = (v % s);
+}
+
+void vector_scalar_u32() {
+  uint3 v = uint3(0u, 0u, 0u);
+  uint s = 0u;
+  uint3 r = uint3(0u, 0u, 0u);
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+  r = (v % s);
+}
+
+void scalar_vector_f32() {
+  float3 v = float3(0.0f, 0.0f, 0.0f);
+  float s = 0.0f;
+  float3 r = float3(0.0f, 0.0f, 0.0f);
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+}
+
+void scalar_vector_i32() {
+  int3 v = int3(0, 0, 0);
+  int s = 0;
+  int3 r = int3(0, 0, 0);
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+  r = (s % v);
+}
+
+void scalar_vector_u32() {
+  uint3 v = uint3(0u, 0u, 0u);
+  uint s = 0u;
+  uint3 r = uint3(0u, 0u, 0u);
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+  r = (s % v);
+}
+
+tint_symbol main() {
+  const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
+  return tint_symbol_1;
+}
+
diff --git a/test/expressions/binary_expressions.wgsl.expected.msl b/test/expressions/binary_expressions.wgsl.expected.msl
new file mode 100644
index 0000000..5630940
--- /dev/null
+++ b/test/expressions/binary_expressions.wgsl.expected.msl
@@ -0,0 +1,75 @@
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_symbol_1 {
+  float4 value [[color(0)]];
+};
+
+void vector_scalar_f32() {
+  float3 v = 0.0f;
+  float s = 0.0f;
+  float3 r = 0.0f;
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+}
+
+void vector_scalar_i32() {
+  int3 v = 0;
+  int s = 0;
+  int3 r = 0;
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+  r = (v % s);
+}
+
+void vector_scalar_u32() {
+  uint3 v = 0u;
+  uint s = 0u;
+  uint3 r = 0u;
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+  r = (v % s);
+}
+
+void scalar_vector_f32() {
+  float3 v = 0.0f;
+  float s = 0.0f;
+  float3 r = 0.0f;
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+}
+
+void scalar_vector_i32() {
+  int3 v = 0;
+  int s = 0;
+  int3 r = 0;
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+  r = (s % v);
+}
+
+void scalar_vector_u32() {
+  uint3 v = 0u;
+  uint s = 0u;
+  uint3 r = 0u;
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+  r = (s % v);
+}
+
+fragment tint_symbol_1 tint_symbol() {
+  return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
+}
+
diff --git a/test/expressions/binary_expressions.wgsl.expected.spvasm b/test/expressions/binary_expressions.wgsl.expected.spvasm
new file mode 100644
index 0000000..445bb49
--- /dev/null
+++ b/test/expressions/binary_expressions.wgsl.expected.spvasm
@@ -0,0 +1,282 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 200
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %tint_symbol_1
+               OpExecutionMode %main OriginUpperLeft
+               OpName %tint_symbol_1 "tint_symbol_1"
+               OpName %vector_scalar_f32 "vector_scalar_f32"
+               OpName %v "v"
+               OpName %s "s"
+               OpName %r "r"
+               OpName %vector_scalar_i32 "vector_scalar_i32"
+               OpName %v_0 "v"
+               OpName %s_0 "s"
+               OpName %r_0 "r"
+               OpName %vector_scalar_u32 "vector_scalar_u32"
+               OpName %v_1 "v"
+               OpName %s_1 "s"
+               OpName %r_1 "r"
+               OpName %scalar_vector_f32 "scalar_vector_f32"
+               OpName %v_2 "v"
+               OpName %s_2 "s"
+               OpName %r_2 "r"
+               OpName %scalar_vector_i32 "scalar_vector_i32"
+               OpName %v_3 "v"
+               OpName %s_3 "s"
+               OpName %r_3 "r"
+               OpName %scalar_vector_u32 "scalar_vector_u32"
+               OpName %v_4 "v"
+               OpName %s_4 "s"
+               OpName %r_4 "r"
+               OpName %tint_symbol_2 "tint_symbol_2"
+               OpName %tint_symbol "tint_symbol"
+               OpName %main "main"
+               OpDecorate %tint_symbol_1 Location 0
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+          %5 = OpConstantNull %v4float
+%tint_symbol_1 = OpVariable %_ptr_Output_v4float Output %5
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+    %v3float = OpTypeVector %float 3
+%_ptr_Function_v3float = OpTypePointer Function %v3float
+         %13 = OpConstantNull %v3float
+%_ptr_Function_float = OpTypePointer Function %float
+         %16 = OpConstantNull %float
+        %int = OpTypeInt 32 1
+      %v3int = OpTypeVector %int 3
+%_ptr_Function_v3int = OpTypePointer Function %v3int
+         %42 = OpConstantNull %v3int
+%_ptr_Function_int = OpTypePointer Function %int
+         %45 = OpConstantNull %int
+       %uint = OpTypeInt 32 0
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Function_v3uint = OpTypePointer Function %v3uint
+         %78 = OpConstantNull %v3uint
+%_ptr_Function_uint = OpTypePointer Function %uint
+         %81 = OpConstantNull %uint
+        %191 = OpTypeFunction %void %v4float
+    %float_0 = OpConstant %float 0
+        %199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%vector_scalar_f32 = OpFunction %void None %6
+          %9 = OpLabel
+          %v = OpVariable %_ptr_Function_v3float Function %13
+          %s = OpVariable %_ptr_Function_float Function %16
+          %r = OpVariable %_ptr_Function_v3float Function %13
+         %21 = OpVariable %_ptr_Function_v3float Function %13
+         %26 = OpVariable %_ptr_Function_v3float Function %13
+         %34 = OpVariable %_ptr_Function_v3float Function %13
+         %18 = OpLoad %v3float %v
+         %19 = OpLoad %float %s
+         %22 = OpCompositeConstruct %v3float %19 %19 %19
+         %20 = OpFAdd %v3float %18 %22
+               OpStore %r %20
+         %23 = OpLoad %v3float %v
+         %24 = OpLoad %float %s
+         %27 = OpCompositeConstruct %v3float %24 %24 %24
+         %25 = OpFSub %v3float %23 %27
+               OpStore %r %25
+         %28 = OpLoad %v3float %v
+         %29 = OpLoad %float %s
+         %30 = OpVectorTimesScalar %v3float %28 %29
+               OpStore %r %30
+         %31 = OpLoad %v3float %v
+         %32 = OpLoad %float %s
+         %35 = OpCompositeConstruct %v3float %32 %32 %32
+         %33 = OpFDiv %v3float %31 %35
+               OpStore %r %33
+               OpReturn
+               OpFunctionEnd
+%vector_scalar_i32 = OpFunction %void None %6
+         %37 = OpLabel
+        %v_0 = OpVariable %_ptr_Function_v3int Function %42
+        %s_0 = OpVariable %_ptr_Function_int Function %45
+        %r_0 = OpVariable %_ptr_Function_v3int Function %42
+         %50 = OpVariable %_ptr_Function_v3int Function %42
+         %55 = OpVariable %_ptr_Function_v3int Function %42
+         %60 = OpVariable %_ptr_Function_v3int Function %42
+         %65 = OpVariable %_ptr_Function_v3int Function %42
+         %70 = OpVariable %_ptr_Function_v3int Function %42
+         %47 = OpLoad %v3int %v_0
+         %48 = OpLoad %int %s_0
+         %51 = OpCompositeConstruct %v3int %48 %48 %48
+         %49 = OpIAdd %v3int %47 %51
+               OpStore %r_0 %49
+         %52 = OpLoad %v3int %v_0
+         %53 = OpLoad %int %s_0
+         %56 = OpCompositeConstruct %v3int %53 %53 %53
+         %54 = OpISub %v3int %52 %56
+               OpStore %r_0 %54
+         %57 = OpLoad %v3int %v_0
+         %58 = OpLoad %int %s_0
+         %61 = OpCompositeConstruct %v3int %58 %58 %58
+         %59 = OpIMul %v3int %57 %61
+               OpStore %r_0 %59
+         %62 = OpLoad %v3int %v_0
+         %63 = OpLoad %int %s_0
+         %66 = OpCompositeConstruct %v3int %63 %63 %63
+         %64 = OpSDiv %v3int %62 %66
+               OpStore %r_0 %64
+         %67 = OpLoad %v3int %v_0
+         %68 = OpLoad %int %s_0
+         %71 = OpCompositeConstruct %v3int %68 %68 %68
+         %69 = OpSMod %v3int %67 %71
+               OpStore %r_0 %69
+               OpReturn
+               OpFunctionEnd
+%vector_scalar_u32 = OpFunction %void None %6
+         %73 = OpLabel
+        %v_1 = OpVariable %_ptr_Function_v3uint Function %78
+        %s_1 = OpVariable %_ptr_Function_uint Function %81
+        %r_1 = OpVariable %_ptr_Function_v3uint Function %78
+         %86 = OpVariable %_ptr_Function_v3uint Function %78
+         %91 = OpVariable %_ptr_Function_v3uint Function %78
+         %96 = OpVariable %_ptr_Function_v3uint Function %78
+        %101 = OpVariable %_ptr_Function_v3uint Function %78
+        %106 = OpVariable %_ptr_Function_v3uint Function %78
+         %83 = OpLoad %v3uint %v_1
+         %84 = OpLoad %uint %s_1
+         %87 = OpCompositeConstruct %v3uint %84 %84 %84
+         %85 = OpIAdd %v3uint %83 %87
+               OpStore %r_1 %85
+         %88 = OpLoad %v3uint %v_1
+         %89 = OpLoad %uint %s_1
+         %92 = OpCompositeConstruct %v3uint %89 %89 %89
+         %90 = OpISub %v3uint %88 %92
+               OpStore %r_1 %90
+         %93 = OpLoad %v3uint %v_1
+         %94 = OpLoad %uint %s_1
+         %97 = OpCompositeConstruct %v3uint %94 %94 %94
+         %95 = OpIMul %v3uint %93 %97
+               OpStore %r_1 %95
+         %98 = OpLoad %v3uint %v_1
+         %99 = OpLoad %uint %s_1
+        %102 = OpCompositeConstruct %v3uint %99 %99 %99
+        %100 = OpUDiv %v3uint %98 %102
+               OpStore %r_1 %100
+        %103 = OpLoad %v3uint %v_1
+        %104 = OpLoad %uint %s_1
+        %107 = OpCompositeConstruct %v3uint %104 %104 %104
+        %105 = OpUMod %v3uint %103 %107
+               OpStore %r_1 %105
+               OpReturn
+               OpFunctionEnd
+%scalar_vector_f32 = OpFunction %void None %6
+        %109 = OpLabel
+        %v_2 = OpVariable %_ptr_Function_v3float Function %13
+        %s_2 = OpVariable %_ptr_Function_float Function %16
+        %r_2 = OpVariable %_ptr_Function_v3float Function %13
+        %116 = OpVariable %_ptr_Function_v3float Function %13
+        %121 = OpVariable %_ptr_Function_v3float Function %13
+        %129 = OpVariable %_ptr_Function_v3float Function %13
+        %113 = OpLoad %float %s_2
+        %114 = OpLoad %v3float %v_2
+        %117 = OpCompositeConstruct %v3float %113 %113 %113
+        %115 = OpFAdd %v3float %117 %114
+               OpStore %r_2 %115
+        %118 = OpLoad %float %s_2
+        %119 = OpLoad %v3float %v_2
+        %122 = OpCompositeConstruct %v3float %118 %118 %118
+        %120 = OpFSub %v3float %122 %119
+               OpStore %r_2 %120
+        %123 = OpLoad %float %s_2
+        %124 = OpLoad %v3float %v_2
+        %125 = OpVectorTimesScalar %v3float %124 %123
+               OpStore %r_2 %125
+        %126 = OpLoad %float %s_2
+        %127 = OpLoad %v3float %v_2
+        %130 = OpCompositeConstruct %v3float %126 %126 %126
+        %128 = OpFDiv %v3float %130 %127
+               OpStore %r_2 %128
+               OpReturn
+               OpFunctionEnd
+%scalar_vector_i32 = OpFunction %void None %6
+        %132 = OpLabel
+        %v_3 = OpVariable %_ptr_Function_v3int Function %42
+        %s_3 = OpVariable %_ptr_Function_int Function %45
+        %r_3 = OpVariable %_ptr_Function_v3int Function %42
+        %139 = OpVariable %_ptr_Function_v3int Function %42
+        %144 = OpVariable %_ptr_Function_v3int Function %42
+        %149 = OpVariable %_ptr_Function_v3int Function %42
+        %154 = OpVariable %_ptr_Function_v3int Function %42
+        %159 = OpVariable %_ptr_Function_v3int Function %42
+        %136 = OpLoad %int %s_3
+        %137 = OpLoad %v3int %v_3
+        %140 = OpCompositeConstruct %v3int %136 %136 %136
+        %138 = OpIAdd %v3int %140 %137
+               OpStore %r_3 %138
+        %141 = OpLoad %int %s_3
+        %142 = OpLoad %v3int %v_3
+        %145 = OpCompositeConstruct %v3int %141 %141 %141
+        %143 = OpISub %v3int %145 %142
+               OpStore %r_3 %143
+        %146 = OpLoad %int %s_3
+        %147 = OpLoad %v3int %v_3
+        %150 = OpCompositeConstruct %v3int %146 %146 %146
+        %148 = OpIMul %v3int %150 %147
+               OpStore %r_3 %148
+        %151 = OpLoad %int %s_3
+        %152 = OpLoad %v3int %v_3
+        %155 = OpCompositeConstruct %v3int %151 %151 %151
+        %153 = OpSDiv %v3int %155 %152
+               OpStore %r_3 %153
+        %156 = OpLoad %int %s_3
+        %157 = OpLoad %v3int %v_3
+        %160 = OpCompositeConstruct %v3int %156 %156 %156
+        %158 = OpSMod %v3int %160 %157
+               OpStore %r_3 %158
+               OpReturn
+               OpFunctionEnd
+%scalar_vector_u32 = OpFunction %void None %6
+        %162 = OpLabel
+        %v_4 = OpVariable %_ptr_Function_v3uint Function %78
+        %s_4 = OpVariable %_ptr_Function_uint Function %81
+        %r_4 = OpVariable %_ptr_Function_v3uint Function %78
+        %169 = OpVariable %_ptr_Function_v3uint Function %78
+        %174 = OpVariable %_ptr_Function_v3uint Function %78
+        %179 = OpVariable %_ptr_Function_v3uint Function %78
+        %184 = OpVariable %_ptr_Function_v3uint Function %78
+        %189 = OpVariable %_ptr_Function_v3uint Function %78
+        %166 = OpLoad %uint %s_4
+        %167 = OpLoad %v3uint %v_4
+        %170 = OpCompositeConstruct %v3uint %166 %166 %166
+        %168 = OpIAdd %v3uint %170 %167
+               OpStore %r_4 %168
+        %171 = OpLoad %uint %s_4
+        %172 = OpLoad %v3uint %v_4
+        %175 = OpCompositeConstruct %v3uint %171 %171 %171
+        %173 = OpISub %v3uint %175 %172
+               OpStore %r_4 %173
+        %176 = OpLoad %uint %s_4
+        %177 = OpLoad %v3uint %v_4
+        %180 = OpCompositeConstruct %v3uint %176 %176 %176
+        %178 = OpIMul %v3uint %180 %177
+               OpStore %r_4 %178
+        %181 = OpLoad %uint %s_4
+        %182 = OpLoad %v3uint %v_4
+        %185 = OpCompositeConstruct %v3uint %181 %181 %181
+        %183 = OpUDiv %v3uint %185 %182
+               OpStore %r_4 %183
+        %186 = OpLoad %uint %s_4
+        %187 = OpLoad %v3uint %v_4
+        %190 = OpCompositeConstruct %v3uint %186 %186 %186
+        %188 = OpUMod %v3uint %190 %187
+               OpStore %r_4 %188
+               OpReturn
+               OpFunctionEnd
+%tint_symbol_2 = OpFunction %void None %191
+%tint_symbol = OpFunctionParameter %v4float
+        %194 = OpLabel
+               OpStore %tint_symbol_1 %tint_symbol
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %6
+        %196 = OpLabel
+        %197 = OpFunctionCall %void %tint_symbol_2 %199
+               OpReturn
+               OpFunctionEnd
diff --git a/test/expressions/binary_expressions.wgsl.expected.wgsl b/test/expressions/binary_expressions.wgsl.expected.wgsl
new file mode 100644
index 0000000..ec7e7d5
--- /dev/null
+++ b/test/expressions/binary_expressions.wgsl.expected.wgsl
@@ -0,0 +1,68 @@
+fn vector_scalar_f32() {
+  var v : vec3<f32>;
+  var s : f32;
+  var r : vec3<f32>;
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+}
+
+fn vector_scalar_i32() {
+  var v : vec3<i32>;
+  var s : i32;
+  var r : vec3<i32>;
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+  r = (v % s);
+}
+
+fn vector_scalar_u32() {
+  var v : vec3<u32>;
+  var s : u32;
+  var r : vec3<u32>;
+  r = (v + s);
+  r = (v - s);
+  r = (v * s);
+  r = (v / s);
+  r = (v % s);
+}
+
+fn scalar_vector_f32() {
+  var v : vec3<f32>;
+  var s : f32;
+  var r : vec3<f32>;
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+}
+
+fn scalar_vector_i32() {
+  var v : vec3<i32>;
+  var s : i32;
+  var r : vec3<i32>;
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+  r = (s % v);
+}
+
+fn scalar_vector_u32() {
+  var v : vec3<u32>;
+  var s : u32;
+  var r : vec3<u32>;
+  r = (s + v);
+  r = (s - v);
+  r = (s * v);
+  r = (s / v);
+  r = (s % v);
+}
+
+[[stage(fragment)]]
+fn main() -> [[location(0)]] vec4<f32> {
+  return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+}