Remove BlockStatement::insert()

Bug: tint:396
Bug: tint:390
Change-Id: I719b84804164fa801ded505ed56717948f06c7a7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35502
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h
index 9b5417a..d485bdf 100644
--- a/src/ast/block_statement.h
+++ b/src/ast/block_statement.h
@@ -35,14 +35,6 @@
   BlockStatement(BlockStatement&&);
   ~BlockStatement() override;
 
-  /// Insert a statement to the block
-  /// @param index the index to insert at
-  /// @param stmt the statement to insert
-  void insert(size_t index, Statement* stmt) {
-    auto offset = static_cast<decltype(statements_)::difference_type>(index);
-    statements_.insert(statements_.begin() + offset, stmt);
-  }
-
   /// @returns true if the block is empty
   bool empty() const { return statements_.empty(); }
   /// @returns the number of statements directly in the block
@@ -60,16 +52,12 @@
   /// Retrieves the statement at `idx`
   /// @param idx the index. The index is not bounds checked.
   /// @returns the statement at `idx`
-  const Statement* get(size_t idx) const { return statements_[idx]; }
+  Statement* get(size_t idx) const { return statements_[idx]; }
 
   /// Retrieves the statement at `idx`
   /// @param idx the index. The index is not bounds checked.
   /// @returns the statement at `idx`
-  Statement* operator[](size_t idx) { return statements_[idx]; }
-  /// Retrieves the statement at `idx`
-  /// @param idx the index. The index is not bounds checked.
-  /// @returns the statement at `idx`
-  const Statement* operator[](size_t idx) const { return statements_[idx]; }
+  Statement* operator[](size_t idx) const { return statements_[idx]; }
 
   /// @returns the beginning iterator
   StatementList::const_iterator begin() const { return statements_.begin(); }
diff --git a/src/ast/block_statement_test.cc b/src/ast/block_statement_test.cc
index 71b71d2..78134c6 100644
--- a/src/ast/block_statement_test.cc
+++ b/src/ast/block_statement_test.cc
@@ -37,24 +37,6 @@
   EXPECT_EQ(b[0], ptr);
 }
 
-TEST_F(BlockStatementTest, Creation_WithInsert) {
-  auto* s1 = create<DiscardStatement>(Source{});
-  auto* s2 = create<DiscardStatement>(Source{});
-  auto* s3 = create<DiscardStatement>(Source{});
-
-  BlockStatement b(Source{}, StatementList{});
-  b.insert(0, s1);
-  b.insert(0, s2);
-  b.insert(1, s3);
-
-  // |b| should contain s2, s3, s1
-
-  ASSERT_EQ(b.size(), 3u);
-  EXPECT_EQ(b[0], s2);
-  EXPECT_EQ(b[1], s3);
-  EXPECT_EQ(b[2], s1);
-}
-
 TEST_F(BlockStatementTest, Creation_WithSource) {
   BlockStatement b(Source{Source::Location{20, 2}}, ast::StatementList{});
   auto src = b.source();
diff --git a/src/ast/module.cc b/src/ast/module.cc
index 6d5cf58..a357554 100644
--- a/src/ast/module.cc
+++ b/src/ast/module.cc
@@ -33,15 +33,30 @@
 Module Module::Clone() {
   Module out;
   CloneContext ctx(&out);
-  Clone(&ctx);
+
+  // Symbol table must be cloned first so that the resulting module has the
+  // symbols before we start the tree mutations.
+  ctx.mod->symbol_table_ = symbol_table_;
+
+  CloneUsing(&ctx);
   return out;
 }
 
-void Module::Clone(CloneContext* ctx) {
+Module Module::Clone(const std::function<void(CloneContext* ctx)>& init) {
+  Module out;
+  CloneContext ctx(&out);
+
   // Symbol table must be cloned first so that the resulting module has the
   // symbols before we start the tree mutations.
-  ctx->mod->symbol_table_ = symbol_table_;
+  ctx.mod->symbol_table_ = symbol_table_;
 
+  init(&ctx);
+
+  CloneUsing(&ctx);
+  return out;
+}
+
+void Module::CloneUsing(CloneContext* ctx) {
   for (auto* ty : constructed_types_) {
     ctx->mod->constructed_types_.emplace_back(ctx->Clone(ty));
   }
diff --git a/src/ast/module.h b/src/ast/module.h
index 1facd79..e5d153d 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -15,6 +15,7 @@
 #ifndef SRC_AST_MODULE_H_
 #define SRC_AST_MODULE_H_
 
+#include <functional>
 #include <memory>
 #include <string>
 #include <type_traits>
@@ -55,13 +56,11 @@
   /// @return a deep copy of this module
   Module Clone();
 
-  /// Clone this module into `ctx->mod` using the provided CloneContext
-  /// The module will be cloned in this order:
-  /// * Constructed types
-  /// * Global variables
-  /// * Functions
-  /// @param ctx the clone context
-  void Clone(CloneContext* ctx);
+  /// @param init a callback function to configure the CloneContex before
+  /// cloning any of the module's state
+  /// @return a deep copy of this module, calling `init` to first initialize the
+  /// context.
+  Module Clone(const std::function<void(CloneContext* ctx)>& init);
 
   /// Add a global variable to the module
   /// @param var the variable to add
@@ -181,6 +180,14 @@
  private:
   Module(const Module&) = delete;
 
+  /// Clone this module into `ctx->mod` using the provided CloneContext
+  /// The module will be cloned in this order:
+  /// * Constructed types
+  /// * Global variables
+  /// * Functions
+  /// @param ctx the clone context
+  void CloneUsing(CloneContext* ctx);
+
   SymbolTable symbol_table_;
   VariableList global_variables_;
   // The constructed types are owned by the type manager
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 6219b32..663bdc6 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1418,7 +1418,12 @@
 // body_stmt
 //   : BRACKET_LEFT statements BRACKET_RIGHT
 Expect<ast::BlockStatement*> ParserImpl::expect_body_stmt() {
-  return expect_brace_block("", [&] { return expect_statements(); });
+  return expect_brace_block("", [&]() -> Expect<ast::BlockStatement*> {
+    auto stmts = expect_statements();
+    if (stmts.errored)
+      return Failure::kErrored;
+    return create<ast::BlockStatement>(Source{}, stmts.value);
+  });
 }
 
 // paren_rhs_stmt
@@ -1437,7 +1442,7 @@
 
 // statements
 //   : statement*
-Expect<ast::BlockStatement*> ParserImpl::expect_statements() {
+Expect<ast::StatementList> ParserImpl::expect_statements() {
   bool errored = false;
   ast::StatementList stmts;
 
@@ -1455,7 +1460,7 @@
   if (errored)
     return Failure::kErrored;
 
-  return create<ast::BlockStatement>(Source{}, stmts);
+  return stmts;
 }
 
 // statement
@@ -1859,15 +1864,16 @@
     return Failure::kNoMatch;
 
   return expect_brace_block("loop", [&]() -> Maybe<ast::LoopStatement*> {
-    auto body = expect_statements();
-    if (body.errored)
+    auto stmts = expect_statements();
+    if (stmts.errored)
       return Failure::kErrored;
 
     auto continuing = continuing_stmt();
     if (continuing.errored)
       return Failure::kErrored;
 
-    return create<ast::LoopStatement>(source, body.value, continuing.value);
+    auto* body = create<ast::BlockStatement>(source, stmts.value);
+    return create<ast::LoopStatement>(source, body, continuing.value);
   });
 }
 
@@ -1958,9 +1964,9 @@
   if (header.errored)
     return Failure::kErrored;
 
-  auto body =
+  auto stmts =
       expect_brace_block("for loop", [&] { return expect_statements(); });
-  if (body.errored)
+  if (stmts.errored)
     return Failure::kErrored;
 
   // The for statement is a syntactic sugar on top of the loop statement.
@@ -1980,7 +1986,7 @@
     auto* break_if_not_condition =
         create<ast::IfStatement>(not_condition->source(), not_condition,
                                  break_body, ast::ElseStatementList{});
-    body->insert(0, break_if_not_condition);
+    stmts.value.insert(stmts.value.begin(), break_if_not_condition);
   }
 
   ast::BlockStatement* continuing_body = nullptr;
@@ -1991,7 +1997,8 @@
                                                   });
   }
 
-  auto* loop = create<ast::LoopStatement>(source, body.value, continuing_body);
+  auto* body = create<ast::BlockStatement>(source, stmts.value);
+  auto* loop = create<ast::LoopStatement>(source, body, continuing_body);
 
   if (header->initializer != nullptr) {
     return create<ast::BlockStatement>(source, ast::StatementList{
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 1bdd73b..c523d06 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -468,7 +468,7 @@
   Expect<ast::Expression*> expect_paren_rhs_stmt();
   /// Parses a `statements` grammar element
   /// @returns the statements parsed
-  Expect<ast::BlockStatement*> expect_statements();
+  Expect<ast::StatementList> expect_statements();
   /// Parses a `statement` grammar element
   /// @returns the parsed statement or nullptr
   Maybe<ast::Statement*> statement();
diff --git a/src/reader/wgsl/parser_impl_for_stmt_test.cc b/src/reader/wgsl/parser_impl_for_stmt_test.cc
index 2172e21..b112779 100644
--- a/src/reader/wgsl/parser_impl_for_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_for_stmt_test.cc
@@ -15,6 +15,7 @@
 #include <string>
 
 #include "gtest/gtest.h"
+#include "src/ast/block_statement.h"
 #include "src/reader/wgsl/parser_impl.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
@@ -30,15 +31,15 @@
     auto e_loop = p_loop->expect_statements();
     EXPECT_FALSE(e_loop.errored);
     EXPECT_FALSE(p_loop->has_error()) << p_loop->error();
-    ASSERT_NE(e_loop.value, nullptr);
 
     auto p_for = parser(for_str);
     auto e_for = p_for->expect_statements();
     EXPECT_FALSE(e_for.errored);
     EXPECT_FALSE(p_for->has_error()) << p_for->error();
-    ASSERT_NE(e_for.value, nullptr);
 
-    EXPECT_EQ(e_loop->str(), e_for->str());
+    std::string loop = ast::BlockStatement({}, e_loop.value).str();
+    std::string for_ = ast::BlockStatement({}, e_for.value).str();
+    EXPECT_EQ(loop, for_);
   }
 };
 
diff --git a/src/reader/wgsl/parser_impl_statements_test.cc b/src/reader/wgsl/parser_impl_statements_test.cc
index 88b6012..d33176c 100644
--- a/src/reader/wgsl/parser_impl_statements_test.cc
+++ b/src/reader/wgsl/parser_impl_statements_test.cc
@@ -29,8 +29,8 @@
   EXPECT_FALSE(e.errored);
   EXPECT_FALSE(p->has_error()) << p->error();
   ASSERT_EQ(e->size(), 2u);
-  EXPECT_TRUE(e->get(0)->Is<ast::DiscardStatement>());
-  EXPECT_TRUE(e->get(1)->Is<ast::ReturnStatement>());
+  EXPECT_TRUE(e.value[0]->Is<ast::DiscardStatement>());
+  EXPECT_TRUE(e.value[1]->Is<ast::ReturnStatement>());
 }
 
 TEST_F(ParserImplTest, Statements_Empty) {
diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc
index d68da7c..3adf6b7 100644
--- a/src/transform/bound_array_accessors.cc
+++ b/src/transform/bound_array_accessors.cc
@@ -55,11 +55,11 @@
 
 Transform::Output BoundArrayAccessors::Run(ast::Module* mod) {
   Output out;
-  ast::CloneContext ctx(&out.module);
-  ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) {
-    return Transform(expr, &ctx, &out.diagnostics);
+  out.module = mod->Clone([&](ast::CloneContext* ctx) {
+    ctx->ReplaceAll([&, ctx](ast::ArrayAccessorExpression* expr) {
+      return Transform(expr, ctx, &out.diagnostics);
+    });
   });
-  mod->Clone(&ctx);
   return out;
 }
 
diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc
index 1d22a26..d8640d8 100644
--- a/src/transform/emit_vertex_point_size.cc
+++ b/src/transform/emit_vertex_point_size.cc
@@ -19,6 +19,7 @@
 
 #include "src/ast/assignment_statement.h"
 #include "src/ast/block_statement.h"
+#include "src/ast/clone_context.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/scalar_constructor_expression.h"
@@ -39,45 +40,56 @@
 
 Transform::Output EmitVertexPointSize::Run(ast::Module* in) {
   Output out;
-  out.module = in->Clone();
-  auto* mod = &out.module;
 
-  if (!mod->HasStage(ast::PipelineStage::kVertex)) {
+  if (!in->HasStage(ast::PipelineStage::kVertex)) {
     // If the module doesn't have any vertex stages, then there's nothing to do.
+    out.module = in->Clone();
     return out;
   }
 
-  auto* f32 = mod->create<ast::type::F32>();
+  tint::ast::AssignmentStatement* pointsize_assign = nullptr;
+  auto get_pointsize_assign = [&pointsize_assign](ast::Module* mod) {
+    if (pointsize_assign != nullptr) {
+      return pointsize_assign;
+    }
 
-  // Declare the pointsize builtin output variable.
-  auto* pointsize_var =
-      mod->create<ast::Variable>(Source{},                    // source
-                                 kPointSizeVar,               // name
-                                 ast::StorageClass::kOutput,  // storage_class
-                                 f32,                         // type
-                                 false,                       // is_const
-                                 nullptr,                     // constructor
-                                 ast::VariableDecorationList{
-                                     // decorations
-                                     mod->create<ast::BuiltinDecoration>(
-                                         ast::Builtin::kPointSize, Source{}),
-                                 });
-  mod->AddGlobalVariable(pointsize_var);
+    auto* f32 = mod->create<ast::type::F32>();
 
-  // Build the AST expression & statement for assigning pointsize one.
-  auto* one = mod->create<ast::ScalarConstructorExpression>(
-      Source{}, mod->create<ast::FloatLiteral>(Source{}, f32, 1.0f));
-  auto* pointsize_ident = mod->create<ast::IdentifierExpression>(
-      Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar);
-  auto* pointsize_assign =
-      mod->create<ast::AssignmentStatement>(Source{}, pointsize_ident, one);
+    // Declare the pointsize builtin output variable.
+    auto* pointsize_var =
+        mod->create<ast::Variable>(Source{},                    // source
+                                   kPointSizeVar,               // name
+                                   ast::StorageClass::kOutput,  // storage_class
+                                   f32,                         // type
+                                   false,                       // is_const
+                                   nullptr,                     // constructor
+                                   ast::VariableDecorationList{
+                                       // decorations
+                                       mod->create<ast::BuiltinDecoration>(
+                                           ast::Builtin::kPointSize, Source{}),
+                                   });
+    mod->AddGlobalVariable(pointsize_var);
+
+    // Build the AST expression & statement for assigning pointsize one.
+    auto* one = mod->create<ast::ScalarConstructorExpression>(
+        Source{}, mod->create<ast::FloatLiteral>(Source{}, f32, 1.0f));
+    auto* pointsize_ident = mod->create<ast::IdentifierExpression>(
+        Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar);
+    pointsize_assign =
+        mod->create<ast::AssignmentStatement>(Source{}, pointsize_ident, one);
+    return pointsize_assign;
+  };
 
   // Add the pointsize assignment statement to the front of all vertex stages.
-  for (auto* func : mod->functions()) {
-    if (func->pipeline_stage() == ast::PipelineStage::kVertex) {
-      func->body()->insert(0, pointsize_assign);
-    }
-  }
+  out.module = in->Clone([&](ast::CloneContext* ctx) {
+    ctx->ReplaceAll([&, ctx](ast::Function* func) -> ast::Function* {
+      if (func->pipeline_stage() != ast::PipelineStage::kVertex) {
+        return nullptr;  // Just clone func
+      }
+      return CloneWithStatementsAtStart(ctx, func,
+                                        {get_pointsize_assign(ctx->mod)});
+    });
+  });
 
   return out;
 }
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index 97c7670..807a2c6 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -112,70 +112,63 @@
   std::string vertex_index_name;
   std::string instance_index_name;
 
-  Output out;
-
   // Lazilly construct the UniformBuffer on first call to
   // maybe_create_buffer_var()
   ast::Variable* buffer_var = nullptr;
-  auto maybe_create_buffer_var = [&] {
+  auto maybe_create_buffer_var = [&](ast::Module* mod) {
     if (buffer_var == nullptr) {
-      buffer_var = AddUniformBuffer(&out.module);
+      buffer_var = AddUniformBuffer(mod);
     }
   };
 
   // Clone the AST, renaming the kVertexIdx and kInstanceIdx builtins, and add
   // a CreateFirstIndexOffset() statement to each function that uses one of
   // these builtins.
-  ast::CloneContext ctx(&out.module);
-  ctx.ReplaceAll([&](ast::Variable* var) -> ast::Variable* {
-    for (ast::VariableDecoration* dec : var->decorations()) {
-      if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
-        ast::Builtin blt_type = blt_dec->value();
-        if (blt_type == ast::Builtin::kVertexIdx) {
-          vertex_index_name = var->name();
-          has_vertex_index_ = true;
-          return clone_variable_with_new_name(&ctx, var,
-                                              kIndexOffsetPrefix + var->name());
-        } else if (blt_type == ast::Builtin::kInstanceIdx) {
-          instance_index_name = var->name();
-          has_instance_index_ = true;
-          return clone_variable_with_new_name(&ctx, var,
-                                              kIndexOffsetPrefix + var->name());
-        }
-      }
-    }
-    return nullptr;  // Just clone var
-  });
-  ctx.ReplaceAll(  // Note: This happens in the same pass as the rename above
-                   // which determines the original builtin variable names,
-                   // but this should be fine, as variables are cloned first.
-      [&](ast::Function* func) -> ast::Function* {
-        maybe_create_buffer_var();
-        if (buffer_var == nullptr) {
-          return nullptr;  // no transform need, just clone func
-        }
-        ast::StatementList statements;
-        for (const auto& data : func->local_referenced_builtin_variables()) {
-          if (data.second->value() == ast::Builtin::kVertexIdx) {
-            statements.emplace_back(CreateFirstIndexOffset(
-                vertex_index_name, kFirstVertexName, buffer_var, ctx.mod));
-          } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
-            statements.emplace_back(CreateFirstIndexOffset(
-                instance_index_name, kFirstInstanceName, buffer_var, ctx.mod));
+
+  Output out;
+  out.module = in->Clone([&](ast::CloneContext* ctx) {
+    ctx->ReplaceAll([&, ctx](ast::Variable* var) -> ast::Variable* {
+      for (ast::VariableDecoration* dec : var->decorations()) {
+        if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
+          ast::Builtin blt_type = blt_dec->value();
+          if (blt_type == ast::Builtin::kVertexIdx) {
+            vertex_index_name = var->name();
+            has_vertex_index_ = true;
+            return clone_variable_with_new_name(
+                ctx, var, kIndexOffsetPrefix + var->name());
+          } else if (blt_type == ast::Builtin::kInstanceIdx) {
+            instance_index_name = var->name();
+            has_instance_index_ = true;
+            return clone_variable_with_new_name(
+                ctx, var, kIndexOffsetPrefix + var->name());
           }
         }
-        for (auto* s : *func->body()) {
-          statements.emplace_back(ctx.Clone(s));
-        }
-        return ctx.mod->create<ast::Function>(
-            ctx.Clone(func->source()), func->symbol(), func->name(),
-            ctx.Clone(func->params()), ctx.Clone(func->return_type()),
-            ctx.mod->create<ast::BlockStatement>(
-                ctx.Clone(func->body()->source()), statements),
-            ctx.Clone(func->decorations()));
-      });
+      }
+      return nullptr;  // Just clone var
+    });
+    ctx->ReplaceAll(  // Note: This happens in the same pass as the rename above
+                      // which determines the original builtin variable names,
+                      // but this should be fine, as variables are cloned first.
+        [&, ctx](ast::Function* func) -> ast::Function* {
+          maybe_create_buffer_var(ctx->mod);
+          if (buffer_var == nullptr) {
+            return nullptr;  // no transform need, just clone func
+          }
+          ast::StatementList statements;
+          for (const auto& data : func->local_referenced_builtin_variables()) {
+            if (data.second->value() == ast::Builtin::kVertexIdx) {
+              statements.emplace_back(CreateFirstIndexOffset(
+                  vertex_index_name, kFirstVertexName, buffer_var, ctx->mod));
+            } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
+              statements.emplace_back(CreateFirstIndexOffset(
+                  instance_index_name, kFirstInstanceName, buffer_var,
+                  ctx->mod));
+            }
+          }
+          return CloneWithStatementsAtStart(ctx, func, statements);
+        });
+  });
 
-  in->Clone(&ctx);
   return out;
 }
 
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index f6bdfd2..a03b943 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -14,11 +14,30 @@
 
 #include "src/transform/transform.h"
 
+#include "src/ast/block_statement.h"
+#include "src/ast/clone_context.h"
+#include "src/ast/function.h"
+
 namespace tint {
 namespace transform {
 
 Transform::Transform() = default;
 Transform::~Transform() = default;
 
+ast::Function* Transform::CloneWithStatementsAtStart(
+    ast::CloneContext* ctx,
+    ast::Function* in,
+    ast::StatementList statements) {
+  for (auto* s : *in->body()) {
+    statements.emplace_back(ctx->Clone(s));
+  }
+  return ctx->mod->create<ast::Function>(
+      ctx->Clone(in->source()), in->symbol(), in->name(),
+      ctx->Clone(in->params()), ctx->Clone(in->return_type()),
+      ctx->mod->create<ast::BlockStatement>(ctx->Clone(in->body()->source()),
+                                            statements),
+      ctx->Clone(in->decorations()));
+}
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/transform.h b/src/transform/transform.h
index 2a68467..211be8a 100644
--- a/src/transform/transform.h
+++ b/src/transform/transform.h
@@ -48,6 +48,18 @@
   /// @param module the source module to transform
   /// @returns the transformation result
   virtual Output Run(ast::Module* module) = 0;
+
+ protected:
+  /// Clones the function `in` adding `statements` to the beginning of the
+  /// cloned function body.
+  /// @param ctx the clone context
+  /// @param in the function to clone
+  /// @param statements the statements to prepend to `in`'s body
+  /// @return the cloned function
+  static ast::Function* CloneWithStatementsAtStart(
+      ast::CloneContext* ctx,
+      ast::Function* in,
+      ast::StatementList statements);
 };
 
 }  // namespace transform
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 163d875..74feab3 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -20,6 +20,7 @@
 #include "src/ast/assignment_statement.h"
 #include "src/ast/binary_expression.h"
 #include "src/ast/bitcast_expression.h"
+#include "src/ast/clone_context.h"
 #include "src/ast/member_accessor_expression.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/stride_decoration.h"
@@ -69,27 +70,24 @@
 }
 
 Transform::Output VertexPulling::Run(ast::Module* in) {
-  Output out;
-  out.module = in->Clone();
-
-  ast::Module* mod = &out.module;
-
   // Check SetVertexState was called
   if (!cfg.vertex_state_set) {
     diag::Diagnostic err;
     err.severity = diag::Severity::Error;
     err.message = "SetVertexState not called";
+    Output out;
     out.diagnostics.add(std::move(err));
     return out;
   }
 
   // Find entry point
-  auto* func = mod->FindFunctionBySymbolAndStage(
-      mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
+  auto* func = in->FindFunctionBySymbolAndStage(
+      in->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
   if (func == nullptr) {
     diag::Diagnostic err;
     err.severity = diag::Severity::Error;
     err.message = "Vertex stage entry point not found";
+    Output out;
     out.diagnostics.add(std::move(err));
     return out;
   }
@@ -99,13 +97,22 @@
 
   // TODO(idanr): Make sure we covered all error cases, to guarantee the
   // following stages will pass
+  Output out;
+  out.module = in->Clone([&](ast::CloneContext* ctx) {
+    State state{in, ctx->mod, cfg};
+    state.FindOrInsertVertexIndexIfUsed();
+    state.FindOrInsertInstanceIndexIfUsed();
+    state.ConvertVertexInputVariablesToPrivate();
+    state.AddVertexStorageBuffers();
 
-  State state{mod, cfg};
-  state.FindOrInsertVertexIndexIfUsed();
-  state.FindOrInsertInstanceIndexIfUsed();
-  state.ConvertVertexInputVariablesToPrivate();
-  state.AddVertexStorageBuffers();
-  func->body()->insert(0, state.CreateVertexPullingPreamble());
+    ctx->ReplaceAll([func, ctx, state](ast::Function* f) -> ast::Function* {
+      if (f == func) {
+        return CloneWithStatementsAtStart(
+            ctx, f, {state.CreateVertexPullingPreamble()});
+      }
+      return nullptr;  // Just clone func
+    });
+  });
 
   return out;
 }
@@ -114,11 +121,14 @@
 VertexPulling::Config::Config(const Config&) = default;
 VertexPulling::Config::~Config() = default;
 
-VertexPulling::State::State(ast::Module* m, const Config& c) : mod(m), cfg(c) {}
+VertexPulling::State::State(ast::Module* i, ast::Module* o, const Config& c)
+    : in(i), out(o), cfg(c) {}
+
+VertexPulling::State::State(const State&) = default;
 
 VertexPulling::State::~State() = default;
 
-std::string VertexPulling::State::GetVertexBufferName(uint32_t index) {
+std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const {
   return kVertexBufferNamePrefix + std::to_string(index);
 }
 
@@ -135,7 +145,7 @@
   }
 
   // Look for an existing vertex index builtin
-  for (auto* v : mod->global_variables()) {
+  for (auto* v : in->global_variables()) {
     if (v->storage_class() != ast::StorageClass::kInput) {
       continue;
     }
@@ -154,7 +164,7 @@
   vertex_index_name = kDefaultVertexIndexName;
 
   auto* var =
-      mod->create<ast::Variable>(Source{},                   // source
+      out->create<ast::Variable>(Source{},                   // source
                                  vertex_index_name,          // name
                                  ast::StorageClass::kInput,  // storage_class
                                  GetI32Type(),               // type
@@ -162,11 +172,11 @@
                                  nullptr,                    // constructor
                                  ast::VariableDecorationList{
                                      // decorations
-                                     mod->create<ast::BuiltinDecoration>(
+                                     out->create<ast::BuiltinDecoration>(
                                          ast::Builtin::kVertexIdx, Source{}),
                                  });
 
-  mod->AddGlobalVariable(var);
+  out->AddGlobalVariable(var);
 }
 
 void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
@@ -182,7 +192,7 @@
   }
 
   // Look for an existing instance index builtin
-  for (auto* v : mod->global_variables()) {
+  for (auto* v : in->global_variables()) {
     if (v->storage_class() != ast::StorageClass::kInput) {
       continue;
     }
@@ -201,7 +211,7 @@
   instance_index_name = kDefaultInstanceIndexName;
 
   auto* var =
-      mod->create<ast::Variable>(Source{},                   // source
+      out->create<ast::Variable>(Source{},                   // source
                                  instance_index_name,        // name
                                  ast::StorageClass::kInput,  // storage_class
                                  GetI32Type(),               // type
@@ -209,14 +219,14 @@
                                  nullptr,                    // constructor
                                  ast::VariableDecorationList{
                                      // decorations
-                                     mod->create<ast::BuiltinDecoration>(
+                                     out->create<ast::BuiltinDecoration>(
                                          ast::Builtin::kInstanceIdx, Source{}),
                                  });
-  mod->AddGlobalVariable(var);
+  out->AddGlobalVariable(var);
 }
 
 void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
-  for (auto*& v : mod->global_variables()) {
+  for (auto*& v : in->global_variables()) {
     if (v->storage_class() != ast::StorageClass::kInput) {
       continue;
     }
@@ -227,7 +237,7 @@
         // This is where the replacement happens. Expressions use identifier
         // strings instead of pointers, so we don't need to update any other
         // place in the AST.
-        v = mod->create<ast::Variable>(
+        v = out->create<ast::Variable>(
             Source{},                        // source
             v->name(),                       // name
             ast::StorageClass::kPrivate,     // storage_class
@@ -245,31 +255,31 @@
 void VertexPulling::State::AddVertexStorageBuffers() {
   // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
   // The array inside the struct definition
-  auto* internal_array_type = mod->create<ast::type::Array>(
+  auto* internal_array_type = out->create<ast::type::Array>(
       GetU32Type(), 0,
       ast::ArrayDecorationList{
-          mod->create<ast::StrideDecoration>(4u, Source{}),
+          out->create<ast::StrideDecoration>(4u, Source{}),
       });
 
   // Creating the struct type
   ast::StructMemberList members;
   ast::StructMemberDecorationList member_dec;
   member_dec.push_back(
-      mod->create<ast::StructMemberOffsetDecoration>(0u, Source{}));
+      out->create<ast::StructMemberOffsetDecoration>(0u, Source{}));
 
-  members.push_back(mod->create<ast::StructMember>(
+  members.push_back(out->create<ast::StructMember>(
       Source{}, kStructBufferName, internal_array_type, std::move(member_dec)));
 
   ast::StructDecorationList decos;
-  decos.push_back(mod->create<ast::StructBlockDecoration>(Source{}));
+  decos.push_back(out->create<ast::StructBlockDecoration>(Source{}));
 
-  auto* struct_type = mod->create<ast::type::Struct>(
-      mod->RegisterSymbol(kStructName), kStructName,
-      mod->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
+  auto* struct_type = out->create<ast::type::Struct>(
+      out->RegisterSymbol(kStructName), kStructName,
+      out->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
 
   for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
     // The decorated variable with struct type
-    auto* var = mod->create<ast::Variable>(
+    auto* var = out->create<ast::Variable>(
         Source{},                           // source
         GetVertexBufferName(i),             // name
         ast::StorageClass::kStorageBuffer,  // storage_class
@@ -278,23 +288,23 @@
         nullptr,                            // constructor
         ast::VariableDecorationList{
             // decorations
-            mod->create<ast::BindingDecoration>(i, Source{}),
-            mod->create<ast::SetDecoration>(cfg.pulling_set, Source{}),
+            out->create<ast::BindingDecoration>(i, Source{}),
+            out->create<ast::SetDecoration>(cfg.pulling_set, Source{}),
         });
-    mod->AddGlobalVariable(var);
+    out->AddGlobalVariable(var);
   }
-  mod->AddConstructedType(struct_type);
+  out->AddConstructedType(struct_type);
 }
 
-ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() {
+ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
   // Assign by looking at the vertex descriptor to find attributes with matching
   // location.
 
   ast::StatementList stmts;
 
   // Declare the |kPullingPosVarName| variable in the shader
-  auto* pos_declaration = mod->create<ast::VariableDeclStatement>(
-      Source{}, mod->create<ast::Variable>(
+  auto* pos_declaration = out->create<ast::VariableDeclStatement>(
+      Source{}, out->create<ast::Variable>(
                     Source{},                         // source
                     kPullingPosVarName,               // name
                     ast::StorageClass::kFunction,     // storage_class
@@ -323,45 +333,46 @@
                       ? vertex_index_name
                       : instance_index_name;
       // Identifier to index by
-      auto* index_identifier = mod->create<ast::IdentifierExpression>(
-          Source{}, mod->RegisterSymbol(name), name);
+      auto* index_identifier = out->create<ast::IdentifierExpression>(
+          Source{}, out->RegisterSymbol(name), name);
 
       // An expression for the start of the read in the buffer in bytes
-      auto* pos_value = mod->create<ast::BinaryExpression>(
+      auto* pos_value = out->create<ast::BinaryExpression>(
           Source{}, ast::BinaryOp::kAdd,
-          mod->create<ast::BinaryExpression>(
+          out->create<ast::BinaryExpression>(
               Source{}, ast::BinaryOp::kMultiply, index_identifier,
               GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
           GenUint(static_cast<uint32_t>(attribute_desc.offset)));
 
       // Update position of the read
-      auto* set_pos_expr = mod->create<ast::AssignmentStatement>(
+      auto* set_pos_expr = out->create<ast::AssignmentStatement>(
           Source{}, CreatePullingPositionIdent(), pos_value);
       stmts.emplace_back(set_pos_expr);
 
-      stmts.emplace_back(mod->create<ast::AssignmentStatement>(
+      stmts.emplace_back(out->create<ast::AssignmentStatement>(
           Source{},
-          mod->create<ast::IdentifierExpression>(
-              Source{}, mod->RegisterSymbol(v->name()), v->name()),
+          out->create<ast::IdentifierExpression>(
+              Source{}, out->RegisterSymbol(v->name()), v->name()),
           AccessByFormat(i, attribute_desc.format)));
     }
   }
 
-  return mod->create<ast::BlockStatement>(Source{}, stmts);
+  return out->create<ast::BlockStatement>(Source{}, stmts);
 }
 
-ast::Expression* VertexPulling::State::GenUint(uint32_t value) {
-  return mod->create<ast::ScalarConstructorExpression>(
-      Source{}, mod->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
+ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
+  return out->create<ast::ScalarConstructorExpression>(
+      Source{}, out->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
 }
 
-ast::Expression* VertexPulling::State::CreatePullingPositionIdent() {
-  return mod->create<ast::IdentifierExpression>(
-      Source{}, mod->RegisterSymbol(kPullingPosVarName), kPullingPosVarName);
+ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
+  return out->create<ast::IdentifierExpression>(
+      Source{}, out->RegisterSymbol(kPullingPosVarName), kPullingPosVarName);
 }
 
-ast::Expression* VertexPulling::State::AccessByFormat(uint32_t buffer,
-                                                      VertexFormat format) {
+ast::Expression* VertexPulling::State::AccessByFormat(
+    uint32_t buffer,
+    VertexFormat format) const {
   // TODO(idanr): this doesn't account for the format of the attribute in the
   // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
   // right now, we would try to assign a vec4<f32> to this attribute, but we
@@ -388,43 +399,44 @@
 }
 
 ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
-                                                 ast::Expression* pos) {
+                                                 ast::Expression* pos) const {
   // Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
   // has byte offsets for each attribute, and we will convert it to u32 indexes
   // by dividing. Then, that element is going to be read, and if needed,
   // unpacked into an appropriate variable. All reads should end up here as a
   // base case.
   auto vbuf_name = GetVertexBufferName(buffer);
-  return mod->create<ast::ArrayAccessorExpression>(
+  return out->create<ast::ArrayAccessorExpression>(
       Source{},
-      mod->create<ast::MemberAccessorExpression>(
+      out->create<ast::MemberAccessorExpression>(
           Source{},
-          mod->create<ast::IdentifierExpression>(
-              Source{}, mod->RegisterSymbol(vbuf_name), vbuf_name),
-          mod->create<ast::IdentifierExpression>(
-              Source{}, mod->RegisterSymbol(kStructBufferName),
+          out->create<ast::IdentifierExpression>(
+              Source{}, out->RegisterSymbol(vbuf_name), vbuf_name),
+          out->create<ast::IdentifierExpression>(
+              Source{}, out->RegisterSymbol(kStructBufferName),
               kStructBufferName)),
-      mod->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
+      out->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
                                          GenUint(4)));
 }
 
 ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
-                                                 ast::Expression* pos) {
+                                                 ast::Expression* pos) const {
   // as<T> reinterprets bits
-  return mod->create<ast::BitcastExpression>(Source{}, GetI32Type(),
+  return out->create<ast::BitcastExpression>(Source{}, GetI32Type(),
                                              AccessU32(buffer, pos));
 }
 
 ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
-                                                 ast::Expression* pos) {
+                                                 ast::Expression* pos) const {
   // as<T> reinterprets bits
-  return mod->create<ast::BitcastExpression>(Source{}, GetF32Type(),
+  return out->create<ast::BitcastExpression>(Source{}, GetF32Type(),
                                              AccessU32(buffer, pos));
 }
 
-ast::Expression* VertexPulling::State::AccessPrimitive(uint32_t buffer,
-                                                       ast::Expression* pos,
-                                                       VertexFormat format) {
+ast::Expression* VertexPulling::State::AccessPrimitive(
+    uint32_t buffer,
+    ast::Expression* pos,
+    VertexFormat format) const {
   // This function uses a position expression to read, rather than using the
   // position variable. This allows us to read from offset positions relative to
   // |kPullingPosVarName|. We can't call AccessByFormat because it reads only
@@ -445,31 +457,31 @@
                                                  uint32_t element_stride,
                                                  ast::type::Type* base_type,
                                                  VertexFormat base_format,
-                                                 uint32_t count) {
+                                                 uint32_t count) const {
   ast::ExpressionList expr_list;
   for (uint32_t i = 0; i < count; ++i) {
     // Offset read position by element_stride for each component
-    auto* cur_pos = mod->create<ast::BinaryExpression>(
+    auto* cur_pos = out->create<ast::BinaryExpression>(
         Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
         GenUint(element_stride * i));
     expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
   }
 
-  return mod->create<ast::TypeConstructorExpression>(
-      Source{}, mod->create<ast::type::Vector>(base_type, count),
+  return out->create<ast::TypeConstructorExpression>(
+      Source{}, out->create<ast::type::Vector>(base_type, count),
       std::move(expr_list));
 }
 
-ast::type::Type* VertexPulling::State::GetU32Type() {
-  return mod->create<ast::type::U32>();
+ast::type::Type* VertexPulling::State::GetU32Type() const {
+  return out->create<ast::type::U32>();
 }
 
-ast::type::Type* VertexPulling::State::GetI32Type() {
-  return mod->create<ast::type::I32>();
+ast::type::Type* VertexPulling::State::GetI32Type() const {
+  return out->create<ast::type::I32>();
 }
 
-ast::type::Type* VertexPulling::State::GetF32Type() {
-  return mod->create<ast::type::F32>();
+ast::type::Type* VertexPulling::State::GetF32Type() const {
+  return out->create<ast::type::F32>();
 }
 
 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h
index 4f55b1f..6d4f23c 100644
--- a/src/transform/vertex_pulling.h
+++ b/src/transform/vertex_pulling.h
@@ -178,12 +178,13 @@
   Config cfg;
 
   struct State {
-    State(ast::Module* m, const Config& c);
+    State(ast::Module* in, ast::Module* out, const Config& c);
+    explicit State(const State&);
     ~State();
 
     /// Generate the vertex buffer binding name
     /// @param index index to append to buffer name
-    std::string GetVertexBufferName(uint32_t index);
+    std::string GetVertexBufferName(uint32_t index) const;
 
     /// Inserts vertex_idx binding, or finds the existing one
     void FindOrInsertVertexIndexIfUsed();
@@ -198,36 +199,36 @@
     void AddVertexStorageBuffers();
 
     /// Creates and returns the assignment to the variables from the buffers
-    ast::BlockStatement* CreateVertexPullingPreamble();
+    ast::BlockStatement* CreateVertexPullingPreamble() const;
 
     /// Generates an expression holding a constant uint
     /// @param value uint value
-    ast::Expression* GenUint(uint32_t value);
+    ast::Expression* GenUint(uint32_t value) const;
 
     /// Generates an expression to read the shader value `kPullingPosVarName`
-    ast::Expression* CreatePullingPositionIdent();
+    ast::Expression* CreatePullingPositionIdent() const;
 
     /// Generates an expression reading from a buffer a specific format.
     /// This reads the value wherever `kPullingPosVarName` points to at the time
     /// of the read.
     /// @param buffer the index of the vertex buffer
     /// @param format the format to read
-    ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format);
+    ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) const;
 
     /// Generates an expression reading a uint32 from a vertex buffer
     /// @param buffer the index of the vertex buffer
     /// @param pos an expression for the position of the access, in bytes
-    ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos);
+    ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) const;
 
     /// Generates an expression reading an int32 from a vertex buffer
     /// @param buffer the index of the vertex buffer
     /// @param pos an expression for the position of the access, in bytes
-    ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos);
+    ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) const;
 
     /// Generates an expression reading a float from a vertex buffer
     /// @param buffer the index of the vertex buffer
     /// @param pos an expression for the position of the access, in bytes
-    ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos);
+    ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) const;
 
     /// Generates an expression reading a basic type (u32, i32, f32) from a
     /// vertex buffer
@@ -236,7 +237,7 @@
     /// @param format the underlying vertex format
     ast::Expression* AccessPrimitive(uint32_t buffer,
                                      ast::Expression* pos,
-                                     VertexFormat format);
+                                     VertexFormat format) const;
 
     /// Generates an expression reading a vec2/3/4 from a vertex buffer.
     /// This reads the value wherever `kPullingPosVarName` points to at the time
@@ -250,14 +251,15 @@
                                uint32_t element_stride,
                                ast::type::Type* base_type,
                                VertexFormat base_format,
-                               uint32_t count);
+                               uint32_t count) const;
 
     // Used to grab corresponding types from the type manager
-    ast::type::Type* GetU32Type();
-    ast::type::Type* GetI32Type();
-    ast::type::Type* GetF32Type();
+    ast::type::Type* GetU32Type() const;
+    ast::type::Type* GetI32Type() const;
+    ast::type::Type* GetF32Type() const;
 
-    ast::Module* const mod;
+    ast::Module* const in;
+    ast::Module* const out;
     Config const cfg;
 
     std::unordered_map<uint32_t, ast::Variable*> location_to_var;
diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc
index 30145ce..818853f 100644
--- a/src/transform/vertex_pulling_test.cc
+++ b/src/transform/vertex_pulling_test.cc
@@ -176,11 +176,6 @@
     StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
   }
   Variable{
-    var_a
-    private
-    __f32
-  }
-  Variable{
     Decorations{
       BuiltinDecoration{vertex_idx}
     }
@@ -197,6 +192,11 @@
     storage_buffer
     __struct_TintVertexData
   }
+  Variable{
+    var_a
+    private
+    __f32
+  }
   Function main -> __void
   StageDecoration{vertex}
   ()
@@ -263,11 +263,6 @@
     StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
   }
   Variable{
-    var_a
-    private
-    __f32
-  }
-  Variable{
     Decorations{
       BuiltinDecoration{instance_idx}
     }
@@ -284,6 +279,11 @@
     storage_buffer
     __struct_TintVertexData
   }
+  Variable{
+    var_a
+    private
+    __f32
+  }
   Function main -> __void
   StageDecoration{vertex}
   ()
@@ -350,11 +350,6 @@
     StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
   }
   Variable{
-    var_a
-    private
-    __f32
-  }
-  Variable{
     Decorations{
       BuiltinDecoration{vertex_idx}
     }
@@ -371,6 +366,11 @@
     storage_buffer
     __struct_TintVertexData
   }
+  Variable{
+    var_a
+    private
+    __f32
+  }
   Function main -> __void
   StageDecoration{vertex}
   ()
@@ -466,6 +466,24 @@
     StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
   }
   Variable{
+    Decorations{
+      BindingDecoration{0}
+      SetDecoration{4}
+    }
+    _tint_pulling_vertex_buffer_0
+    storage_buffer
+    __struct_TintVertexData
+  }
+  Variable{
+    Decorations{
+      BindingDecoration{1}
+      SetDecoration{4}
+    }
+    _tint_pulling_vertex_buffer_1
+    storage_buffer
+    __struct_TintVertexData
+  }
+  Variable{
     var_a
     private
     __f32
@@ -491,24 +509,6 @@
     in
     __i32
   }
-  Variable{
-    Decorations{
-      BindingDecoration{0}
-      SetDecoration{4}
-    }
-    _tint_pulling_vertex_buffer_0
-    storage_buffer
-    __struct_TintVertexData
-  }
-  Variable{
-    Decorations{
-      BindingDecoration{1}
-      SetDecoration{4}
-    }
-    _tint_pulling_vertex_buffer_1
-    storage_buffer
-    __struct_TintVertexData
-  }
   Function main -> __void
   StageDecoration{vertex}
   ()
@@ -608,16 +608,6 @@
     StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
   }
   Variable{
-    var_a
-    private
-    __f32
-  }
-  Variable{
-    var_b
-    private
-    __array__f32_4
-  }
-  Variable{
     Decorations{
       BuiltinDecoration{vertex_idx}
     }
@@ -634,6 +624,16 @@
     storage_buffer
     __struct_TintVertexData
   }
+  Variable{
+    var_a
+    private
+    __f32
+  }
+  Variable{
+    var_b
+    private
+    __array__f32_4
+  }
   Function main -> __void
   StageDecoration{vertex}
   ()
@@ -795,21 +795,6 @@
     StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4}
   }
   Variable{
-    var_a
-    private
-    __array__f32_2
-  }
-  Variable{
-    var_b
-    private
-    __array__f32_3
-  }
-  Variable{
-    var_c
-    private
-    __array__f32_4
-  }
-  Variable{
     Decorations{
       BuiltinDecoration{vertex_idx}
     }
@@ -844,6 +829,21 @@
     storage_buffer
     __struct_TintVertexData
   }
+  Variable{
+    var_a
+    private
+    __array__f32_2
+  }
+  Variable{
+    var_b
+    private
+    __array__f32_3
+  }
+  Variable{
+    var_c
+    private
+    __array__f32_4
+  }
   Function main -> __void
   StageDecoration{vertex}
   ()