[metal-writer] Add entry point support.
This CL adds preliminary entry point support to the Metal backend.
Bug: tint:8
Change-Id: I7b904621d706d4503d5054711de64872f79cf2fa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23708
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index ba4c935..61e560d 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -44,12 +44,16 @@
GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate(const ast::Module& module) {
+ module_ = &module;
+
for (const auto& func : module.functions()) {
if (!EmitFunction(func.get())) {
return false;
}
out_ << std::endl;
}
+
+ module_ = nullptr;
return true;
}
@@ -227,14 +231,51 @@
return false;
}
+void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
+ switch (stage) {
+ case ast::PipelineStage::kFragment:
+ out_ << "fragment";
+ break;
+ case ast::PipelineStage::kVertex:
+ out_ << "vertex";
+ break;
+ case ast::PipelineStage::kCompute:
+ out_ << "kernel";
+ break;
+ case ast::PipelineStage::kNone:
+ break;
+ }
+ return;
+}
+
bool GeneratorImpl::EmitFunction(ast::Function* func) {
make_indent();
+ // TODO(dsinclair): Technically this is wrong as you could, in theory, have
+ // multiple entry points pointing at the same function. I'm ignoring that for
+ // now. It will either go away with the entry_point changes in the spec
+ // or we'll have to figure out how to deal with it.
+
+ auto name = func->name();
+
+ for (const auto& ep : module_->entry_points()) {
+ if (ep->function_name() == name) {
+ EmitStage(ep->stage());
+ out_ << " ";
+
+ if (!ep->name().empty()) {
+ name = ep->name();
+ }
+
+ break;
+ }
+ }
+
if (!EmitType(func->return_type(), "")) {
return false;
}
- out_ << " " << func->name() << "(";
+ out_ << " " << name << "(";
bool first = true;
for (const auto& v : func->params()) {
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 491f120..10d9073 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -76,6 +76,9 @@
/// @param expr the scalar constructor expression
/// @returns true if the scalar constructor is emitted
bool EmitScalarConstructor(ast::ScalarConstructorExpression* expr);
+ /// Handles emitting a pipeline stage name
+ /// @param stage the stage to emit
+ void EmitStage(ast::PipelineStage stage);
/// Handles a brace-enclosed list of statements.
/// @param statements the statements to output
/// @returns true if the statements were emitted
@@ -97,6 +100,9 @@
/// @param expr the type constructor expression
/// @returns true if the constructor is emitted
bool EmitTypeConstructor(ast::TypeConstructorExpression* expr);
+
+ private:
+ const ast::Module* module_ = nullptr;
};
} // namespace msl
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index d47864e..965eaba 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -14,6 +14,7 @@
#include "gtest/gtest.h"
#include "src/ast/function.h"
+#include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
@@ -32,19 +33,24 @@
TEST_F(MslGeneratorImplTest, Emit_Function) {
ast::type::VoidType void_type;
- ast::Function func("my_func", {}, &void_type);
+ auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{},
+ &void_type);
ast::StatementList body;
body.push_back(std::make_unique<ast::ReturnStatement>());
- func.set_body(std::move(body));
+ func->set_body(std::move(body));
+
+ ast::Module m;
+ m.AddFunction(std::move(func));
GeneratorImpl g;
g.increment_indent();
- ASSERT_TRUE(g.EmitFunction(&func));
+ ASSERT_TRUE(g.Generate(m)) << g.error();
EXPECT_EQ(g.result(), R"( void my_func() {
return;
}
+
)");
}
@@ -59,19 +65,64 @@
std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
ast::type::VoidType void_type;
- ast::Function func("my_func", std::move(params), &void_type);
+ auto func =
+ std::make_unique<ast::Function>("my_func", std::move(params), &void_type);
ast::StatementList body;
body.push_back(std::make_unique<ast::ReturnStatement>());
- func.set_body(std::move(body));
+ func->set_body(std::move(body));
+
+ ast::Module m;
+ m.AddFunction(std::move(func));
GeneratorImpl g;
g.increment_indent();
- ASSERT_TRUE(g.EmitFunction(&func));
+ ASSERT_TRUE(g.Generate(m)) << g.error();
EXPECT_EQ(g.result(), R"( void my_func(float a, int b) {
return;
}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_NoName) {
+ ast::type::VoidType void_type;
+
+ auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
+ &void_type);
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+ "frag_main");
+
+ ast::Module m;
+ m.AddFunction(std::move(func));
+ m.AddEntryPoint(std::move(ep));
+
+ GeneratorImpl g;
+ ASSERT_TRUE(g.Generate(m)) << g.error();
+ EXPECT_EQ(g.result(), R"(fragment void frag_main() {
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) {
+ ast::type::VoidType void_type;
+
+ auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
+ &void_type);
+ auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+ "main", "comp_main");
+
+ ast::Module m;
+ m.AddFunction(std::move(func));
+ m.AddEntryPoint(std::move(ep));
+
+ GeneratorImpl g;
+ ASSERT_TRUE(g.Generate(m)) << g.error();
+ EXPECT_EQ(g.result(), R"(kernel void main() {
+}
+
)");
}