transform: Fix multiple mutability issues with VertexPulling

ConvertVertexInputVariablesToPrivate() mutated the source program global variables, and copied them into the destination program.

Symbols and types were assigned across the program boundary without cloning.

Bug: tint:390
Change-Id: I03c8924e6ba94b745e74de0ab57f8a489e85cc50
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/38554
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index 081c8c9..caca8c2 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -103,21 +103,24 @@
   // following stages will pass
   Output out;
 
-  State state{in, &out.program, cfg};
+  CloneContext ctx(&out.program, in);
+  State state{ctx, cfg};
   state.FindOrInsertVertexIndexIfUsed();
   state.FindOrInsertInstanceIndexIfUsed();
   state.ConvertVertexInputVariablesToPrivate();
   state.AddVertexStorageBuffers();
 
-  CloneContext(&out.program, in)
-      .ReplaceAll([&](CloneContext* ctx, ast::Function* f) -> ast::Function* {
-        if (f == func) {
-          return CloneWithStatementsAtStart(
-              ctx, f, {state.CreateVertexPullingPreamble()});
-        }
-        return nullptr;  // Just clone func
-      })
-      .Clone();
+  for (auto& replacement : state.location_replacements) {
+    ctx.Replace(replacement.from, replacement.to);
+  }
+  ctx.ReplaceAll([&](CloneContext*, ast::Function* f) -> ast::Function* {
+    if (f == func) {
+      return CloneWithStatementsAtStart(&ctx, f,
+                                        {state.CreateVertexPullingPreamble()});
+    }
+    return nullptr;  // Just clone func
+  });
+  ctx.Clone();
 
   return out;
 }
@@ -126,8 +129,8 @@
 VertexPulling::Config::Config(const Config&) = default;
 VertexPulling::Config::~Config() = default;
 
-VertexPulling::State::State(const Program* i, Program* o, const Config& c)
-    : in(i), out(o), cfg(c) {}
+VertexPulling::State::State(CloneContext& context, const Config& c)
+    : ctx(context), cfg(c) {}
 
 VertexPulling::State::State(const State&) = default;
 
@@ -150,7 +153,7 @@
   }
 
   // Look for an existing vertex index builtin
-  for (auto* v : in->AST().GlobalVariables()) {
+  for (auto* v : ctx.src->AST().GlobalVariables()) {
     if (v->storage_class() != ast::StorageClass::kInput) {
       continue;
     }
@@ -158,7 +161,7 @@
     for (auto* d : v->decorations()) {
       if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
         if (builtin->value() == ast::Builtin::kVertexIndex) {
-          vertex_index_name = in->Symbols().NameFor(v->symbol());
+          vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
           return;
         }
       }
@@ -168,20 +171,19 @@
   // We didn't find a vertex index builtin, so create one
   vertex_index_name = kDefaultVertexIndexName;
 
-  auto* var = out->create<ast::Variable>(
-      Source{},                                    // source
-      out->Symbols().Register(vertex_index_name),  // symbol
-      ast::StorageClass::kInput,                   // storage_class
-      GetI32Type(),                                // type
-      false,                                       // is_const
-      nullptr,                                     // constructor
+  auto* var = ctx.dst->create<ast::Variable>(
+      Source{},                                        // source
+      ctx.dst->Symbols().Register(vertex_index_name),  // symbol
+      ast::StorageClass::kInput,                       // storage_class
+      GetI32Type(),                                    // type
+      false,                                           // is_const
+      nullptr,                                         // constructor
       ast::VariableDecorationList{
-          // decorations
-          out->create<ast::BuiltinDecoration>(Source{},
-                                              ast::Builtin::kVertexIndex),
+          ctx.dst->create<ast::BuiltinDecoration>(Source{},
+                                                  ast::Builtin::kVertexIndex),
       });
 
-  out->AST().AddGlobalVariable(var);
+  ctx.dst->AST().AddGlobalVariable(var);
 }
 
 void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
@@ -197,7 +199,7 @@
   }
 
   // Look for an existing instance index builtin
-  for (auto* v : in->AST().GlobalVariables()) {
+  for (auto* v : ctx.src->AST().GlobalVariables()) {
     if (v->storage_class() != ast::StorageClass::kInput) {
       continue;
     }
@@ -205,7 +207,7 @@
     for (auto* d : v->decorations()) {
       if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
         if (builtin->value() == ast::Builtin::kInstanceIndex) {
-          instance_index_name = in->Symbols().NameFor(v->symbol());
+          instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
           return;
         }
       }
@@ -215,24 +217,22 @@
   // We didn't find an instance index builtin, so create one
   instance_index_name = kDefaultInstanceIndexName;
 
-  auto* var = out->create<ast::Variable>(
-      Source{},                                      // source
-      out->Symbols().Register(instance_index_name),  // symbol
-      ast::StorageClass::kInput,                     // storage_class
-      GetI32Type(),                                  // type
-      false,                                         // is_const
-      nullptr,                                       // constructor
+  auto* var = ctx.dst->create<ast::Variable>(
+      Source{},                                          // source
+      ctx.dst->Symbols().Register(instance_index_name),  // symbol
+      ast::StorageClass::kInput,                         // storage_class
+      GetI32Type(),                                      // type
+      false,                                             // is_const
+      nullptr,                                           // constructor
       ast::VariableDecorationList{
-          // decorations
-          out->create<ast::BuiltinDecoration>(Source{},
-                                              ast::Builtin::kInstanceIndex),
+          ctx.dst->create<ast::BuiltinDecoration>(Source{},
+                                                  ast::Builtin::kInstanceIndex),
       });
-  out->AST().AddGlobalVariable(var);
+  ctx.dst->AST().AddGlobalVariable(var);
 }
 
 void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
-  // TODO(https://crbug.com/tint/390): Remove this const_cast hack!
-  for (auto*& v : const_cast<Program*>(in)->AST().GlobalVariables()) {
+  for (auto* v : ctx.src->AST().GlobalVariables()) {
     if (v->storage_class() != ast::StorageClass::kInput) {
       continue;
     }
@@ -240,18 +240,20 @@
     for (auto* d : v->decorations()) {
       if (auto* l = d->As<ast::LocationDecoration>()) {
         uint32_t location = l->value();
-        // This is where the replacement happens. Expressions use identifier
+        // This is where the replacement is created. Expressions use identifier
         // strings instead of pointers, so we don't need to update any other
         // place in the AST.
-        v = out->create<ast::Variable>(
-            Source{},                        // source
-            v->symbol(),                     // symbol
-            ast::StorageClass::kPrivate,     // storage_class
-            v->type(),                       // type
-            false,                           // is_const
-            nullptr,                         // constructor
-            ast::VariableDecorationList{});  // decorations
-        location_to_var[location] = v;
+        auto name = ctx.src->Symbols().NameFor(v->symbol());
+        auto* replacement = ctx.dst->create<ast::Variable>(
+            Source{},                           // source
+            ctx.dst->Symbols().Register(name),  // symbol
+            ast::StorageClass::kPrivate,        // storage_class
+            ctx.Clone(v->type()),               // type
+            false,                              // is_const
+            nullptr,                            // constructor
+            ast::VariableDecorationList{});     // decorations
+        location_to_var[location] = replacement;
+        location_replacements.emplace_back(LocationReplacement{v, replacement});
         break;
       }
     }
@@ -261,47 +263,47 @@
 void VertexPulling::State::AddVertexStorageBuffers() {
   // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
   // The array inside the struct definition
-  auto* internal_array_type = out->create<type::Array>(
+  auto* internal_array_type = ctx.dst->create<type::Array>(
       GetU32Type(), 0,
       ast::ArrayDecorationList{
-          out->create<ast::StrideDecoration>(Source{}, 4u),
+          ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
       });
 
   // Creating the struct type
   ast::StructMemberList members;
   ast::StructMemberDecorationList member_dec;
   member_dec.push_back(
-      out->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
+      ctx.dst->create<ast::StructMemberOffsetDecoration>(Source{}, 0u));
 
-  members.push_back(out->create<ast::StructMember>(
-      Source{}, out->Symbols().Register(kStructBufferName), internal_array_type,
-      std::move(member_dec)));
+  members.push_back(ctx.dst->create<ast::StructMember>(
+      Source{}, ctx.dst->Symbols().Register(kStructBufferName),
+      internal_array_type, std::move(member_dec)));
 
   ast::StructDecorationList decos;
-  decos.push_back(out->create<ast::StructBlockDecoration>(Source{}));
+  decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));
 
-  auto* struct_type = out->create<type::Struct>(
-      out->Symbols().Register(kStructName),
-      out->create<ast::Struct>(Source{}, std::move(members), std::move(decos)));
+  auto* struct_type = ctx.dst->create<type::Struct>(
+      ctx.dst->Symbols().Register(kStructName),
+      ctx.dst->create<ast::Struct>(Source{}, std::move(members),
+                                   std::move(decos)));
 
   for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
     // The decorated variable with struct type
     std::string name = GetVertexBufferName(i);
-    auto* var = out->create<ast::Variable>(
-        Source{},                       // source
-        out->Symbols().Register(name),  // symbol
-        ast::StorageClass::kStorage,    // storage_class
-        struct_type,                    // type
-        false,                          // is_const
-        nullptr,                        // constructor
+    auto* var = ctx.dst->create<ast::Variable>(
+        Source{},                           // source
+        ctx.dst->Symbols().Register(name),  // symbol
+        ast::StorageClass::kStorage,        // storage_class
+        struct_type,                        // type
+        false,                              // is_const
+        nullptr,                            // constructor
         ast::VariableDecorationList{
-            // decorations
-            out->create<ast::BindingDecoration>(Source{}, i),
-            out->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
+            ctx.dst->create<ast::BindingDecoration>(Source{}, i),
+            ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
         });
-    out->AST().AddGlobalVariable(var);
+    ctx.dst->AST().AddGlobalVariable(var);
   }
-  out->AST().AddConstructedType(struct_type);
+  ctx.dst->AST().AddConstructedType(struct_type);
 }
 
 ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
@@ -311,10 +313,10 @@
   ast::StatementList stmts;
 
   // Declare the |kPullingPosVarName| variable in the shader
-  auto* pos_declaration = out->create<ast::VariableDeclStatement>(
-      Source{}, out->create<ast::Variable>(
-                    Source{},                                     // source
-                    out->Symbols().Register(kPullingPosVarName),  // symbol
+  auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
+      Source{}, ctx.dst->create<ast::Variable>(
+                    Source{},                                         // source
+                    ctx.dst->Symbols().Register(kPullingPosVarName),  // symbol
                     ast::StorageClass::kFunction,     // storage_class
                     GetI32Type(),                     // type
                     false,                            // is_const
@@ -341,42 +343,41 @@
                       ? vertex_index_name
                       : instance_index_name;
       // Identifier to index by
-      auto* index_identifier = out->create<ast::IdentifierExpression>(
-          Source{}, out->Symbols().Register(name));
+      auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
+          Source{}, ctx.dst->Symbols().Register(name));
 
       // An expression for the start of the read in the buffer in bytes
-      auto* pos_value = out->create<ast::BinaryExpression>(
+      auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
           Source{}, ast::BinaryOp::kAdd,
-          out->create<ast::BinaryExpression>(
+          ctx.dst->create<ast::BinaryExpression>(
               Source{}, ast::BinaryOp::kMultiply, index_identifier,
               GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
           GenUint(static_cast<uint32_t>(attribute_desc.offset)));
 
       // Update position of the read
-      auto* set_pos_expr = out->create<ast::AssignmentStatement>(
+      auto* set_pos_expr = ctx.dst->create<ast::AssignmentStatement>(
           Source{}, CreatePullingPositionIdent(), pos_value);
       stmts.emplace_back(set_pos_expr);
 
-      auto ident_name = in->Symbols().NameFor(v->symbol());
-      stmts.emplace_back(out->create<ast::AssignmentStatement>(
+      stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
           Source{},
-          out->create<ast::IdentifierExpression>(
-              Source{}, out->Symbols().Register(ident_name)),
+          ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
           AccessByFormat(i, attribute_desc.format)));
     }
   }
 
-  return out->create<ast::BlockStatement>(Source{}, stmts);
+  return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
 }
 
 ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
-  return out->create<ast::ScalarConstructorExpression>(
-      Source{}, out->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
+  return ctx.dst->create<ast::ScalarConstructorExpression>(
+      Source{},
+      ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value));
 }
 
 ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
-  return out->create<ast::IdentifierExpression>(
-      Source{}, out->Symbols().Register(kPullingPosVarName));
+  return ctx.dst->create<ast::IdentifierExpression>(
+      Source{}, ctx.dst->Symbols().Register(kPullingPosVarName));
 }
 
 ast::Expression* VertexPulling::State::AccessByFormat(
@@ -415,30 +416,30 @@
   // unpacked into an appropriate variable. All reads should end up here as a
   // base case.
   auto vbuf_name = GetVertexBufferName(buffer);
-  return out->create<ast::ArrayAccessorExpression>(
+  return ctx.dst->create<ast::ArrayAccessorExpression>(
       Source{},
-      out->create<ast::MemberAccessorExpression>(
+      ctx.dst->create<ast::MemberAccessorExpression>(
           Source{},
-          out->create<ast::IdentifierExpression>(
-              Source{}, out->Symbols().Register(vbuf_name)),
-          out->create<ast::IdentifierExpression>(
-              Source{}, out->Symbols().Register(kStructBufferName))),
-      out->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide, pos,
-                                         GenUint(4)));
+          ctx.dst->create<ast::IdentifierExpression>(
+              Source{}, ctx.dst->Symbols().Register(vbuf_name)),
+          ctx.dst->create<ast::IdentifierExpression>(
+              Source{}, ctx.dst->Symbols().Register(kStructBufferName))),
+      ctx.dst->create<ast::BinaryExpression>(Source{}, ast::BinaryOp::kDivide,
+                                             pos, GenUint(4)));
 }
 
 ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
                                                  ast::Expression* pos) const {
   // as<T> reinterprets bits
-  return out->create<ast::BitcastExpression>(Source{}, GetI32Type(),
-                                             AccessU32(buffer, pos));
+  return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
+                                                 AccessU32(buffer, pos));
 }
 
 ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
                                                  ast::Expression* pos) const {
   // as<T> reinterprets bits
-  return out->create<ast::BitcastExpression>(Source{}, GetF32Type(),
-                                             AccessU32(buffer, pos));
+  return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
+                                                 AccessU32(buffer, pos));
 }
 
 ast::Expression* VertexPulling::State::AccessPrimitive(
@@ -469,27 +470,27 @@
   ast::ExpressionList expr_list;
   for (uint32_t i = 0; i < count; ++i) {
     // Offset read position by element_stride for each component
-    auto* cur_pos = out->create<ast::BinaryExpression>(
+    auto* cur_pos = ctx.dst->create<ast::BinaryExpression>(
         Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(),
         GenUint(element_stride * i));
     expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format));
   }
 
-  return out->create<ast::TypeConstructorExpression>(
-      Source{}, out->create<type::Vector>(base_type, count),
+  return ctx.dst->create<ast::TypeConstructorExpression>(
+      Source{}, ctx.dst->create<type::Vector>(base_type, count),
       std::move(expr_list));
 }
 
 type::Type* VertexPulling::State::GetU32Type() const {
-  return out->create<type::U32>();
+  return ctx.dst->create<type::U32>();
 }
 
 type::Type* VertexPulling::State::GetI32Type() const {
-  return out->create<type::I32>();
+  return ctx.dst->create<type::I32>();
 }
 
 type::Type* VertexPulling::State::GetF32Type() const {
-  return out->create<type::F32>();
+  return ctx.dst->create<type::F32>();
 }
 
 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h
index 7569cb7..3db054d 100644
--- a/src/transform/vertex_pulling.h
+++ b/src/transform/vertex_pulling.h
@@ -183,7 +183,7 @@
   Config cfg;
 
   struct State {
-    State(const Program* in, Program* out, const Config& c);
+    State(CloneContext& ctx, const Config& c);
     explicit State(const State&);
     ~State();
 
@@ -263,11 +263,20 @@
     type::Type* GetI32Type() const;
     type::Type* GetF32Type() const;
 
-    const Program* const in;
-    Program* const out;
+    CloneContext& ctx;
     Config const cfg;
 
+    /// LocationReplacement describes an ast::Variable replacement for a
+    /// location input.
+    struct LocationReplacement {
+      /// The variable to replace in the source Program
+      ast::Variable* from;
+      /// The replacement to use in the target ProgramBuilder
+      ast::Variable* to;
+    };
+
     std::unordered_map<uint32_t, ast::Variable*> location_to_var;
+    std::vector<LocationReplacement> location_replacements;
     std::string vertex_index_name;
     std::string instance_index_name;
   };