CloneContext: Add an overload of Replace() that takes a function

Replace(T* what, T* with) is bug-prone, as more complex transforms may want to clone `what` multiple times, or not at all. In both cases, this will likely result in an ICE as either the replacement will be reachable multiple times, or not at all.

This is the cause of some of the CTS failures reported in crbug.com/tint/993.

Bug: tint:993
Change-Id: I880ece45faab0e7f07230a1b4436f4e9846edc84
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58221
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h
index 6c89ef7..8be0b9f 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -109,7 +109,9 @@
     // Was Replace() called for this object?
     auto it = replacements_.find(a);
     if (it != replacements_.end()) {
-      return CheckedCast<T>(it->second);
+      auto* replacement = it->second();
+      TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, replacement);
+      return CheckedCast<T>(replacement);
     }
 
     Cloneable* cloned = nullptr;
@@ -342,8 +344,10 @@
     return *this;
   }
 
-  /// Replace replaces all occurrences of `what` in #src with `with` in #dst
-  /// when calling Clone().
+  /// Replace replaces all occurrences of `what` in #src with the pointer `with`
+  /// in #dst when calling Clone().
+  /// [DEPRECATED]: This function cannot handle nested replacements. Use the
+  /// overload of Replace() that take a function for the `WITH` argument.
   /// @param what a pointer to the object in #src that will be replaced with
   /// `with`
   /// @param with a pointer to the replacement object owned by #dst that will be
@@ -352,10 +356,32 @@
   /// references of the original object. A type mismatch will result in an
   /// assertion in debug builds, and undefined behavior in release builds.
   /// @returns this CloneContext so calls can be chained
-  template <typename WHAT, typename WITH>
+  template <typename WHAT,
+            typename WITH,
+            typename = traits::EnableIfIsType<WITH, Cloneable>>
   CloneContext& Replace(WHAT* what, WITH* with) {
     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with);
+    replacements_[what] = [with]() -> Cloneable* { return with; };
+    return *this;
+  }
+
+  /// Replace replaces all occurrences of `what` in #src with the result of the
+  /// function `with` in #dst when calling Clone(). `with` will be called each
+  /// time `what` is cloned by this context. If `what` is not cloned, then
+  /// `with` may never be called.
+  /// @param what a pointer to the object in #src that will be replaced with
+  /// `with`
+  /// @param with a function that takes no arguments and returns a pointer to
+  /// the replacement object owned by #dst. The returned pointer will be used as
+  /// a replacement for `what`.
+  /// @warning The replacement object must be of the correct type for all
+  /// references of the original object. A type mismatch will result in an
+  /// assertion in debug builds, and undefined behavior in release builds.
+  /// @returns this CloneContext so calls can be chained
+  template <typename WHAT, typename WITH, typename = std::result_of_t<WITH()>>
+  CloneContext& Replace(WHAT* what, WITH&& with) {
+    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
     replacements_[what] = with;
     return *this;
   }
@@ -532,8 +558,10 @@
     std::unordered_map<const Cloneable*, CloneableList> insert_after_;
   };
 
-  /// A map of object in #src to their replacement in #dst
-  std::unordered_map<const Cloneable*, Cloneable*> replacements_;
+  /// A map of object in #src to functions that create their replacement in
+  /// #dst
+  std::unordered_map<const Cloneable*, std::function<Cloneable*()>>
+      replacements_;
 
   /// A map of symbol in #src to their cloned equivalent in #dst
   std::unordered_map<Symbol, Symbol> cloned_symbols_;
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index a9865a2..391c7c6 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -291,7 +291,7 @@
   EXPECT_EQ(cloned_node->name, cloned.Symbols().Get("root"));
 }
 
-TEST_F(CloneContextNodeTest, CloneWithReplace) {
+TEST_F(CloneContextNodeTest, CloneWithReplacePointer) {
   Allocator a;
 
   ProgramBuilder builder;
@@ -323,6 +323,39 @@
   EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
 }
 
+TEST_F(CloneContextNodeTest, CloneWithReplaceFunction) {
+  Allocator a;
+
+  ProgramBuilder builder;
+  auto* original_root = a.Create<Node>(builder.Symbols().New("root"));
+  original_root->a = a.Create<Node>(builder.Symbols().New("a"));
+  original_root->b = a.Create<Node>(builder.Symbols().New("b"));
+  original_root->c = a.Create<Node>(builder.Symbols().New("c"));
+  Program original(std::move(builder));
+
+  //                          root
+  //        ╭──────────────────┼──────────────────╮
+  //       (a)                (b)                (c)
+  //                        Replaced
+
+  ProgramBuilder cloned;
+  auto* replacement = a.Create<Node>(cloned.Symbols().New("replacement"));
+
+  auto* cloned_root =
+      CloneContext(&cloned, &original)
+          .Replace(original_root->b, [=] { return replacement; })
+          .Clone(original_root);
+
+  EXPECT_NE(cloned_root->a, replacement);
+  EXPECT_EQ(cloned_root->b, replacement);
+  EXPECT_NE(cloned_root->c, replacement);
+
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement"));
+  EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
+}
+
 TEST_F(CloneContextNodeTest, CloneWithRemove) {
   Allocator a;
 
@@ -638,7 +671,7 @@
           replaceable_name);
 }
 
-TEST_F(CloneContextNodeTest, CloneWithReplace_WithNotANode) {
+TEST_F(CloneContextNodeTest, CloneWithReplacePointer_WithNotANode) {
   EXPECT_FATAL_FAILURE(
       {
         Allocator allocator;
@@ -666,6 +699,34 @@
       "internal compiler error");
 }
 
+TEST_F(CloneContextNodeTest, CloneWithReplaceFunction_WithNotANode) {
+  EXPECT_FATAL_FAILURE(
+      {
+        Allocator allocator;
+        ProgramBuilder builder;
+        auto* original_root =
+            allocator.Create<Node>(builder.Symbols().New("root"));
+        original_root->a = allocator.Create<Node>(builder.Symbols().New("a"));
+        original_root->b = allocator.Create<Node>(builder.Symbols().New("b"));
+        original_root->c = allocator.Create<Node>(builder.Symbols().New("c"));
+        Program original(std::move(builder));
+
+        //                          root
+        //        ╭──────────────────┼──────────────────╮
+        //       (a)                (b)                (c)
+        //                        Replaced
+
+        ProgramBuilder cloned;
+        auto* replacement = allocator.Create<NotANode>();
+
+        CloneContext ctx(&cloned, &original);
+        ctx.Replace(original_root->b, [=] { return replacement; });
+
+        ctx.Clone(original_root);
+      },
+      "internal compiler error");
+}
+
 using CloneContextTest = ::testing::Test;
 
 TEST_F(CloneContextTest, CloneWithReplaceAll_SymbolsTwice) {