CloneContext: Add InsertBefore()

Inserts objects before others when cloning

Change-Id: Ibf247abae3aeb3d351048f1182db2a2b42b2c677
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41547
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/clone_context.h b/src/clone_context.h
index 8424b73..0442b3e 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -152,6 +152,30 @@
   }
 
   /// Clones each of the elements of the vector `v` into the ProgramBuilder
+  /// #dst, inserting any additional elements into the list that were registered
+  /// with calls to InsertBefore().
+  ///
+  /// All the elements of the vector `v` must be owned by the Program #src.
+  ///
+  /// @param v the vector to clone
+  /// @return the cloned vector
+  template <typename T>
+  std::vector<T*> Clone(const std::vector<T*>& v) {
+    std::vector<T*> out;
+    out.reserve(v.size());
+    for (auto& el : v) {
+      auto it = insert_before_.find(el);
+      if (it != insert_before_.end()) {
+        for (auto insert : it->second) {
+          out.emplace_back(CheckedCast<T>(insert));
+        }
+      }
+      out.emplace_back(Clone(el));
+    }
+    return out;
+  }
+
+  /// Clones each of the elements of the vector `v` into the ProgramBuilder
   /// #dst.
   ///
   /// All the elements of the vector `v` must be owned by the Program #src.
@@ -219,6 +243,19 @@
     return *this;
   }
 
+  /// Inserts `object` before `before` whenever a vector containing `object` is
+  /// cloned.
+  /// @param before a pointer to the object in #src
+  /// @param object a pointer to the object in #dst that will be inserted before
+  /// any occurrence of the clone of `before`
+  /// @returns this CloneContext so calls can be chained
+  template <typename BEFORE, typename OBJECT>
+  CloneContext& InsertBefore(BEFORE* before, OBJECT* object) {
+    auto& list = insert_before_[before];
+    list.emplace_back(object);
+    return *this;
+  }
+
   /// Clone performs the clone of the entire Program #src to #dst.
   void Clone();
 
@@ -266,7 +303,18 @@
     return cast;
   }
 
+  /// A vector of CastableBase*
+  using CastableList = std::vector<CastableBase*>;
+
+  /// A map of object in #src to their cloned equivalent in #dst
   std::unordered_map<CastableBase*, CastableBase*> cloned_;
+
+  /// A map of object in #src to the list of cloned objects in #dst.
+  /// Clone(const std::vector<T*>& v) will use this to insert the map-value list
+  /// into the target vector/ before cloning and inserting the map-key.
+  std::unordered_map<CastableBase*, CastableList> insert_before_;
+
+  /// Transform functions registered with ReplaceAll()
   std::vector<Transform> transforms_;
 };
 
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index 5e47cb2..edeed3a 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -14,7 +14,9 @@
 
 #include "src/clone_context.h"
 
+#include <string>
 #include <utility>
+#include <vector>
 
 #include "gtest/gtest.h"
 
@@ -24,17 +26,21 @@
 namespace {
 
 struct Cloneable : public Castable<Cloneable, ast::Node> {
-  explicit Cloneable(const Source& source) : Base(source) {}
+  explicit Cloneable(const Source& source, std::string n)
+      : Base(source), name(n) {}
 
+  std::string name;
   Cloneable* a = nullptr;
   Cloneable* b = nullptr;
   Cloneable* c = nullptr;
+  std::vector<Cloneable*> vec;
 
   Cloneable* Clone(CloneContext* ctx) const override {
-    auto* out = ctx->dst->create<Cloneable>();
+    auto* out = ctx->dst->create<Cloneable>(name);
     out->a = ctx->Clone(a);
     out->b = ctx->Clone(b);
     out->c = ctx->Clone(c);
+    out->vec = ctx->Clone(vec);
     return out;
   }
 
@@ -43,10 +49,10 @@
 };
 
 struct Replaceable : public Castable<Replaceable, Cloneable> {
-  explicit Replaceable(const Source& source) : Base(source) {}
+  explicit Replaceable(const Source& source, std::string n) : Base(source, n) {}
 };
 struct Replacement : public Castable<Replacement, Replaceable> {
-  explicit Replacement(const Source& source) : Base(source) {}
+  explicit Replacement(const Source& source, std::string n) : Base(source, n) {}
 };
 
 struct NotACloneable : public Castable<NotACloneable, ast::Node> {
@@ -62,12 +68,12 @@
 
 TEST(CloneContext, Clone) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>();
-  original_root->a = builder.create<Cloneable>();
-  original_root->a->b = builder.create<Cloneable>();
-  original_root->b = builder.create<Cloneable>();
+  auto* original_root = builder.create<Cloneable>("root");
+  original_root->a = builder.create<Cloneable>("a");
+  original_root->a->b = builder.create<Cloneable>("a->b");
+  original_root->b = builder.create<Cloneable>("b");
   original_root->b->a = original_root->a;  // Aliased
-  original_root->b->b = builder.create<Cloneable>();
+  original_root->b->b = builder.create<Cloneable>("b->b");
   original_root->c = original_root->b;  // Aliased
   Program original(std::move(builder));
 
@@ -101,16 +107,22 @@
   EXPECT_NE(cloned_root->b->b, original_root->b->b);
   EXPECT_NE(cloned_root->c, original_root->c);
 
+  EXPECT_EQ(cloned_root->name, "root");
+  EXPECT_EQ(cloned_root->a->name, "a");
+  EXPECT_EQ(cloned_root->a->b->name, "a->b");
+  EXPECT_EQ(cloned_root->b->name, "b");
+  EXPECT_EQ(cloned_root->b->b->name, "b->b");
+
   EXPECT_EQ(cloned_root->b->a, cloned_root->a);  // Aliased
   EXPECT_EQ(cloned_root->c, cloned_root->b);     // Aliased
 }
 
 TEST(CloneContext, CloneWithReplacements) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>();
-  original_root->a = builder.create<Cloneable>();
-  original_root->a->b = builder.create<Replaceable>();
-  original_root->b = builder.create<Replaceable>();
+  auto* original_root = builder.create<Cloneable>("root");
+  original_root->a = builder.create<Cloneable>("a");
+  original_root->a->b = builder.create<Replaceable>("a->b");
+  original_root->b = builder.create<Replaceable>("b");
   original_root->b->a = original_root->a;  // Aliased
   original_root->c = original_root->b;     // Aliased
   Program original(std::move(builder));
@@ -127,14 +139,15 @@
   // R: Replaceable
 
   ProgramBuilder cloned;
-  auto* cloned_root = CloneContext(&cloned, &original)
-                          .ReplaceAll([&](CloneContext* ctx, Replaceable* in) {
-                            auto* out = cloned.create<Replacement>();
-                            out->b = cloned.create<Cloneable>();
-                            out->c = ctx->Clone(in->a);
-                            return out;
-                          })
-                          .Clone(original_root);
+  auto* cloned_root =
+      CloneContext(&cloned, &original)
+          .ReplaceAll([&](CloneContext* ctx, Replaceable* in) {
+            auto* out = cloned.create<Replacement>("replacement:" + in->name);
+            out->b = cloned.create<Cloneable>("replacement-child:" + in->name);
+            out->c = ctx->Clone(in->a);
+            return out;
+          })
+          .Clone(original_root);
 
   //                         root
   //        ╭─────────────────┼──────────────────╮
@@ -169,6 +182,13 @@
   EXPECT_NE(cloned_root->b->a, original_root->b->a);
   EXPECT_NE(cloned_root->c, original_root->c);
 
+  EXPECT_EQ(cloned_root->name, "root");
+  EXPECT_EQ(cloned_root->a->name, "a");
+  EXPECT_EQ(cloned_root->a->b->name, "replacement:a->b");
+  EXPECT_EQ(cloned_root->a->b->b->name, "replacement-child:a->b");
+  EXPECT_EQ(cloned_root->b->name, "replacement:b");
+  EXPECT_EQ(cloned_root->b->b->name, "replacement-child:b");
+
   EXPECT_EQ(cloned_root->b->c, cloned_root->a);  // Aliased
   EXPECT_EQ(cloned_root->c, cloned_root->b);     // Aliased
 
@@ -181,10 +201,10 @@
 
 TEST(CloneContext, CloneWithReplace) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>();
-  original_root->a = builder.create<Cloneable>();
-  original_root->b = builder.create<Cloneable>();
-  original_root->c = builder.create<Cloneable>();
+  auto* original_root = builder.create<Cloneable>("root");
+  original_root->a = builder.create<Cloneable>("a");
+  original_root->b = builder.create<Cloneable>("b");
+  original_root->c = builder.create<Cloneable>("c");
   Program original(std::move(builder));
 
   //                          root
@@ -193,7 +213,7 @@
   //                        Replaced
 
   ProgramBuilder cloned;
-  auto* replacement = cloned.create<Cloneable>();
+  auto* replacement = cloned.create<Cloneable>("replacement");
 
   auto* cloned_root = CloneContext(&cloned, &original)
                           .Replace(original_root->b, replacement)
@@ -202,14 +222,47 @@
   EXPECT_NE(cloned_root->a, replacement);
   EXPECT_EQ(cloned_root->b, replacement);
   EXPECT_NE(cloned_root->c, replacement);
+
+  EXPECT_EQ(cloned_root->name, "root");
+  EXPECT_EQ(cloned_root->a->name, "a");
+  EXPECT_EQ(cloned_root->b->name, "replacement");
+  EXPECT_EQ(cloned_root->c->name, "c");
+}
+
+TEST(CloneContext, CloneWithInsertBefore) {
+  ProgramBuilder builder;
+  auto* original_root = builder.create<Cloneable>("root");
+  original_root->a = builder.create<Cloneable>("a");
+  original_root->b = builder.create<Cloneable>("b");
+  original_root->c = builder.create<Cloneable>("c");
+  original_root->vec = {original_root->a, original_root->b, original_root->c};
+  Program original(std::move(builder));
+
+  ProgramBuilder cloned;
+  auto* insertion = cloned.create<Cloneable>("insertion");
+
+  auto* cloned_root = CloneContext(&cloned, &original)
+                          .InsertBefore(original_root->b, insertion)
+                          .Clone(original_root);
+
+  EXPECT_EQ(cloned_root->vec.size(), 4u);
+  EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
+  EXPECT_EQ(cloned_root->vec[2], cloned_root->b);
+  EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
+
+  EXPECT_EQ(cloned_root->name, "root");
+  EXPECT_EQ(cloned_root->vec[0]->name, "a");
+  EXPECT_EQ(cloned_root->vec[1]->name, "insertion");
+  EXPECT_EQ(cloned_root->vec[2]->name, "b");
+  EXPECT_EQ(cloned_root->vec[3]->name, "c");
 }
 
 TEST(CloneContext, CloneWithReplace_WithNotACloneable) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>();
-  original_root->a = builder.create<Cloneable>();
-  original_root->b = builder.create<Cloneable>();
-  original_root->c = builder.create<Cloneable>();
+  auto* original_root = builder.create<Cloneable>("root");
+  original_root->a = builder.create<Cloneable>("a");
+  original_root->b = builder.create<Cloneable>("b");
+  original_root->c = builder.create<Cloneable>("c");
   Program original(std::move(builder));
 
   //                          root