Implement addition and subtraction of float matrices

Bug: tint:316
Change-Id: I3a1082c41c47daacf0220d029cb2a5f118684959
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52580
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index baceffa..99dc700 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2124,13 +2124,20 @@
   }
 
   // Matrix arithmetic
-  // TODO(amaiorano): matrix-matrix addition and subtraction
+  auto* lhs_mat = lhs_type->As<Matrix>();
+  auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
+  auto* rhs_mat = rhs_type->As<Matrix>();
+  auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
+  // Addition and subtraction of float matrices
+  if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
+      lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
+      rhs_mat_elem_type->Is<F32>() &&
+      (lhs_mat->columns() == rhs_mat->columns()) &&
+      (lhs_mat->rows() == rhs_mat->rows())) {
+    SetType(expr, rhs_type);
+    return true;
+  }
   if (expr->IsMultiply()) {
-    auto* lhs_mat = lhs_type->As<Matrix>();
-    auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
-    auto* rhs_mat = rhs_type->As<Matrix>();
-    auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
-
     // Multiplication of a matrix and a scalar
     if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
         rhs_mat_elem_type->Is<F32>()) {
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 06b6ce1..7406204 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1238,9 +1238,12 @@
 };
 
 static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
-    ast_bool,        ast_u32,         ast_i32,        ast_f32,
-    ast_vec3<bool>,  ast_vec3<i32>,   ast_vec3<u32>,  ast_vec3<f32>,
-    ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>};
+    ast_bool,        ast_u32,         ast_i32,         ast_f32,
+    ast_vec3<bool>,  ast_vec3<i32>,   ast_vec3<u32>,   ast_vec3<f32>,
+    ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>,  //
+    ast_mat2x3<i32>, ast_mat2x3<u32>, ast_mat2x3<f32>,  //
+    ast_mat3x2<i32>, ast_mat3x2<u32>, ast_mat3x2<f32>   //
+};
 
 // A list of all valid test cases for 'lhs op rhs', except that for vecN and
 // matNxN, we only test N=3.
@@ -1338,14 +1341,43 @@
     // Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
 
     // Matrix arithmetic
+    Params{Op::kMultiply, ast_mat2x3<f32>, ast_f32, sem_mat2x3<sem_f32>},
+    Params{Op::kMultiply, ast_mat3x2<f32>, ast_f32, sem_mat3x2<sem_f32>},
     Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
+
+    Params{Op::kMultiply, ast_f32, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
+    Params{Op::kMultiply, ast_f32, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
     Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
 
+    Params{Op::kMultiply, ast_vec3<f32>, ast_mat2x3<f32>, sem_vec2<sem_f32>},
+    Params{Op::kMultiply, ast_vec2<f32>, ast_mat3x2<f32>, sem_vec3<sem_f32>},
     Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
+
+    Params{Op::kMultiply, ast_mat3x2<f32>, ast_vec3<f32>, sem_vec2<sem_f32>},
+    Params{Op::kMultiply, ast_mat2x3<f32>, ast_vec2<f32>, sem_vec3<sem_f32>},
     Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
-    // TODO(amaiorano): add mat+mat and mat-mat
+
+    Params{Op::kMultiply, ast_mat2x3<f32>, ast_mat3x2<f32>,
+           sem_mat3x3<sem_f32>},
+    Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat2x3<f32>,
+           sem_mat2x2<sem_f32>},
+    Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat3x3<f32>,
+           sem_mat3x2<sem_f32>},
     Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
            sem_mat3x3<sem_f32>},
+    Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat2x3<f32>,
+           sem_mat2x3<sem_f32>},
+
+    Params{Op::kAdd, ast_mat2x3<f32>, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
+    Params{Op::kAdd, ast_mat3x2<f32>, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
+    Params{Op::kAdd, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
+
+    Params{Op::kSubtract, ast_mat2x3<f32>, ast_mat2x3<f32>,
+           sem_mat2x3<sem_f32>},
+    Params{Op::kSubtract, ast_mat3x2<f32>, ast_mat3x2<f32>,
+           sem_mat3x2<sem_f32>},
+    Params{Op::kSubtract, ast_mat3x3<f32>, ast_mat3x3<f32>,
+           sem_mat3x3<sem_f32>},
 
     // Comparison expressions
     // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index b82aec0..4a10110 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -177,6 +177,26 @@
 }
 
 template <typename T>
+ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat2x3<T>();
+}
+
+template <create_ast_type_func_ptr create_type>
+ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat2x3(create_type(ty));
+}
+
+template <typename T>
+ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat3x2<T>();
+}
+
+template <create_ast_type_func_ptr create_type>
+ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
+  return ty.mat3x2(create_type(ty));
+}
+
+template <typename T>
 ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
   return ty.mat3x3<T>();
 }
@@ -250,6 +270,18 @@
 }
 
 template <create_sem_type_func_ptr create_type>
+sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
+  auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
+  return ty.builder->create<sem::Matrix>(column_type, 2u);
+}
+
+template <create_sem_type_func_ptr create_type>
+sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
+  auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
+  return ty.builder->create<sem::Matrix>(column_type, 3u);
+}
+
+template <create_sem_type_func_ptr create_type>
 sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
   auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
   return ty.builder->create<sem::Matrix>(column_type, 3u);
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index dd8de52..79e551c 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1748,6 +1748,67 @@
   return splat_result.to_i();
 }
 
+uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
+                                         uint32_t rhs_id,
+                                         const sem::Matrix* type,
+                                         spv::Op op) {
+  // Example addition of two matrices:
+  // %31 = OpLoad %mat3v4float %m34
+  // %32 = OpLoad %mat3v4float %m34
+  // %33 = OpCompositeExtract %v4float %31 0
+  // %34 = OpCompositeExtract %v4float %32 0
+  // %35 = OpFAdd %v4float %33 %34
+  // %36 = OpCompositeExtract %v4float %31 1
+  // %37 = OpCompositeExtract %v4float %32 1
+  // %38 = OpFAdd %v4float %36 %37
+  // %39 = OpCompositeExtract %v4float %31 2
+  // %40 = OpCompositeExtract %v4float %32 2
+  // %41 = OpFAdd %v4float %39 %40
+  // %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
+
+  auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
+  auto column_type_id = GenerateTypeIfNeeded(column_type);
+
+  OperandList ops;
+
+  for (uint32_t i = 0; i < type->columns(); ++i) {
+    // Extract column `i` from lhs mat
+    auto lhs_column_id = result_op();
+    if (!push_function_inst(spv::Op::OpCompositeExtract,
+                            {Operand::Int(column_type_id), lhs_column_id,
+                             Operand::Int(lhs_id), Operand::Int(i)})) {
+      return 0;
+    }
+
+    // Extract column `i` from rhs mat
+    auto rhs_column_id = result_op();
+    if (!push_function_inst(spv::Op::OpCompositeExtract,
+                            {Operand::Int(column_type_id), rhs_column_id,
+                             Operand::Int(rhs_id), Operand::Int(i)})) {
+      return 0;
+    }
+
+    // Add or subtract the two columns
+    auto result = result_op();
+    if (!push_function_inst(op, {Operand::Int(column_type_id), result,
+                                 lhs_column_id, rhs_column_id})) {
+      return 0;
+    }
+
+    ops.push_back(result);
+  }
+
+  // Create the result matrix from the added/subtracted column vectors
+  auto result_mat_id = result_op();
+  ops.insert(ops.begin(), result_mat_id);
+  ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
+  if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
+    return 0;
+  }
+
+  return result_mat_id.to_i();
+}
+
 uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
   // There is special logic for short circuiting operators.
   if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
@@ -1779,6 +1840,24 @@
   auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
   auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
 
+  // Handle matrix-matrix addition and subtraction
+  if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
+      rhs_type->is_float_matrix()) {
+    auto* lhs_mat = lhs_type->As<sem::Matrix>();
+    auto* rhs_mat = rhs_type->As<sem::Matrix>();
+
+    // This should already have been validated by resolver
+    if (lhs_mat->rows() != rhs_mat->rows() ||
+        lhs_mat->columns() != rhs_mat->columns()) {
+      error_ = "matrices must have same dimensionality for add or subtract";
+      return 0;
+    }
+
+    return GenerateMatrixAddOrSub(
+        lhs_id, rhs_id, lhs_mat,
+        expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
+  }
+
   // For vector-scalar arithmetic operations, splat scalar into a vector. We
   // skip this for multiply as we can use OpVectorTimesScalar.
   const bool is_float_scalar_vector_multiply =
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 3dac6d2..1a7271d 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -482,6 +482,17 @@
   /// @returns id of the new vector
   uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
 
+  /// Generates instructions to add or subtract two matrices
+  /// @param lhs_id id of multiplicand
+  /// @param rhs_id id of multiplier
+  /// @param type type of both matrices and of result
+  /// @param op one of `spv::Op::OpFAdd` or `spv::Op::OpFSub`
+  /// @returns id of the result matrix
+  uint32_t GenerateMatrixAddOrSub(uint32_t lhs_id,
+                                  uint32_t rhs_id,
+                                  const sem::Matrix* type,
+                                  spv::Op op);
+
   /// Converts AST image format to SPIR-V and pushes an appropriate capability.
   /// @param format AST image format type
   /// @returns SPIR-V image format type
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index e15d4ec..3b8391b 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -863,8 +863,10 @@
 )");
 }
 
+namespace BinaryArithVectorScalar {
+
 enum class Type { f32, i32, u32 };
-ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
+static ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
   switch (type) {
     case Type::f32:
       return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
@@ -875,7 +877,7 @@
   }
   return nullptr;
 }
-ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
+static ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
   switch (type) {
     case Type::f32:
       return builder->Expr(1.f);
@@ -886,7 +888,7 @@
   }
   return nullptr;
 }
-std::string OpTypeDecl(Type type) {
+static std::string OpTypeDecl(Type type) {
   switch (type) {
     case Type::f32:
       return "OpTypeFloat 32";
@@ -904,8 +906,8 @@
   std::string name;
 };
 
-using MixedBinaryArithTest = TestParamHelper<Param>;
-TEST_P(MixedBinaryArithTest, VectorScalar) {
+using BinaryArithVectorScalarTest = TestParamHelper<Param>;
+TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
   auto& param = GetParam();
 
   ast::Expression* lhs = MakeVectorExpr(this, param.type);
@@ -943,7 +945,7 @@
 
   Validate(b);
 }
-TEST_P(MixedBinaryArithTest, ScalarVector) {
+TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
   auto& param = GetParam();
 
   ast::Expression* lhs = MakeScalarExpr(this, param.type);
@@ -983,7 +985,7 @@
 }
 INSTANTIATE_TEST_SUITE_P(
     BuilderTest,
-    MixedBinaryArithTest,
+    BinaryArithVectorScalarTest,
     testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
                     Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
                     // NOTE: Modulo not allowed on mixed float scalar-vector
@@ -1005,8 +1007,8 @@
                     Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
                     Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
 
-using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
-TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
+using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
+TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
   auto& param = GetParam();
 
   ast::Expression* lhs = MakeVectorExpr(this, param.type);
@@ -1040,7 +1042,7 @@
 
   Validate(b);
 }
-TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
+TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
   auto& param = GetParam();
 
   ast::Expression* lhs = MakeScalarExpr(this, param.type);
@@ -1075,10 +1077,113 @@
   Validate(b);
 }
 INSTANTIATE_TEST_SUITE_P(BuilderTest,
-                         MixedBinaryArithMultiplyTest,
+                         BinaryArithVectorScalarMultiplyTest,
                          testing::Values(Param{
                              Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
 
+}  // namespace BinaryArithVectorScalar
+
+namespace BinaryArithMatrixMatrix {
+
+struct Param {
+  ast::BinaryOp op;
+  std::string name;
+};
+
+using BinaryArithMatrixMatrix = TestParamHelper<Param>;
+TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
+  auto& param = GetParam();
+
+  ast::Expression* lhs = mat3x4<f32>();
+  ast::Expression* rhs = mat3x4<f32>();
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeFloat 32
+%6 = OpTypeVector %7 4
+%5 = OpTypeMatrix %6 3
+%8 = OpConstantNull %5
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%10 = OpCompositeExtract %6 %8 0
+%11 = OpCompositeExtract %6 %8 0
+%12 = )" + param.name + R"( %6 %10 %11
+%13 = OpCompositeExtract %6 %8 1
+%14 = OpCompositeExtract %6 %8 1
+%15 = )" + param.name + R"( %6 %13 %14
+%16 = OpCompositeExtract %6 %8 2
+%17 = OpCompositeExtract %6 %8 2
+%18 = )" + param.name + R"( %6 %16 %17
+%19 = OpCompositeConstruct %5 %12 %15 %18
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+INSTANTIATE_TEST_SUITE_P(  //
+    BuilderTest,
+    BinaryArithMatrixMatrix,
+    testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
+                    Param{ast::BinaryOp::kSubtract, "OpFSub"}));
+
+using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
+TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
+  auto& param = GetParam();
+
+  ast::Expression* lhs = mat3x4<f32>();
+  ast::Expression* rhs = mat4x3<f32>();
+
+  auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
+
+  WrapInFunction(expr);
+
+  spirv::Builder& b = Build();
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %3 "test_function"
+OpExecutionMode %3 LocalSize 1 1 1
+OpName %3 "test_function"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeFloat 32
+%6 = OpTypeVector %7 4
+%5 = OpTypeMatrix %6 3
+%8 = OpConstantNull %5
+%10 = OpTypeVector %7 3
+%9 = OpTypeMatrix %10 4
+%11 = OpConstantNull %9
+%13 = OpTypeMatrix %6 4
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+%12 = OpMatrixTimesMatrix %13 %8 %11
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+INSTANTIATE_TEST_SUITE_P(  //
+    BuilderTest,
+    BinaryArithMatrixMatrixMultiply,
+    testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
+
+}  // namespace BinaryArithMatrixMatrix
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/test/expressions/binary_expressions.wgsl b/test/expressions/binary_expressions.wgsl
index a58eece..b7444a0 100644
--- a/test/expressions/binary_expressions.wgsl
+++ b/test/expressions/binary_expressions.wgsl
@@ -64,6 +64,19 @@
     r = s % v;

 }

 

+fn matrix_matrix_f32() {

+    var m34 : mat3x4<f32>;

+    var m43 : mat4x3<f32>;

+    var m33 : mat3x3<f32>;

+    var m44 : mat4x4<f32>;

+

+    m34 = m34 + m34;

+    m34 = m34 - m34;

+

+    m33 = m43 * m34;

+    m44 = m34 * m43;

+}

+

 [[stage(fragment)]]

 fn main() -> [[location(0)]] vec4<f32> {

     return vec4<f32>(0.0,0.0,0.0,0.0);

diff --git a/test/expressions/binary_expressions.wgsl.expected.hlsl b/test/expressions/binary_expressions.wgsl.expected.hlsl
index 01f2d2b..c234983 100644
--- a/test/expressions/binary_expressions.wgsl.expected.hlsl
+++ b/test/expressions/binary_expressions.wgsl.expected.hlsl
@@ -66,6 +66,17 @@
   r = (s % v);
 }
 
+void matrix_matrix_f32() {
+  float3x4 m34 = float3x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  float4x3 m43 = float4x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  float3x3 m33 = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  float4x4 m44 = float4x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+  m34 = (m34 + m34);
+  m34 = (m34 - m34);
+  m33 = mul(m34, m43);
+  m44 = mul(m43, m34);
+}
+
 tint_symbol main() {
   const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
   return tint_symbol_1;
diff --git a/test/expressions/binary_expressions.wgsl.expected.msl b/test/expressions/binary_expressions.wgsl.expected.msl
index 5630940..e46aa45 100644
--- a/test/expressions/binary_expressions.wgsl.expected.msl
+++ b/test/expressions/binary_expressions.wgsl.expected.msl
@@ -69,6 +69,17 @@
   r = (s % v);
 }
 
+void matrix_matrix_f32() {
+  float3x4 m34 = float3x4(0.0f);
+  float4x3 m43 = float4x3(0.0f);
+  float3x3 m33 = float3x3(0.0f);
+  float4x4 m44 = float4x4(0.0f);
+  m34 = (m34 + m34);
+  m34 = (m34 - m34);
+  m33 = (m43 * m34);
+  m44 = (m34 * m43);
+}
+
 fragment tint_symbol_1 tint_symbol() {
   return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
 }
diff --git a/test/expressions/binary_expressions.wgsl.expected.spvasm b/test/expressions/binary_expressions.wgsl.expected.spvasm
index 445bb49..8052168 100644
--- a/test/expressions/binary_expressions.wgsl.expected.spvasm
+++ b/test/expressions/binary_expressions.wgsl.expected.spvasm
@@ -1,7 +1,7 @@
 ; SPIR-V
 ; Version: 1.3
 ; Generator: Google Tint Compiler; 0
-; Bound: 200
+; Bound: 250
 ; Schema: 0
                OpCapability Shader
                OpMemoryModel Logical GLSL450
@@ -32,6 +32,11 @@
                OpName %v_4 "v"
                OpName %s_4 "s"
                OpName %r_4 "r"
+               OpName %matrix_matrix_f32 "matrix_matrix_f32"
+               OpName %m34 "m34"
+               OpName %m43 "m43"
+               OpName %m33 "m33"
+               OpName %m44 "m44"
                OpName %tint_symbol_2 "tint_symbol_2"
                OpName %tint_symbol "tint_symbol"
                OpName %main "main"
@@ -60,9 +65,21 @@
          %78 = OpConstantNull %v3uint
 %_ptr_Function_uint = OpTypePointer Function %uint
          %81 = OpConstantNull %uint
-        %191 = OpTypeFunction %void %v4float
+%mat3v4float = OpTypeMatrix %v4float 3
+%_ptr_Function_mat3v4float = OpTypePointer Function %mat3v4float
+        %196 = OpConstantNull %mat3v4float
+%mat4v3float = OpTypeMatrix %v3float 4
+%_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
+        %200 = OpConstantNull %mat4v3float
+%mat3v3float = OpTypeMatrix %v3float 3
+%_ptr_Function_mat3v3float = OpTypePointer Function %mat3v3float
+        %204 = OpConstantNull %mat3v3float
+%mat4v4float = OpTypeMatrix %v4float 4
+%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float
+        %208 = OpConstantNull %mat4v4float
+        %241 = OpTypeFunction %void %v4float
     %float_0 = OpConstant %float 0
-        %199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+        %249 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
 %vector_scalar_f32 = OpFunction %void None %6
           %9 = OpLabel
           %v = OpVariable %_ptr_Function_v3float Function %13
@@ -269,14 +286,56 @@
                OpStore %r_4 %188
                OpReturn
                OpFunctionEnd
-%tint_symbol_2 = OpFunction %void None %191
+%matrix_matrix_f32 = OpFunction %void None %6
+        %192 = OpLabel
+        %m34 = OpVariable %_ptr_Function_mat3v4float Function %196
+        %m43 = OpVariable %_ptr_Function_mat4v3float Function %200
+        %m33 = OpVariable %_ptr_Function_mat3v3float Function %204
+        %m44 = OpVariable %_ptr_Function_mat4v4float Function %208
+        %209 = OpLoad %mat3v4float %m34
+        %210 = OpLoad %mat3v4float %m34
+        %212 = OpCompositeExtract %v4float %209 0
+        %213 = OpCompositeExtract %v4float %210 0
+        %214 = OpFAdd %v4float %212 %213
+        %215 = OpCompositeExtract %v4float %209 1
+        %216 = OpCompositeExtract %v4float %210 1
+        %217 = OpFAdd %v4float %215 %216
+        %218 = OpCompositeExtract %v4float %209 2
+        %219 = OpCompositeExtract %v4float %210 2
+        %220 = OpFAdd %v4float %218 %219
+        %221 = OpCompositeConstruct %mat3v4float %214 %217 %220
+               OpStore %m34 %221
+        %222 = OpLoad %mat3v4float %m34
+        %223 = OpLoad %mat3v4float %m34
+        %225 = OpCompositeExtract %v4float %222 0
+        %226 = OpCompositeExtract %v4float %223 0
+        %227 = OpFSub %v4float %225 %226
+        %228 = OpCompositeExtract %v4float %222 1
+        %229 = OpCompositeExtract %v4float %223 1
+        %230 = OpFSub %v4float %228 %229
+        %231 = OpCompositeExtract %v4float %222 2
+        %232 = OpCompositeExtract %v4float %223 2
+        %233 = OpFSub %v4float %231 %232
+        %234 = OpCompositeConstruct %mat3v4float %227 %230 %233
+               OpStore %m34 %234
+        %235 = OpLoad %mat4v3float %m43
+        %236 = OpLoad %mat3v4float %m34
+        %237 = OpMatrixTimesMatrix %mat3v3float %235 %236
+               OpStore %m33 %237
+        %238 = OpLoad %mat3v4float %m34
+        %239 = OpLoad %mat4v3float %m43
+        %240 = OpMatrixTimesMatrix %mat4v4float %238 %239
+               OpStore %m44 %240
+               OpReturn
+               OpFunctionEnd
+%tint_symbol_2 = OpFunction %void None %241
 %tint_symbol = OpFunctionParameter %v4float
-        %194 = OpLabel
+        %244 = OpLabel
                OpStore %tint_symbol_1 %tint_symbol
                OpReturn
                OpFunctionEnd
        %main = OpFunction %void None %6
-        %196 = OpLabel
-        %197 = OpFunctionCall %void %tint_symbol_2 %199
+        %246 = OpLabel
+        %247 = OpFunctionCall %void %tint_symbol_2 %249
                OpReturn
                OpFunctionEnd
diff --git a/test/expressions/binary_expressions.wgsl.expected.wgsl b/test/expressions/binary_expressions.wgsl.expected.wgsl
index ec7e7d5..cd0bf55 100644
--- a/test/expressions/binary_expressions.wgsl.expected.wgsl
+++ b/test/expressions/binary_expressions.wgsl.expected.wgsl
@@ -62,6 +62,17 @@
   r = (s % v);
 }
 
+fn matrix_matrix_f32() {
+  var m34 : mat3x4<f32>;
+  var m43 : mat4x3<f32>;
+  var m33 : mat3x3<f32>;
+  var m44 : mat4x4<f32>;
+  m34 = (m34 + m34);
+  m34 = (m34 - m34);
+  m33 = (m43 * m34);
+  m44 = (m34 * m43);
+}
+
 [[stage(fragment)]]
 fn main() -> [[location(0)]] vec4<f32> {
   return vec4<f32>(0.0, 0.0, 0.0, 0.0);