[spirv-writer] Add support for dot call.
This CL adds support for generating OpDot.
Bug: tint:5
Change-Id: I5a77e49ff26ff12b4ed7b2b01665f0928e51a568
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22624
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 7269a30..aeb4b94 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1250,10 +1250,12 @@
op = spv::Op::OpAny;
} else if (name == "all") {
op = spv::Op::OpAll;
- } else if (name == "is_nan") {
- op = spv::Op::OpIsNan;
+ } else if (name == "dot") {
+ op = spv::Op::OpDot;
} else if (name == "is_inf") {
op = spv::Op::OpIsInf;
+ } else if (name == "is_nan") {
+ op = spv::Op::OpIsNan;
}
if (op == spv::Op::OpNop) {
error_ = "unable to determine operator for: " + name;
diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc
index f5151da..468ead6 100644
--- a/src/writer/spirv/builder_intrinsic_test.cc
+++ b/src/writer/spirv/builder_intrinsic_test.cc
@@ -170,6 +170,44 @@
testing::Values(IntrinsicData{"is_nan", "OpIsNan"},
IntrinsicData{"is_inf", "OpIsInf"}));
+TEST_F(BuilderTest, Call_Dot) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &vec3);
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
+ params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
+ ast::CallExpression expr(std::make_unique<ast::IdentifierExpression>("dot"),
+ std::move(params));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateCallExpression(&expr), 6u) << b.error();
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 3
+%2 = OpTypePointer Private %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Private %5
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%7 = OpLoad %3 %1
+%8 = OpLoad %3 %1
+%6 = OpDot %4 %7 %8
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer