[metal-writer] Add entry point support. This CL adds preliminary entry point support to the Metal backend. Bug: tint:8 Change-Id: I7b904621d706d4503d5054711de64872f79cf2fa Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23708 Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index ba4c935..61e560d 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc
@@ -44,12 +44,16 @@ GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate(const ast::Module& module) { + module_ = &module; + for (const auto& func : module.functions()) { if (!EmitFunction(func.get())) { return false; } out_ << std::endl; } + + module_ = nullptr; return true; } @@ -227,14 +231,51 @@ return false; } +void GeneratorImpl::EmitStage(ast::PipelineStage stage) { + switch (stage) { + case ast::PipelineStage::kFragment: + out_ << "fragment"; + break; + case ast::PipelineStage::kVertex: + out_ << "vertex"; + break; + case ast::PipelineStage::kCompute: + out_ << "kernel"; + break; + case ast::PipelineStage::kNone: + break; + } + return; +} + bool GeneratorImpl::EmitFunction(ast::Function* func) { make_indent(); + // TODO(dsinclair): Technically this is wrong as you could, in theory, have + // multiple entry points pointing at the same function. I'm ignoring that for + // now. It will either go away with the entry_point changes in the spec + // or we'll have to figure out how to deal with it. + + auto name = func->name(); + + for (const auto& ep : module_->entry_points()) { + if (ep->function_name() == name) { + EmitStage(ep->stage()); + out_ << " "; + + if (!ep->name().empty()) { + name = ep->name(); + } + + break; + } + } + if (!EmitType(func->return_type(), "")) { return false; } - out_ << " " << func->name() << "("; + out_ << " " << name << "("; bool first = true; for (const auto& v : func->params()) {
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 491f120..10d9073 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h
@@ -76,6 +76,9 @@ /// @param expr the scalar constructor expression /// @returns true if the scalar constructor is emitted bool EmitScalarConstructor(ast::ScalarConstructorExpression* expr); + /// Handles emitting a pipeline stage name + /// @param stage the stage to emit + void EmitStage(ast::PipelineStage stage); /// Handles a brace-enclosed list of statements. /// @param statements the statements to output /// @returns true if the statements were emitted @@ -97,6 +100,9 @@ /// @param expr the type constructor expression /// @returns true if the constructor is emitted bool EmitTypeConstructor(ast::TypeConstructorExpression* expr); + + private: + const ast::Module* module_ = nullptr; }; } // namespace msl
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index d47864e..965eaba 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc
@@ -14,6 +14,7 @@ #include "gtest/gtest.h" #include "src/ast/function.h" +#include "src/ast/module.h" #include "src/ast/return_statement.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" @@ -32,19 +33,24 @@ TEST_F(MslGeneratorImplTest, Emit_Function) { ast::type::VoidType void_type; - ast::Function func("my_func", {}, &void_type); + auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{}, + &void_type); ast::StatementList body; body.push_back(std::make_unique<ast::ReturnStatement>()); - func.set_body(std::move(body)); + func->set_body(std::move(body)); + + ast::Module m; + m.AddFunction(std::move(func)); GeneratorImpl g; g.increment_indent(); - ASSERT_TRUE(g.EmitFunction(&func)); + ASSERT_TRUE(g.Generate(m)) << g.error(); EXPECT_EQ(g.result(), R"( void my_func() { return; } + )"); } @@ -59,19 +65,64 @@ std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32)); ast::type::VoidType void_type; - ast::Function func("my_func", std::move(params), &void_type); + auto func = + std::make_unique<ast::Function>("my_func", std::move(params), &void_type); ast::StatementList body; body.push_back(std::make_unique<ast::ReturnStatement>()); - func.set_body(std::move(body)); + func->set_body(std::move(body)); + + ast::Module m; + m.AddFunction(std::move(func)); GeneratorImpl g; g.increment_indent(); - ASSERT_TRUE(g.EmitFunction(&func)); + ASSERT_TRUE(g.Generate(m)) << g.error(); EXPECT_EQ(g.result(), R"( void my_func(float a, int b) { return; } + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_NoName) { + ast::type::VoidType void_type; + + auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{}, + &void_type); + auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "", + "frag_main"); + + ast::Module m; + m.AddFunction(std::move(func)); + m.AddEntryPoint(std::move(ep)); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(m)) << g.error(); + EXPECT_EQ(g.result(), R"(fragment void frag_main() { +} + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) { + ast::type::VoidType void_type; + + auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{}, + &void_type); + auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute, + "main", "comp_main"); + + ast::Module m; + m.AddFunction(std::move(func)); + m.AddEntryPoint(std::move(ep)); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(m)) << g.error(); + EXPECT_EQ(g.result(), R"(kernel void main() { +} + )"); }