[spirv-writer] Handle entry point IO struct types

Recursively hoist struct members out to module-scope variables, and
redeclare the structs without entry point IO decorations. Generate a
function for storing entry point outputs to the corresponding
module-scope variables and replace return statements with calls to
this function.

Fixed: tint:509
Change-Id: I8977f384b3c7425f844e9346dbbde33b750ea920
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45821
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index d404d4a..2d0580a 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -17,6 +17,7 @@
 #include <string>
 #include <utility>
 
+#include "src/ast/call_statement.h"
 #include "src/ast/return_statement.h"
 #include "src/program_builder.h"
 #include "src/semantic/function.h"
@@ -48,7 +49,8 @@
 void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
   // Hoist entry point parameters, return values, and struct members out to
   // global variables. Declare and construct struct parameters in the function
-  // body. Replace entry point return statements with variable assignments.
+  // body. Replace entry point return statements with calls to a function that
+  // assigns the return value to the global output variables.
   //
   // Before:
   // ```
@@ -62,11 +64,13 @@
   // };
   //
   // [[stage(fragment)]]
-  // fn fs_main(
+  // fn frag_main(
   //   [[builtin(frag_coord)]] coord : vec4<f32>,
   //   samples : FragmentInput
   // ) -> FragmentOutput {
-  //   return FragmentOutput(1.0, samples.sample_mask_in);
+  //   var output : FragmentOutput = FragmentOutput(1.0,
+  //                                                samples.sample_mask_in);
+  //   return output;
   // }
   // ```
   //
@@ -87,71 +91,85 @@
   // [[builtin(frag_depth)]] var<out> depth: f32;
   // [[builtin(sample_mask_out)]] var<out> mask_out : u32;
   //
+  // fn frag_main_ret(retval : FragmentOutput) -> void {
+  //   depth = reval.depth;
+  //   mask_out = retval.mask_out;
+  // }
+  //
   // [[stage(fragment)]]
-  // fn fs_main() -> void {
+  // fn frag_main() -> void {
   //   const samples : FragmentInput(sample_index, sample_mask_in);
-  //   depth = 1.0;
-  //   mask_out = samples.sample_mask_in;
+  //   var output : FragmentOutput = FragmentOutput(1.0,
+  //                                                samples.sample_mask_in);
+  //   frag_main_ret(output);
   //   return;
   // }
   // ```
 
-  // TODO(jrprice): Hoist struct members decorated as entry point IO types out
-  // of struct declarations, and redeclare the structs without the decorations.
+  // Strip entry point IO decorations from struct declarations.
+  for (auto* ty : ctx.src->AST().ConstructedTypes()) {
+    if (auto* struct_ty = ty->As<type::Struct>()) {
+      // Build new list of struct members without entry point IO decorations.
+      ast::StructMemberList new_struct_members;
+      for (auto* member : struct_ty->impl()->members()) {
+        ast::DecorationList new_decorations = RemoveDecorations(
+            &ctx, member->decorations(), [](const ast::Decoration* deco) {
+              return deco
+                  ->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
+            });
+        new_struct_members.push_back(
+            ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()),
+                            ctx.Clone(member->type()), new_decorations));
+      }
+
+      // Redeclare the struct.
+      auto* new_struct = ctx.dst->create<type::Struct>(
+          ctx.Clone(struct_ty->symbol()),
+          ctx.dst->create<ast::Struct>(
+              new_struct_members, ctx.Clone(struct_ty->impl()->decorations())));
+      ctx.Replace(struct_ty, new_struct);
+    }
+  }
 
   for (auto* func : ctx.src->AST().Functions()) {
     if (!func->IsEntryPoint()) {
       continue;
     }
 
-    auto* sem_func = ctx.src->Sem().Get(func);
-
     for (auto* param : func->params()) {
-      // TODO(jrprice): Handle structures by moving the declaration and
-      // construction to the function body.
-      if (param->type()->Is<type::Struct>()) {
-        TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
-            << "structures as entry point parameters are not yet supported";
-        continue;
-      }
+      Symbol new_var =
+          HoistToInputVariables(ctx, func, param->type(), param->decorations());
 
-      // Create a new symbol for the global variable.
-      auto var_symbol = ctx.dst->Symbols().New();
-      // Create the global variable.
-      auto* var = ctx.dst->Var(var_symbol, ctx.Clone(param->type()),
-                               ast::StorageClass::kInput, nullptr,
-                               ctx.Clone(param->decorations()));
-      ctx.InsertBefore(func, var);
-
-      // Replace all uses of the function parameter with the global variable.
+      // Replace all uses of the function parameter with the new variable.
       for (auto* user : ctx.src->Sem().Get(param)->Users()) {
         ctx.Replace<ast::Expression>(user->Declaration(),
-                                     ctx.dst->Expr(var_symbol));
+                                     ctx.dst->Expr(new_var));
       }
     }
 
     if (!func->return_type()->Is<type::Void>()) {
-      // TODO(jrprice): Handle structures by creating a variable for each member
-      // and replacing return statements with extracts+stores.
-      if (func->return_type()->UnwrapAll()->Is<type::Struct>()) {
-        TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
-            << "structures as entry point return values are not yet supported";
-        continue;
-      }
+      ast::StatementList stores;
+      auto store_value_symbol = ctx.dst->Symbols().New();
+      HoistToOutputVariables(ctx, func, func->return_type(),
+                             func->return_type_decorations(), {},
+                             store_value_symbol, stores);
 
-      // Create a new symbol for the global variable.
-      auto var_symbol = ctx.dst->Symbols().New();
-      // Create the global variable.
-      auto* var = ctx.dst->Var(var_symbol, ctx.Clone(func->return_type()),
-                               ast::StorageClass::kOutput, nullptr,
-                               ctx.Clone(func->return_type_decorations()));
-      ctx.InsertBefore(func, var);
+      // Create a function that writes a return value to all output variables.
+      auto* store_value =
+          ctx.dst->Var(store_value_symbol, ctx.Clone(func->return_type()),
+                       ast::StorageClass::kFunction, nullptr);
+      auto return_func_symbol = ctx.dst->Symbols().New();
+      auto* return_func = ctx.dst->create<ast::Function>(
+          return_func_symbol, ast::VariableList{store_value},
+          ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
+          ast::DecorationList{}, ast::DecorationList{});
+      ctx.InsertBefore(func, return_func);
 
-      // Replace all return statements with stores to the global variable.
+      // Replace all return statements with calls to the output function.
+      auto* sem_func = ctx.src->Sem().Get(func);
       for (auto* ret : sem_func->ReturnStatements()) {
-        ctx.InsertBefore(
-            ret, ctx.dst->create<ast::AssignmentStatement>(
-                     ctx.dst->Expr(var_symbol), ctx.Clone(ret->value())));
+        auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
+        ctx.InsertBefore(ret, ctx.dst->create<ast::CallStatement>(call));
         ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
       }
     }
@@ -214,5 +232,91 @@
   }
 }
 
+Symbol Spirv::HoistToInputVariables(
+    CloneContext& ctx,
+    const ast::Function* func,
+    type::Type* ty,
+    const ast::DecorationList& decorations) const {
+  if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+    // Base case: create a global variable and return.
+    ast::DecorationList new_decorations =
+        RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
+          return !deco->IsAnyOf<ast::BuiltinDecoration,
+                                ast::LocationDecoration>();
+        });
+    auto global_var_symbol = ctx.dst->Symbols().New();
+    auto* global_var =
+        ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
+                     ast::StorageClass::kInput, nullptr, new_decorations);
+    ctx.InsertBefore(func, global_var);
+    return global_var_symbol;
+  }
+
+  // Recurse into struct members and build the initializer list.
+  ast::ExpressionList init_values;
+  auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
+  for (auto* member : struct_ty->impl()->members()) {
+    auto member_var =
+        HoistToInputVariables(ctx, func, member->type(), member->decorations());
+    init_values.push_back(ctx.dst->Expr(member_var));
+  }
+
+  auto func_var_symbol = ctx.dst->Symbols().New();
+  if (func->body()->empty()) {
+    // The return value should never get used.
+    return func_var_symbol;
+  }
+
+  // Create a function-scope variable for the struct.
+  // TODO(jrprice): Use Const when crbug.com/tint/662 is fixed
+  auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
+  auto* func_var =
+      ctx.dst->Var(func_var_symbol, ctx.Clone(ty), ast::StorageClass::kFunction,
+                   initializer, ast::DecorationList{});
+  ctx.InsertBefore(*func->body()->begin(), ctx.dst->WrapInStatement(func_var));
+  return func_var_symbol;
+}
+
+void Spirv::HoistToOutputVariables(CloneContext& ctx,
+                                   const ast::Function* func,
+                                   type::Type* ty,
+                                   const ast::DecorationList& decorations,
+                                   std::vector<Symbol> member_accesses,
+                                   Symbol store_value,
+                                   ast::StatementList& stores) const {
+  // Base case.
+  if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
+    // Create a global variable.
+    ast::DecorationList new_decorations =
+        RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
+          return !deco->IsAnyOf<ast::BuiltinDecoration,
+                                ast::LocationDecoration>();
+        });
+    auto global_var_symbol = ctx.dst->Symbols().New();
+    auto* global_var =
+        ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
+                     ast::StorageClass::kOutput, nullptr, new_decorations);
+    ctx.InsertBefore(func, global_var);
+
+    // Create the assignment instruction.
+    ast::Expression* rhs = ctx.dst->Expr(store_value);
+    for (auto member : member_accesses) {
+      rhs = ctx.dst->MemberAccessor(rhs, member);
+    }
+    stores.push_back(ctx.dst->Assign(ctx.dst->Expr(global_var_symbol), rhs));
+
+    return;
+  }
+
+  // Recurse into struct members.
+  auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
+  for (auto* member : struct_ty->impl()->members()) {
+    member_accesses.push_back(member->symbol());
+    HoistToOutputVariables(ctx, func, member->type(), member->decorations(),
+                           member_accesses, store_value, stores);
+    member_accesses.pop_back();
+  }
+}
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index 5231f0e..e2d1f12 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -15,6 +15,8 @@
 #ifndef SRC_TRANSFORM_SPIRV_H_
 #define SRC_TRANSFORM_SPIRV_H_
 
+#include <vector>
+
 #include "src/transform/transform.h"
 
 namespace tint {
@@ -44,6 +46,34 @@
   void HandleEntryPointIOTypes(CloneContext& ctx) const;
   /// Change type of sample mask builtin variables to single element arrays.
   void HandleSampleMaskBuiltins(CloneContext& ctx) const;
+
+  /// Recursively create module-scope input variables for `ty` and add
+  /// function-scope variables for structs to `func`.
+  ///
+  /// For non-structures, create a module-scope input variable.
+  /// For structures, recurse into members and then create a function-scope
+  /// variable initialized using the variables created for its members.
+  /// Return the symbol for the variable that was created.
+  Symbol HoistToInputVariables(CloneContext& ctx,
+                               const ast::Function* func,
+                               type::Type* ty,
+                               const ast::DecorationList& decorations) const;
+
+  /// Recursively create module-scope output variables for `ty` and build a list
+  /// of assignment instructions to write to them from `store_value`.
+  ///
+  /// For non-structures, create a module-scope output variable and generate the
+  /// assignment instruction.
+  /// For structures, recurse into members, tracking the chain of member
+  /// accessors.
+  /// Returns the list of variable assignments in `stores`.
+  void HoistToOutputVariables(CloneContext& ctx,
+                              const ast::Function* func,
+                              type::Type* ty,
+                              const ast::DecorationList& decorations,
+                              std::vector<Symbol> member_accesses,
+                              Symbol store_value,
+                              ast::StatementList& stores) const;
 };
 
 }  // namespace transform
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index 2a22b67..d8d0d09 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -95,11 +95,15 @@
 )";
 
   auto* expect = R"(
-[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
+[[builtin(position)]] var<out> tint_symbol_2 : vec4<f32>;
+
+fn tint_symbol_3(tint_symbol_1 : vec4<f32>) -> void {
+  tint_symbol_2 = tint_symbol_1;
+}
 
 [[stage(vertex)]]
 fn vert_main() -> void {
-  tint_symbol_1 = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+  tint_symbol_3(vec4<f32>(1.0, 2.0, 3.0, 0.0));
   return;
 }
 )";
@@ -123,15 +127,19 @@
   auto* expect = R"(
 [[location(0)]] var<in> tint_symbol_1 : u32;
 
-[[location(0)]] var<out> tint_symbol_2 : f32;
+[[location(0)]] var<out> tint_symbol_3 : f32;
+
+fn tint_symbol_4(tint_symbol_2 : f32) -> void {
+  tint_symbol_3 = tint_symbol_2;
+}
 
 [[stage(fragment)]]
 fn frag_main() -> void {
   if ((tint_symbol_1 > 10u)) {
-    tint_symbol_2 = 0.5;
+    tint_symbol_4(0.5);
     return;
   }
-  tint_symbol_2 = 1.0;
+  tint_symbol_4(1.0);
   return;
 }
 )";
@@ -159,15 +167,587 @@
 
 [[location(0)]] var<in> tint_symbol_1 : u32;
 
-[[location(0)]] var<out> tint_symbol_2 : myf32;
+[[location(0)]] var<out> tint_symbol_3 : myf32;
+
+fn tint_symbol_5(tint_symbol_2 : myf32) -> void {
+  tint_symbol_3 = tint_symbol_2;
+}
 
 [[stage(fragment)]]
 fn frag_main() -> void {
   if ((tint_symbol_1 > 10u)) {
-    tint_symbol_2 = 0.5;
+    tint_symbol_5(0.5);
     return;
   }
-  tint_symbol_2 = 1.0;
+  tint_symbol_5(1.0);
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_StructParameters) {
+  auto* src = R"(
+struct FragmentInput {
+  [[builtin(frag_coord)]] coord : vec4<f32>;
+  [[location(1)]] value : f32;
+};
+
+[[stage(fragment)]]
+fn frag_main(inputs : FragmentInput) -> void {
+  var col : f32 = inputs.coord.x * inputs.value;
+}
+)";
+
+  auto* expect = R"(
+struct FragmentInput {
+  coord : vec4<f32>;
+  value : f32;
+};
+
+[[builtin(frag_coord)]] var<in> tint_symbol_4 : vec4<f32>;
+
+[[location(1)]] var<in> tint_symbol_5 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  var tint_symbol_6 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5);
+  var col : f32 = (tint_symbol_6.coord.x * tint_symbol_6.value);
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_StructParameters_Nested) {
+  auto* src = R"(
+struct Builtins {
+  [[builtin(frag_coord)]] coord : vec4<f32>;
+};
+
+struct Locations {
+  [[location(2)]] l2 : f32;
+  [[location(3)]] l3 : f32;
+};
+
+struct Other {
+  l : Locations;
+};
+
+struct FragmentInput {
+  b : Builtins;
+  o : Other;
+  [[location(1)]] value : f32;
+};
+
+[[stage(fragment)]]
+fn frag_main(inputs : FragmentInput) -> void {
+  var col : f32 = inputs.b.coord.x * inputs.value;
+  var l : f32 = inputs.o.l.l2 + inputs.o.l.l3;
+}
+)";
+
+  auto* expect = R"(
+struct Builtins {
+  coord : vec4<f32>;
+};
+
+struct Locations {
+  l2 : f32;
+  l3 : f32;
+};
+
+struct Other {
+  l : Locations;
+};
+
+struct FragmentInput {
+  b : Builtins;
+  o : Other;
+  value : f32;
+};
+
+[[builtin(frag_coord)]] var<in> tint_symbol_12 : vec4<f32>;
+
+[[location(2)]] var<in> tint_symbol_14 : f32;
+
+[[location(3)]] var<in> tint_symbol_15 : f32;
+
+[[location(1)]] var<in> tint_symbol_18 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  var tint_symbol_13 : Builtins = Builtins(tint_symbol_12);
+  var tint_symbol_16 : Locations = Locations(tint_symbol_14, tint_symbol_15);
+  var tint_symbol_17 : Other = Other(tint_symbol_16);
+  var tint_symbol_19 : FragmentInput = FragmentInput(tint_symbol_13, tint_symbol_17, tint_symbol_18);
+  var col : f32 = (tint_symbol_19.b.coord.x * tint_symbol_19.value);
+  var l : f32 = (tint_symbol_19.o.l.l2 + tint_symbol_19.o.l.l3);
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_StructParameters_EmptyBody) {
+  auto* src = R"(
+struct Locations {
+  [[location(1)]] value : f32;
+};
+
+struct FragmentInput {
+  locations : Locations;
+};
+
+[[stage(fragment)]]
+fn frag_main(inputs : FragmentInput) -> void {
+}
+)";
+
+  auto* expect = R"(
+struct Locations {
+  value : f32;
+};
+
+struct FragmentInput {
+  locations : Locations;
+};
+
+[[location(1)]] var<in> tint_symbol_5 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnStruct) {
+  auto* src = R"(
+struct VertexOutput {
+  [[builtin(position)]] pos : vec4<f32>;
+  [[location(1)]] value : f32;
+};
+
+[[stage(vertex)]]
+fn vert_main() -> VertexOutput {
+  if (false) {
+    return VertexOutput();
+  }
+  var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+  return VertexOutput(pos, 2.0);
+}
+)";
+
+  auto* expect = R"(
+struct VertexOutput {
+  pos : vec4<f32>;
+  value : f32;
+};
+
+[[builtin(position)]] var<out> tint_symbol_5 : vec4<f32>;
+
+[[location(1)]] var<out> tint_symbol_6 : f32;
+
+fn tint_symbol_7(tint_symbol_4 : VertexOutput) -> void {
+  tint_symbol_5 = tint_symbol_4.pos;
+  tint_symbol_6 = tint_symbol_4.value;
+}
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  if (false) {
+    tint_symbol_7(VertexOutput());
+    return;
+  }
+  var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+  tint_symbol_7(VertexOutput(pos, 2.0));
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnStruct_Nested) {
+  auto* src = R"(
+struct Builtins {
+  [[builtin(position)]] pos : vec4<f32>;
+};
+
+struct Locations {
+  [[location(2)]] l2 : f32;
+  [[location(3)]] l3 : f32;
+};
+
+struct Other {
+  l : Locations;
+};
+
+struct VertexOutput {
+  b : Builtins;
+  o : Other;
+  [[location(1)]] value : f32;
+};
+
+[[stage(vertex)]]
+fn vert_main() -> VertexOutput {
+  if (false) {
+    return VertexOutput();
+  }
+  var output : VertexOutput = VertexOutput();
+  output.b.pos = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+  output.o.l.l2 = 4.0;
+  output.o.l.l3 = 5.0;
+  output.value = 6.0;
+  return output;
+}
+)";
+
+  auto* expect = R"(
+struct Builtins {
+  pos : vec4<f32>;
+};
+
+struct Locations {
+  l2 : f32;
+  l3 : f32;
+};
+
+struct Other {
+  l : Locations;
+};
+
+struct VertexOutput {
+  b : Builtins;
+  o : Other;
+  value : f32;
+};
+
+[[builtin(position)]] var<out> tint_symbol_13 : vec4<f32>;
+
+[[location(2)]] var<out> tint_symbol_14 : f32;
+
+[[location(3)]] var<out> tint_symbol_15 : f32;
+
+[[location(1)]] var<out> tint_symbol_16 : f32;
+
+fn tint_symbol_17(tint_symbol_12 : VertexOutput) -> void {
+  tint_symbol_13 = tint_symbol_12.b.pos;
+  tint_symbol_14 = tint_symbol_12.o.l.l2;
+  tint_symbol_15 = tint_symbol_12.o.l.l3;
+  tint_symbol_16 = tint_symbol_12.value;
+}
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  if (false) {
+    tint_symbol_17(VertexOutput());
+    return;
+  }
+  var output : VertexOutput = VertexOutput();
+  output.b.pos = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+  output.o.l.l2 = 4.0;
+  output.o.l.l3 = 5.0;
+  output.value = 6.0;
+  tint_symbol_17(output);
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_SharedStruct_SameShader) {
+  auto* src = R"(
+struct Interface {
+  [[location(1)]] value : f32;
+};
+
+[[stage(vertex)]]
+fn vert_main(inputs : Interface) -> Interface {
+  return inputs;
+}
+)";
+
+  auto* expect = R"(
+struct Interface {
+  value : f32;
+};
+
+[[location(1)]] var<in> tint_symbol_3 : f32;
+
+[[location(1)]] var<out> tint_symbol_6 : f32;
+
+fn tint_symbol_7(tint_symbol_5 : Interface) -> void {
+  tint_symbol_6 = tint_symbol_5.value;
+}
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  var tint_symbol_4 : Interface = Interface(tint_symbol_3);
+  tint_symbol_7(tint_symbol_4);
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_SharedStruct_DifferentShaders) {
+  auto* src = R"(
+struct Interface {
+  [[location(1)]] value : f32;
+};
+
+[[stage(vertex)]]
+fn vert_main() -> Interface {
+  return Interface(42.0);
+}
+
+[[stage(fragment)]]
+fn frag_main(inputs : Interface) -> void {
+  var x : f32 = inputs.value;
+}
+)";
+
+  auto* expect = R"(
+struct Interface {
+  value : f32;
+};
+
+[[location(1)]] var<out> tint_symbol_4 : f32;
+
+fn tint_symbol_5(tint_symbol_3 : Interface) -> void {
+  tint_symbol_4 = tint_symbol_3.value;
+}
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  tint_symbol_5(Interface(42.0));
+  return;
+}
+
+[[location(1)]] var<in> tint_symbol_7 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  var tint_symbol_8 : Interface = Interface(tint_symbol_7);
+  var x : f32 = tint_symbol_8.value;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_SharedSubStruct) {
+  auto* src = R"(
+struct Interface {
+  [[location(1)]] value : f32;
+};
+
+struct VertexOutput {
+  [[builtin(position)]] pos : vec4<f32>;
+  interface : Interface;
+};
+
+struct FragmentInput {
+  [[builtin(sample_index)]] index : u32;
+  interface : Interface;
+};
+
+[[stage(vertex)]]
+fn vert_main() -> VertexOutput {
+  return VertexOutput(vec4<f32>(), Interface(42.0));
+}
+
+[[stage(fragment)]]
+fn frag_main(inputs : FragmentInput) -> void {
+  var x : f32 = inputs.interface.value;
+}
+)";
+
+  auto* expect = R"(
+struct Interface {
+  value : f32;
+};
+
+struct VertexOutput {
+  pos : vec4<f32>;
+  interface : Interface;
+};
+
+struct FragmentInput {
+  index : u32;
+  interface : Interface;
+};
+
+[[builtin(position)]] var<out> tint_symbol_9 : vec4<f32>;
+
+[[location(1)]] var<out> tint_symbol_10 : f32;
+
+fn tint_symbol_11(tint_symbol_8 : VertexOutput) -> void {
+  tint_symbol_9 = tint_symbol_8.pos;
+  tint_symbol_10 = tint_symbol_8.interface.value;
+}
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  tint_symbol_11(VertexOutput(vec4<f32>(), Interface(42.0)));
+  return;
+}
+
+[[builtin(sample_index)]] var<in> tint_symbol_13 : u32;
+
+[[location(1)]] var<in> tint_symbol_14 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  var tint_symbol_15 : Interface = Interface(tint_symbol_14);
+  var tint_symbol_16 : FragmentInput = FragmentInput(tint_symbol_13, tint_symbol_15);
+  var x : f32 = tint_symbol_16.interface.value;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_NestedStruct_TypeAlias) {
+  auto* src = R"(
+type myf32 = f32;
+
+struct Location {
+  [[location(2)]] l2 : myf32;
+};
+
+type MyLocation = Location;
+
+struct VertexIO {
+  l : MyLocation;
+  [[location(1)]] value : myf32;
+};
+
+type MyVertexInput = VertexIO;
+
+type MyVertexOutput = VertexIO;
+
+[[stage(vertex)]]
+fn vert_main(inputs : MyVertexInput) -> MyVertexOutput {
+  return inputs;
+}
+)";
+
+  auto* expect = R"(
+type myf32 = f32;
+
+struct Location {
+  l2 : myf32;
+};
+
+type MyLocation = Location;
+
+struct VertexIO {
+  l : MyLocation;
+  value : myf32;
+};
+
+type MyVertexInput = VertexIO;
+
+type MyVertexOutput = VertexIO;
+
+[[location(2)]] var<in> tint_symbol_8 : myf32;
+
+[[location(1)]] var<in> tint_symbol_10 : myf32;
+
+[[location(2)]] var<out> tint_symbol_14 : myf32;
+
+[[location(1)]] var<out> tint_symbol_15 : myf32;
+
+fn tint_symbol_17(tint_symbol_13 : MyVertexOutput) -> void {
+  tint_symbol_14 = tint_symbol_13.l.l2;
+  tint_symbol_15 = tint_symbol_13.value;
+}
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  var tint_symbol_9 : MyLocation = MyLocation(tint_symbol_8);
+  var tint_symbol_11 : MyVertexInput = MyVertexInput(tint_symbol_9, tint_symbol_10);
+  tint_symbol_17(tint_symbol_11);
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_StructLayoutDecorations) {
+  auto* src = R"(
+[[block]]
+struct FragmentInput {
+  [[size(16), location(1)]] value : f32;
+  [[builtin(frag_coord)]] [[align(32)]] coord : vec4<f32>;
+};
+
+struct FragmentOutput {
+  [[size(16), location(1)]] value : f32;
+};
+
+[[stage(fragment)]]
+fn frag_main(inputs : FragmentInput) -> FragmentOutput {
+  return FragmentOutput(inputs.coord.x * inputs.value);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct FragmentInput {
+  [[size(16)]]
+  value : f32;
+  [[align(32)]]
+  coord : vec4<f32>;
+};
+
+struct FragmentOutput {
+  [[size(16)]]
+  value : f32;
+};
+
+[[location(1)]] var<in> tint_symbol_5 : f32;
+
+[[builtin(frag_coord)]] var<in> tint_symbol_6 : vec4<f32>;
+
+[[location(1)]] var<out> tint_symbol_9 : f32;
+
+fn tint_symbol_10(tint_symbol_8 : FragmentOutput) -> void {
+  tint_symbol_9 = tint_symbol_8.value;
+}
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  var tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6);
+  tint_symbol_10(FragmentOutput((tint_symbol_7.coord.x * tint_symbol_7.value)));
   return;
 }
 )";
@@ -269,11 +849,15 @@
 
 [[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
 
-[[builtin(sample_mask_out)]] var<out> tint_symbol_3 : array<u32, 1>;
+[[builtin(sample_mask_out)]] var<out> tint_symbol_4 : array<u32, 1>;
+
+fn tint_symbol_5(tint_symbol_3 : u32) -> void {
+  tint_symbol_4[0] = tint_symbol_3;
+}
 
 [[stage(fragment)]]
 fn main() -> void {
-  tint_symbol_3[0] = tint_symbol_2[0];
+  tint_symbol_5(tint_symbol_2[0]);
   return;
 }
 )";
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index fbf3e13..955f16e 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -81,5 +81,18 @@
   });
 }
 
+ast::DecorationList Transform::RemoveDecorations(
+    CloneContext* ctx,
+    const ast::DecorationList& in,
+    std::function<bool(const ast::Decoration*)> should_remove) {
+  ast::DecorationList new_decorations;
+  for (auto* deco : in) {
+    if (!should_remove(deco)) {
+      new_decorations.push_back(ctx->Clone(deco));
+    }
+  }
+  return new_decorations;
+}
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/transform.h b/src/transform/transform.h
index 6e51c04..8b475bb 100644
--- a/src/transform/transform.h
+++ b/src/transform/transform.h
@@ -170,6 +170,16 @@
                                      const char* (&names)[N]) {
     RenameReservedKeywords(ctx, names, N);
   }
+
+  /// Clones the decoration list `in`, removing decorations based on a filter.
+  /// @param ctx the clone context
+  /// @param in the decorations to clone
+  /// @param should_remove the function to select which decorations to remove
+  /// @return the cloned decorations
+  static ast::DecorationList RemoveDecorations(
+      CloneContext* ctx,
+      const ast::DecorationList& in,
+      std::function<bool(const ast::Decoration*)> should_remove);
 };
 
 }  // namespace transform
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index a45e506..3650c3d 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -1,4 +1,4 @@
-// Copyright 2020 The Tint Authors.
+// Copyright 2021 The Tint Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -129,11 +129,13 @@
   // Output storage class, and the return statements are replaced with stores.
   EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
 OpMemoryModel Logical GLSL450
-OpEntryPoint Fragment %10 "frag_main" %1 %4
-OpExecutionMode %10 OriginUpperLeft
+OpEntryPoint Fragment %15 "frag_main" %1 %4
+OpExecutionMode %15 OriginUpperLeft
 OpName %1 "tint_symbol_1"
-OpName %4 "tint_symbol_2"
-OpName %10 "frag_main"
+OpName %4 "tint_symbol_3"
+OpName %10 "tint_symbol_4"
+OpName %11 "tint_symbol_2"
+OpName %15 "frag_main"
 OpDecorate %1 Location 0
 OpDecorate %4 Location 0
 %3 = OpTypeInt 32 0
@@ -144,22 +146,216 @@
 %7 = OpConstantNull %6
 %4 = OpVariable %5 Output %7
 %9 = OpTypeVoid
-%8 = OpTypeFunction %9
-%13 = OpConstant %3 10
-%15 = OpTypeBool
-%18 = OpConstant %6 0.5
-%19 = OpConstant %6 1
+%8 = OpTypeFunction %9 %6
+%14 = OpTypeFunction %9
+%18 = OpConstant %3 10
+%20 = OpTypeBool
+%24 = OpConstant %6 0.5
+%26 = OpConstant %6 1
 %10 = OpFunction %9 None %8
-%11 = OpLabel
-%12 = OpLoad %3 %1
-%14 = OpUGreaterThan %15 %12 %13
-OpSelectionMerge %16 None
-OpBranchConditional %14 %17 %16
-%17 = OpLabel
-OpStore %4 %18
+%11 = OpFunctionParameter %6
+%12 = OpLabel
+%13 = OpLoad %6 %11
+OpStore %4 %13
 OpReturn
+OpFunctionEnd
+%15 = OpFunction %9 None %14
 %16 = OpLabel
-OpStore %4 %19
+%17 = OpLoad %3 %1
+%19 = OpUGreaterThan %20 %17 %18
+OpSelectionMerge %21 None
+OpBranchConditional %19 %22 %21
+%22 = OpLabel
+%23 = OpFunctionCall %9 %10 %24
+OpReturn
+%21 = OpLabel
+%25 = OpFunctionCall %9 %10 %26
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, EntryPoint_SharedSubStruct) {
+  // struct Interface {
+  //   [[location(1)]] value : f32;
+  // };
+  //
+  // struct VertexOutput {
+  //   [[builtin(position)]] pos : vec4<f32>;
+  //   interface : Interface;
+  // };
+  //
+  // struct FragmentInput {
+  //   [[location(0)]] mul : f32;
+  //   interface : Interface;
+  // };
+  //
+  // [[stage(vertex)]]
+  // fn vert_main() -> VertexOutput {
+  //   return VertexOutput(vec4<f32>(), Interface(42.0));
+  // }
+  //
+  // [[stage(fragment)]]
+  // fn frag_main(inputs : FragmentInput) -> [[builtin(frag_depth)]] f32 {
+  //   return inputs.mul * inputs.interface.value;
+  // }
+
+  auto* interface =
+      Structure("Interface",
+                ast::StructMemberList{Member(
+                    "value", ty.f32(),
+                    ast::DecorationList{create<ast::LocationDecoration>(1u)})});
+  auto* vertex_output = Structure(
+      "VertexOutput",
+      ast::StructMemberList{
+          Member("pos", ty.vec4<f32>(),
+                 ast::DecorationList{
+                     create<ast::BuiltinDecoration>(ast::Builtin::kPosition)}),
+          Member("interface", interface)});
+  auto* fragment_input = Structure(
+      "FragmentInput",
+      ast::StructMemberList{
+          Member("mul", ty.f32(),
+                 ast::DecorationList{create<ast::LocationDecoration>(0u)}),
+          Member("interface", interface)});
+
+  auto* vert_retval = Construct(vertex_output, Construct(ty.vec4<f32>()),
+                                Construct(interface, 42.f));
+  Func("vert_main", ast::VariableList{}, vertex_output,
+       ast::StatementList{
+           create<ast::ReturnStatement>(vert_retval),
+       },
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  auto* frag_retval =
+      Mul(MemberAccessor(Expr("inputs"), "mul"),
+          MemberAccessor(MemberAccessor(Expr("inputs"), "interface"), "value"));
+  auto* frag_inputs =
+      Var("inputs", fragment_input, ast::StorageClass::kFunction, nullptr);
+  Func("frag_main", ast::VariableList{frag_inputs}, ty.f32(),
+       ast::StatementList{
+           create<ast::ReturnStatement>(frag_retval),
+       },
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       },
+       ast::DecorationList{
+           create<ast::BuiltinDecoration>(ast::Builtin::kFragDepth)});
+
+  spirv::Builder& b = SanitizeAndBuild();
+
+  ASSERT_TRUE(b.Build()) << b.error();
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %30 "vert_main" %1 %6
+OpEntryPoint Fragment %41 "frag_main" %11 %9 %12
+OpExecutionMode %41 OriginUpperLeft
+OpExecutionMode %41 DepthReplacing
+OpName %1 "tint_symbol_9"
+OpName %6 "tint_symbol_10"
+OpName %9 "tint_symbol_13"
+OpName %11 "tint_symbol_14"
+OpName %12 "tint_symbol_18"
+OpName %15 "VertexOutput"
+OpMemberName %15 0 "pos"
+OpMemberName %15 1 "interface"
+OpName %16 "Interface"
+OpMemberName %16 0 "value"
+OpName %17 "tint_symbol_11"
+OpName %18 "tint_symbol_8"
+OpName %30 "vert_main"
+OpName %37 "tint_symbol_19"
+OpName %38 "tint_symbol_17"
+OpName %41 "frag_main"
+OpName %45 "tint_symbol_15"
+OpName %48 "FragmentInput"
+OpMemberName %48 0 "mul"
+OpMemberName %48 1 "interface"
+OpName %52 "tint_symbol_16"
+OpDecorate %1 BuiltIn Position
+OpDecorate %6 Location 1
+OpDecorate %9 Location 0
+OpDecorate %11 Location 1
+OpDecorate %12 BuiltIn FragDepth
+OpMemberDecorate %15 0 Offset 0
+OpMemberDecorate %15 1 Offset 16
+OpMemberDecorate %16 0 Offset 0
+OpMemberDecorate %48 0 Offset 0
+OpMemberDecorate %48 1 Offset 4
+%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 4
+%2 = OpTypePointer Output %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Output %5
+%7 = OpTypePointer Output %4
+%8 = OpConstantNull %4
+%6 = OpVariable %7 Output %8
+%10 = OpTypePointer Input %4
+%9 = OpVariable %10 Input
+%11 = OpVariable %10 Input
+%12 = OpVariable %7 Output %8
+%14 = OpTypeVoid
+%16 = OpTypeStruct %4
+%15 = OpTypeStruct %3 %16
+%13 = OpTypeFunction %14 %15
+%20 = OpTypeInt 32 0
+%21 = OpConstant %20 0
+%22 = OpTypePointer Function %3
+%25 = OpConstant %20 1
+%26 = OpTypePointer Function %4
+%29 = OpTypeFunction %14
+%33 = OpConstant %4 42
+%34 = OpConstantComposite %16 %33
+%35 = OpConstantComposite %15 %5 %34
+%36 = OpTypeFunction %14 %4
+%46 = OpTypePointer Function %16
+%47 = OpConstantNull %16
+%48 = OpTypeStruct %4 %16
+%53 = OpTypePointer Function %48
+%54 = OpConstantNull %48
+%17 = OpFunction %14 None %13
+%18 = OpFunctionParameter %15
+%19 = OpLabel
+%23 = OpAccessChain %22 %18 %21
+%24 = OpLoad %3 %23
+OpStore %1 %24
+%27 = OpAccessChain %26 %18 %25 %21
+%28 = OpLoad %4 %27
+OpStore %6 %28
+OpReturn
+OpFunctionEnd
+%30 = OpFunction %14 None %29
+%31 = OpLabel
+%32 = OpFunctionCall %14 %17 %35
+OpReturn
+OpFunctionEnd
+%37 = OpFunction %14 None %36
+%38 = OpFunctionParameter %4
+%39 = OpLabel
+%40 = OpLoad %4 %38
+OpStore %12 %40
+OpReturn
+OpFunctionEnd
+%41 = OpFunction %14 None %29
+%42 = OpLabel
+%45 = OpVariable %46 Function %47
+%52 = OpVariable %53 Function %54
+%43 = OpLoad %4 %11
+%44 = OpCompositeConstruct %16 %43
+OpStore %45 %44
+%49 = OpLoad %4 %9
+%50 = OpLoad %16 %45
+%51 = OpCompositeConstruct %48 %49 %50
+OpStore %52 %51
+%56 = OpAccessChain %26 %52 %21
+%57 = OpLoad %4 %56
+%58 = OpAccessChain %26 %52 %25 %21
+%59 = OpLoad %4 %58
+%60 = OpFMul %4 %57 %59
+%55 = OpFunctionCall %14 %37 %60
 OpReturn
 OpFunctionEnd
 )");