CloneContext: Add Remove()

Omits an object from a vector when that vector is cloned

Bug: tint:183
Change-Id: I543c885609591dcd3b930ca00b8c1a78bc61f920
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51301
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h
index 40017a5..b00e547 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <functional>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -248,6 +249,9 @@
     if (list_transform_it != list_transforms_.end()) {
       const auto& transforms = list_transform_it->second;
       for (auto& el : v) {
+        if (transforms.remove_.count(el)) {
+          continue;
+        }
         auto insert_before_it = transforms.insert_before_.find(el);
         if (insert_before_it != transforms.insert_before_.end()) {
           for (auto insert : insert_before_it->second) {
@@ -369,10 +373,30 @@
   /// @returns this CloneContext so calls can be chained
   template <typename WHAT, typename WITH>
   CloneContext& Replace(WHAT* what, WITH* with) {
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, what);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(dst, with);
     cloned_[what] = with;
     return *this;
   }
 
+  /// Removes `object` from the cloned copy of `vector`.
+  /// @param vector the vector in #src
+  /// @param object a pointer to the object in #src that will be omitted from
+  /// the cloned vector.
+  /// @returns this CloneContext so calls can be chained
+  template <typename T, typename OBJECT>
+  CloneContext& Remove(const std::vector<T>& vector, OBJECT* object) {
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, object);
+    if (std::find(vector.begin(), vector.end(), object) == vector.end()) {
+      TINT_ICE(Diagnostics())
+          << "CloneContext::Remove() vector does not contain object";
+      return *this;
+    }
+
+    list_transforms_[&vector].remove_.emplace(object);
+    return *this;
+  }
+
   /// Inserts `object` before `before` whenever `vector` is cloned.
   /// @param vector the vector in #src
   /// @param before a pointer to the object in #src
@@ -383,6 +407,8 @@
   CloneContext& InsertBefore(const std::vector<T>& vector,
                              const BEFORE* before,
                              OBJECT* object) {
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, before);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(dst, object);
     if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
       TINT_ICE(Diagnostics())
           << "CloneContext::InsertBefore() vector does not contain before";
@@ -405,6 +431,8 @@
   CloneContext& InsertAfter(const std::vector<T>& vector,
                             const AFTER* after,
                             OBJECT* object) {
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, after);
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(dst, object);
     if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
       TINT_ICE(Diagnostics())
           << "CloneContext::InsertAfter() vector does not contain after";
@@ -453,7 +481,10 @@
     if (TO* cast = As<TO>(obj)) {
       return cast;
     }
-    TINT_ICE(Diagnostics()) << "Cloned object was not of the expected type";
+    TINT_ICE(Diagnostics())
+        << "Cloned object was not of the expected type\n"
+        << "got:      " << (obj ? obj->TypeInfo().name : "<null>") << "\n"
+        << "expected: " << TypeInfo::Of<TO>().name;
     return nullptr;
   }
 
@@ -470,6 +501,9 @@
     /// Destructor
     ~ListTransforms();
 
+    /// A map of object in #src to omit when cloned into #dst.
+    std::unordered_set<const Cloneable*> remove_;
+
     /// A map of object in #src to the list of cloned objects in #dst.
     /// Clone(const std::vector<T*>& v) will use this to insert the map-value
     /// list into the target vector before cloning and inserting the map-key.
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index d5086c5..4ec9585 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -397,6 +397,39 @@
   EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
 }
 
+TYPED_TEST(CloneContextNodeTest, CloneWithRemove) {
+  using Node = typename TestFixture::Node;
+  constexpr bool is_unique = TestFixture::is_unique;
+
+  Allocator a;
+
+  ProgramBuilder builder;
+  auto* original_root = a.Create<Node>(builder.Symbols().Register("root"));
+  original_root->a = a.Create<Node>(builder.Symbols().Register("a"));
+  original_root->b = a.Create<Node>(builder.Symbols().Register("b"));
+  original_root->c = a.Create<Node>(builder.Symbols().Register("c"));
+  original_root->vec = {original_root->a, original_root->b, original_root->c};
+  Program original(std::move(builder));
+
+  ProgramBuilder cloned;
+  auto* cloned_root = CloneContext(&cloned, &original)
+                          .Remove(original_root->vec, original_root->b)
+                          .Clone(original_root);
+
+  EXPECT_EQ(cloned_root->vec.size(), 2u);
+  if (is_unique) {
+    EXPECT_NE(cloned_root->vec[0], cloned_root->a);
+    EXPECT_NE(cloned_root->vec[1], cloned_root->c);
+  } else {
+    EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
+    EXPECT_EQ(cloned_root->vec[1], cloned_root->c);
+  }
+
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("c"));
+}
+
 TYPED_TEST(CloneContextNodeTest, CloneWithInsertBefore) {
   using Node = typename TestFixture::Node;
   constexpr bool is_unique = TestFixture::is_unique;
@@ -691,7 +724,7 @@
   EXPECT_EQ(cloned->program_id, dst.ID());
 }
 
-TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedBySrc) {
+TEST_F(CloneContextTest, ProgramIDs_Clone_ObjectNotOwnedBySrc) {
   EXPECT_FATAL_FAILURE(
       {
         ProgramBuilder dst;
@@ -703,7 +736,7 @@
       R"(internal compiler error: TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a))");
 }
 
-TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedByDst) {
+TEST_F(CloneContextTest, ProgramIDs_Clone_ObjectNotOwnedByDst) {
   EXPECT_FATAL_FAILURE(
       {
         ProgramBuilder dst;