[spirv-writer] Add binary multiplication.

This CL adds binary multiplication generation to the SPIR-V writer.

Bug: tint:5
Change-Id: I668d24035e947c51a9737549fd0841a4e8af1331
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19700
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index c388e70..e9e45d7 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -79,6 +79,36 @@
   return model;
 }
 
+bool is_float_scalar(ast::type::Type* type) {
+  return type->IsF32();
+}
+
+bool is_float_matrix(ast::type::Type* type) {
+  return type->IsMatrix() && type->AsMatrix()->type()->IsF32();
+}
+
+bool is_float_vector(ast::type::Type* type) {
+  return type->IsVector() && type->AsVector()->type()->IsF32();
+}
+
+bool is_float_scalar_or_vector(ast::type::Type* type) {
+  return is_float_scalar(type) || is_float_vector(type);
+}
+
+bool is_unsigned_scalar_or_vector(ast::type::Type* type) {
+  return type->IsU32() ||
+         (type->IsVector() && type->AsVector()->type()->IsU32());
+}
+
+bool is_signed_scalar_or_vector(ast::type::Type* type) {
+  return type->IsI32() ||
+         (type->IsVector() && type->AsVector()->type()->IsI32());
+}
+
+bool is_integer_scalar_or_vector(ast::type::Type* type) {
+  return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type);
+}
+
 }  // namespace
 
 Builder::Builder() : scope_stack_({}) {}
@@ -569,12 +599,9 @@
   // Handle int and float and the vectors of those types. Other types
   // should have been rejected by validation.
   auto* lhs_type = expr->lhs()->result_type();
-  bool lhs_is_float_or_vec =
-      lhs_type->IsF32() ||
-      (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32());
-  bool lhs_is_unsigned =
-      lhs_type->IsU32() ||
-      (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsU32());
+  auto* rhs_type = expr->rhs()->result_type();
+  bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type);
+  bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type);
 
   spv::Op op = spv::Op::OpNop;
   if (expr->IsAnd()) {
@@ -631,6 +658,45 @@
     } else {
       op = spv::Op::OpSMod;
     }
+  } else if (expr->IsMultiply()) {
+    if (is_integer_scalar_or_vector(lhs_type)) {
+      // If the left hand side is an integer then this _has_ to be OpIMul as
+      // there there is no other integer multiplication.
+      op = spv::Op::OpIMul;
+    } else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) {
+      // Float scalars multiply with OpFMul
+      op = spv::Op::OpFMul;
+    } else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) {
+      // Float vectors must be validated to be the same size and then use OpFMul
+      op = spv::Op::OpFMul;
+    } else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) {
+      // Scalar * Vector we need to flip lhs and rhs types
+      // because OpVectorTimesScalar expects <vector>, <scalar>
+      std::swap(lhs_id, rhs_id);
+      op = spv::Op::OpVectorTimesScalar;
+    } else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) {
+      // float vector * scalar
+      op = spv::Op::OpVectorTimesScalar;
+    } else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) {
+      // Scalar * Matrix we need to flip lhs and rhs types because
+      // OpMatrixTimesScalar expects <matrix>, <scalar>
+      std::swap(lhs_id, rhs_id);
+      op = spv::Op::OpMatrixTimesScalar;
+    } else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) {
+      // float matrix * scalar
+      op = spv::Op::OpMatrixTimesScalar;
+    } else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) {
+      // float vector * matrix
+      op = spv::Op::OpVectorTimesMatrix;
+    } else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) {
+      // float matrix * vector
+      op = spv::Op::OpMatrixTimesVector;
+    } else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) {
+      // float matrix * matrix
+      op = spv::Op::OpMatrixTimesMatrix;
+    } else {
+      return 0;
+    }
   } else if (expr->IsNotEqual()) {
     op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual;
   } else if (expr->IsOr()) {
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index 97b3569..fe18196 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -17,10 +17,12 @@
 #include "gtest/gtest.h"
 #include "src/ast/binary_expression.h"
 #include "src/ast/float_literal.h"
+#include "src/ast/identifier_expression.h"
 #include "src/ast/int_literal.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
 #include "src/ast/type/u32_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
@@ -65,7 +67,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
 %2 = OpConstant %1 3
 %3 = OpConstant %1 4
@@ -108,7 +110,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
 %1 = OpTypeVector %2 3
 %3 = OpConstant %2 1
@@ -125,6 +127,7 @@
         BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
         BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
         BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
+        BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
         BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
         BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
         BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
@@ -152,7 +155,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
 %2 = OpConstant %1 3
 %3 = OpConstant %1 4
@@ -195,7 +198,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
 %1 = OpTypeVector %2 3
 %3 = OpConstant %2 1
@@ -212,6 +215,7 @@
         BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
         BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
         BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
+        BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
         BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
         BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
         BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
@@ -239,7 +243,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
 %2 = OpConstant %1 3.20000005
 %3 = OpConstant %1 4.5
@@ -283,7 +287,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 3
 %3 = OpConstant %2 1
@@ -298,6 +302,7 @@
     testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
                     BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
                     BinaryData{ast::BinaryOp::kModulo, "OpFMod"},
+                    BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
                     BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
 
 using BinaryCompareUnsignedIntegerTest = testing::TestWithParam<BinaryData>;
@@ -320,7 +325,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
 %2 = OpConstant %1 3
 %3 = OpConstant %1 4
@@ -365,7 +370,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
 %1 = OpTypeVector %2 3
 %3 = OpConstant %2 1
@@ -407,7 +412,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
 %2 = OpConstant %1 3
 %3 = OpConstant %1 4
@@ -452,7 +457,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
 %1 = OpTypeVector %2 3
 %3 = OpConstant %2 1
@@ -494,7 +499,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
 %2 = OpConstant %1 3.20000005
 %3 = OpConstant %1 4.5
@@ -539,7 +544,7 @@
   Builder b;
   b.push_function(Function{});
 
-  ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
   EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
 %1 = OpTypeVector %2 3
 %3 = OpConstant %2 1
@@ -561,6 +566,288 @@
         BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
         BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
 
+TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  auto lhs =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%1 = OpTypeVector %2 3
+%3 = OpConstant %2 1
+%4 = OpConstantComposite %1 %3 %3 %3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            "%5 = OpVectorTimesScalar %1 %4 %3\n");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  auto rhs =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
+%2 = OpConstant %1 1
+%3 = OpTypeVector %1 3
+%4 = OpConstantComposite %3 %2 %2 %2
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            "%5 = OpVectorTimesScalar %3 %4 %2\n");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto var = std::make_unique<ast::Variable>(
+      "mat", ast::StorageClass::kFunction, &mat3);
+  auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
+  auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%7 = OpConstant %5 1
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%6 = OpLoad %3 %1
+%8 = OpMatrixTimesScalar %3 %6 %7
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto var = std::make_unique<ast::Variable>(
+      "mat", ast::StorageClass::kFunction, &mat3);
+  auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+  auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%6 = OpConstant %5 1
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%7 = OpLoad %3 %1
+%8 = OpMatrixTimesScalar %3 %7 %6
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto var = std::make_unique<ast::Variable>(
+      "mat", ast::StorageClass::kFunction, &mat3);
+  auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  auto rhs =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%7 = OpConstant %5 1
+%8 = OpConstantComposite %4 %7 %7 %7
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%6 = OpLoad %3 %1
+%9 = OpMatrixTimesVector %4 %6 %8
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto var = std::make_unique<ast::Variable>(
+      "mat", ast::StorageClass::kFunction, &mat3);
+
+  ast::ExpressionList vals;
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+  auto lhs =
+      std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+  auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%6 = OpConstant %5 1
+%7 = OpConstantComposite %4 %6 %6 %6
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%8 = OpLoad %3 %1
+%9 = OpVectorTimesMatrix %4 %7 %8
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto var = std::make_unique<ast::Variable>(
+      "mat", ast::StorageClass::kFunction, &mat3);
+  auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
+  auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%6 = OpLoad %3 %1
+%7 = OpLoad %3 %1
+%8 = OpMatrixTimesMatrix %3 %6 %7
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer