transform/VertexPulling: Handle entry point parameters

Adds support for the new shader IO syntax by processing entry
parameters and pushing them to function-scope variables as necessary.

Module-scope variables are still supported for now.

Fixed: tint:731
Change-Id: I36d7ce4e3a990b6323292cb7c685af37187d6fda
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48960
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 98e4000..f84bb9b 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -47,12 +47,14 @@
 
   CloneContext& ctx;
   VertexPulling::Config const cfg;
-  std::unordered_map<uint32_t, ast::Variable*> location_to_var;
-  Symbol vertex_index_name;
-  Symbol instance_index_name;
+  std::unordered_map<uint32_t, std::function<ast::Expression*()>>
+      location_to_expr;
+  std::function<ast::Expression*()> vertex_index_expr = nullptr;
+  std::function<ast::Expression*()> instance_index_expr = nullptr;
   Symbol pulling_position_name;
   Symbol struct_buffer_name;
   std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
+  ast::VariableList new_function_parameters;
 
   /// Generate the vertex buffer binding name
   /// @param index index to append to buffer name
@@ -106,7 +108,9 @@
       for (auto* d : v->decorations()) {
         if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
           if (builtin->value() == ast::Builtin::kVertexIndex) {
-            vertex_index_name = ctx.Clone(v->symbol());
+            vertex_index_expr = [this, v]() {
+              return ctx.dst->Expr(ctx.Clone(v->symbol()));
+            };
             return;
           }
         }
@@ -114,11 +118,10 @@
     }
 
     // We didn't find a vertex index builtin, so create one
-    static const char kDefaultVertexIndexName[] = "tint_pulling_vertex_index";
-    vertex_index_name = ctx.dst->Symbols().New(kDefaultVertexIndexName);
+    auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
+    vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
 
-    ctx.dst->Global(vertex_index_name, ctx.dst->ty.u32(),
-                    ast::StorageClass::kInput, nullptr,
+    ctx.dst->Global(name, ctx.dst->ty.u32(), ast::StorageClass::kInput, nullptr,
                     ast::DecorationList{
                         ctx.dst->Builtin(ast::Builtin::kVertexIndex),
                     });
@@ -147,7 +150,9 @@
       for (auto* d : v->decorations()) {
         if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
           if (builtin->value() == ast::Builtin::kInstanceIndex) {
-            instance_index_name = ctx.Clone(v->symbol());
+            instance_index_expr = [this, v]() {
+              return ctx.dst->Expr(ctx.Clone(v->symbol()));
+            };
             return;
           }
         }
@@ -155,12 +160,10 @@
     }
 
     // We didn't find an instance index builtin, so create one
-    static const char kDefaultInstanceIndexName[] =
-        "tint_pulling_instance_index";
-    instance_index_name = ctx.dst->Symbols().New(kDefaultInstanceIndexName);
+    auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
+    instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
 
-    ctx.dst->Global(instance_index_name, ctx.dst->ty.u32(),
-                    ast::StorageClass::kInput, nullptr,
+    ctx.dst->Global(name, ctx.dst->ty.u32(), ast::StorageClass::kInput, nullptr,
                     ast::DecorationList{
                         ctx.dst->Builtin(ast::Builtin::kInstanceIndex),
                     });
@@ -180,10 +183,12 @@
           // This is where the replacement is created. Expressions use
           // identifier strings instead of pointers, so we don't need to update
           // any other place in the AST.
-          auto* replacement = ctx.dst->Var(ctx.Clone(v->symbol()),
-                                           ctx.Clone(v->declared_type()),
+          auto name = ctx.Clone(v->symbol());
+          auto* replacement = ctx.dst->Var(name, ctx.Clone(v->declared_type()),
                                            ast::StorageClass::kPrivate);
-          location_to_var[location] = replacement;
+          location_to_expr[location] = [this, name]() {
+            return ctx.dst->Expr(name);
+          };
           ctx.Replace(v, replacement);
           break;
         }
@@ -237,30 +242,29 @@
 
       for (const VertexAttributeDescriptor& attribute_desc :
            buffer_layout.attributes) {
-        auto it = location_to_var.find(attribute_desc.shader_location);
-        if (it == location_to_var.end()) {
+        auto it = location_to_expr.find(attribute_desc.shader_location);
+        if (it == location_to_expr.end()) {
           continue;
         }
-        auto* v = it->second;
+        auto* ident = it->second();
 
-        auto name = buffer_layout.step_mode == InputStepMode::kVertex
-                        ? vertex_index_name
-                        : instance_index_name;
+        auto* index_expr = buffer_layout.step_mode == InputStepMode::kVertex
+                               ? vertex_index_expr()
+                               : instance_index_expr();
 
         // An expression for the start of the read in the buffer in bytes
         auto* pos_value = ctx.dst->Add(
-            ctx.dst->Mul(name,
+            ctx.dst->Mul(index_expr,
                          static_cast<uint32_t>(buffer_layout.array_stride)),
             static_cast<uint32_t>(attribute_desc.offset));
 
         // Update position of the read
-        auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
-            ctx.dst->Expr(GetPullingPositionName()), pos_value);
+        auto* set_pos_expr =
+            ctx.dst->Assign(ctx.dst->Expr(GetPullingPositionName()), pos_value);
         stmts.emplace_back(set_pos_expr);
 
-        stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
-            ctx.dst->create<ast::IdentifierExpression>(v->symbol()),
-            AccessByFormat(i, attribute_desc.format)));
+        stmts.emplace_back(
+            ctx.dst->Assign(ident, AccessByFormat(i, attribute_desc.format)));
       }
     }
 
@@ -379,6 +383,186 @@
     return ctx.dst->create<ast::TypeConstructorExpression>(
         ctx.dst->create<sem::Vector>(base_type, count), std::move(expr_list));
   }
+
+  /// Process a non-struct entry point parameter.
+  /// Generate function-scope variables for location parameters, and record
+  /// vertex_index and instance_index builtins if present.
+  /// @param func the entry point function
+  /// @param param the parameter to process
+  void ProcessNonStructParameter(ast::Function* func, ast::Variable* param) {
+    if (auto* location =
+            ast::GetDecoration<ast::LocationDecoration>(param->decorations())) {
+      // Create a function-scope variable to replace the parameter.
+      auto func_var_sym = ctx.Clone(param->symbol());
+      auto* func_var_type = ctx.Clone(param->declared_type());
+      auto* func_var = ctx.dst->Var(func_var_sym, func_var_type,
+                                    ast::StorageClass::kFunction);
+      ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+                       ctx.dst->Decl(func_var));
+      // Capture mapping from location to the new variable.
+      location_to_expr[location->value()] = [this, func_var]() {
+        return ctx.dst->Expr(func_var);
+      };
+    } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
+                   param->decorations())) {
+      // Check for existing vertex_index and instance_index builtins.
+      if (builtin->value() == ast::Builtin::kVertexIndex) {
+        vertex_index_expr = [this, param]() {
+          return ctx.dst->Expr(ctx.Clone(param->symbol()));
+        };
+      } else if (builtin->value() == ast::Builtin::kInstanceIndex) {
+        instance_index_expr = [this, param]() {
+          return ctx.dst->Expr(ctx.Clone(param->symbol()));
+        };
+      }
+      new_function_parameters.push_back(ctx.Clone(param));
+    } else {
+      TINT_ICE(ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+    }
+  }
+
+  /// Process a struct entry point parameter.
+  /// If the struct has members with location attributes, push the parameter to
+  /// a function-scope variable and create a new struct parameter without those
+  /// attributes. Record expressions for members that are vertex_index and
+  /// instance_index builtins.
+  /// @param func the entry point function
+  /// @param param the parameter to process
+  void ProcessStructParameter(ast::Function* func, ast::Variable* param) {
+    auto* struct_ty = param->declared_type()->As<sem::StructType>();
+    if (!struct_ty) {
+      TINT_ICE(ctx.dst->Diagnostics()) << "Invalid struct parameter";
+    }
+
+    auto param_sym = ctx.Clone(param->symbol());
+
+    // Process the struct members.
+    bool has_locations = false;
+    ast::StructMemberList members_to_clone;
+    for (auto* member : struct_ty->impl()->members()) {
+      auto member_sym = ctx.Clone(member->symbol());
+      std::function<ast::Expression*()> member_expr = [this, param_sym,
+                                                       member_sym]() {
+        return ctx.dst->MemberAccessor(param_sym, member_sym);
+      };
+
+      if (auto* location = ast::GetDecoration<ast::LocationDecoration>(
+              member->decorations())) {
+        // Capture mapping from location to struct member.
+        location_to_expr[location->value()] = member_expr;
+        has_locations = true;
+      } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
+                     member->decorations())) {
+        // Check for existing vertex_index and instance_index builtins.
+        if (builtin->value() == ast::Builtin::kVertexIndex) {
+          vertex_index_expr = member_expr;
+        } else if (builtin->value() == ast::Builtin::kInstanceIndex) {
+          instance_index_expr = member_expr;
+        }
+        members_to_clone.push_back(member);
+      } else {
+        TINT_ICE(ctx.dst->Diagnostics()) << "Invalid entry point parameter";
+      }
+    }
+
+    if (!has_locations) {
+      // Nothing to do.
+      new_function_parameters.push_back(ctx.Clone(param));
+      return;
+    }
+
+    // Create a function-scope variable to replace the parameter.
+    auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->declared_type()),
+                                  ast::StorageClass::kFunction);
+    ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+                     ctx.dst->Decl(func_var));
+
+    if (!members_to_clone.empty()) {
+      // Create a new struct without the location attributes.
+      ast::StructMemberList new_members;
+      for (auto* member : members_to_clone) {
+        auto member_sym = ctx.Clone(member->symbol());
+        auto member_type = ctx.Clone(member->type());
+        auto member_decos = ctx.Clone(member->decorations());
+        new_members.push_back(
+            ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
+      }
+      auto new_struct =
+          ctx.dst->Structure(ctx.dst->Symbols().New(), new_members);
+
+      // Create a new function parameter with this struct.
+      auto* new_param = ctx.dst->Param(ctx.dst->Symbols().New(), new_struct);
+      new_function_parameters.push_back(new_param);
+
+      // Copy values from the new parameter to the function-scope variable.
+      for (auto* member : members_to_clone) {
+        auto member_name = ctx.Clone(member->symbol());
+        ctx.InsertBefore(
+            func->body()->statements(), *func->body()->begin(),
+            ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
+                            ctx.dst->MemberAccessor(new_param, member_name)));
+      }
+    }
+  }
+
+  /// Process an entry point function.
+  /// @param func the entry point function
+  void Process(ast::Function* func) {
+    if (func->body()->empty()) {
+      return;
+    }
+
+    // Process entry point parameters.
+    for (auto* param : func->params()) {
+      auto* sem = ctx.src->Sem().Get(param);
+      if (sem->Type()->Is<sem::StructType>()) {
+        ProcessStructParameter(func, param);
+      } else {
+        ProcessNonStructParameter(func, param);
+      }
+    }
+
+    // Insert new parameters for vertex_index and instance_index if needed.
+    if (!vertex_index_expr) {
+      for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+        if (layout.step_mode == InputStepMode::kVertex) {
+          auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
+          new_function_parameters.push_back(
+              ctx.dst->Param(name, ctx.dst->ty.u32(),
+                             {ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
+          vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+          break;
+        }
+      }
+    }
+    if (!instance_index_expr) {
+      for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
+        if (layout.step_mode == InputStepMode::kInstance) {
+          auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
+          new_function_parameters.push_back(
+              ctx.dst->Param(name, ctx.dst->ty.u32(),
+                             {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
+          instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
+          break;
+        }
+      }
+    }
+
+    // Generate vertex pulling preamble.
+    ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+                     CreateVertexPullingPreamble());
+
+    // Rewrite the function header with the new parameters.
+    auto func_sym = ctx.Clone(func->symbol());
+    auto ret_type = ctx.Clone(func->return_type());
+    auto* body = ctx.Clone(func->body());
+    auto decos = ctx.Clone(func->decorations());
+    auto ret_decos = ctx.Clone(func->return_type_decorations());
+    auto* new_func = ctx.dst->create<ast::Function>(
+        func->source(), func_sym, new_function_parameters, ret_type, body,
+        std::move(decos), std::move(ret_decos));
+    ctx.Replace(func, new_func);
+  }
 };
 
 }  // namespace
@@ -413,18 +597,26 @@
   CloneContext ctx(&out, in);
 
   State state{ctx, cfg};
-  state.FindOrInsertVertexIndexIfUsed();
-  state.FindOrInsertInstanceIndexIfUsed();
-  state.ConvertVertexInputVariablesToPrivate();
-  state.AddVertexStorageBuffers();
 
-  ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* {
-    if (f == func) {
-      return CloneWithStatementsAtStart(&ctx, f,
-                                        {state.CreateVertexPullingPreamble()});
-    }
-    return nullptr;  // Just clone func
-  });
+  if (func->params().empty()) {
+    // TODO(crbug.com/tint/697): Remove this path for the old shader IO syntax.
+    state.FindOrInsertVertexIndexIfUsed();
+    state.FindOrInsertInstanceIndexIfUsed();
+    state.ConvertVertexInputVariablesToPrivate();
+    state.AddVertexStorageBuffers();
+
+    ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* {
+      if (f == func) {
+        return CloneWithStatementsAtStart(
+            &ctx, f, {state.CreateVertexPullingPreamble()});
+      }
+      return nullptr;  // Just clone func
+    });
+  } else {
+    state.AddVertexStorageBuffers();
+    state.Process(func);
+  }
+
   ctx.Clone();
 
   return Output(Program(std::move(out)));
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index 8effc23..1e6ff9e 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -109,17 +109,13 @@
 
 TEST_F(VertexPullingTest, OneAttribute) {
   auto* src = R"(
-[[location(0)]] var<in> var_a : f32;
-
 [[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
-  return vec4<f32>();
+fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
 )";
 
   auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_pulling_vertex_index : u32;
-
 [[block]]
 struct TintVertexData {
   tint_vertex_data : [[stride(4)]] array<u32>;
@@ -127,16 +123,15 @@
 
 [[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
 
-var<private> var_a : f32;
-
 [[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : f32;
   {
     var tint_pulling_pos : u32;
     tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u);
     var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
   }
-  return vec4<f32>();
+  return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
 )";
 
@@ -154,17 +149,13 @@
 
 TEST_F(VertexPullingTest, OneInstancedAttribute) {
   auto* src = R"(
-[[location(0)]] var<in> var_a : f32;
-
 [[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
-  return vec4<f32>();
+fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
 )";
 
   auto* expect = R"(
-[[builtin(instance_index)]] var<in> tint_pulling_instance_index : u32;
-
 [[block]]
 struct TintVertexData {
   tint_vertex_data : [[stride(4)]] array<u32>;
@@ -172,16 +163,15 @@
 
 [[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
 
-var<private> var_a : f32;
-
 [[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
+fn main([[builtin(instance_index)]] tint_pulling_instance_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : f32;
   {
     var tint_pulling_pos : u32;
     tint_pulling_pos = ((tint_pulling_instance_index * 4u) + 0u);
     var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
   }
-  return vec4<f32>();
+  return vec4<f32>(var_a, 0.0, 0.0, 1.0);
 }
 )";
 
@@ -199,6 +189,472 @@
 
 TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
   auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(5)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : f32;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u);
+    var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+  }
+  return vec4<f32>(var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {
+      {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}};
+  cfg.pulling_group = 5;
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, OneAttribute_Struct) {
+  auto* src = R"(
+struct Inputs {
+  [[location(0)]] var_a : f32;
+};
+
+[[stage(vertex)]]
+fn main(inputs : Inputs) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+struct Inputs {
+  [[location(0)]]
+  var_a : f32;
+};
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var inputs : Inputs;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u);
+    inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+  }
+  return vec4<f32>(inputs.var_a, 0.0, 0.0, 1.0);
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {
+      {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+// We expect the transform to use an existing builtin variables if it finds them
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
+  auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32,
+        [[location(1)]] var_b : f32,
+        [[builtin(vertex_index)]] custom_vertex_index : u32,
+        [[builtin(instance_index)]] custom_instance_index : u32
+        ) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(var_a, var_b, 0.0, 1.0);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] custom_vertex_index : u32, [[builtin(instance_index)]] custom_instance_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : f32;
+  var var_b : f32;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((custom_vertex_index * 4u) + 0u);
+    var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+    tint_pulling_pos = ((custom_instance_index * 4u) + 0u);
+    var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]);
+  }
+  return vec4<f32>(var_a, var_b, 0.0, 1.0);
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {{
+      {
+          4,
+          InputStepMode::kVertex,
+          {{VertexFormat::kF32, 0, 0}},
+      },
+      {
+          4,
+          InputStepMode::kInstance,
+          {{VertexFormat::kF32, 0, 1}},
+      },
+  }};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct) {
+  auto* src = R"(
+struct Inputs {
+  [[location(0)]] var_a : f32;
+  [[location(1)]] var_b : f32;
+  [[builtin(vertex_index)]] custom_vertex_index : u32;
+  [[builtin(instance_index)]] custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(inputs : Inputs) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+struct tint_symbol {
+  [[builtin(vertex_index)]]
+  custom_vertex_index : u32;
+  [[builtin(instance_index)]]
+  custom_instance_index : u32;
+};
+
+struct Inputs {
+  [[location(0)]]
+  var_a : f32;
+  [[location(1)]]
+  var_b : f32;
+  [[builtin(vertex_index)]]
+  custom_vertex_index : u32;
+  [[builtin(instance_index)]]
+  custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(tint_symbol_1 : tint_symbol) -> [[builtin(position)]] vec4<f32> {
+  var inputs : Inputs;
+  inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index;
+  inputs.custom_instance_index = tint_symbol_1.custom_instance_index;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((inputs.custom_vertex_index * 4u) + 0u);
+    inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+    tint_pulling_pos = ((inputs.custom_instance_index * 4u) + 0u);
+    inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]);
+  }
+  return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {{
+      {
+          4,
+          InputStepMode::kVertex,
+          {{VertexFormat::kF32, 0, 0}},
+      },
+      {
+          4,
+          InputStepMode::kInstance,
+          {{VertexFormat::kF32, 0, 1}},
+      },
+  }};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct) {
+  auto* src = R"(
+struct Inputs {
+  [[location(0)]] var_a : f32;
+  [[location(1)]] var_b : f32;
+};
+
+struct Indices {
+  [[builtin(vertex_index)]] custom_vertex_index : u32;
+  [[builtin(instance_index)]] custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(inputs : Inputs, indices : Indices) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+struct Inputs {
+  [[location(0)]]
+  var_a : f32;
+  [[location(1)]]
+  var_b : f32;
+};
+
+struct Indices {
+  [[builtin(vertex_index)]]
+  custom_vertex_index : u32;
+  [[builtin(instance_index)]]
+  custom_instance_index : u32;
+};
+
+[[stage(vertex)]]
+fn main(indices : Indices) -> [[builtin(position)]] vec4<f32> {
+  var inputs : Inputs;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((indices.custom_vertex_index * 4u) + 0u);
+    inputs.var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+    tint_pulling_pos = ((indices.custom_instance_index * 4u) + 0u);
+    inputs.var_b = bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]);
+  }
+  return vec4<f32>(inputs.var_a, inputs.var_b, 0.0, 1.0);
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {{
+      {
+          4,
+          InputStepMode::kVertex,
+          {{VertexFormat::kF32, 0, 0}},
+      },
+      {
+          4,
+          InputStepMode::kInstance,
+          {{VertexFormat::kF32, 0, 1}},
+      },
+  }};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
+  auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32,
+        [[location(1)]] var_b : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>();
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : f32;
+  var var_b : vec4<f32>;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
+    var_a = bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]);
+    tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
+    var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)]));
+  }
+  return vec4<f32>();
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {
+      {{16,
+        InputStepMode::kVertex,
+        {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, FloatVectorAttributes) {
+  auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : vec2<f32>,
+        [[location(1)]] var_b : vec3<f32>,
+        [[location(2)]] var_c : vec4<f32>
+        ) -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>();
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
+
+[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
+
+[[binding(2), group(4)]] var<storage> tint_pulling_vertex_buffer_2 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : vec2<f32>;
+  var var_b : vec3<f32>;
+  var var_c : vec4<f32>;
+  {
+    var tint_pulling_pos : u32;
+    tint_pulling_pos = ((tint_pulling_vertex_index * 8u) + 0u);
+    var_a = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]));
+    tint_pulling_pos = ((tint_pulling_vertex_index * 12u) + 0u);
+    var_b = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]));
+    tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
+    var_c = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)]));
+  }
+  return vec4<f32>();
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {{
+      {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
+      {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
+      {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}},
+  }};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(VertexPullingTest, AttemptSymbolCollision) {
+  auto* src = R"(
+[[stage(vertex)]]
+fn main([[location(0)]] var_a : f32,
+        [[location(1)]] var_b : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
+  var tint_pulling_vertex_index : i32;
+  var tint_pulling_vertex_buffer_0 : i32;
+  var tint_vertex_data : i32;
+  var tint_pulling_pos : i32;
+  return vec4<f32>();
+}
+)";
+
+  auto* expect = R"(
+[[block]]
+struct TintVertexData {
+  tint_vertex_data_1 : [[stride(4)]] array<u32>;
+};
+
+[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0_1 : [[access(read)]] TintVertexData;
+
+[[stage(vertex)]]
+fn main([[builtin(vertex_index)]] tint_pulling_vertex_index_1 : u32) -> [[builtin(position)]] vec4<f32> {
+  var var_a : f32;
+  var var_b : vec4<f32>;
+  {
+    var tint_pulling_pos_1 : u32;
+    tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
+    var_a = bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(tint_pulling_pos_1 / 4u)]);
+    tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
+    var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 12u) / 4u)]));
+  }
+  var tint_pulling_vertex_index : i32;
+  var tint_pulling_vertex_buffer_0 : i32;
+  var tint_vertex_data : i32;
+  var tint_pulling_pos : i32;
+  return vec4<f32>();
+}
+)";
+
+  VertexPulling::Config cfg;
+  cfg.vertex_state = {
+      {{16,
+        InputStepMode::kVertex,
+        {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
+  cfg.entry_point_name = "main";
+
+  DataMap data;
+  data.Add<VertexPulling::Config>(cfg);
+  auto got = Run<VertexPulling>(src, std::move(data));
+
+  EXPECT_EQ(expect, str(got));
+}
+
+// TODO(crbug.com/tint/697): Remove this.
+TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet_Legacy) {
+  auto* src = R"(
 [[location(0)]] var<in> var_a : f32;
 
 [[stage(vertex)]]
@@ -243,8 +699,9 @@
   EXPECT_EQ(expect, str(got));
 }
 
+// TODO(crbug.com/tint/697): Remove this.
 // We expect the transform to use an existing builtin variables if it finds them
-TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
+TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Legacy) {
   auto* src = R"(
 [[location(0)]] var<in> var_a : f32;
 [[location(1)]] var<in> var_b : f32;
@@ -310,7 +767,8 @@
   EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
+// TODO(crbug.com/tint/697): Remove this.
+TEST_F(VertexPullingTest, TwoAttributesSameBuffer_Legacy) {
   auto* src = R"(
 [[location(0)]] var<in> var_a : f32;
 [[location(1)]] var<in> var_b : vec4<f32>;
@@ -362,127 +820,6 @@
   EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(VertexPullingTest, FloatVectorAttributes) {
-  auto* src = R"(
-[[location(0)]] var<in> var_a : vec2<f32>;
-[[location(1)]] var<in> var_b : vec3<f32>;
-[[location(2)]] var<in> var_c : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
-  return vec4<f32>();
-}
-)";
-
-  auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_pulling_vertex_index : u32;
-
-[[block]]
-struct TintVertexData {
-  tint_vertex_data : [[stride(4)]] array<u32>;
-};
-
-[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData;
-
-[[binding(1), group(4)]] var<storage> tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData;
-
-[[binding(2), group(4)]] var<storage> tint_pulling_vertex_buffer_2 : [[access(read)]] TintVertexData;
-
-var<private> var_a : vec2<f32>;
-
-var<private> var_b : vec3<f32>;
-
-var<private> var_c : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
-  {
-    var tint_pulling_pos : u32;
-    tint_pulling_pos = ((tint_pulling_vertex_index * 8u) + 0u);
-    var_a = vec2<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]));
-    tint_pulling_pos = ((tint_pulling_vertex_index * 12u) + 0u);
-    var_b = vec3<f32>(bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]));
-    tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u);
-    var_c = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)]));
-  }
-  return vec4<f32>();
-}
-)";
-
-  VertexPulling::Config cfg;
-  cfg.vertex_state = {{
-      {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
-      {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
-      {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}},
-  }};
-  cfg.entry_point_name = "main";
-
-  DataMap data;
-  data.Add<VertexPulling::Config>(cfg);
-  auto got = Run<VertexPulling>(src, data);
-
-  EXPECT_EQ(expect, str(got));
-}
-
-TEST_F(VertexPullingTest, AttemptSymbolCollision) {
-  auto* src = R"(
-[[location(0)]] var<in> var_a : f32;
-[[location(1)]] var<in> var_b : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
-  var tint_pulling_vertex_index : i32;
-  var tint_pulling_vertex_buffer_0 : i32;
-  var tint_vertex_data : i32;
-  var tint_pulling_pos : i32;
-  return vec4<f32>();
-}
-)";
-
-  auto* expect = R"(
-[[builtin(vertex_index)]] var<in> tint_pulling_vertex_index_1 : u32;
-
-[[block]]
-struct TintVertexData {
-  tint_vertex_data_1 : [[stride(4)]] array<u32>;
-};
-
-[[binding(0), group(4)]] var<storage> tint_pulling_vertex_buffer_0_1 : [[access(read)]] TintVertexData;
-
-var<private> var_a : f32;
-
-var<private> var_b : vec4<f32>;
-
-[[stage(vertex)]]
-fn main() -> [[builtin(position)]] vec4<f32> {
-  {
-    var tint_pulling_pos_1 : u32;
-    tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
-    var_a = bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(tint_pulling_pos_1 / 4u)]);
-    tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u);
-    var_b = vec4<f32>(bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 0u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 4u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 8u) / 4u)]), bitcast<f32>(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 12u) / 4u)]));
-  }
-  var tint_pulling_vertex_index : i32;
-  var tint_pulling_vertex_buffer_0 : i32;
-  var tint_vertex_data : i32;
-  var tint_pulling_pos : i32;
-  return vec4<f32>();
-}
-)";
-
-  VertexPulling::Config cfg;
-  cfg.vertex_state = {
-      {{16,
-        InputStepMode::kVertex,
-        {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
-  cfg.entry_point_name = "main";
-
-  DataMap data;
-  data.Add<VertexPulling::Config>(cfg);
-  auto got = Run<VertexPulling>(src, std::move(data));
-
-  EXPECT_EQ(expect, str(got));
-}
 }  // namespace
 }  // namespace transform
 }  // namespace tint