[msl-writer] Generate entry point functions.

This CL generates entry point functions and duplicate functions as
needed to call from the entry points.

Bug: tint:8
Change-Id: I8092ce463248e7a887c26ae05b0774e8fa21ab94
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24764
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/decorated_variable.cc b/src/ast/decorated_variable.cc
index 340075e..89db9e4 100644
--- a/src/ast/decorated_variable.cc
+++ b/src/ast/decorated_variable.cc
@@ -26,6 +26,15 @@
 
 DecoratedVariable::~DecoratedVariable() = default;
 
+bool DecoratedVariable::HasLocationDecoration() const {
+  for (const auto& deco : decorations_) {
+    if (deco->IsLocation()) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool DecoratedVariable::IsDecorated() const {
   return true;
 }
diff --git a/src/ast/decorated_variable.h b/src/ast/decorated_variable.h
index 992e641..d2e381a 100644
--- a/src/ast/decorated_variable.h
+++ b/src/ast/decorated_variable.h
@@ -45,6 +45,9 @@
   /// @returns the decorations attached to this variable
   const VariableDecorationList& decorations() const { return decorations_; }
 
+  /// @returns true if the decorations include a LocationDecoration
+  bool HasLocationDecoration() const;
+
   /// @returns true if this is a decorated variable
   bool IsDecorated() const override;
 
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 7ddd7e7..74a081c 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -43,6 +43,14 @@
   return nullptr;
 }
 
+bool Module::IsFunctionEntryPoint(const std::string& name) const {
+  for (const auto& ep : entry_points_) {
+    if (ep->function_name() == name)
+      return true;
+  }
+  return false;
+}
+
 bool Module::IsValid() const {
   for (const auto& import : imports_) {
     if (import == nullptr || !import->IsValid()) {
diff --git a/src/ast/module.h b/src/ast/module.h
index 77e1924..d17a0cb 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -65,6 +65,11 @@
   /// @returns the entry points in the module
   const EntryPointList& entry_points() const { return entry_points_; }
 
+  /// Checks if the given function name is an entry point function
+  /// @param name the function name
+  /// @returns true if name is an entry point function
+  bool IsFunctionEntryPoint(const std::string& name) const;
+
   /// Adds a type alias to the module
   /// @param type the alias to add
   void AddAliasType(type::AliasType* type) { alias_types_.push_back(type); }
diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc
index 869f304..4cd0ebb 100644
--- a/src/ast/module_test.cc
+++ b/src/ast/module_test.cc
@@ -91,6 +91,19 @@
   EXPECT_EQ(func_ptr, m.FindFunctionByName("main"));
 }
 
+TEST_F(ModuleTest, IsEntryPoint) {
+  type::F32Type f32;
+  Module m;
+
+  auto func = std::make_unique<Function>("other_func", VariableList{}, &f32);
+  m.AddFunction(std::move(func));
+
+  m.AddEntryPoint(
+      std::make_unique<EntryPoint>(PipelineStage::kVertex, "main", "vtx_main"));
+  EXPECT_TRUE(m.IsFunctionEntryPoint("vtx_main"));
+  EXPECT_FALSE(m.IsFunctionEntryPoint("other_func"));
+}
+
 TEST_F(ModuleTest, LookupFunctionMissing) {
   Module m;
   EXPECT_EQ(nullptr, m.FindFunctionByName("Missing"));
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 5568d7e..defabf5 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -59,6 +59,8 @@
 
 const char kInStructNameSuffix[] = "in";
 const char kOutStructNameSuffix[] = "out";
+const char kTintStructInVarPrefix[] = "tint_in";
+const char kTintStructOutVarPrefix[] = "tint_out";
 
 bool last_is_break_or_fallthrough(const ast::StatementList& stmts) {
   if (stmts.empty()) {
@@ -78,13 +80,11 @@
   module_ = mod;
 }
 
-std::string GeneratorImpl::generate_struct_name(ast::EntryPoint* ep,
-                                                const std::string& type) {
-  std::string base_name = ep->function_name() + "_" + type;
-  std::string name = base_name;
+std::string GeneratorImpl::generate_name(const std::string& prefix) {
+  std::string name = prefix;
   uint32_t i = 0;
   while (namer_.IsMapped(name)) {
-    name = base_name + "_" + std::to_string(i);
+    name = prefix + "_" + std::to_string(i);
     ++i;
   }
   namer_.RegisterRemappedName(name);
@@ -96,6 +96,10 @@
 
   out_ << "#include <metal_stdlib>" << std::endl << std::endl;
 
+  for (const auto& global : module.global_variables()) {
+    global_variables_.set(global->name(), global.get());
+  }
+
   for (auto* const alias : module.alias_types()) {
     if (!EmitAliasType(alias)) {
       return false;
@@ -106,7 +110,7 @@
   }
 
   for (const auto& ep : module.entry_points()) {
-    if (!EmitEntryPoint(ep.get())) {
+    if (!EmitEntryPointData(ep.get())) {
       return false;
     }
   }
@@ -115,6 +119,12 @@
     if (!EmitFunction(func.get())) {
       return false;
     }
+  }
+
+  for (const auto& ep : module.entry_points()) {
+    if (!EmitEntryPointFunction(ep.get())) {
+      return false;
+    }
     out_ << std::endl;
   }
 
@@ -283,12 +293,32 @@
   }
 
   if (!ident->has_path()) {
-    if (!EmitExpression(expr->func())) {
-      return false;
+    auto name = ident->name();
+    auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
+    if (it != ep_func_name_remapped_.end()) {
+      name = it->second;
     }
-    out_ << "(";
+    out_ << name << "(";
 
     bool first = true;
+
+    auto in_it = ep_name_to_in_data_.find(current_ep_name_);
+    if (in_it != ep_name_to_in_data_.end()) {
+      out_ << in_it->second.var_name;
+      first = false;
+    }
+
+    auto out_it = ep_name_to_out_data_.find(current_ep_name_);
+    if (out_it != ep_name_to_out_data_.end()) {
+      if (!first) {
+        out_ << ", ";
+      }
+      out_ << out_it->second.var_name;
+      first = false;
+    }
+
+    // TODO(dsinclair): Emit builtins
+
     const auto& params = expr->params();
     for (const auto& param : params) {
       if (!first) {
@@ -459,7 +489,7 @@
   return true;
 }
 
-bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) {
+bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) {
   auto* func = module_->FindFunctionByName(ep->function_name());
   if (func == nullptr) {
     error_ = "Unable to find entry point function: " + ep->function_name();
@@ -491,9 +521,20 @@
     }
   }
 
+  auto ep_name = ep->name();
+  if (ep_name.empty()) {
+    ep_name = ep->function_name();
+  }
+
+  // TODO(dsinclair): There is a potential bug here. Entry points can have the
+  // same name in WGSL if they have different pipeline stages. This does not
+  // take that into account and will emit duplicate struct names. I'm ignoring
+  // this until https://github.com/gpuweb/gpuweb/issues/662 is resolved as it
+  // may remove this issue and entry point names will need to be unique.
   if (!in_locations.empty()) {
-    auto in_struct_name = generate_struct_name(ep, kInStructNameSuffix);
-    ep_name_to_in_struct_[ep->name()] = in_struct_name;
+    auto in_struct_name = generate_name(ep_name + "_" + kInStructNameSuffix);
+    auto in_var_name = generate_name(kTintStructInVarPrefix);
+    ep_name_to_in_data_[ep_name] = {in_struct_name, in_var_name};
 
     make_indent();
     out_ << "struct " << in_struct_name << " {" << std::endl;
@@ -527,8 +568,9 @@
   }
 
   if (!out_locations.empty()) {
-    auto out_struct_name = generate_struct_name(ep, kOutStructNameSuffix);
-    ep_name_to_out_struct_[ep->name()] = out_struct_name;
+    auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix);
+    auto out_var_name = generate_name(kTintStructOutVarPrefix);
+    ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name};
 
     make_indent();
     out_ << "struct " << out_struct_name << " {" << std::endl;
@@ -615,33 +657,82 @@
 bool GeneratorImpl::EmitFunction(ast::Function* func) {
   make_indent();
 
-  // TODO(dsinclair): Technically this is wrong as you could, in theory, have
-  // multiple entry points pointing at the same function. I'm ignoring that for
-  // now. It will either go away with the entry_point changes in the spec
-  // or we'll have to figure out how to deal with it.
-
-  auto name = func->name();
-
-  for (const auto& ep : module_->entry_points()) {
-    if (ep->function_name() == name) {
-      EmitStage(ep->stage());
-      out_ << " ";
-
-      if (!ep->name().empty()) {
-        name = ep->name();
-      }
-
-      break;
-    }
+  // Entry points will be emitted later, skip for now.
+  if (module_->IsFunctionEntryPoint(func->name())) {
+    return true;
   }
 
+  // TODO(dsinclair): This could be smarter. If the input/outputs for multiple
+  // entry points are the same we could generate a single struct and then have
+  // this determine it's the same struct and just emit once.
+  bool emit_duplicate_functions =
+      func->ancestor_entry_points().size() > 0 &&
+      func->referenced_module_variables().size() > 0;
+
+  if (emit_duplicate_functions) {
+    for (const auto& ep_name : func->ancestor_entry_points()) {
+      if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
+        return false;
+      }
+      out_ << std::endl;
+    }
+  } else {
+    // Emit as non-duplicated
+    if (!EmitFunctionInternal(func, false, "")) {
+      return false;
+    }
+    out_ << std::endl;
+  }
+
+  return true;
+}
+
+bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
+                                         bool emit_duplicate_functions,
+                                         const std::string& ep_name) {
+  auto name = func->name();
+
   if (!EmitType(func->return_type(), "")) {
     return false;
   }
 
-  out_ << " " << namer_.NameFor(name) << "(";
+  out_ << " ";
+
+  if (emit_duplicate_functions) {
+    name = generate_name(name + "_" + ep_name);
+    ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
+  } else {
+    name = namer_.NameFor(name);
+  }
+  out_ << name << "(";
 
   bool first = true;
+
+  // If we're emitting duplicate functions that means the function takes
+  // the stage_in or stage_out value from the entry point, emit them.
+  //
+  // We emit both of them if they're there regardless of if they're both used.
+  if (emit_duplicate_functions) {
+    auto in_it = ep_name_to_in_data_.find(ep_name);
+    if (in_it != ep_name_to_in_data_.end()) {
+      out_ << "thread " << in_it->second.struct_name << "& "
+           << in_it->second.var_name;
+      first = false;
+    }
+
+    auto out_it = ep_name_to_out_data_.find(ep_name);
+    if (out_it != ep_name_to_out_data_.end()) {
+      if (!first) {
+        out_ << ", ";
+      }
+      out_ << "thread " << out_it->second.struct_name << "& "
+           << out_it->second.var_name;
+      first = false;
+    }
+  }
+
+  // TODO(dsinclair): Handle any entry point builtin params used here
+
   for (const auto& v : func->params()) {
     if (!first) {
       out_ << ", ";
@@ -656,9 +747,79 @@
       out_ << " " << v->name();
     }
   }
+
   out_ << ")";
 
-  return EmitStatementBlockAndNewline(func->body());
+  current_ep_name_ = ep_name;
+
+  if (!EmitStatementBlockAndNewline(func->body())) {
+    return false;
+  }
+
+  current_ep_name_ = "";
+
+  return true;
+}
+
+bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) {
+  make_indent();
+
+  current_ep_name_ = ep->name();
+  if (current_ep_name_.empty()) {
+    current_ep_name_ = ep->function_name();
+  }
+
+  auto* func = module_->FindFunctionByName(ep->function_name());
+  if (func == nullptr) {
+    error_ = "unable to find function for entry point: " + ep->function_name();
+    return false;
+  }
+
+  EmitStage(ep->stage());
+  out_ << " ";
+
+  // This is an entry point, the return type is the entry point output structure
+  // if one exists, or void otherwise.
+  auto out_data = ep_name_to_out_data_.find(current_ep_name_);
+  bool has_out_data = out_data != ep_name_to_out_data_.end();
+  if (has_out_data) {
+    out_ << out_data->second.struct_name;
+  } else {
+    out_ << "void";
+  }
+  out_ << " " << namer_.NameFor(current_ep_name_) << "(";
+
+  auto in_data = ep_name_to_in_data_.find(current_ep_name_);
+  if (in_data != ep_name_to_in_data_.end()) {
+    out_ << in_data->second.struct_name << " " << in_data->second.var_name
+         << " [[stage_in]]";
+  }
+
+  // TODO(dsinclair): Output other builtin inputs
+  out_ << ") {" << std::endl;
+
+  increment_indent();
+
+  if (has_out_data) {
+    make_indent();
+    out_ << out_data->second.struct_name << " " << out_data->second.var_name
+         << " = {};" << std::endl;
+  }
+
+  generating_entry_point_ = true;
+  for (const auto& s : func->body()) {
+    if (!EmitStatement(s.get())) {
+      return false;
+    }
+  }
+  generating_entry_point_ = false;
+
+  decrement_indent();
+  make_indent();
+  out_ << "}" << std::endl;
+
+  current_ep_name_ = "";
+  return true;
 }
 
 bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) {
@@ -668,7 +829,30 @@
     error_ = "Identifier paths not handled yet.";
     return false;
   }
+
+  ast::Variable* var = nullptr;
+  if (global_variables_.get(ident->name(), &var)) {
+    if (var->storage_class() == ast::StorageClass::kInput &&
+        var->IsDecorated() && var->AsDecorated()->HasLocationDecoration()) {
+      auto it = ep_name_to_in_data_.find(current_ep_name_);
+      if (it == ep_name_to_in_data_.end()) {
+        error_ = "unable to find entry point data for input";
+        return false;
+      }
+      out_ << it->second.var_name << ".";
+    } else if (var->storage_class() == ast::StorageClass::kOutput &&
+               var->IsDecorated() &&
+               var->AsDecorated()->HasLocationDecoration()) {
+      auto it = ep_name_to_out_data_.find(current_ep_name_);
+      if (it == ep_name_to_out_data_.end()) {
+        error_ = "unable to find entry point data for output";
+        return false;
+      }
+      out_ << it->second.var_name << ".";
+    }
+  }
   out_ << namer_.NameFor(ident->name());
+
   return true;
 }
 
@@ -785,7 +969,13 @@
   make_indent();
 
   out_ << "return";
-  if (stmt->has_value()) {
+
+  if (generating_entry_point_) {
+    auto out_data = ep_name_to_out_data_.find(current_ep_name_);
+    if (out_data != ep_name_to_out_data_.end()) {
+      out_ << " " << out_data->second.var_name;
+    }
+  } else if (stmt->has_value()) {
     out_ << " ";
     if (!EmitExpression(stmt->value())) {
       return false;
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 4c94f3c..c851ee7 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -23,6 +23,7 @@
 #include "src/ast/module.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/type_constructor_expression.h"
+#include "src/scope_stack.h"
 #include "src/writer/msl/namer.h"
 #include "src/writer/text_generator.h"
 
@@ -93,7 +94,11 @@
   /// Handles emitting information for an entry point
   /// @param ep the entry point
   /// @returns true if the entry point data was emitted
-  bool EmitEntryPoint(ast::EntryPoint* ep);
+  bool EmitEntryPointData(ast::EntryPoint* ep);
+  /// Handles emitting the entry point function
+  /// @param ep the entry point
+  /// @returns true if the entry point function was emitted
+  bool EmitEntryPointFunction(ast::EntryPoint* ep);
   /// Handles generate an Expression
   /// @param expr the expression
   /// @returns true if the expression was emitted
@@ -102,6 +107,15 @@
   /// @param func the function to generate
   /// @returns true if the function was emitted
   bool EmitFunction(ast::Function* func);
+  /// Internal helper for emitting functions
+  /// @param func the function to emit
+  /// @param emit_duplicate_functions set true if we need to duplicate per entry
+  /// point
+  /// @param ep_name the current entry point or blank if none set
+  /// @returns true if the function was emitted.
+  bool EmitFunctionInternal(ast::Function* func,
+                            bool emit_duplicate_functions,
+                            const std::string& ep_name);
   /// Handles generating an identifier expression
   /// @param expr the identifier expression
   /// @returns true if the identifeir was emitted
@@ -179,22 +193,33 @@
   /// @param mod the module to set.
   void set_module_for_testing(ast::Module* mod);
 
-  /// Generates a name for the input struct
-  /// @param ep the entry point to generate for
-  /// @param type the type of struct to generate
-  /// @returns the input struct name
-  std::string generate_struct_name(ast::EntryPoint* ep,
-                                   const std::string& type);
+  /// Generates a name for the prefix
+  /// @param prefix the prefix of the name to generate
+  /// @returns the name
+  std::string generate_name(const std::string& prefix);
 
   /// @returns the namer for testing
   Namer* namer_for_testing() { return &namer_; }
 
  private:
   Namer namer_;
+  ScopeStack<ast::Variable*> global_variables_;
+  std::string current_ep_name_;
+  bool generating_entry_point_ = false;
   const ast::Module* module_ = nullptr;
   uint32_t loop_emission_counter_ = 0;
-  std::unordered_map<std::string, std::string> ep_name_to_in_struct_;
-  std::unordered_map<std::string, std::string> ep_name_to_out_struct_;
+
+  struct EntryPointData {
+    std::string struct_name;
+    std::string var_name;
+  };
+  std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_;
+  std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_;
+
+  // This maps an input of "<entry_point_name>_<function_name>" to a remapped
+  // function name. If there is no entry for a given key then function did
+  // not need to be remapped for the entry point and can be emitted directly.
+  std::unordered_map<std::string, std::string> ep_func_name_remapped_;
 };
 
 }  // namespace msl
diff --git a/src/writer/msl/generator_impl_entry_point_test.cc b/src/writer/msl/generator_impl_entry_point_test.cc
index a102300..824ac6c 100644
--- a/src/writer/msl/generator_impl_entry_point_test.cc
+++ b/src/writer/msl/generator_impl_entry_point_test.cc
@@ -33,7 +33,7 @@
 
 using MslGeneratorImplTest = testing::Test;
 
-TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) {
+TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Input) {
   // [[location 0]] var<in> foo : f32;
   // [[location 1]] var<in> bar : i32;
   //
@@ -81,8 +81,8 @@
 
   mod.AddFunction(std::move(func));
 
-  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
-                                              "main", "vtx_main");
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex, "",
+                                              "vtx_main");
   auto* ep_ptr = ep.get();
 
   mod.AddEntryPoint(std::move(ep));
@@ -91,7 +91,7 @@
 
   GeneratorImpl g;
   g.set_module_for_testing(&mod);
-  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
   EXPECT_EQ(g.result(), R"(struct vtx_main_in {
   float foo [[attribute(0)]];
   int bar [[attribute(1)]];
@@ -100,7 +100,7 @@
 )");
 }
 
-TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) {
+TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Output) {
   // [[location 0]] var<out> foo : f32;
   // [[location 1]] var<out> bar : i32;
   //
@@ -148,8 +148,8 @@
 
   mod.AddFunction(std::move(func));
 
-  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
-                                              "main", "vtx_main");
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex, "",
+                                              "vtx_main");
   auto* ep_ptr = ep.get();
 
   mod.AddEntryPoint(std::move(ep));
@@ -158,7 +158,7 @@
 
   GeneratorImpl g;
   g.set_module_for_testing(&mod);
-  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
   EXPECT_EQ(g.result(), R"(struct vtx_main_out {
   float foo [[user(locn0)]];
   int bar [[user(locn1)]];
@@ -167,7 +167,7 @@
 )");
 }
 
-TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) {
+TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Input) {
   // [[location 0]] var<in> foo : f32;
   // [[location 1]] var<in> bar : i32;
   //
@@ -225,8 +225,8 @@
 
   GeneratorImpl g;
   g.set_module_for_testing(&mod);
-  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
-  EXPECT_EQ(g.result(), R"(struct frag_main_in {
+  ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
+  EXPECT_EQ(g.result(), R"(struct main_in {
   float foo [[user(locn0)]];
   int bar [[user(locn1)]];
 };
@@ -234,7 +234,7 @@
 )");
 }
 
-TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) {
+TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Output) {
   // [[location 0]] var<out> foo : f32;
   // [[location 1]] var<out> bar : i32;
   //
@@ -292,8 +292,8 @@
 
   GeneratorImpl g;
   g.set_module_for_testing(&mod);
-  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
-  EXPECT_EQ(g.result(), R"(struct frag_main_out {
+  ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error();
+  EXPECT_EQ(g.result(), R"(struct main_out {
   float foo [[color(0)]];
   int bar [[color(1)]];
 };
@@ -301,7 +301,7 @@
 )");
 }
 
-TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Input) {
+TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Input) {
   // [[location 0]] var<in> foo : f32;
   // [[location 1]] var<in> bar : i32;
   //
@@ -356,11 +356,11 @@
 
   GeneratorImpl g;
   g.set_module_for_testing(&mod);
-  ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  ASSERT_FALSE(g.EmitEntryPointData(ep_ptr)) << g.error();
   EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
 }
 
-TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Output) {
+TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Output) {
   // [[location 0]] var<out> foo : f32;
   // [[location 1]] var<out> bar : i32;
   //
@@ -415,7 +415,7 @@
 
   GeneratorImpl g;
   g.set_module_for_testing(&mod);
-  ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  ASSERT_FALSE(g.EmitEntryPointData(ep_ptr)) << g.error();
   EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
 }
 
diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc
index e555ed5..7d49b1d 100644
--- a/src/writer/msl/generator_impl_function_test.cc
+++ b/src/writer/msl/generator_impl_function_test.cc
@@ -13,14 +13,27 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/binary_expression.h"
+#include "src/ast/call_expression.h"
+#include "src/ast/decorated_variable.h"
+#include "src/ast/float_literal.h"
 #include "src/ast/function.h"
+#include "src/ast/identifier_expression.h"
+#include "src/ast/if_statement.h"
+#include "src/ast/location_decoration.h"
 #include "src/ast/module.h"
 #include "src/ast/return_statement.h"
+#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/sint_literal.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/void_type.h"
 #include "src/ast/variable.h"
+#include "src/ast/variable_decl_statement.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
 #include "src/writer/msl/generator_impl.h"
 
 namespace tint {
@@ -138,6 +151,415 @@
 )");
 }
 
+TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func = std::make_unique<ast::Function>("frag_main", std::move(params),
+                                              &void_type);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment, "",
+                                              "frag_main");
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+struct frag_main_in {
+  float foo [[user(locn0)]];
+};
+
+struct frag_main_out {
+  float bar [[color(1)]];
+};
+
+fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) {
+  frag_main_out tint_out = {};
+  tint_out.bar = tint_in.foo;
+  return tint_out;
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest,
+       Emit_Function_Called_By_EntryPoints_WithGlobals_And_Params) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  auto val_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kOutput, &f32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  val_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+  td.RegisterVariableForTesting(val_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+  mod.AddGlobalVariable(std::move(val_var));
+
+  ast::VariableList params;
+  params.push_back(std::make_unique<ast::Variable>(
+      "param", ast::StorageClass::kFunction, &f32));
+  auto sub_func =
+      std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("param")));
+  body.push_back(std::make_unique<ast::ReturnStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  sub_func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(sub_func));
+
+  auto func_1 = std::make_unique<ast::Function>("frag_1_main",
+                                                std::move(params), &void_type);
+
+  ast::ExpressionList expr;
+  expr.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("sub_func"),
+          std::move(expr))));
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func_1->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func_1));
+
+  auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                               "ep_1", "frag_1_main");
+  mod.AddEntryPoint(std::move(ep1));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+struct ep_1_in {
+  float foo [[user(locn0)]];
+};
+
+struct ep_1_out {
+  float bar [[color(1)]];
+  float val [[color(0)]];
+};
+
+float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out, float param) {
+  tint_out.bar = tint_in.foo;
+  tint_out.val = param;
+  return tint_in.foo;
+}
+
+fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) {
+  ep_1_out tint_out = {};
+  tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.00000000f);
+  return tint_out;
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto sub_func =
+      std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::ReturnStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  sub_func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(sub_func));
+
+  auto func_1 = std::make_unique<ast::Function>("frag_1_main",
+                                                std::move(params), &void_type);
+
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("sub_func"),
+          ast::ExpressionList{})));
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func_1->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func_1));
+
+  auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                               "ep_1", "frag_1_main");
+  auto ep2 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                               "ep_2", "frag_1_main");
+  mod.AddEntryPoint(std::move(ep1));
+  mod.AddEntryPoint(std::move(ep2));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+struct ep_1_in {
+  float foo [[user(locn0)]];
+};
+
+struct ep_1_out {
+  float bar [[color(1)]];
+};
+
+struct ep_2_in {
+  float foo [[user(locn0)]];
+};
+
+struct ep_2_out {
+  float bar [[color(1)]];
+};
+
+float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out) {
+  tint_out.bar = tint_in.foo;
+  return tint_in.foo;
+}
+
+float sub_func_ep_2(thread ep_2_in& tint_in, thread ep_2_out& tint_out) {
+  tint_out.bar = tint_in.foo;
+  return tint_in.foo;
+}
+
+fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) {
+  ep_1_out tint_out = {};
+  tint_out.bar = sub_func_ep_1(tint_in, tint_out);
+  return tint_out;
+}
+
+fragment ep_2_out ep_2(ep_2_in tint_in [[stage_in]]) {
+  ep_2_out tint_out = {};
+  tint_out.bar = sub_func_ep_2(tint_in, tint_out);
+  return tint_out;
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest,
+       Emit_Function_EntryPoints_WithGlobal_Nested_Return) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &f32));
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(bar_var.get());
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func_1 = std::make_unique<ast::Function>("frag_1_main",
+                                                std::move(params), &void_type);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::FloatLiteral>(&f32, 1.0f))));
+
+  ast::StatementList list;
+  list.push_back(std::make_unique<ast::ReturnStatement>());
+  body.push_back(std::make_unique<ast::IfStatement>(
+      std::make_unique<ast::BinaryExpression>(
+          ast::BinaryOp::kEqual,
+          std::make_unique<ast::ScalarConstructorExpression>(
+              std::make_unique<ast::SintLiteral>(&i32, 1)),
+          std::make_unique<ast::ScalarConstructorExpression>(
+              std::make_unique<ast::SintLiteral>(&i32, 1))),
+      std::move(list)));
+
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func_1->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func_1));
+
+  auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                               "ep_1", "frag_1_main");
+  mod.AddEntryPoint(std::move(ep1));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+struct ep_1_out {
+  float bar [[color(1)]];
+};
+
+fragment ep_1_out ep_1() {
+  ep_1_out tint_out = {};
+  tint_out.bar = 1.00000000f;
+  if ((1 == 1)) {
+    return tint_out;
+  }
+  return tint_out;
+}
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest,
+       Emit_Function_Called_Two_EntryPoints_WithoutGlobals) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ast::VariableList params;
+  auto sub_func =
+      std::make_unique<ast::Function>("sub_func", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::ReturnStatement>(
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::FloatLiteral>(&f32, 1.0))));
+  sub_func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(sub_func));
+
+  auto func_1 = std::make_unique<ast::Function>("frag_1_main",
+                                                std::move(params), &void_type);
+
+  body.push_back(std::make_unique<ast::VariableDeclStatement>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kFunction,
+                                      &f32)));
+  body.back()->AsVariableDecl()->variable()->set_constructor(
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("sub_func"),
+          ast::ExpressionList{}));
+
+  body.push_back(std::make_unique<ast::ReturnStatement>());
+  func_1->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func_1));
+
+  auto ep1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                               "ep_1", "frag_1_main");
+  auto ep2 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                               "ep_2", "frag_1_main");
+  mod.AddEntryPoint(std::move(ep1));
+  mod.AddEntryPoint(std::move(ep2));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  ASSERT_TRUE(g.Generate(mod)) << g.error();
+  EXPECT_EQ(g.result(), R"(#include <metal_stdlib>
+
+float sub_func() {
+  return 1.00000000f;
+}
+
+fragment void ep_1() {
+  float foo = sub_func();
+  return;
+}
+
+fragment void ep_2() {
+  float foo = sub_func();
+  return;
+}
+
+)");
+}
 TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) {
   ast::type::VoidType void_type;
 
diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc
index 41b0cbb..063fd1f 100644
--- a/src/writer/msl/generator_impl_test.cc
+++ b/src/writer/msl/generator_impl_test.cc
@@ -51,29 +51,23 @@
 }
 
 TEST_F(MslGeneratorImplTest, InputStructName) {
-  ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
-
   GeneratorImpl g;
-  ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in");
+  ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
 }
 
 TEST_F(MslGeneratorImplTest, InputStructName_ConflictWithExisting) {
-  ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
-
   GeneratorImpl g;
 
   // Register the struct name as existing.
   auto* namer = g.namer_for_testing();
   namer->NameFor("func_main_out");
 
-  ASSERT_EQ(g.generate_struct_name(&ep, "out"), "func_main_out_0");
+  ASSERT_EQ(g.generate_name("func_main_out"), "func_main_out_0");
 }
 
 TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) {
-  ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
-
   GeneratorImpl g;
-  ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in");
+  ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in");
 
   ast::IdentifierExpression ident("func_main_in");
   ASSERT_TRUE(g.EmitIdentifier(&ident));
diff --git a/test/triangle.wgsl b/test/triangle.wgsl
index 864417f..6eb9ed7 100644
--- a/test/triangle.wgsl
+++ b/test/triangle.wgsl
@@ -28,9 +28,9 @@
 entry_point vertex as "main" = vtx_main;
 
 # Fragment shader
-[[location 0]] var outColor : ptr<out, vec4<f32>>;
+[[location 0]] var<out> outColor : vec4<f32>;
 fn frag_main() -> void {
   outColor = vec4<f32>(1, 0, 0, 1);
   return;
 }
-entry_point fragment as "main" = frag_main;
+entry_point fragment = frag_main;