[msl-writer] Generate input/output structs

This CL adds generation of the input/output structures for entry points.

Bug: tint:8
Change-Id: I93942496bcea0a2eea944e5e1cd0baf383530f5e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24721
Reviewed-by: David Neto <dneto@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 49e1a9a..65890d6 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -896,6 +896,7 @@
     "src/writer/msl/generator_impl_cast_test.cc",
     "src/writer/msl/generator_impl_constructor_test.cc",
     "src/writer/msl/generator_impl_continue_test.cc",
+    "src/writer/msl/generator_impl_entry_point_test.cc",
     "src/writer/msl/generator_impl_function_test.cc",
     "src/writer/msl/generator_impl_identifier_test.cc",
     "src/writer/msl/generator_impl_if_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 66ae755..fe8ba84 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -504,6 +504,7 @@
     writer/msl/generator_impl_cast_test.cc
     writer/msl/generator_impl_constructor_test.cc
     writer/msl/generator_impl_continue_test.cc
+    writer/msl/generator_impl_entry_point_test.cc
     writer/msl/generator_impl_function_test.cc
     writer/msl/generator_impl_identifier_test.cc
     writer/msl/generator_impl_if_test.cc
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 88d97d8..7ddd7e7 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -25,10 +25,20 @@
 
 Module::~Module() = default;
 
-Import* Module::FindImportByName(const std::string& name) {
+Import* Module::FindImportByName(const std::string& name) const {
   for (const auto& import : imports_) {
-    if (import->name() == name)
+    if (import->name() == name) {
       return import.get();
+    }
+  }
+  return nullptr;
+}
+
+Function* Module::FindFunctionByName(const std::string& name) const {
+  for (const auto& func : functions_) {
+    if (func->name() == name) {
+      return func.get();
+    }
   }
   return nullptr;
 }
diff --git a/src/ast/module.h b/src/ast/module.h
index 26e9da1..77e1924 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -47,7 +47,7 @@
   /// Find the import of the given name
   /// @param name The import name to search for
   /// @returns the import with the given name if found, nullptr otherwise.
-  Import* FindImportByName(const std::string& name);
+  Import* FindImportByName(const std::string& name) const;
 
   /// Add a global variable to the module
   /// @param var the variable to add
@@ -80,6 +80,10 @@
   }
   /// @returns the modules functions
   const FunctionList& functions() const { return functions_; }
+  /// Returns the function with the given name
+  /// @param name the name to search for
+  /// @returns the associated function or nullptr if none exists
+  Function* FindFunctionByName(const std::string& name) const;
 
   /// @returns true if all required fields in the AST are present.
   bool IsValid() const;
diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc
index 0e02c13..869f304 100644
--- a/src/ast/module_test.cc
+++ b/src/ast/module_test.cc
@@ -81,6 +81,21 @@
   EXPECT_EQ(nullptr, m.FindImportByName("Missing"));
 }
 
+TEST_F(ModuleTest, LookupFunction) {
+  type::F32Type f32;
+  Module m;
+
+  auto func = std::make_unique<Function>("main", VariableList{}, &f32);
+  auto* func_ptr = func.get();
+  m.AddFunction(std::move(func));
+  EXPECT_EQ(func_ptr, m.FindFunctionByName("main"));
+}
+
+TEST_F(ModuleTest, LookupFunctionMissing) {
+  Module m;
+  EXPECT_EQ(nullptr, m.FindFunctionByName("Missing"));
+}
+
 TEST_F(ModuleTest, IsValid_Empty) {
   Module m;
   EXPECT_TRUE(m.IsValid());
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 2759bba..fe838df 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -23,11 +23,13 @@
 #include "src/ast/case_statement.h"
 #include "src/ast/cast_expression.h"
 #include "src/ast/continue_statement.h"
+#include "src/ast/decorated_variable.h"
 #include "src/ast/else_statement.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/function.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
+#include "src/ast/location_decoration.h"
 #include "src/ast/loop_statement.h"
 #include "src/ast/member_accessor_expression.h"
 #include "src/ast/return_statement.h"
@@ -53,6 +55,9 @@
 namespace msl {
 namespace {
 
+const char kInStructNameSuffix[] = "in";
+const char kOutStructNameSuffix[] = "out";
+
 bool last_is_break_or_fallthrough(const ast::StatementList& stmts) {
   if (stmts.empty()) {
     return false;
@@ -67,6 +72,23 @@
 
 GeneratorImpl::~GeneratorImpl() = default;
 
+void GeneratorImpl::set_module_for_testing(ast::Module* mod) {
+  module_ = mod;
+}
+
+std::string GeneratorImpl::generate_struct_name(ast::EntryPoint* ep,
+                                                const std::string& type) {
+  std::string base_name = ep->function_name() + "_" + type;
+  std::string name = base_name;
+  uint32_t i = 0;
+  while (namer_.IsMapped(name)) {
+    name = base_name + "_" + std::to_string(i);
+    ++i;
+  }
+  namer_.RegisterRemappedName(name);
+  return name;
+}
+
 bool GeneratorImpl::Generate(const ast::Module& module) {
   module_ = &module;
 
@@ -81,6 +103,12 @@
     out_ << std::endl;
   }
 
+  for (const auto& ep : module.entry_points()) {
+    if (!EmitEntryPoint(ep.get())) {
+      return false;
+    }
+  }
+
   for (const auto& func : module.functions()) {
     if (!EmitFunction(func.get())) {
       return false;
@@ -387,6 +415,109 @@
   return true;
 }
 
+bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) {
+  auto* func = module_->FindFunctionByName(ep->function_name());
+  if (func == nullptr) {
+    error_ = "Unable to find entry point function: " + ep->function_name();
+    return false;
+  }
+
+  std::vector<std::pair<ast::Variable*, uint32_t>> in_locations;
+  std::vector<std::pair<ast::Variable*, uint32_t>> out_locations;
+  for (auto* var : func->referenced_module_variables()) {
+    if (!var->IsDecorated()) {
+      continue;
+    }
+    auto* decorated = var->AsDecorated();
+    ast::LocationDecoration* locn_deco = nullptr;
+    for (auto& deco : decorated->decorations()) {
+      if (deco->IsLocation()) {
+        locn_deco = deco.get()->AsLocation();
+        break;
+      }
+    }
+    if (locn_deco == nullptr) {
+      continue;
+    }
+
+    if (var->storage_class() == ast::StorageClass::kInput) {
+      in_locations.push_back({var, locn_deco->value()});
+    } else if (var->storage_class() == ast::StorageClass::kOutput) {
+      out_locations.push_back({var, locn_deco->value()});
+    }
+  }
+
+  if (!in_locations.empty()) {
+    auto in_struct_name = generate_struct_name(ep, kInStructNameSuffix);
+    ep_name_to_in_struct_[ep->name()] = in_struct_name;
+
+    make_indent();
+    out_ << "struct " << in_struct_name << " {" << std::endl;
+
+    increment_indent();
+
+    for (auto& data : in_locations) {
+      auto* var = data.first;
+      uint32_t loc = data.second;
+
+      make_indent();
+      if (!EmitType(var->type(), var->name())) {
+        return false;
+      }
+
+      out_ << " " << var->name() << " [[";
+      if (ep->stage() == ast::PipelineStage::kVertex) {
+        out_ << "attribute(" << loc << ")";
+      } else if (ep->stage() == ast::PipelineStage::kFragment) {
+        out_ << "user(locn" << loc << ")";
+      } else {
+        error_ = "invalid location variable for pipeline stage";
+        return false;
+      }
+      out_ << "]];" << std::endl;
+    }
+    decrement_indent();
+    make_indent();
+
+    out_ << "};" << std::endl << std::endl;
+  }
+
+  if (!out_locations.empty()) {
+    auto out_struct_name = generate_struct_name(ep, kOutStructNameSuffix);
+    ep_name_to_out_struct_[ep->name()] = out_struct_name;
+
+    make_indent();
+    out_ << "struct " << out_struct_name << " {" << std::endl;
+
+    increment_indent();
+    for (auto& data : out_locations) {
+      auto* var = data.first;
+      uint32_t loc = data.second;
+
+      make_indent();
+      if (!EmitType(var->type(), var->name())) {
+        return false;
+      }
+
+      out_ << " " << var->name() << " [[";
+      if (ep->stage() == ast::PipelineStage::kVertex) {
+        out_ << "user(locn" << loc << ")";
+      } else if (ep->stage() == ast::PipelineStage::kFragment) {
+        out_ << "color(" << loc << ")";
+      } else {
+        error_ = "invalid location variable for pipeline stage";
+        return false;
+      }
+      out_ << "]];" << std::endl;
+    }
+    decrement_indent();
+    make_indent();
+    out_ << "};" << std::endl << std::endl;
+  }
+
+  return true;
+}
+
 bool GeneratorImpl::EmitExpression(ast::Expression* expr) {
   if (expr->IsArrayAccessor()) {
     return EmitArrayAccessor(expr->AsArrayAccessor());
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 719e1c3..f731c61 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -17,6 +17,7 @@
 
 #include <sstream>
 #include <string>
+#include <unordered_map>
 
 #include "src/ast/literal.h"
 #include "src/ast/module.h"
@@ -85,6 +86,10 @@
   /// @param stmt the statement to emit
   /// @returns true if the statement was emitted
   bool EmitElse(ast::ElseStatement* stmt);
+  /// Handles emitting information for an entry point
+  /// @param ep the entry point
+  /// @returns true if the entry point data was emitted
+  bool EmitEntryPoint(ast::EntryPoint* ep);
   /// Handles generate an Expression
   /// @param expr the expression
   /// @returns true if the expression was emitted
@@ -166,10 +171,26 @@
   /// @returns true if the zero value was successfully emitted.
   bool EmitZeroValue(ast::type::Type* type);
 
+  /// Sets the module for testing purposes
+  /// @param mod the module to set.
+  void set_module_for_testing(ast::Module* mod);
+
+  /// Generates a name for the input struct
+  /// @param ep the entry point to generate for
+  /// @param type the type of struct to generate
+  /// @returns the input struct name
+  std::string generate_struct_name(ast::EntryPoint* ep,
+                                   const std::string& type);
+
+  /// @returns the namer for testing
+  Namer* namer_for_testing() { return &namer_; }
+
  private:
   Namer namer_;
   const ast::Module* module_ = nullptr;
   uint32_t loop_emission_counter_ = 0;
+  std::unordered_map<std::string, std::string> ep_name_to_in_struct_;
+  std::unordered_map<std::string, std::string> ep_name_to_out_struct_;
 };
 
 }  // namespace msl
diff --git a/src/writer/msl/generator_impl_entry_point_test.cc b/src/writer/msl/generator_impl_entry_point_test.cc
new file mode 100644
index 0000000..a102300
--- /dev/null
+++ b/src/writer/msl/generator_impl_entry_point_test.cc
@@ -0,0 +1,425 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/decorated_variable.h"
+#include "src/ast/entry_point.h"
+#include "src/ast/identifier_expression.h"
+#include "src/ast/location_decoration.h"
+#include "src/ast/module.h"
+#include "src/ast/type/f32_type.h"
+#include "src/ast/type/i32_type.h"
+#include "src/ast/variable.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
+#include "src/writer/msl/generator_impl.h"
+
+namespace tint {
+namespace writer {
+namespace msl {
+namespace {
+
+using MslGeneratorImplTest = testing::Test;
+
+TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) {
+  // [[location 0]] var<in> foo : f32;
+  // [[location 1]] var<in> bar : i32;
+  //
+  // struct vtx_main_in {
+  //   float foo [[attribute(0)]];
+  //   int bar [[attribute(1)]];
+  // };
+
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kInput, &i32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("vtx_main", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("bar")));
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
+                                              "main", "vtx_main");
+  auto* ep_ptr = ep.get();
+
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  g.set_module_for_testing(&mod);
+  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  EXPECT_EQ(g.result(), R"(struct vtx_main_in {
+  float foo [[attribute(0)]];
+  int bar [[attribute(1)]];
+};
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) {
+  // [[location 0]] var<out> foo : f32;
+  // [[location 1]] var<out> bar : i32;
+  //
+  // struct vtx_main_out {
+  //   float foo [[user(locn0)]];
+  //   int bar [[user(locn1)]];
+  // };
+
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kOutput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &i32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("vtx_main", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("bar")));
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
+                                              "main", "vtx_main");
+  auto* ep_ptr = ep.get();
+
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  g.set_module_for_testing(&mod);
+  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  EXPECT_EQ(g.result(), R"(struct vtx_main_out {
+  float foo [[user(locn0)]];
+  int bar [[user(locn1)]];
+};
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) {
+  // [[location 0]] var<in> foo : f32;
+  // [[location 1]] var<in> bar : i32;
+  //
+  // struct frag_main_in {
+  //   float foo [[user(locn0)]];
+  //   int bar [[user(locn1)]];
+  // };
+
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kInput, &i32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("frag_main", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("bar")));
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                              "main", "frag_main");
+  auto* ep_ptr = ep.get();
+
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  g.set_module_for_testing(&mod);
+  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  EXPECT_EQ(g.result(), R"(struct frag_main_in {
+  float foo [[user(locn0)]];
+  int bar [[user(locn1)]];
+};
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) {
+  // [[location 0]] var<out> foo : f32;
+  // [[location 1]] var<out> bar : i32;
+  //
+  // struct frag_main_out {
+  //   float foo [[color(0)]];
+  //   int bar [[color(1)]];
+  // };
+
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kOutput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &i32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("frag_main", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("bar")));
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
+                                              "main", "frag_main");
+  auto* ep_ptr = ep.get();
+
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  g.set_module_for_testing(&mod);
+  ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  EXPECT_EQ(g.result(), R"(struct frag_main_out {
+  float foo [[color(0)]];
+  int bar [[color(1)]];
+};
+
+)");
+}
+
+TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Input) {
+  // [[location 0]] var<in> foo : f32;
+  // [[location 1]] var<in> bar : i32;
+  //
+  // -> Error, not allowed
+
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kInput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kInput, &i32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("comp_main", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("bar")));
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+                                              "main", "comp_main");
+  auto* ep_ptr = ep.get();
+
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  g.set_module_for_testing(&mod);
+  ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
+}
+
+TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Output) {
+  // [[location 0]] var<out> foo : f32;
+  // [[location 1]] var<out> bar : i32;
+  //
+  // -> Error not allowed
+
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+
+  auto foo_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("foo", ast::StorageClass::kOutput, &f32));
+
+  ast::VariableDecorationList decos;
+  decos.push_back(std::make_unique<ast::LocationDecoration>(0));
+  foo_var->set_decorations(std::move(decos));
+
+  auto bar_var = std::make_unique<ast::DecoratedVariable>(
+      std::make_unique<ast::Variable>("bar", ast::StorageClass::kOutput, &i32));
+  decos.push_back(std::make_unique<ast::LocationDecoration>(1));
+  bar_var->set_decorations(std::move(decos));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(foo_var.get());
+  td.RegisterVariableForTesting(bar_var.get());
+
+  mod.AddGlobalVariable(std::move(foo_var));
+  mod.AddGlobalVariable(std::move(bar_var));
+
+  ast::VariableList params;
+  auto func =
+      std::make_unique<ast::Function>("comp_main", std::move(params), &f32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("foo"),
+      std::make_unique<ast::IdentifierExpression>("foo")));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("bar"),
+      std::make_unique<ast::IdentifierExpression>("bar")));
+  func->set_body(std::move(body));
+
+  mod.AddFunction(std::move(func));
+
+  auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
+                                              "main", "comp_main");
+  auto* ep_ptr = ep.get();
+
+  mod.AddEntryPoint(std::move(ep));
+
+  ASSERT_TRUE(td.Determine()) << td.error();
+
+  GeneratorImpl g;
+  g.set_module_for_testing(&mod);
+  ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error();
+  EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)");
+}
+
+}  // namespace
+}  // namespace msl
+}  // namespace writer
+}  // namespace tint
diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc
index 05e7015..41b0cbb 100644
--- a/src/writer/msl/generator_impl_test.cc
+++ b/src/writer/msl/generator_impl_test.cc
@@ -19,9 +19,11 @@
 #include "gtest/gtest.h"
 #include "src/ast/entry_point.h"
 #include "src/ast/function.h"
+#include "src/ast/identifier_expression.h"
 #include "src/ast/module.h"
 #include "src/ast/pipeline_stage.h"
 #include "src/ast/type/void_type.h"
+#include "src/writer/msl/namer.h"
 
 namespace tint {
 namespace writer {
@@ -48,6 +50,36 @@
 )");
 }
 
+TEST_F(MslGeneratorImplTest, InputStructName) {
+  ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
+
+  GeneratorImpl g;
+  ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in");
+}
+
+TEST_F(MslGeneratorImplTest, InputStructName_ConflictWithExisting) {
+  ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
+
+  GeneratorImpl g;
+
+  // Register the struct name as existing.
+  auto* namer = g.namer_for_testing();
+  namer->NameFor("func_main_out");
+
+  ASSERT_EQ(g.generate_struct_name(&ep, "out"), "func_main_out_0");
+}
+
+TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) {
+  ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main");
+
+  GeneratorImpl g;
+  ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in");
+
+  ast::IdentifierExpression ident("func_main_in");
+  ASSERT_TRUE(g.EmitIdentifier(&ident));
+  EXPECT_EQ(g.result(), "func_main_in_0");
+}
+
 }  // namespace
 }  // namespace msl
 }  // namespace writer
diff --git a/src/writer/msl/namer.cc b/src/writer/msl/namer.cc
index cde3032..1a95f89 100644
--- a/src/writer/msl/namer.cc
+++ b/src/writer/msl/namer.cc
@@ -290,7 +290,7 @@
       }
       i++;
     }
-    remapped_names_.insert(ret_name);
+    RegisterRemappedName(ret_name);
   } else {
     uint32_t i = 0;
     // Make sure the ident name wasn't assigned by a remapping.
@@ -302,13 +302,22 @@
       ret_name = name + "_" + std::to_string(i);
       i++;
     }
-    remapped_names_.insert(ret_name);
+    RegisterRemappedName(ret_name);
   }
 
   name_map_[name] = ret_name;
   return ret_name;
 }
 
+bool Namer::IsMapped(const std::string& name) {
+  auto it = name_map_.find(name);
+  return it != name_map_.end();
+}
+
+void Namer::RegisterRemappedName(const std::string& name) {
+  remapped_names_.insert(name);
+}
+
 }  // namespace msl
 }  // namespace writer
 }  // namespace tint
diff --git a/src/writer/msl/namer.h b/src/writer/msl/namer.h
index 887e5be..1191dbb 100644
--- a/src/writer/msl/namer.h
+++ b/src/writer/msl/namer.h
@@ -35,6 +35,15 @@
   /// @returns the sanitized version of |name|
   std::string NameFor(const std::string& name);
 
+  /// Registers a remapped name.
+  /// @param name the name to register
+  void RegisterRemappedName(const std::string& name);
+
+  /// Returns if the given name has been mapped alread
+  /// @param name the name to check
+  /// @returns true if the name has been mapped
+  bool IsMapped(const std::string& name);
+
  private:
   /// Map of original name to new name. The two names may be the same.
   std::unordered_map<std::string, std::string> name_map_;