[spirv-reader] Add mixed scalar/vector/matrix multiply

Bug: tint:3
Change-Id: I5875bf453b05c5d5c96f90122206da04f6799976
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23401
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 932c5ca..05c3b30 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -200,6 +200,11 @@
       return ast::BinaryOp::kSubtract;
     case SpvOpIMul:
     case SpvOpFMul:
+    case SpvOpVectorTimesScalar:
+    case SpvOpMatrixTimesScalar:
+    case SpvOpVectorTimesMatrix:
+    case SpvOpMatrixTimesVector:
+    case SpvOpMatrixTimesMatrix:
       return ast::BinaryOp::kMultiply;
     case SpvOpUDiv:
     case SpvOpSDiv:
diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc
index c92a84c..ba9656d 100644
--- a/src/reader/spirv/function_arithmetic_test.cc
+++ b/src/reader/spirv/function_arithmetic_test.cc
@@ -58,6 +58,10 @@
   %v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
   %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
   %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
+
+  %m2v2float = OpTypeMatrix %v2float 2
+  %m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
+  %m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
 )";
 }
 
@@ -904,17 +908,157 @@
                    "__vec_2__f32", AstFor("v2float_50_60"), "modulo",
                    AstFor("v2float_60_50")}));
 
+TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %v2float %v2float_50_60
+     %2 = OpCopyObject %float %float_50
+     %10 = OpVectorTimesScalar %v2float %1 %2
+     OpReturn
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __vec_2__f32
+    {
+      Binary{
+        Identifier{x_1}
+        multiply
+        Identifier{x_2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, MatrixTimesScalar) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %m2v2float %m2v2float_a
+     %2 = OpCopyObject %float %float_50
+     %10 = OpMatrixTimesScalar %m2v2float %1 %2
+     OpReturn
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __mat_2_2__f32
+    {
+      Binary{
+        Identifier{x_1}
+        multiply
+        Identifier{x_2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, VectorTimesMatrix) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %v2float %v2float_50_60
+     %2 = OpCopyObject %m2v2float %m2v2float_a
+     %10 = OpMatrixTimesVector %m2v2float %1 %2
+     OpReturn
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __mat_2_2__f32
+    {
+      Binary{
+        Identifier{x_1}
+        multiply
+        Identifier{x_2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, MatrixTimesVector) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %m2v2float %m2v2float_a
+     %2 = OpCopyObject %v2float %v2float_50_60
+     %10 = OpMatrixTimesVector %m2v2float %1 %2
+     OpReturn
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __mat_2_2__f32
+    {
+      Binary{
+        Identifier{x_1}
+        multiply
+        Identifier{x_2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %m2v2float %m2v2float_a
+     %2 = OpCopyObject %m2v2float %m2v2float_b
+     %10 = OpMatrixTimesMatrix %m2v2float %1 %2
+     OpReturn
+     OpFunctionEnd
+)";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __mat_2_2__f32
+    {
+      Binary{
+        Identifier{x_1}
+        multiply
+        Identifier{x_2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
 // TODO(dneto): OpSRem. Missing from WGSL
 // https://github.com/gpuweb/gpuweb/issues/702
 
 // TODO(dneto): OpFRem. Missing from WGSL
 // https://github.com/gpuweb/gpuweb/issues/702
 
-// TODO(dneto): OpVectorTimesScalar
-// TODO(dneto): OpMatrixTimesScalar
-// TODO(dneto): OpVectorTimesMatrix
-// TODO(dneto): OpMatrixTimesVector
-// TODO(dneto): OpMatrixTimesMatrix
 // TODO(dneto): OpOuterProduct
 // TODO(dneto): OpDot
 // TODO(dneto): OpIAddCarry