CloneContext: Add support for transforming symbols

Will be used by a Renamer transform

Change-Id: Ic0e9b69874f51103f0beec7745d32a9f8419e93a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42841
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/clone_context.cc b/src/clone_context.cc
index 1ff9ff2..edc6e8b 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -28,6 +28,9 @@
 CloneContext::~CloneContext() = default;
 
 Symbol CloneContext::Clone(const Symbol& s) const {
+  if (symbol_transform_) {
+    return symbol_transform_(s);
+  }
   return dst->Symbols().Register(src->Symbols().NameFor(s));
 }
 
@@ -48,4 +51,9 @@
   return dst->Diagnostics();
 }
 
+CloneContext::CloneableTransform::CloneableTransform() = default;
+CloneContext::CloneableTransform::CloneableTransform(
+    const CloneableTransform&) = default;
+CloneContext::CloneableTransform::~CloneableTransform() = default;
+
 }  // namespace tint
diff --git a/src/clone_context.h b/src/clone_context.h
index 31c18c7..94851c7 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -18,6 +18,7 @@
 #include <cassert>
 #include <functional>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "src/castable.h"
@@ -48,7 +49,18 @@
 
 /// CloneContext holds the state used while cloning AST nodes and types.
 class CloneContext {
+  /// ParamTypeIsPtrOf<F, T>::value is true iff the first parameter of
+  /// F is a pointer of (or derives from) type T.
+  template <typename F, typename T>
+  using ParamTypeIsPtrOf = traits::IsTypeOrDerived<
+      typename std::remove_pointer<traits::ParamTypeT<F, 0>>::type,
+      T>;
+
  public:
+  /// SymbolTransform is a function that takes a symbol and returns a new
+  /// symbol.
+  using SymbolTransform = std::function<Symbol(Symbol)>;
+
   /// Constructor
   /// @param to the target ProgramBuilder to clone into
   /// @param from the source Program to clone from
@@ -199,10 +211,8 @@
   /// `replacer` must be function-like with the signature: `T* (T*)`
   ///  where `T` is a type deriving from Cloneable.
   ///
-  /// If `replacer` returns a nullptr then Clone() will attempt the next
-  /// registered replacer function that matches the object type. If no replacers
-  /// match the object type, or all returned nullptr then Clone() will call
-  /// `T::Clone()` to clone the object.
+  /// If `replacer` returns a nullptr then Clone() will call `T::Clone()` to
+  /// clone the object.
   ///
   /// Example:
   ///
@@ -218,6 +228,9 @@
   ///   ctx.Clone();
   /// ```
   ///
+  /// @warning a single handler can only be registered for any given type.
+  /// Attempting to register two handlers for the same type will result in an
+  /// ICE.
   /// @warning The replacement object must be of the correct type for all
   /// references of the original object. A type mismatch will result in an
   /// assertion in debug builds, and undefined behavior in release builds.
@@ -225,13 +238,44 @@
   ///        `T* (T*)`, where `T` derives from Cloneable
   /// @returns this CloneContext so calls can be chained
   template <typename F>
-  CloneContext& ReplaceAll(F&& replacer) {
+  traits::EnableIf<ParamTypeIsPtrOf<F, Cloneable>::value, CloneContext>&
+  ReplaceAll(F&& replacer) {
     using TPtr = traits::ParamTypeT<F, 0>;
     using T = typename std::remove_pointer<TPtr>::type;
-    transforms_.emplace_back([=](Cloneable* in) {
-      auto* in_as_t = in->As<T>();
-      return in_as_t != nullptr ? replacer(in_as_t) : nullptr;
-    });
+    for (auto& transform : transforms_) {
+      if (transform.typeinfo->Is(TypeInfo::Of<T>()) ||
+          TypeInfo::Of<T>().Is(*transform.typeinfo)) {
+        TINT_ICE(Diagnostics())
+            << "ReplaceAll() called with a handler for type "
+            << TypeInfo::Of<T>().name
+            << " that is already handled by a handler for type "
+            << transform.typeinfo->name;
+        return *this;
+      }
+    }
+    CloneableTransform transform;
+    transform.typeinfo = &TypeInfo::Of<T>();
+    transform.function = [=](Cloneable* in) { return replacer(in->As<T>()); };
+    transforms_.emplace_back(std::move(transform));
+    return *this;
+  }
+
+  /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
+  /// is called with a Symbol.
+  /// The returned symbol of `replacer` will be used as the replacement for
+  /// all references to the symbol that's being cloned. This returned Symbol
+  /// must be owned by the Program #dst.
+  /// @param replacer a function the signature `Symbol(Symbol)`.
+  /// @warning a SymbolTransform can only be registered once. Attempting to
+  /// register a SymbolTransform more than once will result in an ICE.
+  /// @returns this CloneContext so calls can be chained
+  CloneContext& ReplaceAll(const SymbolTransform& replacer) {
+    if (symbol_transform_) {
+      TINT_ICE(Diagnostics()) << "ReplaceAll(const SymbolTransform&) called "
+                                 "multiple times on the same CloneContext";
+      return *this;
+    }
+    symbol_transform_ = replacer;
     return *this;
   }
 
@@ -276,7 +320,19 @@
   Program const* const src;
 
  private:
-  using Transform = std::function<Cloneable*(Cloneable*)>;
+  struct CloneableTransform {
+    /// Constructor
+    CloneableTransform();
+    /// Copy constructor
+    /// @param other the CloneableTransform to copy
+    CloneableTransform(const CloneableTransform& other);
+    /// Destructor
+    ~CloneableTransform();
+
+    // TypeInfo of the Cloneable that the transform operates on
+    const TypeInfo* typeinfo;
+    std::function<Cloneable*(Cloneable*)> function;
+  };
 
   CloneContext(const CloneContext&) = delete;
   CloneContext& operator=(const CloneContext&) = delete;
@@ -293,11 +349,16 @@
     }
 
     // Attempt to clone using the registered replacer functions.
-    for (auto& f : transforms_) {
-      if (Cloneable* c = f(a)) {
+    auto& typeinfo = a->TypeInfo();
+    for (auto& transform : transforms_) {
+      if (!typeinfo.Is(*transform.typeinfo)) {
+        continue;
+      }
+      if (Cloneable* c = transform.function(a)) {
         cloned_.emplace(a, c);
         return c;
       }
+      break;
     }
 
     // No luck, Clone() will have to call T::Clone().
@@ -329,8 +390,11 @@
   /// into the target vector/ before cloning and inserting the map-key.
   std::unordered_map<Cloneable*, CloneableList> insert_before_;
 
-  /// Transform functions registered with ReplaceAll()
-  std::vector<Transform> transforms_;
+  /// Cloneable transform functions registered with ReplaceAll()
+  std::vector<CloneableTransform> transforms_;
+
+  /// Symbol transform registered with ReplaceAll()
+  SymbolTransform symbol_transform_;
 };
 
 }  // namespace tint
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index 5f38c99..99e1137 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -26,16 +26,16 @@
 namespace {
 
 struct Node : public Castable<Node, ast::Node> {
-  explicit Node(const Source& source, std::string n) : Base(source), name(n) {}
+  explicit Node(const Source& source, Symbol n) : Base(source), name(n) {}
 
-  std::string name;
+  Symbol name;
   Node* a = nullptr;
   Node* b = nullptr;
   Node* c = nullptr;
   std::vector<Node*> vec;
 
   Node* Clone(CloneContext* ctx) const override {
-    auto* out = ctx->dst->create<Node>(name);
+    auto* out = ctx->dst->create<Node>(ctx->Clone(name));
     out->a = ctx->Clone(a);
     out->b = ctx->Clone(b);
     out->c = ctx->Clone(c);
@@ -48,10 +48,10 @@
 };
 
 struct Replaceable : public Castable<Replaceable, Node> {
-  explicit Replaceable(const Source& source, std::string n) : Base(source, n) {}
+  explicit Replaceable(const Source& source, Symbol n) : Base(source, n) {}
 };
 struct Replacement : public Castable<Replacement, Replaceable> {
-  explicit Replacement(const Source& source, std::string n) : Base(source, n) {}
+  explicit Replacement(const Source& source, Symbol n) : Base(source, n) {}
 };
 
 struct NotANode : public Castable<NotANode, ast::Node> {
@@ -67,12 +67,15 @@
 
 TEST(CloneContext, Clone) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Node>("root");
-  original_root->a = builder.create<Node>("a");
-  original_root->a->b = builder.create<Node>("a->b");
-  original_root->b = builder.create<Node>("b");
+  auto* original_root =
+      builder.create<Node>(builder.Symbols().Register("root"));
+  original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+  original_root->a->b =
+      builder.create<Node>(builder.Symbols().Register("a->b"));
+  original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
   original_root->b->a = original_root->a;  // Aliased
-  original_root->b->b = builder.create<Node>("b->b");
+  original_root->b->b =
+      builder.create<Node>(builder.Symbols().Register("b->b"));
   original_root->c = original_root->b;  // Aliased
   Program original(std::move(builder));
 
@@ -106,22 +109,25 @@
   EXPECT_NE(cloned_root->b->b, original_root->b->b);
   EXPECT_NE(cloned_root->c, original_root->c);
 
-  EXPECT_EQ(cloned_root->name, "root");
-  EXPECT_EQ(cloned_root->a->name, "a");
-  EXPECT_EQ(cloned_root->a->b->name, "a->b");
-  EXPECT_EQ(cloned_root->b->name, "b");
-  EXPECT_EQ(cloned_root->b->b->name, "b->b");
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("a->b"));
+  EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("b"));
+  EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("b->b"));
 
   EXPECT_EQ(cloned_root->b->a, cloned_root->a);  // Aliased
   EXPECT_EQ(cloned_root->c, cloned_root->b);     // Aliased
 }
 
-TEST(CloneContext, CloneWithReplacements) {
+TEST(CloneContext, CloneWithReplaceAll_Cloneable) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Node>("root");
-  original_root->a = builder.create<Node>("a");
-  original_root->a->b = builder.create<Replaceable>("a->b");
-  original_root->b = builder.create<Replaceable>("b");
+  auto* original_root =
+      builder.create<Node>(builder.Symbols().Register("root"));
+  original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+  original_root->a->b =
+      builder.create<Replaceable>(builder.Symbols().Register("a->b"));
+  original_root->b =
+      builder.create<Replaceable>(builder.Symbols().Register("b"));
   original_root->b->a = original_root->a;  // Aliased
   original_root->c = original_root->b;     // Aliased
   Program original(std::move(builder));
@@ -141,8 +147,12 @@
 
   CloneContext ctx(&cloned, &original);
   ctx.ReplaceAll([&](Replaceable* in) {
-    auto* out = cloned.create<Replacement>("replacement:" + in->name);
-    out->b = cloned.create<Node>("replacement-child:" + in->name);
+    auto out_name = cloned.Symbols().Register(
+        "replacement:" + original.Symbols().NameFor(in->name));
+    auto b_name = cloned.Symbols().Register(
+        "replacement-child:" + original.Symbols().NameFor(in->name));
+    auto* out = cloned.create<Replacement>(out_name);
+    out->b = cloned.create<Node>(b_name);
     out->c = ctx.Clone(in->a);
     return out;
   });
@@ -181,12 +191,14 @@
   EXPECT_NE(cloned_root->b->a, original_root->b->a);
   EXPECT_NE(cloned_root->c, original_root->c);
 
-  EXPECT_EQ(cloned_root->name, "root");
-  EXPECT_EQ(cloned_root->a->name, "a");
-  EXPECT_EQ(cloned_root->a->b->name, "replacement:a->b");
-  EXPECT_EQ(cloned_root->a->b->b->name, "replacement-child:a->b");
-  EXPECT_EQ(cloned_root->b->name, "replacement:b");
-  EXPECT_EQ(cloned_root->b->b->name, "replacement-child:b");
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("replacement:a->b"));
+  EXPECT_EQ(cloned_root->a->b->b->name,
+            cloned.Symbols().Get("replacement-child:a->b"));
+  EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement:b"));
+  EXPECT_EQ(cloned_root->b->b->name,
+            cloned.Symbols().Get("replacement-child:b"));
 
   EXPECT_EQ(cloned_root->b->c, cloned_root->a);  // Aliased
   EXPECT_EQ(cloned_root->c, cloned_root->b);     // Aliased
@@ -198,12 +210,53 @@
   EXPECT_FALSE(cloned_root->b->b->Is<Replacement>());
 }
 
+TEST(CloneContext, CloneWithReplaceAll_Symbols) {
+  ProgramBuilder builder;
+  auto* original_root =
+      builder.create<Node>(builder.Symbols().Register("root"));
+  original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+  original_root->a->b =
+      builder.create<Node>(builder.Symbols().Register("a->b"));
+  original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+  original_root->b->a = original_root->a;  // Aliased
+  original_root->b->b =
+      builder.create<Node>(builder.Symbols().Register("b->b"));
+  original_root->c = original_root->b;  // Aliased
+  Program original(std::move(builder));
+
+  //                          root
+  //        ╭──────────────────┼──────────────────╮
+  //       (a)                (b)                (c)
+  //        N  <──────┐        N  <───────────────┘
+  //   ╭────┼────╮    │   ╭────┼────╮
+  //  (a)  (b)  (c)   │  (a)  (b)  (c)
+  //        N         └───┘    N
+  //
+  // N: Node
+
+  ProgramBuilder cloned;
+  auto* cloned_root = CloneContext(&cloned, &original)
+                          .ReplaceAll([&](Symbol sym) {
+                            auto in = original.Symbols().NameFor(sym);
+                            auto out = "transformed<" + in + ">";
+                            return cloned.Symbols().Register(out);
+                          })
+                          .Clone(original_root);
+
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("transformed<root>"));
+  EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("transformed<a>"));
+  EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("transformed<a->b>"));
+  EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("transformed<b>"));
+  EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("transformed<b->b>"));
+}
+
 TEST(CloneContext, CloneWithReplace) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Node>("root");
-  original_root->a = builder.create<Node>("a");
-  original_root->b = builder.create<Node>("b");
-  original_root->c = builder.create<Node>("c");
+  auto* original_root =
+      builder.create<Node>(builder.Symbols().Register("root"));
+  original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+  original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+  original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
   Program original(std::move(builder));
 
   //                          root
@@ -212,7 +265,8 @@
   //                        Replaced
 
   ProgramBuilder cloned;
-  auto* replacement = cloned.create<Node>("replacement");
+  auto* replacement =
+      cloned.create<Node>(cloned.Symbols().Register("replacement"));
 
   auto* cloned_root = CloneContext(&cloned, &original)
                           .Replace(original_root->b, replacement)
@@ -222,23 +276,24 @@
   EXPECT_EQ(cloned_root->b, replacement);
   EXPECT_NE(cloned_root->c, replacement);
 
-  EXPECT_EQ(cloned_root->name, "root");
-  EXPECT_EQ(cloned_root->a->name, "a");
-  EXPECT_EQ(cloned_root->b->name, "replacement");
-  EXPECT_EQ(cloned_root->c->name, "c");
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement"));
+  EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
 }
 
 TEST(CloneContext, CloneWithInsertBefore) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Node>("root");
-  original_root->a = builder.create<Node>("a");
-  original_root->b = builder.create<Node>("b");
-  original_root->c = builder.create<Node>("c");
+  auto* original_root =
+      builder.create<Node>(builder.Symbols().Register("root"));
+  original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+  original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+  original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
   original_root->vec = {original_root->a, original_root->b, original_root->c};
   Program original(std::move(builder));
 
   ProgramBuilder cloned;
-  auto* insertion = cloned.create<Node>("insertion");
+  auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
 
   auto* cloned_root = CloneContext(&cloned, &original)
                           .InsertBefore(original_root->b, insertion)
@@ -249,21 +304,77 @@
   EXPECT_EQ(cloned_root->vec[2], cloned_root->b);
   EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
 
-  EXPECT_EQ(cloned_root->name, "root");
-  EXPECT_EQ(cloned_root->vec[0]->name, "a");
-  EXPECT_EQ(cloned_root->vec[1]->name, "insertion");
-  EXPECT_EQ(cloned_root->vec[2]->name, "b");
-  EXPECT_EQ(cloned_root->vec[3]->name, "c");
+  EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+  EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
+  EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("insertion"));
+  EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("b"));
+  EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
+}
+
+TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder cloned;
+        Program original;
+        CloneContext ctx(&cloned, &original);
+        ctx.ReplaceAll([](Node*) { return nullptr; });
+        ctx.ReplaceAll([](Node*) { return nullptr; });
+      },
+      "internal compiler error: ReplaceAll() called with a handler for type "
+      "Node that is already handled by a handler for type Node");
+}
+
+TEST(CloneContext, CloneWithReplaceAll_BaseThenDerived) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder cloned;
+        Program original;
+        CloneContext ctx(&cloned, &original);
+        ctx.ReplaceAll([](Node*) { return nullptr; });
+        ctx.ReplaceAll([](Replaceable*) { return nullptr; });
+      },
+      "internal compiler error: ReplaceAll() called with a handler for type "
+      "Replaceable that is already handled by a handler for type Node");
+}
+
+TEST(CloneContext, CloneWithReplaceAll_DerivedThenBase) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder cloned;
+        Program original;
+        CloneContext ctx(&cloned, &original);
+        ctx.ReplaceAll([](Replaceable*) { return nullptr; });
+        ctx.ReplaceAll([](Node*) { return nullptr; });
+      },
+      "internal compiler error: ReplaceAll() called with a handler for type "
+      "Node that is already handled by a handler for type Replaceable");
+}
+
+TEST(CloneContext, CloneWithReplaceAll_SymbolsTwice) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder cloned;
+        Program original;
+        CloneContext ctx(&cloned, &original);
+        ctx.ReplaceAll([](Symbol s) { return s; });
+        ctx.ReplaceAll([](Symbol s) { return s; });
+      },
+      "internal compiler error: ReplaceAll(const SymbolTransform&) called "
+      "multiple times on the same CloneContext");
 }
 
 TEST(CloneContext, CloneWithReplace_WithNotANode) {
   EXPECT_FATAL_FAILURE(
       {
         ProgramBuilder builder;
-        auto* original_root = builder.create<Node>("root");
-        original_root->a = builder.create<Node>("a");
-        original_root->b = builder.create<Node>("b");
-        original_root->c = builder.create<Node>("c");
+        auto* original_root =
+            builder.create<Node>(builder.Symbols().Register("root"));
+        original_root->a =
+            builder.create<Node>(builder.Symbols().Register("a"));
+        original_root->b =
+            builder.create<Node>(builder.Symbols().Register("b"));
+        original_root->c =
+            builder.create<Node>(builder.Symbols().Register("c"));
         Program original(std::move(builder));
 
         //                          root