[metal-writer] Add entry point support.

This CL adds preliminary entry point support to the Metal backend.

Bug: tint:8
Change-Id: I7b904621d706d4503d5054711de64872f79cf2fa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23708
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index ba4c935..61e560d 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -44,12 +44,16 @@
 GeneratorImpl::~GeneratorImpl() = default;
 
 bool GeneratorImpl::Generate(const ast::Module& module) {
+  module_ = &module;
+
   for (const auto& func : module.functions()) {
     if (!EmitFunction(func.get())) {
       return false;
     }
     out_ << std::endl;
   }
+
+  module_ = nullptr;
   return true;
 }
 
@@ -227,14 +231,51 @@
   return false;
 }
 
+void GeneratorImpl::EmitStage(ast::PipelineStage stage) {
+  switch (stage) {
+    case ast::PipelineStage::kFragment:
+      out_ << "fragment";
+      break;
+    case ast::PipelineStage::kVertex:
+      out_ << "vertex";
+      break;
+    case ast::PipelineStage::kCompute:
+      out_ << "kernel";
+      break;
+    case ast::PipelineStage::kNone:
+      break;
+  }
+  return;
+}
+
 bool GeneratorImpl::EmitFunction(ast::Function* func) {
   make_indent();
 
+  // TODO(dsinclair): Technically this is wrong as you could, in theory, have
+  // multiple entry points pointing at the same function. I'm ignoring that for
+  // now. It will either go away with the entry_point changes in the spec
+  // or we'll have to figure out how to deal with it.
+
+  auto name = func->name();
+
+  for (const auto& ep : module_->entry_points()) {
+    if (ep->function_name() == name) {
+      EmitStage(ep->stage());
+      out_ << " ";
+
+      if (!ep->name().empty()) {
+        name = ep->name();
+      }
+
+      break;
+    }
+  }
+
   if (!EmitType(func->return_type(), "")) {
     return false;
   }
 
-  out_ << " " << func->name() << "(";
+  out_ << " " << name << "(";
 
   bool first = true;
   for (const auto& v : func->params()) {
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 491f120..10d9073 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -76,6 +76,9 @@
   /// @param expr the scalar constructor expression
   /// @returns true if the scalar constructor is emitted
   bool EmitScalarConstructor(ast::ScalarConstructorExpression* expr);
+  /// Handles emitting a pipeline stage name
+  /// @param stage the stage to emit
+  void EmitStage(ast::PipelineStage stage);
   /// Handles a brace-enclosed list of statements.
   /// @param statements the statements to output
   /// @returns true if the statements were emitted
@@ -97,6 +100,9 @@
   /// @param expr the type constructor expression
   /// @returns true if the constructor is emitted
   bool EmitTypeConstructor(ast::TypeConstructorExpression* expr);
+
+ private:
+  const ast::Module* module_ = nullptr;
 };
 
 }  // namespace msl
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index d47864e..965eaba 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -14,6 +14,7 @@
 
 #include "gtest/gtest.h"
 #include "src/ast/function.h"
+#include "src/ast/module.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/f32_type.h"
@@ -32,19 +33,24 @@
 TEST_F(MslGeneratorImplTest, Emit_Function) {
   ast::type::VoidType void_type;
 
-  ast::Function func("my_func", {}, &void_type);
+  auto func = std::make_unique<ast::Function>("my_func", ast::VariableList{},
+                                              &void_type);
 
   ast::StatementList body;
   body.push_back(std::make_unique<ast::ReturnStatement>());
-  func.set_body(std::move(body));
+  func->set_body(std::move(body));
+
+  ast::Module m;
+  m.AddFunction(std::move(func));
 
   GeneratorImpl g;
   g.increment_indent();
 
-  ASSERT_TRUE(g.EmitFunction(&func));
+  ASSERT_TRUE(g.Generate(m)) << g.error();
   EXPECT_EQ(g.result(), R"(  void my_func() {
     return;
   }
+
 )");
 }
 
@@ -59,19 +65,64 @@
       std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
 
   ast::type::VoidType void_type;
-  ast::Function func("my_func", std::move(params), &void_type);
+  auto func =
+      std::make_unique<ast::Function>("my_func", std::move(params), &void_type);
 
   ast::StatementList body;
   body.push_back(std::make_unique<ast::ReturnStatement>());
-  func.set_body(std::move(body));
+  func->set_body(std::move(body));
+
+  ast::Module m;
+  m.AddFunction(std::move(func));
 
   GeneratorImpl g;
   g.increment_indent();
 
-  ASSERT_TRUE(g.EmitFunction(&func));
+  ASSERT_TRUE(g.Generate(m)) << g.error();
   EXPECT_EQ(g.result(), R"(  void my_func(float a, int b) {
     return;
   }
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_NoName) {
+  ast::type::VoidType void_type;
+
+  auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
+                                              &void_type);
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+                                              "frag_main");
+
+  ast::Module m;
+  m.AddFunction(std::move(func));
+  m.AddEntryPoint(std::move(ep));
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(m)) << g.error();
+  EXPECT_EQ(g.result(), R"(fragment void frag_main() {
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) {
+  ast::type::VoidType void_type;
+
+  auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{},
+                                              &void_type);
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+                                              "main", "comp_main");
+
+  ast::Module m;
+  m.AddFunction(std::move(func));
+  m.AddEntryPoint(std::move(ep));
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(m)) << g.error();
+  EXPECT_EQ(g.result(), R"(kernel void main() {
+}
+
 )");
 }