CloneContext: Support inplace cloning

Add a new constructor that only takes a ProgramBuilder.
This allows cloning objects to and from the same ProgramBuilder.

Also clean up tests.

Change-Id: I7c7bbaced4956f9094d0a6231aa4d7f7b6f17d4c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49744
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/clone_context.cc b/src/clone_context.cc
index 1f6b09b..b0ea74e 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -39,9 +39,15 @@
   }
 }
 
+CloneContext::CloneContext(ProgramBuilder* builder)
+    : CloneContext(builder, nullptr, false) {}
+
 CloneContext::~CloneContext() = default;
 
 Symbol CloneContext::Clone(Symbol s) {
+  if (!src) {
+    return s;  // In-place clone
+  }
   return utils::GetOrCreate(cloned_symbols_, s, [&]() -> Symbol {
     if (symbol_transform_) {
       return symbol_transform_(s);
diff --git a/src/clone_context.h b/src/clone_context.h
index ad02692..2623f63 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -74,7 +74,7 @@
   /// symbol.
   using SymbolTransform = std::function<Symbol(Symbol)>;
 
-  /// Constructor
+  /// Constructor for cloning objects from `from` into `to`.
   /// @param to the target ProgramBuilder to clone into
   /// @param from the source Program to clone from
   /// @param auto_clone_symbols clone all symbols in `from` before returning
@@ -82,6 +82,10 @@
                Program const* from,
                bool auto_clone_symbols = true);
 
+  /// Constructor for cloning objects from and to the ProgramBuilder `builder`.
+  /// @param builder the ProgramBuilder
+  explicit CloneContext(ProgramBuilder* builder);
+
   /// Destructor
   ~CloneContext();
 
@@ -93,7 +97,8 @@
   /// Clone() may use a function registered with ReplaceAll() to create a
   /// transformed version of the object. See ReplaceAll() for more information.
   ///
-  /// The Node or sem::Type `a` must be owned by the Program #src.
+  /// If the CloneContext is cloning from a Program to a ProgramBuilder, then
+  /// the Node or sem::Type `a` must be owned by the Program #src.
   ///
   /// @param a the `Node` or `sem::Type` to clone
   /// @return the cloned node
@@ -104,7 +109,9 @@
       return nullptr;
     }
 
-    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a);
+    if (src) {
+      TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a);
+    }
 
     // Have we cloned this object already, or was Replace() called for this
     // object?
@@ -133,7 +140,7 @@
 
     // Does the type derive from ShareableCloneable?
     if (Is<ShareableCloneable, kDontErrorOnImpossibleCast>(a)) {
-      // Yes. Record this src -> dst mapping so that future calls to Clone()
+      // Yes. Record this clone mapping so that future calls to Clone()
       // return the same cloned object.
       cloned_.emplace(a, cloned);
     }
@@ -153,7 +160,8 @@
   /// Unlike Clone(), this method does not invoke or use any transformations
   /// registered by ReplaceAll().
   ///
-  /// The Node or sem::Type `a` must be owned by the Program #src.
+  /// If the CloneContext is cloning from a Program to a ProgramBuilder, then
+  /// the Node or sem::Type `a` must be owned by the Program #src.
   ///
   /// @param a the `Node` or `sem::Type` to clone
   /// @return the cloned node
@@ -164,7 +172,9 @@
       return nullptr;
     }
 
-    TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a);
+    if (src) {
+      TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a);
+    }
 
     // Have we seen this object before? If so, return the previously cloned
     // version instead of making yet another copy.
@@ -371,7 +381,7 @@
   /// @returns this CloneContext so calls can be chained
   template <typename T, typename BEFORE, typename OBJECT>
   CloneContext& InsertBefore(const std::vector<T>& vector,
-                             BEFORE* before,
+                             const BEFORE* before,
                              OBJECT* object) {
     if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
       TINT_ICE(Diagnostics())
@@ -393,7 +403,7 @@
   /// @returns this CloneContext so calls can be chained
   template <typename T, typename AFTER, typename OBJECT>
   CloneContext& InsertAfter(const std::vector<T>& vector,
-                            AFTER* after,
+                            const AFTER* after,
                             OBJECT* object) {
     if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
       TINT_ICE(Diagnostics())
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index de8c9e4..d5086c5 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -126,17 +126,17 @@
 };
 
 template <typename T>
-struct CloneContextTest : public ::testing::Test {
+struct CloneContextNodeTest : public ::testing::Test {
   using Node = typename T::Node;
   using Replaceable = typename T::Replaceable;
   using Replacement = typename T::Replacement;
   static constexpr bool is_unique = std::is_same<Node, UniqueNode>::value;
 };
 
-using CloneContextTestTypes = ::testing::Types<UniqueTypes, ShareableTypes>;
-TYPED_TEST_SUITE(CloneContextTest, CloneContextTestTypes, /**/);
+using CloneContextTestNodeTypes = ::testing::Types<UniqueTypes, ShareableTypes>;
+TYPED_TEST_SUITE(CloneContextNodeTest, CloneContextTestNodeTypes, /**/);
 
-TYPED_TEST(CloneContextTest, Clone) {
+TYPED_TEST(CloneContextNodeTest, Clone) {
   using Node = typename TestFixture::Node;
   constexpr bool is_unique = TestFixture::is_unique;
 
@@ -199,7 +199,7 @@
   EXPECT_EQ(cloned_root->c->name, cloned_root->b->name);
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplaceAll_Cloneable) {
+TYPED_TEST(CloneContextNodeTest, CloneWithReplaceAll_Cloneable) {
   using Node = typename TestFixture::Node;
   using Replaceable = typename TestFixture::Replaceable;
   using Replacement = typename TestFixture::Replacement;
@@ -301,7 +301,7 @@
   EXPECT_FALSE(Is<Replacement>(cloned_root->b->b));
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplaceAll_Symbols) {
+TYPED_TEST(CloneContextNodeTest, CloneWithReplaceAll_Symbols) {
   using Node = typename TestFixture::Node;
 
   Allocator a;
@@ -342,7 +342,7 @@
   EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("transformed<b->b>"));
 }
 
-TYPED_TEST(CloneContextTest, CloneWithoutTransform) {
+TYPED_TEST(CloneContextNodeTest, CloneWithoutTransform) {
   using Node = typename TestFixture::Node;
   using Replacement = typename TestFixture::Replacement;
 
@@ -363,7 +363,7 @@
   EXPECT_EQ(cloned_node->name, cloned.Symbols().Get("root"));
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplace) {
+TYPED_TEST(CloneContextNodeTest, CloneWithReplace) {
   using Node = typename TestFixture::Node;
 
   Allocator a;
@@ -397,7 +397,7 @@
   EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
 }
 
-TYPED_TEST(CloneContextTest, CloneWithInsertBefore) {
+TYPED_TEST(CloneContextNodeTest, CloneWithInsertBefore) {
   using Node = typename TestFixture::Node;
   constexpr bool is_unique = TestFixture::is_unique;
 
@@ -437,7 +437,7 @@
   EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
 }
 
-TYPED_TEST(CloneContextTest, CloneWithInsertAfter) {
+TYPED_TEST(CloneContextNodeTest, CloneWithInsertAfter) {
   using Node = typename TestFixture::Node;
   constexpr bool is_unique = TestFixture::is_unique;
 
@@ -477,7 +477,26 @@
   EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplaceAll_SameTypeTwice) {
+TYPED_TEST(CloneContextNodeTest, CloneIntoSameBuilder) {
+  using Node = typename TestFixture::Node;
+  constexpr bool is_unique = TestFixture::is_unique;
+
+  ProgramBuilder builder;
+  CloneContext ctx(&builder);
+  Allocator allocator;
+  auto* original = allocator.Create<Node>(builder.Symbols().New());
+  auto* cloned_a = ctx.Clone(original);
+  auto* cloned_b = ctx.Clone(original);
+  EXPECT_NE(original, cloned_a);
+  EXPECT_NE(original, cloned_b);
+  if (is_unique) {
+    EXPECT_NE(cloned_a, cloned_b);
+  } else {
+    EXPECT_EQ(cloned_a, cloned_b);
+  }
+}
+
+TYPED_TEST(CloneContextNodeTest, CloneWithReplaceAll_SameTypeTwice) {
   std::string node_name = TypeInfo::Of<typename TestFixture::Node>().name;
 
   EXPECT_FATAL_FAILURE(
@@ -494,7 +513,7 @@
           node_name);
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplaceAll_BaseThenDerived) {
+TYPED_TEST(CloneContextNodeTest, CloneWithReplaceAll_BaseThenDerived) {
   std::string node_name = TypeInfo::Of<typename TestFixture::Node>().name;
   std::string replaceable_name =
       TypeInfo::Of<typename TestFixture::Replaceable>().name;
@@ -515,7 +534,7 @@
           node_name);
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplaceAll_DerivedThenBase) {
+TYPED_TEST(CloneContextNodeTest, CloneWithReplaceAll_DerivedThenBase) {
   std::string node_name = TypeInfo::Of<typename TestFixture::Node>().name;
   std::string replaceable_name =
       TypeInfo::Of<typename TestFixture::Replaceable>().name;
@@ -535,20 +554,7 @@
           replaceable_name);
 }
 
-TYPED_TEST(CloneContextTest, CloneWithReplaceAll_SymbolsTwice) {
-  EXPECT_FATAL_FAILURE(
-      {
-        ProgramBuilder cloned;
-        Program original;
-        CloneContext ctx(&cloned, &original);
-        ctx.ReplaceAll([](Symbol s) { return s; });
-        ctx.ReplaceAll([](Symbol s) { return s; });
-      },
-      "internal compiler error: ReplaceAll(const SymbolTransform&) called "
-      "multiple times on the same CloneContext");
-}
-
-TYPED_TEST(CloneContextTest, CloneWithReplace_WithNotANode) {
+TYPED_TEST(CloneContextNodeTest, CloneWithReplace_WithNotANode) {
   EXPECT_FATAL_FAILURE(
       {
         using Node = typename TestFixture::Node;
@@ -577,7 +583,22 @@
       "internal compiler error");
 }
 
-TYPED_TEST(CloneContextTest, CloneNewUnnamedSymbols) {
+using CloneContextTest = ::testing::Test;
+
+TEST_F(CloneContextTest, CloneWithReplaceAll_SymbolsTwice) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder cloned;
+        Program original;
+        CloneContext ctx(&cloned, &original);
+        ctx.ReplaceAll([](Symbol s) { return s; });
+        ctx.ReplaceAll([](Symbol s) { return s; });
+      },
+      "internal compiler error: ReplaceAll(const SymbolTransform&) called "
+      "multiple times on the same CloneContext");
+}
+
+TEST_F(CloneContextTest, CloneNewUnnamedSymbols) {
   ProgramBuilder builder;
   Symbol old_a = builder.Symbols().New();
   Symbol old_b = builder.Symbols().New();
@@ -605,7 +626,7 @@
   EXPECT_EQ(cloned.Symbols().NameFor(new_c), "tint_symbol_2_1");
 }
 
-TYPED_TEST(CloneContextTest, CloneNewSymbols) {
+TEST_F(CloneContextTest, CloneNewSymbols) {
   ProgramBuilder builder;
   Symbol old_a = builder.Symbols().New("a");
   Symbol old_b = builder.Symbols().New("b");
@@ -633,7 +654,7 @@
   EXPECT_EQ(cloned.Symbols().NameFor(new_c), "c_1");
 }
 
-TYPED_TEST(CloneContextTest, CloneNewSymbols_AfterCloneSymbols) {
+TEST_F(CloneContextTest, CloneNewSymbols_AfterCloneSymbols) {
   ProgramBuilder builder;
   Symbol old_a = builder.Symbols().New("a");
   Symbol old_b = builder.Symbols().New("b");
@@ -661,15 +682,16 @@
   EXPECT_EQ(cloned.Symbols().NameFor(new_c), "c");
 }
 
-TYPED_TEST(CloneContextTest, ProgramIDs) {
+TEST_F(CloneContextTest, ProgramIDs) {
   ProgramBuilder dst;
   Program src(ProgramBuilder{});
   CloneContext ctx(&dst, &src);
   Allocator allocator;
-  ctx.Clone(allocator.Create<ProgramNode>(src.ID(), dst.ID()));
+  auto* cloned = ctx.Clone(allocator.Create<ProgramNode>(src.ID(), dst.ID()));
+  EXPECT_EQ(cloned->program_id, dst.ID());
 }
 
-TYPED_TEST(CloneContextTest, ProgramIDs_ObjectNotOwnedBySrc) {
+TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedBySrc) {
   EXPECT_FATAL_FAILURE(
       {
         ProgramBuilder dst;
@@ -681,7 +703,7 @@
       R"(internal compiler error: TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a))");
 }
 
-TYPED_TEST(CloneContextTest, ProgramIDs_ObjectNotOwnedByDst) {
+TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedByDst) {
   EXPECT_FATAL_FAILURE(
       {
         ProgramBuilder dst;