[spirv-writer] Add binary multiplication.
This CL adds binary multiplication generation to the SPIR-V writer.
Bug: tint:5
Change-Id: I668d24035e947c51a9737549fd0841a4e8af1331
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19700
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index c388e70..e9e45d7 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -79,6 +79,36 @@
return model;
}
+bool is_float_scalar(ast::type::Type* type) {
+ return type->IsF32();
+}
+
+bool is_float_matrix(ast::type::Type* type) {
+ return type->IsMatrix() && type->AsMatrix()->type()->IsF32();
+}
+
+bool is_float_vector(ast::type::Type* type) {
+ return type->IsVector() && type->AsVector()->type()->IsF32();
+}
+
+bool is_float_scalar_or_vector(ast::type::Type* type) {
+ return is_float_scalar(type) || is_float_vector(type);
+}
+
+bool is_unsigned_scalar_or_vector(ast::type::Type* type) {
+ return type->IsU32() ||
+ (type->IsVector() && type->AsVector()->type()->IsU32());
+}
+
+bool is_signed_scalar_or_vector(ast::type::Type* type) {
+ return type->IsI32() ||
+ (type->IsVector() && type->AsVector()->type()->IsI32());
+}
+
+bool is_integer_scalar_or_vector(ast::type::Type* type) {
+ return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type);
+}
+
} // namespace
Builder::Builder() : scope_stack_({}) {}
@@ -569,12 +599,9 @@
// Handle int and float and the vectors of those types. Other types
// should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type();
- bool lhs_is_float_or_vec =
- lhs_type->IsF32() ||
- (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32());
- bool lhs_is_unsigned =
- lhs_type->IsU32() ||
- (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsU32());
+ auto* rhs_type = expr->rhs()->result_type();
+ bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type);
+ bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type);
spv::Op op = spv::Op::OpNop;
if (expr->IsAnd()) {
@@ -631,6 +658,45 @@
} else {
op = spv::Op::OpSMod;
}
+ } else if (expr->IsMultiply()) {
+ if (is_integer_scalar_or_vector(lhs_type)) {
+ // If the left hand side is an integer then this _has_ to be OpIMul as
+ // there there is no other integer multiplication.
+ op = spv::Op::OpIMul;
+ } else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) {
+ // Float scalars multiply with OpFMul
+ op = spv::Op::OpFMul;
+ } else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) {
+ // Float vectors must be validated to be the same size and then use OpFMul
+ op = spv::Op::OpFMul;
+ } else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) {
+ // Scalar * Vector we need to flip lhs and rhs types
+ // because OpVectorTimesScalar expects <vector>, <scalar>
+ std::swap(lhs_id, rhs_id);
+ op = spv::Op::OpVectorTimesScalar;
+ } else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) {
+ // float vector * scalar
+ op = spv::Op::OpVectorTimesScalar;
+ } else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) {
+ // Scalar * Matrix we need to flip lhs and rhs types because
+ // OpMatrixTimesScalar expects <matrix>, <scalar>
+ std::swap(lhs_id, rhs_id);
+ op = spv::Op::OpMatrixTimesScalar;
+ } else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) {
+ // float matrix * scalar
+ op = spv::Op::OpMatrixTimesScalar;
+ } else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) {
+ // float vector * matrix
+ op = spv::Op::OpVectorTimesMatrix;
+ } else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) {
+ // float matrix * vector
+ op = spv::Op::OpMatrixTimesVector;
+ } else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) {
+ // float matrix * matrix
+ op = spv::Op::OpMatrixTimesMatrix;
+ } else {
+ return 0;
+ }
} else if (expr->IsNotEqual()) {
op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual;
} else if (expr->IsOr()) {
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index 97b3569..fe18196 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -17,10 +17,12 @@
#include "gtest/gtest.h"
#include "src/ast/binary_expression.h"
#include "src/ast/float_literal.h"
+#include "src/ast/identifier_expression.h"
#include "src/ast/int_literal.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
@@ -65,7 +67,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
%2 = OpConstant %1 3
%3 = OpConstant %1 4
@@ -108,7 +110,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
@@ -125,6 +127,7 @@
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
+ BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
@@ -152,7 +155,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
%2 = OpConstant %1 3
%3 = OpConstant %1 4
@@ -195,7 +198,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
@@ -212,6 +215,7 @@
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
+ BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
@@ -239,7 +243,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 3.20000005
%3 = OpConstant %1 4.5
@@ -283,7 +287,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
@@ -298,6 +302,7 @@
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpFMod"},
+ BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryCompareUnsignedIntegerTest = testing::TestWithParam<BinaryData>;
@@ -320,7 +325,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
%2 = OpConstant %1 3
%3 = OpConstant %1 4
@@ -365,7 +370,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
@@ -407,7 +412,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
%2 = OpConstant %1 3
%3 = OpConstant %1 4
@@ -452,7 +457,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
@@ -494,7 +499,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 3.20000005
%3 = OpConstant %1 4.5
@@ -539,7 +544,7 @@
Builder b;
b.push_function(Function{});
- ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
@@ -561,6 +566,288 @@
BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
+TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ auto lhs =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%1 = OpTypeVector %2 3
+%3 = OpConstant %2 1
+%4 = OpConstantComposite %1 %3 %3 %3
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ "%5 = OpVectorTimesScalar %1 %4 %3\n");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ auto rhs =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
+%2 = OpConstant %1 1
+%3 = OpTypeVector %1 3
+%4 = OpConstantComposite %3 %2 %2 %2
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ "%5 = OpVectorTimesScalar %3 %4 %2\n");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto var = std::make_unique<ast::Variable>(
+ "mat", ast::StorageClass::kFunction, &mat3);
+ auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
+ auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%7 = OpConstant %5 1
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%6 = OpLoad %3 %1
+%8 = OpMatrixTimesScalar %3 %6 %7
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto var = std::make_unique<ast::Variable>(
+ "mat", ast::StorageClass::kFunction, &mat3);
+ auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f));
+ auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%6 = OpConstant %5 1
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%7 = OpLoad %3 %1
+%8 = OpMatrixTimesScalar %3 %7 %6
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto var = std::make_unique<ast::Variable>(
+ "mat", ast::StorageClass::kFunction, &mat3);
+ auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ auto rhs =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%7 = OpConstant %5 1
+%8 = OpConstantComposite %4 %7 %7 %7
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%6 = OpLoad %3 %1
+%9 = OpMatrixTimesVector %4 %6 %8
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto var = std::make_unique<ast::Variable>(
+ "mat", ast::StorageClass::kFunction, &mat3);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ auto lhs =
+ std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
+
+ auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+%6 = OpConstant %5 1
+%7 = OpConstantComposite %4 %6 %6 %6
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%8 = OpLoad %3 %1
+%9 = OpVectorTimesMatrix %4 %7 %8
+)");
+}
+
+TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto var = std::make_unique<ast::Variable>(
+ "mat", ast::StorageClass::kFunction, &mat3);
+ auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
+ auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+ td.RegisterVariableForTesting(var.get());
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
+ std::move(rhs));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b;
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
+%4 = OpTypeVector %5 3
+%3 = OpTypeMatrix %4 3
+%2 = OpTypePointer Function %3
+%1 = OpVariable %2 Function
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%6 = OpLoad %3 %1
+%7 = OpLoad %3 %1
+%8 = OpMatrixTimesMatrix %3 %6 %7
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer