[spirv-reader] Add mixed scalar/vector/matrix multiply
Bug: tint:3
Change-Id: I5875bf453b05c5d5c96f90122206da04f6799976
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23401
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 932c5ca..05c3b30 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -200,6 +200,11 @@
return ast::BinaryOp::kSubtract;
case SpvOpIMul:
case SpvOpFMul:
+ case SpvOpVectorTimesScalar:
+ case SpvOpMatrixTimesScalar:
+ case SpvOpVectorTimesMatrix:
+ case SpvOpMatrixTimesVector:
+ case SpvOpMatrixTimesMatrix:
return ast::BinaryOp::kMultiply;
case SpvOpUDiv:
case SpvOpSDiv:
diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc
index c92a84c..ba9656d 100644
--- a/src/reader/spirv/function_arithmetic_test.cc
+++ b/src/reader/spirv/function_arithmetic_test.cc
@@ -58,6 +58,10 @@
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
+
+ %m2v2float = OpTypeMatrix %v2float 2
+ %m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
+ %m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
)";
}
@@ -904,17 +908,157 @@
"__vec_2__f32", AstFor("v2float_50_60"), "modulo",
AstFor("v2float_60_50")}));
+TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCopyObject %v2float %v2float_50_60
+ %2 = OpCopyObject %float %float_50
+ %10 = OpVectorTimesScalar %v2float %1 %2
+ OpReturn
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+ x_10
+ none
+ __vec_2__f32
+ {
+ Binary{
+ Identifier{x_1}
+ multiply
+ Identifier{x_2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, MatrixTimesScalar) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCopyObject %m2v2float %m2v2float_a
+ %2 = OpCopyObject %float %float_50
+ %10 = OpMatrixTimesScalar %m2v2float %1 %2
+ OpReturn
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+ x_10
+ none
+ __mat_2_2__f32
+ {
+ Binary{
+ Identifier{x_1}
+ multiply
+ Identifier{x_2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, VectorTimesMatrix) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCopyObject %v2float %v2float_50_60
+ %2 = OpCopyObject %m2v2float %m2v2float_a
+ %10 = OpMatrixTimesVector %m2v2float %1 %2
+ OpReturn
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+ x_10
+ none
+ __mat_2_2__f32
+ {
+ Binary{
+ Identifier{x_1}
+ multiply
+ Identifier{x_2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, MatrixTimesVector) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCopyObject %m2v2float %m2v2float_a
+ %2 = OpCopyObject %v2float %v2float_50_60
+ %10 = OpMatrixTimesVector %m2v2float %1 %2
+ OpReturn
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+ x_10
+ none
+ __mat_2_2__f32
+ {
+ Binary{
+ Identifier{x_1}
+ multiply
+ Identifier{x_2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
+TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
+ const auto assembly = CommonTypes() + R"(
+ %100 = OpFunction %void None %voidfn
+ %entry = OpLabel
+ %1 = OpCopyObject %m2v2float %m2v2float_a
+ %2 = OpCopyObject %m2v2float %m2v2float_b
+ %10 = OpMatrixTimesMatrix %m2v2float %1 %2
+ OpReturn
+ OpFunctionEnd
+)";
+ auto* p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+ FunctionEmitter fe(p, *spirv_function(100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+ x_10
+ none
+ __mat_2_2__f32
+ {
+ Binary{
+ Identifier{x_1}
+ multiply
+ Identifier{x_2}
+ }
+ }
+ })"))
+ << ToString(fe.ast_body());
+}
+
// TODO(dneto): OpSRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702
// TODO(dneto): OpFRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702
-// TODO(dneto): OpVectorTimesScalar
-// TODO(dneto): OpMatrixTimesScalar
-// TODO(dneto): OpVectorTimesMatrix
-// TODO(dneto): OpMatrixTimesVector
-// TODO(dneto): OpMatrixTimesMatrix
// TODO(dneto): OpOuterProduct
// TODO(dneto): OpDot
// TODO(dneto): OpIAddCarry