Emit call statements from the various backends.
This CL adds emission of CallStatement to the various backends.
Bug: tint:45
Change-Id: Ia2bdf0433f136c516ecccdcbc64a5365094220af
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25281
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index b2c5af0..f826213 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -22,6 +22,7 @@
#include "src/ast/bool_literal.h"
#include "src/ast/break_statement.h"
#include "src/ast/call_expression.h"
+#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/cast_expression.h"
#include "src/ast/continue_statement.h"
@@ -1480,6 +1481,14 @@
if (stmt->IsBreak()) {
return EmitBreak(stmt->AsBreak());
}
+ if (stmt->IsCall()) {
+ make_indent();
+ if (!EmitCall(stmt->AsCall()->expr())) {
+ return false;
+ }
+ out_ << ";" << std::endl;
+ return true;
+ }
if (stmt->IsContinue()) {
return EmitContinue(stmt->AsContinue());
}
diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc
index 97272c6..ef6bb6b 100644
--- a/src/writer/msl/generator_impl_call_test.cc
+++ b/src/writer/msl/generator_impl_call_test.cc
@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "src/ast/call_expression.h"
+#include "src/ast/call_statement.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/module.h"
@@ -66,6 +67,28 @@
EXPECT_EQ(g.result(), "my_func(param1, param2)");
}
+TEST_F(MslGeneratorImplTest, EmitStatement_Call) {
+ ast::type::VoidType void_type;
+
+ auto id = std::make_unique<ast::IdentifierExpression>("my_func");
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("param1"));
+ params.push_back(std::make_unique<ast::IdentifierExpression>("param2"));
+ ast::CallStatement call(
+ std::make_unique<ast::CallExpression>(std::move(id), std::move(params)));
+
+ auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{},
+ &void_type);
+
+ ast::Module m;
+ m.AddFunction(std::move(func));
+
+ GeneratorImpl g(&m);
+ g.increment_indent();
+ ASSERT_TRUE(g.EmitStatement(&call)) << g.error();
+ EXPECT_EQ(g.result(), " my_func(param1, param2);\n");
+}
+
} // namespace
} // namespace msl
} // namespace writer
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 59a36c0..bfad844 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -26,6 +26,7 @@
#include "src/ast/bool_literal.h"
#include "src/ast/builtin_decoration.h"
#include "src/ast/call_expression.h"
+#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/cast_expression.h"
#include "src/ast/constructor_expression.h"
@@ -1807,6 +1808,9 @@
if (stmt->IsBreak()) {
return GenerateBreakStatement(stmt->AsBreak());
}
+ if (stmt->IsCall()) {
+ return GenerateCallExpression(stmt->AsCall()->expr()) != 0;
+ }
if (stmt->IsContinue()) {
return GenerateContinueStatement(stmt->AsContinue());
}
diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc
index 36fe539..5499f59 100644
--- a/src/writer/spirv/builder_call_test.cc
+++ b/src/writer/spirv/builder_call_test.cc
@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "src/ast/binary_expression.h"
#include "src/ast/call_expression.h"
+#include "src/ast/call_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/return_statement.h"
@@ -129,7 +130,7 @@
)");
}
-TEST_F(BuilderTest, Call) {
+TEST_F(BuilderTest, Expression_Call) {
ast::type::F32Type f32;
ast::type::VoidType void_type;
@@ -197,6 +198,74 @@
)");
}
+TEST_F(BuilderTest, Statement_Call) {
+ ast::type::F32Type f32;
+ ast::type::VoidType void_type;
+
+ ast::VariableList func_params;
+ func_params.push_back(
+ std::make_unique<ast::Variable>("a", ast::StorageClass::kFunction, &f32));
+ func_params.push_back(
+ std::make_unique<ast::Variable>("b", ast::StorageClass::kFunction, &f32));
+
+ ast::Function a_func("a_func", std::move(func_params), &void_type);
+
+ ast::StatementList body;
+ body.push_back(std::make_unique<ast::ReturnStatement>(
+ std::make_unique<ast::BinaryExpression>(
+ ast::BinaryOp::kAdd, std::make_unique<ast::IdentifierExpression>("a"),
+ std::make_unique<ast::IdentifierExpression>("b"))));
+ a_func.set_body(std::move(body));
+
+ ast::Function func("main", {}, &void_type);
+
+ ast::ExpressionList call_params;
+ call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+ call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
+
+ ast::CallStatement expr(std::make_unique<ast::CallExpression>(
+ std::make_unique<ast::IdentifierExpression>("a_func"),
+ std::move(call_params)));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ ASSERT_TRUE(td.DetermineFunction(&func)) << td.error();
+ ASSERT_TRUE(td.DetermineFunction(&a_func)) << td.error();
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+ Builder b(&mod);
+ ASSERT_TRUE(b.GenerateFunction(&a_func)) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+
+ EXPECT_TRUE(b.GenerateStatement(&expr)) << b.error();
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %4 "a_func"
+OpName %5 "a"
+OpName %6 "b"
+OpName %12 "main"
+%2 = OpTypeVoid
+%3 = OpTypeFloat 32
+%1 = OpTypeFunction %2 %3 %3
+%11 = OpTypeFunction %2
+%15 = OpConstant %3 1
+%4 = OpFunction %2 None %1
+%5 = OpFunctionParameter %3
+%6 = OpFunctionParameter %3
+%7 = OpLabel
+%8 = OpLoad %3 %5
+%9 = OpLoad %3 %6
+%10 = OpFAdd %3 %8 %9
+OpReturnValue %10
+OpFunctionEnd
+%12 = OpFunction %2 None %11
+%13 = OpLabel
+%14 = OpFunctionCall %2 %4 %15 %15
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 817da91..4eb9583 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -26,6 +26,7 @@
#include "src/ast/break_statement.h"
#include "src/ast/builtin_decoration.h"
#include "src/ast/call_expression.h"
+#include "src/ast/call_statement.h"
#include "src/ast/case_statement.h"
#include "src/ast/cast_expression.h"
#include "src/ast/constructor_expression.h"
@@ -631,6 +632,14 @@
if (stmt->IsBreak()) {
return EmitBreak(stmt->AsBreak());
}
+ if (stmt->IsCall()) {
+ make_indent();
+ if (!EmitCall(stmt->AsCall()->expr())) {
+ return false;
+ }
+ out_ << ";" << std::endl;
+ return true;
+ }
if (stmt->IsContinue()) {
return EmitContinue(stmt->AsContinue());
}
@@ -656,7 +665,7 @@
return EmitVariable(stmt->AsVariableDecl()->variable());
}
- error_ = "unknown statement type";
+ error_ = "unknown statement type: " + stmt->str();
return false;
}
diff --git a/src/writer/wgsl/generator_impl_call_test.cc b/src/writer/wgsl/generator_impl_call_test.cc
index 91a5a4c..3ed0b3d 100644
--- a/src/writer/wgsl/generator_impl_call_test.cc
+++ b/src/writer/wgsl/generator_impl_call_test.cc
@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "src/ast/call_expression.h"
+#include "src/ast/call_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/writer/wgsl/generator_impl.h"
@@ -47,6 +48,21 @@
EXPECT_EQ(g.result(), "my_func(param1, param2)");
}
+TEST_F(WgslGeneratorImplTest, EmitStatement_Call) {
+ auto id = std::make_unique<ast::IdentifierExpression>("my_func");
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::IdentifierExpression>("param1"));
+ params.push_back(std::make_unique<ast::IdentifierExpression>("param2"));
+
+ ast::CallStatement call(
+ std::make_unique<ast::CallExpression>(std::move(id), std::move(params)));
+
+ GeneratorImpl g;
+ g.increment_indent();
+ ASSERT_TRUE(g.EmitStatement(&call)) << g.error();
+ EXPECT_EQ(g.result(), " my_func(param1, param2);\n");
+}
+
} // namespace
} // namespace wgsl
} // namespace writer