transform/VertexPulling: Use SymbolTable::New()

And clean up some code in the process.

Avoids potential symbol collisions. Simplifies the logic.

Bug: tint:712
Change-Id: Ibce5ccbd4c7fd45d5bf29906b5a83b3637b6cdcc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47633
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 951f4fd..df00a17 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -22,6 +22,7 @@
 #include "src/ast/variable_decl_statement.h"
 #include "src/program_builder.h"
 #include "src/semantic/variable.h"
+#include "src/utils/get_or_create.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
 
@@ -29,12 +30,360 @@
 namespace transform {
 namespace {
 
-static const char kVertexBufferNamePrefix[] = "_tint_pulling_vertex_buffer_";
-static const char kStructBufferName[] = "_tint_vertex_data";
-static const char kStructName[] = "TintVertexData";
-static const char kPullingPosVarName[] = "_tint_pulling_pos";
-static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
-static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index";
+struct State {
+  State(CloneContext& context, const VertexPulling::Config& c)
+      : ctx(context), cfg(c) {}
+  State(const State&) = default;
+  ~State() = default;
+
+  /// LocationReplacement describes an ast::Variable replacement for a
+  /// location input.
+  struct LocationReplacement {
+    /// The variable to replace in the source Program
+    ast::Variable* from;
+    /// The replacement to use in the target ProgramBuilder
+    ast::Variable* to;
+  };
+
+  CloneContext& ctx;
+  VertexPulling::Config const cfg;
+  std::unordered_map<uint32_t, ast::Variable*> location_to_var;
+  std::vector<LocationReplacement> location_replacements;
+  Symbol vertex_index_name;
+  Symbol instance_index_name;
+  Symbol pulling_position_name;
+  Symbol struct_buffer_name;
+  std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
+
+  /// Generate the vertex buffer binding name
+  /// @param index index to append to buffer name
+  Symbol GetVertexBufferName(uint32_t index) {
+    return utils::GetOrCreate(vertex_buffer_names, index, [&] {
+      static const char kVertexBufferNamePrefix[] =
+          "_tint_pulling_vertex_buffer_";
+      return ctx.dst->Symbols().New(kVertexBufferNamePrefix +
+                                    std::to_string(index));
+    });
+  }
+
+  /// Lazily generates the pulling position symbol
+  Symbol GetPullingPositionName() {
+    if (!pulling_position_name.IsValid()) {
+      static const char kPullingPosVarName[] = "_tint_pulling_pos";
+      pulling_position_name = ctx.dst->Symbols().New(kPullingPosVarName);
+    }
+    return pulling_position_name;
+  }
+
+  /// Lazily generates the structure buffer symbol
+  Symbol GetStructBufferName() {
+    if (!struct_buffer_name.IsValid()) {
+      static const char kStructBufferName[] = "_tint_vertex_data";
+      struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
+    }
+    return struct_buffer_name;
+  }
+
+  /// Inserts vertex_index binding, or finds the existing one
+  void FindOrInsertVertexIndexIfUsed() {
+    bool uses_vertex_step_mode = false;
+    for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
+      if (buffer_layout.step_mode == InputStepMode::kVertex) {
+        uses_vertex_step_mode = true;
+        break;
+      }
+    }
+    if (!uses_vertex_step_mode) {
+      return;
+    }
+
+    // Look for an existing vertex index builtin
+    for (auto* v : ctx.src->AST().GlobalVariables()) {
+      auto* sem = ctx.src->Sem().Get(v);
+      if (sem->StorageClass() != ast::StorageClass::kInput) {
+        continue;
+      }
+
+      for (auto* d : v->decorations()) {
+        if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
+          if (builtin->value() == ast::Builtin::kVertexIndex) {
+            vertex_index_name = ctx.Clone(v->symbol());
+            return;
+          }
+        }
+      }
+    }
+
+    // We didn't find a vertex index builtin, so create one
+    static const char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
+    vertex_index_name = ctx.dst->Symbols().New(kDefaultVertexIndexName);
+
+    ctx.dst->Global(
+        vertex_index_name, ctx.dst->ty.u32(), ast::StorageClass::kInput,
+        nullptr,
+        ast::DecorationList{
+            ctx.dst->create<ast::BuiltinDecoration>(ast::Builtin::kVertexIndex),
+        });
+  }
+
+  /// Inserts instance_index binding, or finds the existing one
+  void FindOrInsertInstanceIndexIfUsed() {
+    bool uses_instance_step_mode = false;
+    for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
+      if (buffer_layout.step_mode == InputStepMode::kInstance) {
+        uses_instance_step_mode = true;
+        break;
+      }
+    }
+    if (!uses_instance_step_mode) {
+      return;
+    }
+
+    // Look for an existing instance index builtin
+    for (auto* v : ctx.src->AST().GlobalVariables()) {
+      auto* sem = ctx.src->Sem().Get(v);
+      if (sem->StorageClass() != ast::StorageClass::kInput) {
+        continue;
+      }
+
+      for (auto* d : v->decorations()) {
+        if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
+          if (builtin->value() == ast::Builtin::kInstanceIndex) {
+            instance_index_name = ctx.Clone(v->symbol());
+            return;
+          }
+        }
+      }
+    }
+
+    // We didn't find an instance index builtin, so create one
+    static const char kDefaultInstanceIndexName[] =
+        "_tint_pulling_instance_index";
+    instance_index_name = ctx.dst->Symbols().New(kDefaultInstanceIndexName);
+
+    ctx.dst->Global(instance_index_name, ctx.dst->ty.u32(),
+                    ast::StorageClass::kInput, nullptr,
+                    ast::DecorationList{
+                        ctx.dst->create<ast::BuiltinDecoration>(
+                            ast::Builtin::kInstanceIndex),
+                    });
+  }
+
+  /// Converts var<in> with a location decoration to var<private>
+  void ConvertVertexInputVariablesToPrivate() {
+    for (auto* v : ctx.src->AST().GlobalVariables()) {
+      auto* sem = ctx.src->Sem().Get(v);
+      if (sem->StorageClass() != ast::StorageClass::kInput) {
+        continue;
+      }
+
+      for (auto* d : v->decorations()) {
+        if (auto* l = d->As<ast::LocationDecoration>()) {
+          uint32_t location = l->value();
+          // This is where the replacement is created. Expressions use
+          // identifier strings instead of pointers, so we don't need to update
+          // any other place in the AST.
+          auto* replacement = ctx.dst->Var(ctx.Clone(v->symbol()),
+                                           ctx.Clone(v->declared_type()),
+                                           ast::StorageClass::kPrivate);
+          location_to_var[location] = replacement;
+          location_replacements.emplace_back(
+              LocationReplacement{v, replacement});
+          break;
+        }
+      }
+    }
+  }
+
+  /// Adds storage buffer decorated variables for the vertex buffers
+  void AddVertexStorageBuffers() {
+    // TODO(idanr): Make this readonly
+    // https://github.com/gpuweb/gpuweb/issues/935
+
+    // Creating the struct type
+    static const char kStructName[] = "TintVertexData";
+    auto* struct_type = ctx.dst->Structure(
+        ctx.dst->Symbols().New(kStructName),
+        {
+            ctx.dst->Member(GetStructBufferName(),
+                            ctx.dst->ty.array<ProgramBuilder::u32, 0>(4)),
+        },
+        {
+            ctx.dst->create<ast::StructBlockDecoration>(),
+        });
+
+    for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
+      // The decorated variable with struct type
+      ctx.dst->Global(
+          GetVertexBufferName(i), struct_type, ast::StorageClass::kStorage,
+          nullptr,
+          ast::DecorationList{
+              ctx.dst->create<ast::BindingDecoration>(i),
+              ctx.dst->create<ast::GroupDecoration>(cfg.pulling_group),
+          });
+    }
+  }
+
+  /// Creates and returns the assignment to the variables from the buffers
+  ast::BlockStatement* CreateVertexPullingPreamble() {
+    // Assign by looking at the vertex descriptor to find attributes with
+    // matching location.
+
+    ast::StatementList stmts;
+
+    // Declare the pulling position variable in the shader
+    stmts.emplace_back(ctx.dst->create<ast::VariableDeclStatement>(
+        ctx.dst->Var(GetPullingPositionName(), ctx.dst->ty.u32(),
+                     ast::StorageClass::kFunction)));
+
+    for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
+      const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i];
+
+      for (const VertexAttributeDescriptor& attribute_desc :
+           buffer_layout.attributes) {
+        auto it = location_to_var.find(attribute_desc.shader_location);
+        if (it == location_to_var.end()) {
+          continue;
+        }
+        auto* v = it->second;
+
+        auto name = buffer_layout.step_mode == InputStepMode::kVertex
+                        ? vertex_index_name
+                        : instance_index_name;
+
+        // An expression for the start of the read in the buffer in bytes
+        auto* pos_value = ctx.dst->Add(
+            ctx.dst->Mul(name,
+                         static_cast<uint32_t>(buffer_layout.array_stride)),
+            static_cast<uint32_t>(attribute_desc.offset));
+
+        // Update position of the read
+        auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
+            ctx.dst->Expr(GetPullingPositionName()), pos_value);
+        stmts.emplace_back(set_pos_expr);
+
+        stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
+            ctx.dst->create<ast::IdentifierExpression>(v->symbol()),
+            AccessByFormat(i, attribute_desc.format)));
+      }
+    }
+
+    return ctx.dst->create<ast::BlockStatement>(stmts);
+  }
+
+  /// Generates an expression reading from a buffer a specific format.
+  /// This reads the value wherever `kPullingPosVarName` points to at the time
+  /// of the read.
+  /// @param buffer the index of the vertex buffer
+  /// @param format the format to read
+  ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) {
+    // TODO(idanr): this doesn't account for the format of the attribute in the
+    // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
+    // right now, we would try to assign a vec4<f32> to this attribute, but we
+    // really need to assign a vec4<u32> by casting.
+    // We could split this function to first do memory accesses and unpacking
+    // into int/uint/float1-4/etc, then convert that variable to a var<in> with
+    // the conversion defined in the WebGPU spec.
+    switch (format) {
+      case VertexFormat::kU32:
+        return AccessU32(buffer, ctx.dst->Expr(GetPullingPositionName()));
+      case VertexFormat::kI32:
+        return AccessI32(buffer, ctx.dst->Expr(GetPullingPositionName()));
+      case VertexFormat::kF32:
+        return AccessF32(buffer, ctx.dst->Expr(GetPullingPositionName()));
+      case VertexFormat::kVec2F32:
+        return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 2);
+      case VertexFormat::kVec3F32:
+        return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 3);
+      case VertexFormat::kVec4F32:
+        return AccessVec(buffer, 4, ctx.dst->ty.f32(), VertexFormat::kF32, 4);
+      default:
+        return nullptr;
+    }
+  }
+
+  /// Generates an expression reading a uint32 from a vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) {
+    // Here we divide by 4, since the buffer is uint32 not uint8. The input
+    // buffer has byte offsets for each attribute, and we will convert it to u32
+    // indexes by dividing. Then, that element is going to be read, and if
+    // needed, unpacked into an appropriate variable. All reads should end up
+    // here as a base case.
+    return ctx.dst->create<ast::ArrayAccessorExpression>(
+        ctx.dst->MemberAccessor(GetVertexBufferName(buffer),
+                                GetStructBufferName()),
+        ctx.dst->Div(pos, 4u));
+  }
+
+  /// Generates an expression reading an int32 from a vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) {
+    // as<T> reinterprets bits
+    return ctx.dst->create<ast::BitcastExpression>(ctx.dst->ty.i32(),
+                                                   AccessU32(buffer, pos));
+  }
+
+  /// Generates an expression reading a float from a vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) {
+    // as<T> reinterprets bits
+    return ctx.dst->create<ast::BitcastExpression>(ctx.dst->ty.f32(),
+                                                   AccessU32(buffer, pos));
+  }
+
+  /// Generates an expression reading a basic type (u32, i32, f32) from a
+  /// vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  /// @param format the underlying vertex format
+  ast::Expression* AccessPrimitive(uint32_t buffer,
+                                   ast::Expression* pos,
+                                   VertexFormat format) {
+    // This function uses a position expression to read, rather than using the
+    // position variable. This allows us to read from offset positions relative
+    // to |kPullingPosVarName|. We can't call AccessByFormat because it reads
+    // only from the position variable.
+    switch (format) {
+      case VertexFormat::kU32:
+        return AccessU32(buffer, pos);
+      case VertexFormat::kI32:
+        return AccessI32(buffer, pos);
+      case VertexFormat::kF32:
+        return AccessF32(buffer, pos);
+      default:
+        return nullptr;
+    }
+  }
+
+  /// Generates an expression reading a vec2/3/4 from a vertex buffer.
+  /// This reads the value wherever `kPullingPosVarName` points to at the time
+  /// of the read.
+  /// @param buffer the index of the vertex buffer
+  /// @param element_stride stride between elements, in bytes
+  /// @param base_type underlying AST type
+  /// @param base_format underlying vertex format
+  /// @param count how many elements the vector has
+  ast::Expression* AccessVec(uint32_t buffer,
+                             uint32_t element_stride,
+                             type::Type* base_type,
+                             VertexFormat base_format,
+                             uint32_t count) {
+    ast::ExpressionList expr_list;
+    for (uint32_t i = 0; i < count; ++i) {
+      // Offset read position by element_stride for each component
+      auto* cur_pos =
+          ctx.dst->Add(GetPullingPositionName(), element_stride * i);
+      expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
+    }
+
+    return ctx.dst->create<ast::TypeConstructorExpression>(
+        ctx.dst->create<type::Vector>(base_type, count), std::move(expr_list));
+  }
+};
 
 }  // namespace
 
@@ -93,367 +442,6 @@
 VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
     default;
 
-VertexPulling::State::State(CloneContext& context, const Config& c)
-    : ctx(context), cfg(c) {}
-VertexPulling::State::State(const State&) = default;
-VertexPulling::State::~State() = default;
-
-std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const {
-  return kVertexBufferNamePrefix + std::to_string(index);
-}
-
-void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
-  bool uses_vertex_step_mode = false;
-  for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
-    if (buffer_layout.step_mode == InputStepMode::kVertex) {
-      uses_vertex_step_mode = true;
-      break;
-    }
-  }
-  if (!uses_vertex_step_mode) {
-    return;
-  }
-
-  // Look for an existing vertex index builtin
-  for (auto* v : ctx.src->AST().GlobalVariables()) {
-    auto* sem = ctx.src->Sem().Get(v);
-    if (sem->StorageClass() != ast::StorageClass::kInput) {
-      continue;
-    }
-
-    for (auto* d : v->decorations()) {
-      if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
-        if (builtin->value() == ast::Builtin::kVertexIndex) {
-          vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
-          return;
-        }
-      }
-    }
-  }
-
-  // We didn't find a vertex index builtin, so create one
-  vertex_index_name = kDefaultVertexIndexName;
-
-  auto* var = ctx.dst->create<ast::Variable>(
-      Source{},                                        // source
-      ctx.dst->Symbols().Register(vertex_index_name),  // symbol
-      ast::StorageClass::kInput,                       // storage_class
-      GetU32Type(),                                    // type
-      false,                                           // is_const
-      nullptr,                                         // constructor
-      ast::DecorationList{
-          ctx.dst->create<ast::BuiltinDecoration>(Source{},
-                                                  ast::Builtin::kVertexIndex),
-      });
-
-  ctx.dst->AST().AddGlobalVariable(var);
-}
-
-void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
-  bool uses_instance_step_mode = false;
-  for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
-    if (buffer_layout.step_mode == InputStepMode::kInstance) {
-      uses_instance_step_mode = true;
-      break;
-    }
-  }
-  if (!uses_instance_step_mode) {
-    return;
-  }
-
-  // Look for an existing instance index builtin
-  for (auto* v : ctx.src->AST().GlobalVariables()) {
-    auto* sem = ctx.src->Sem().Get(v);
-    if (sem->StorageClass() != ast::StorageClass::kInput) {
-      continue;
-    }
-
-    for (auto* d : v->decorations()) {
-      if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
-        if (builtin->value() == ast::Builtin::kInstanceIndex) {
-          instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
-          return;
-        }
-      }
-    }
-  }
-
-  // We didn't find an instance index builtin, so create one
-  instance_index_name = kDefaultInstanceIndexName;
-
-  auto* var = ctx.dst->create<ast::Variable>(
-      Source{},                                          // source
-      ctx.dst->Symbols().Register(instance_index_name),  // symbol
-      ast::StorageClass::kInput,                         // storage_class
-      GetU32Type(),                                      // type
-      false,                                             // is_const
-      nullptr,                                           // constructor
-      ast::DecorationList{
-          ctx.dst->create<ast::BuiltinDecoration>(Source{},
-                                                  ast::Builtin::kInstanceIndex),
-      });
-  ctx.dst->AST().AddGlobalVariable(var);
-}
-
-void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
-  for (auto* v : ctx.src->AST().GlobalVariables()) {
-    auto* sem = ctx.src->Sem().Get(v);
-    if (sem->StorageClass() != ast::StorageClass::kInput) {
-      continue;
-    }
-
-    for (auto* d : v->decorations()) {
-      if (auto* l = d->As<ast::LocationDecoration>()) {
-        uint32_t location = l->value();
-        // This is where the replacement is created. Expressions use identifier
-        // strings instead of pointers, so we don't need to update any other
-        // place in the AST.
-        auto name = ctx.src->Symbols().NameFor(v->symbol());
-        auto* replacement = ctx.dst->create<ast::Variable>(
-            Source{},                           // source
-            ctx.dst->Symbols().Register(name),  // symbol
-            ast::StorageClass::kPrivate,        // storage_class
-            ctx.Clone(v->declared_type()),      // type
-            false,                              // is_const
-            nullptr,                            // constructor
-            ast::DecorationList{});             // decorations
-        location_to_var[location] = replacement;
-        location_replacements.emplace_back(LocationReplacement{v, replacement});
-        break;
-      }
-    }
-  }
-}
-
-void VertexPulling::State::AddVertexStorageBuffers() {
-  // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
-  // The array inside the struct definition
-  auto* internal_array_type = ctx.dst->create<type::Array>(
-      GetU32Type(), 0,
-      ast::DecorationList{
-          ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
-      });
-
-  // Creating the struct type
-  ast::StructMemberList members;
-  members.push_back(ctx.dst->create<ast::StructMember>(
-      Source{}, ctx.dst->Symbols().Register(kStructBufferName),
-      internal_array_type, ast::DecorationList{}));
-
-  ast::DecorationList decos;
-  decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));
-
-  auto* struct_type = ctx.dst->create<type::Struct>(
-      ctx.dst->Symbols().Register(kStructName),
-      ctx.dst->create<ast::Struct>(Source{}, std::move(members),
-                                   std::move(decos)));
-
-  for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
-    // The decorated variable with struct type
-    std::string name = GetVertexBufferName(i);
-    auto* var = ctx.dst->create<ast::Variable>(
-        Source{},                           // source
-        ctx.dst->Symbols().Register(name),  // symbol
-        ast::StorageClass::kStorage,        // storage_class
-        struct_type,                        // type
-        false,                              // is_const
-        nullptr,                            // constructor
-        ast::DecorationList{
-            ctx.dst->create<ast::BindingDecoration>(Source{}, i),
-            ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
-        });
-    ctx.dst->AST().AddGlobalVariable(var);
-  }
-  ctx.dst->AST().AddConstructedType(struct_type);
-}
-
-ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
-  // Assign by looking at the vertex descriptor to find attributes with matching
-  // location.
-
-  ast::StatementList stmts;
-
-  // Declare the |kPullingPosVarName| variable in the shader
-  auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
-      Source{}, ctx.dst->create<ast::Variable>(
-                    Source{},                                         // source
-                    ctx.dst->Symbols().Register(kPullingPosVarName),  // symbol
-                    ast::StorageClass::kFunction,  // storage_class
-                    GetU32Type(),                  // type
-                    false,                         // is_const
-                    nullptr,                       // constructor
-                    ast::DecorationList{}));       // decorations
-
-  // |kPullingPosVarName| refers to the byte location of the current read. We
-  // declare a variable in the shader to avoid having to reuse Expression
-  // objects.
-  stmts.emplace_back(pos_declaration);
-
-  for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
-    const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i];
-
-    for (const VertexAttributeDescriptor& attribute_desc :
-         buffer_layout.attributes) {
-      auto it = location_to_var.find(attribute_desc.shader_location);
-      if (it == location_to_var.end()) {
-        continue;
-      }
-      auto* v = it->second;
-
-      auto name = buffer_layout.step_mode == InputStepMode::kVertex
-                      ? vertex_index_name
-                      : instance_index_name;
-      // Identifier to index by
-      auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
-          Source{}, ctx.dst->Symbols().Register(name));
-
-      // An expression for the start of the read in the buffer in bytes
-      auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
-          Source{}, ast::BinaryOp::kAdd,
-          ctx.dst->create<ast::BinaryExpression>(
-              Source{}, ast::BinaryOp::kMultiply, index_identifier,
-              GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
-          GenUint(static_cast<uint32_t>(attribute_desc.offset)));
-
-      // Update position of the read
-      auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
-          Source{}, CreatePullingPositionIdent(), pos_value);
-      stmts.emplace_back(set_pos_expr);
-
-      stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
-          Source{},
-          ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
-          AccessByFormat(i, attribute_desc.format)));
-    }
-  }
-
-  return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
-}
-
-ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
-  return ctx.dst->create<ast::ScalarConstructorExpression>(
-      Source{},
-      ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
-}
-
-ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
-  return ctx.dst->create<ast::IdentifierExpression>(
-      Source{}, ctx.dst->Symbols().Register(kPullingPosVarName));
-}
-
-ast::Expression* VertexPulling::State::AccessByFormat(
-    uint32_t buffer,
-    VertexFormat format) const {
-  // TODO(idanr): this doesn't account for the format of the attribute in the
-  // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
-  // right now, we would try to assign a vec4<f32> to this attribute, but we
-  // really need to assign a vec4<u32> by casting.
-  // We could split this function to first do memory accesses and unpacking into
-  // int/uint/float1-4/etc, then convert that variable to a var<in> with the
-  // conversion defined in the WebGPU spec.
-  switch (format) {
-    case VertexFormat::kU32:
-      return AccessU32(buffer, CreatePullingPositionIdent());
-    case VertexFormat::kI32:
-      return AccessI32(buffer, CreatePullingPositionIdent());
-    case VertexFormat::kF32:
-      return AccessF32(buffer, CreatePullingPositionIdent());
-    case VertexFormat::kVec2F32:
-      return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 2);
-    case VertexFormat::kVec3F32:
-      return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 3);
-    case VertexFormat::kVec4F32:
-      return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 4);
-    default:
-      return nullptr;
-  }
-}
-
-ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
-                                                 ast::Expression* pos) const {
-  // Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
-  // has byte offsets for each attribute, and we will convert it to u32 indexes
-  // by dividing. Then, that element is going to be read, and if needed,
-  // unpacked into an appropriate variable. All reads should end up here as a
-  // base case.
-  auto vbuf_name = GetVertexBufferName(buffer);
-  return ctx.dst->create<ast::ArrayAccessorExpression>(
-      Source{},
-      ctx.dst->create<ast::MemberAccessorExpression>(
-          Source{},
-          ctx.dst->create<ast::IdentifierExpression>(
-              Source{}, ctx.dst->Symbols().Register(vbuf_name)),
-          ctx.dst->create<ast::IdentifierExpression>(
-              Source{}, ctx.dst->Symbols().Register(kStructBufferName))),
-      ctx.dst->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide,
-                                             pos, GenUint(4)));
-}
-
-ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
-                                                 ast::Expression* pos) const {
-  // as<T> reinterprets bits
-  return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
-                                                 AccessU32(buffer, pos));
-}
-
-ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
-                                                 ast::Expression* pos) const {
-  // as<T> reinterprets bits
-  return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
-                                                 AccessU32(buffer, pos));
-}
-
-ast::Expression* VertexPulling::State::AccessPrimitive(
-    uint32_t buffer,
-    ast::Expression* pos,
-    VertexFormat format) const {
-  // This function uses a position expression to read, rather than using the
-  // position variable. This allows us to read from offset positions relative to
-  // |kPullingPosVarName|. We can't call AccessByFormat because it reads only
-  // from the position variable.
-  switch (format) {
-    case VertexFormat::kU32:
-      return AccessU32(buffer, pos);
-    case VertexFormat::kI32:
-      return AccessI32(buffer, pos);
-    case VertexFormat::kF32:
-      return AccessF32(buffer, pos);
-    default:
-      return nullptr;
-  }
-}
-
-ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
-                                                 uint32_t element_stride,
-                                                 type::Type* base_type,
-                                                 VertexFormat base_format,
-                                                 uint32_t count) const {
-  ast::ExpressionList expr_list;
-  for (uint32_t i = 0; i < count; ++i) {
-    // Offset read position by element_stride for each component
-    auto* cur_pos = ctx.dst->create<ast::BinaryExpression>(
-        Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
-        GenUint(element_stride * i));
-    expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
-  }
-
-  return ctx.dst->create<ast::TypeConstructorExpression>(
-      Source{}, ctx.dst->create<type::Vector>(base_type, count),
-      std::move(expr_list));
-}
-
-type::Type* VertexPulling::State::GetU32Type() const {
-  return ctx.dst->create<type::U32>();
-}
-
-type::Type* VertexPulling::State::GetI32Type() const {
-  return ctx.dst->create<type::I32>();
-}
-
-type::Type* VertexPulling::State::GetF32Type() const {
-  return ctx.dst->create<type::F32>();
-}
-
 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
 
 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h
index 8be1812..576bc09 100644
--- a/src/transform/vertex_pulling.h
+++ b/src/transform/vertex_pulling.h
@@ -176,105 +176,6 @@
 
  private:
   Config cfg_;
-
-  struct State {
-    State(CloneContext& ctx, const Config& c);
-    explicit State(const State&);
-    ~State();
-
-    /// Generate the vertex buffer binding name
-    /// @param index index to append to buffer name
-    std::string GetVertexBufferName(uint32_t index) const;
-
-    /// Inserts vertex_index binding, or finds the existing one
-    void FindOrInsertVertexIndexIfUsed();
-
-    /// Inserts instance_index binding, or finds the existing one
-    void FindOrInsertInstanceIndexIfUsed();
-
-    /// Converts var<in> with a location decoration to var<private>
-    void ConvertVertexInputVariablesToPrivate();
-
-    /// Adds storage buffer decorated variables for the vertex buffers
-    void AddVertexStorageBuffers();
-
-    /// Creates and returns the assignment to the variables from the buffers
-    ast::BlockStatement* CreateVertexPullingPreamble() const;
-
-    /// Generates an expression holding a constant uint
-    /// @param value uint value
-    ast::Expression* GenUint(uint32_t value) const;
-
-    /// Generates an expression to read the shader value `kPullingPosVarName`
-    ast::Expression* CreatePullingPositionIdent() const;
-
-    /// Generates an expression reading from a buffer a specific format.
-    /// This reads the value wherever `kPullingPosVarName` points to at the time
-    /// of the read.
-    /// @param buffer the index of the vertex buffer
-    /// @param format the format to read
-    ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) const;
-
-    /// Generates an expression reading a uint32 from a vertex buffer
-    /// @param buffer the index of the vertex buffer
-    /// @param pos an expression for the position of the access, in bytes
-    ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) const;
-
-    /// Generates an expression reading an int32 from a vertex buffer
-    /// @param buffer the index of the vertex buffer
-    /// @param pos an expression for the position of the access, in bytes
-    ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) const;
-
-    /// Generates an expression reading a float from a vertex buffer
-    /// @param buffer the index of the vertex buffer
-    /// @param pos an expression for the position of the access, in bytes
-    ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) const;
-
-    /// Generates an expression reading a basic type (u32, i32, f32) from a
-    /// vertex buffer
-    /// @param buffer the index of the vertex buffer
-    /// @param pos an expression for the position of the access, in bytes
-    /// @param format the underlying vertex format
-    ast::Expression* AccessPrimitive(uint32_t buffer,
-                                     ast::Expression* pos,
-                                     VertexFormat format) const;
-
-    /// Generates an expression reading a vec2/3/4 from a vertex buffer.
-    /// This reads the value wherever `kPullingPosVarName` points to at the time
-    /// of the read.
-    /// @param buffer the index of the vertex buffer
-    /// @param element_stride stride between elements, in bytes
-    /// @param base_type underlying AST type
-    /// @param base_format underlying vertex format
-    /// @param count how many elements the vector has
-    ast::Expression* AccessVec(uint32_t buffer,
-                               uint32_t element_stride,
-                               type::Type* base_type,
-                               VertexFormat base_format,
-                               uint32_t count) const;
-
-    // Used to grab corresponding types from the type manager
-    type::Type* GetU32Type() const;
-    type::Type* GetI32Type() const;
-    type::Type* GetF32Type() const;
-
-    CloneContext& ctx;
-    Config const cfg;
-
-    /// LocationReplacement describes an ast::Variable replacement for a
-    /// location input.
-    struct LocationReplacement {
-      /// The variable to replace in the source Program
-      ast::Variable* from;
-      /// The replacement to use in the target ProgramBuilder
-      ast::Variable* to;
-    };
-
-    std::unordered_map<uint32_t, ast::Variable*> location_to_var;
-    std::vector<LocationReplacement> location_replacements;
-    std::string vertex_index_name;
-    std::string instance_index_name;
-  };
 };
 
 }  // namespace transform
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index e1222e6..ce01f87 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -113,13 +113,13 @@
   auto* expect = R"(
 [[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
 [[block]]
 struct TintVertexData {
   _tint_vertex_data : [[stride(4)]] array<u32>;
 };
 
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
 var<private> var_a : f32;
 
 [[stage(vertex)]]
@@ -155,13 +155,13 @@
   auto* expect = R"(
 [[builtin(instance_index)]] var<in> _tint_pulling_instance_index : u32;
 
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
 [[block]]
 struct TintVertexData {
   _tint_vertex_data : [[stride(4)]] array<u32>;
 };
 
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
 var<private> var_a : f32;
 
 [[stage(vertex)]]
@@ -197,13 +197,13 @@
   auto* expect = R"(
 [[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
-[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
 [[block]]
 struct TintVertexData {
   _tint_vertex_data : [[stride(4)]] array<u32>;
 };
 
+[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
 var<private> var_a : f32;
 
 [[stage(vertex)]]
@@ -242,15 +242,15 @@
 )";
 
   auto* expect = R"(
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
-[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
-
 [[block]]
 struct TintVertexData {
   _tint_vertex_data : [[stride(4)]] array<u32>;
 };
 
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
+[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
+
 var<private> var_a : f32;
 
 var<private> var_b : f32;
@@ -305,13 +305,13 @@
   auto* expect = R"(
 [[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
-[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
-
 [[block]]
 struct TintVertexData {
   _tint_vertex_data : [[stride(4)]] array<u32>;
 };
 
+[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
+
 var<private> var_a : f32;
 
 var<private> var_b : vec4<f32>;
@@ -355,17 +355,17 @@
   auto* expect = R"(
 [[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
 
+[[block]]
+struct TintVertexData {
+  _tint_vertex_data : [[stride(4)]] array<u32>;
+};
+
 [[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
 
 [[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
 
 [[binding(2), group(4)]] var<storage> _tint_pulling_vertex_buffer_2 : TintVertexData;
 
-[[block]]
-struct TintVertexData {
-  _tint_vertex_data : [[stride(4)]] array<u32>;
-};
-
 var<private> var_a : vec2<f32>;
 
 var<private> var_b : vec3<f32>;