CloneContext: Pass the vector to InsertBefore()

There's usually only ever one vector we want to insert into.
Inserting into *all* vectors that happen to contain the reference object is likely unintended, and is a foot-gun waiting to go off.

Change-Id: I533084ccad102fc998b851fd238fd6bea9299450
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46445
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/clone_context.cc b/src/clone_context.cc
index 06a28e1..f449b0b 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -20,6 +20,9 @@
 
 namespace tint {
 
+CloneContext::ListTransforms::ListTransforms() = default;
+CloneContext::ListTransforms::~ListTransforms() = default;
+
 CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
     : dst(to), src(from) {}
 CloneContext::~CloneContext() = default;
diff --git a/src/clone_context.h b/src/clone_context.h
index 2853ad4..fd7a0ac 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -15,6 +15,7 @@
 #ifndef SRC_CLONE_CONTEXT_H_
 #define SRC_CLONE_CONTEXT_H_
 
+#include <algorithm>
 #include <functional>
 #include <unordered_map>
 #include <utility>
@@ -178,14 +179,29 @@
   std::vector<T*> Clone(const std::vector<T*>& v) {
     std::vector<T*> out;
     out.reserve(v.size());
-    for (auto& el : v) {
-      auto it = insert_before_.find(el);
-      if (it != insert_before_.end()) {
-        for (auto insert : it->second) {
-          out.emplace_back(CheckedCast<T>(insert));
+
+    auto list_transform_it = list_transforms_.find(&v);
+    if (list_transform_it != list_transforms_.end()) {
+      const auto& transforms = list_transform_it->second;
+      for (auto& el : v) {
+        auto insert_before_it = transforms.insert_before_.find(el);
+        if (insert_before_it != transforms.insert_before_.end()) {
+          for (auto insert : insert_before_it->second) {
+            out.emplace_back(CheckedCast<T>(insert));
+          }
+        }
+        out.emplace_back(Clone(el));
+        auto insert_after_it = transforms.insert_after_.find(el);
+        if (insert_after_it != transforms.insert_after_.end()) {
+          for (auto insert : insert_after_it->second) {
+            out.emplace_back(CheckedCast<T>(insert));
+          }
         }
       }
-      out.emplace_back(Clone(el));
+    } else {
+      for (auto& el : v) {
+        out.emplace_back(Clone(el));
+      }
     }
     return out;
   }
@@ -293,15 +309,46 @@
     return *this;
   }
 
-  /// Inserts `object` before `before` whenever a vector containing `object` is
-  /// cloned.
+  /// Inserts `object` before `before` whenever `vector` is cloned.
+  /// @param vector the vector in #src
   /// @param before a pointer to the object in #src
   /// @param object a pointer to the object in #dst that will be inserted before
   /// any occurrence of the clone of `before`
   /// @returns this CloneContext so calls can be chained
-  template <typename BEFORE, typename OBJECT>
-  CloneContext& InsertBefore(BEFORE* before, OBJECT* object) {
-    auto& list = insert_before_[before];
+  template <typename T, typename BEFORE, typename OBJECT>
+  CloneContext& InsertBefore(const std::vector<T>& vector,
+                             BEFORE* before,
+                             OBJECT* object) {
+    if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
+      TINT_ICE(Diagnostics())
+          << "CloneContext::InsertBefore() vector does not contain before";
+      return *this;
+    }
+
+    auto& transforms = list_transforms_[&vector];
+    auto& list = transforms.insert_before_[before];
+    list.emplace_back(object);
+    return *this;
+  }
+
+  /// Inserts `object` after `after` whenever `vector` is cloned.
+  /// @param vector the vector in #src
+  /// @param after a pointer to the object in #src
+  /// @param object a pointer to the object in #dst that will be inserted after
+  /// any occurrence of the clone of `after`
+  /// @returns this CloneContext so calls can be chained
+  template <typename T, typename AFTER, typename OBJECT>
+  CloneContext& InsertAfter(const std::vector<T>& vector,
+                            AFTER* after,
+                            OBJECT* object) {
+    if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
+      TINT_ICE(Diagnostics())
+          << "CloneContext::InsertAfter() vector does not contain after";
+      return *this;
+    }
+
+    auto& transforms = list_transforms_[&vector];
+    auto& list = transforms.insert_after_[after];
     list.emplace_back(object);
     return *this;
   }
@@ -380,17 +427,33 @@
   /// A vector of Cloneable*
   using CloneableList = std::vector<Cloneable*>;
 
+  // Transformations to be applied to a list (vector)
+  struct ListTransforms {
+    /// Constructor
+    ListTransforms();
+    /// Destructor
+    ~ListTransforms();
+
+    /// A map of object in #src to the list of cloned objects in #dst.
+    /// Clone(const std::vector<T*>& v) will use this to insert the map-value
+    /// list into the target vector before cloning and inserting the map-key.
+    std::unordered_map<const Cloneable*, CloneableList> insert_before_;
+
+    /// A map of object in #src to the list of cloned objects in #dst.
+    /// Clone(const std::vector<T*>& v) will use this to insert the map-value
+    /// list into the target vector after cloning and inserting the map-key.
+    std::unordered_map<const Cloneable*, CloneableList> insert_after_;
+  };
+
   /// A map of object in #src to their cloned equivalent in #dst
   std::unordered_map<const Cloneable*, Cloneable*> cloned_;
 
-  /// A map of object in #src to the list of cloned objects in #dst.
-  /// Clone(const std::vector<T*>& v) will use this to insert the map-value list
-  /// into the target vector/ before cloning and inserting the map-key.
-  std::unordered_map<const Cloneable*, CloneableList> insert_before_;
-
   /// Cloneable transform functions registered with ReplaceAll()
   std::vector<CloneableTransform> transforms_;
 
+  /// Map of std::vector pointer to transforms for that list
+  std::unordered_map<const void*, ListTransforms> list_transforms_;
+
   /// Symbol transform registered with ReplaceAll()
   SymbolTransform symbol_transform_;
 };
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index 372da30..b09ccdc 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -287,9 +287,10 @@
   ProgramBuilder cloned;
   auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
 
-  auto* cloned_root = CloneContext(&cloned, &original)
-                          .InsertBefore(original_root->b, insertion)
-                          .Clone(original_root);
+  auto* cloned_root =
+      CloneContext(&cloned, &original)
+          .InsertBefore(original_root->vec, original_root->b, insertion)
+          .Clone(original_root);
 
   EXPECT_EQ(cloned_root->vec.size(), 4u);
   EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
@@ -303,6 +304,36 @@
   EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
 }
 
+TEST(CloneContext, CloneWithInsertAfter) {
+  ProgramBuilder builder;
+  auto* original_root =
+      builder.create<Node>(builder.Symbols().Register("root"));
+  original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+  original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+  original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
+  original_root->vec = {original_root->a, original_root->b, original_root->c};
+  Program original(std::move(builder));
+
+  ProgramBuilder cloned;
+  auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
+
+  auto* cloned_root =
+      CloneContext(&cloned, &original)
+          .InsertAfter(original_root->vec, original_root->b, insertion)
+          .Clone(original_root);
+
+  EXPECT_EQ(cloned_root->vec.size(), 4u);
+  EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
+  EXPECT_EQ(cloned_root->vec[1], cloned_root->b);
+  EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
+
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("b"));
+  EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("insertion"));
+  EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
+}
+
 TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
   EXPECT_FATAL_FAILURE(
       {
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index ca4af82..1c52066 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -18,6 +18,7 @@
 
 #include "src/program_builder.h"
 #include "src/semantic/function.h"
+#include "src/semantic/statement.h"
 #include "src/semantic/variable.h"
 
 namespace tint {
@@ -119,7 +120,7 @@
         // Initialize it with the value extracted from the new struct parameter.
         auto* func_const = ctx.dst->Const(
             func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
-        ctx.InsertBefore(*func->body()->begin(),
+        ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
                          ctx.dst->WrapInStatement(func_const));
 
         // Replace all uses of the function parameter with the function const.
@@ -134,7 +135,7 @@
           ctx.dst->Symbols().New(),
           ctx.dst->create<ast::Struct>(new_struct_members,
                                        ast::DecorationList{}));
-      ctx.InsertBefore(func, in_struct);
+      ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct);
 
       // Create a new function parameter using this struct type.
       auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct,
@@ -177,12 +178,13 @@
           ctx.dst->Symbols().New(),
           ctx.dst->create<ast::Struct>(new_struct_members,
                                        ast::DecorationList{}));
-      ctx.InsertBefore(func, out_struct);
+      ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct);
       new_ret_type = out_struct;
 
       // Replace all return statements.
       auto* sem_func = ctx.src->Sem().Get(func);
       for (auto* ret : sem_func->ReturnStatements()) {
+        auto* ret_sem = ctx.src->Sem().Get(ret);
         // Reconstruct the return value using the newly created struct.
         auto* new_ret_value = ctx.Clone(ret->value());
         ast::ExpressionList ret_values;
@@ -193,7 +195,7 @@
             auto temp = ctx.dst->Symbols().New();
             auto* temp_var = ctx.dst->Decl(
                 ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value));
-            ctx.InsertBefore(ret, temp_var);
+            ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var);
             new_ret_value = ctx.dst->Expr(temp);
           }
 
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 21b9bf7..0a08e24 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -96,7 +96,8 @@
         auto* dst_ident = ctx.dst->Expr(dst_symbol);
 
         // Insert the constant before the usage
-        ctx.InsertBefore(src_stmt, dst_var_decl);
+        ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
+                         dst_var_decl);
         // Replace the inlined array with a reference to the constant
         ctx.Replace(src_init, dst_ident);
       }
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 1eb7a6a..d29cfa4 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -21,6 +21,7 @@
 #include "src/ast/return_statement.h"
 #include "src/program_builder.h"
 #include "src/semantic/function.h"
+#include "src/semantic/statement.h"
 #include "src/semantic/variable.h"
 
 namespace tint {
@@ -162,13 +163,15 @@
           return_func_symbol, ast::VariableList{store_value},
           ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
           ast::DecorationList{}, ast::DecorationList{});
-      ctx.InsertBefore(func, return_func);
+      ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func);
 
       // Replace all return statements with calls to the output function.
       auto* sem_func = ctx.src->Sem().Get(func);
       for (auto* ret : sem_func->ReturnStatements()) {
+        auto* ret_sem = ctx.src->Sem().Get(ret);
         auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
-        ctx.InsertBefore(ret, ctx.dst->create<ast::CallStatement>(call));
+        ctx.InsertBefore(ret_sem->Block()->statements(), ret,
+                         ctx.dst->create<ast::CallStatement>(call));
         ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
       }
     }
@@ -247,7 +250,7 @@
     auto* global_var =
         ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
                      ast::StorageClass::kInput, nullptr, new_decorations);
-    ctx.InsertBefore(func, global_var);
+    ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
     return global_var_symbol;
   }
 
@@ -269,7 +272,8 @@
   // Create a function-scope variable for the struct.
   auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
   auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
-  ctx.InsertBefore(*func->body()->begin(), ctx.dst->WrapInStatement(func_var));
+  ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+                   ctx.dst->WrapInStatement(func_var));
   return func_var_symbol;
 }
 
@@ -292,7 +296,7 @@
     auto* global_var =
         ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
                      ast::StorageClass::kOutput, nullptr, new_decorations);
-    ctx.InsertBefore(func, global_var);
+    ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
 
     // Create the assignment instruction.
     ast::Expression* rhs = ctx.dst->Expr(store_value);