Add ast::CloneContext::ReplaceAll()

ReplaceAll() registers `replacer` to be called whenever the Clone() method is called with a type that matches (or derives from) the type of the first parameter of `replacer`.

`replacer` must be function-like with the signature: `T* (T*)`, where `T` is a type deriving from CastableBase.

If `replacer` returns a nullptr then Clone() will attempt the next registered replacer function that matches the object type. If no replacers match the object type, or all returned nullptr then Clone() will call `T::Clone()` to clone the object.

Example:

```
  // Replace all ast::UintLiterals with the number 42
  CloneCtx ctx(mod);
  ctx.ReplaceAll([&] (ast::UintLiteral* in) {
    return ctx.mod->create<ast::UintLiteral>(ctx.Clone(in->type()), 42);
  });
  auto* out = ctx.Clone(tree);
```

This is to be used by Transforms that want to replace parts of the AST on clone.

Bug: tint:390
Change-Id: I80a0e58aa3711f309f58a504f6b6a06f6c546ea1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34568
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 019da76..58a188d 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -758,6 +758,7 @@
     "src/ast/call_expression_test.cc",
     "src/ast/call_statement_test.cc",
     "src/ast/case_statement_test.cc",
+    "src/ast/clone_context_test.cc",
     "src/ast/constant_id_decoration_test.cc",
     "src/ast/continue_statement_test.cc",
     "src/ast/decorated_variable_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 870a44a..1c603f7 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -367,6 +367,7 @@
   ast/call_expression_test.cc
   ast/call_statement_test.cc
   ast/case_statement_test.cc
+  ast/clone_context_test.cc
   ast/constant_id_decoration_test.cc
   ast/continue_statement_test.cc
   ast/discard_statement_test.cc
diff --git a/src/ast/clone_context.h b/src/ast/clone_context.h
index 667882d..7733b85 100644
--- a/src/ast/clone_context.h
+++ b/src/ast/clone_context.h
@@ -15,9 +15,12 @@
 #ifndef SRC_AST_CLONE_CONTEXT_H_
 #define SRC_AST_CLONE_CONTEXT_H_
 
+#include <functional>
 #include <unordered_map>
 #include <vector>
 
+#include "src/ast/traits.h"
+#include "src/castable.h"
 #include "src/source.h"
 
 namespace tint {
@@ -31,26 +34,37 @@
   /// Constructor
   /// @param m the target module to clone into
   explicit CloneContext(Module* m);
+
   /// Destructor
   ~CloneContext();
 
   /// Clones the `Node` or `type::Type` `a` into the module #mod if `a` is not
   /// null. If `a` is null, then Clone() returns null. If `a` has been cloned
   /// already by this CloneContext then the same cloned pointer is returned.
+  ///
+  /// Clone() may use a function registered with ReplaceAll() to create a
+  /// transformed version of the object. See ReplaceAll() for more information.
+  ///
   /// @note Semantic information such as resolved expression type and intrinsic
   /// information is not cloned.
   /// @param a the `Node` or `type::Type` to clone
   /// @return the cloned node
   template <typename T>
   T* Clone(T* a) {
+    // If the input is nullptr, there's nothing to clone - just return nullptr.
     if (a == nullptr) {
       return nullptr;
     }
 
-    auto it = cloned_.find(a);
-    if (it != cloned_.end()) {
-      return static_cast<T*>(it->second);
+    // See if we've already cloned this object - if we have return the
+    // previously cloned pointer.
+    // If we haven't cloned this before, try cloning using a replacer transform.
+    if (auto* c = LookupOrTransform(a)) {
+      return static_cast<T*>(c);
     }
+
+    // First time clone and no replacer transforms matched.
+    // Clone with T::Clone().
     auto* c = a->Clone(this);
     cloned_.emplace(a, c);
     return static_cast<T*>(c);
@@ -77,11 +91,72 @@
     return out;
   }
 
+  /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
+  /// is called with a type that matches (or derives from) the type of the first
+  /// parameter of `replacer`.
+  ///
+  /// `replacer` must be function-like with the signature: `T* (T*)`, where `T`
+  /// is a type deriving from CastableBase.
+  ///
+  /// If `replacer` returns a nullptr then Clone() will attempt the next
+  /// registered replacer function that matches the object type. If no replacers
+  /// match the object type, or all returned nullptr then Clone() will call
+  /// `T::Clone()` to clone the object.
+  ///
+  /// Example:
+  ///
+  /// ```
+  ///   // Replace all ast::UintLiterals with the number 42
+  ///   CloneCtx ctx(mod);
+  ///   ctx.ReplaceAll([&] (ast::UintLiteral* in) {
+  ///     return ctx.mod->create<ast::UintLiteral>(ctx.Clone(in->type()), 42);
+  ///   });
+  ///   auto* out = ctx.Clone(tree);
+  /// ```
+  ///
+  /// @param replacer a function or function-like object with the signature
+  ///        `T* (T*)`, where `T` derives from CastableBase
+  template <typename F>
+  void ReplaceAll(F replacer) {
+    using TPtr = traits::FirstParamTypeT<F>;
+    using T = typename std::remove_pointer<TPtr>::type;
+    transforms_.emplace_back([=](CastableBase* in) {
+      auto* in_as_t = in->As<T>();
+      return in_as_t != nullptr ? replacer(in_as_t) : nullptr;
+    });
+  }
+
   /// The target module to clone into.
   Module* const mod;
 
  private:
-  std::unordered_map<void*, void*> cloned_;
+  using Transform = std::function<CastableBase*(CastableBase*)>;
+
+  /// LookupOrTransform is the template-independent logic of Clone().
+  /// This is outside of Clone() to reduce the amount of template-instantiated
+  /// code.
+  CastableBase* LookupOrTransform(CastableBase* a) {
+    // Have we seen this object before? If so, return the previously cloned
+    // version instead of making yet another copy.
+    auto it = cloned_.find(a);
+    if (it != cloned_.end()) {
+      return it->second;
+    }
+
+    // Attempt to clone using the registered replacer functions.
+    for (auto& f : transforms_) {
+      if (CastableBase* c = f(a)) {
+        cloned_.emplace(a, c);
+        return c;
+      }
+    }
+
+    // No luck, Clone() will have to call T::Clone().
+    return nullptr;
+  }
+
+  std::unordered_map<CastableBase*, CastableBase*> cloned_;
+  std::vector<Transform> transforms_;
 };
 
 }  // namespace ast
diff --git a/src/ast/clone_context_test.cc b/src/ast/clone_context_test.cc
new file mode 100644
index 0000000..48c00d7
--- /dev/null
+++ b/src/ast/clone_context_test.cc
@@ -0,0 +1,170 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/clone_context.h"
+
+#include "gtest/gtest.h"
+
+#include "src/ast/module.h"
+
+namespace tint {
+namespace ast {
+namespace {
+
+struct Cloneable : public Castable<Cloneable, Node> {
+  Cloneable* a = nullptr;
+  Cloneable* b = nullptr;
+  Cloneable* c = nullptr;
+
+  Cloneable* Clone(CloneContext* ctx) const override {
+    auto* out = ctx->mod->create<Cloneable>();
+    out->a = ctx->Clone(a);
+    out->b = ctx->Clone(b);
+    out->c = ctx->Clone(c);
+    return out;
+  }
+
+  bool IsValid() const override { return true; }
+  void to_str(std::ostream&, size_t) const override {}
+};
+
+struct Replaceable : public Castable<Replaceable, Cloneable> {};
+struct Replacement : public Castable<Replacement, Replaceable> {};
+
+TEST(CloneContext, Clone) {
+  ast::Module original;
+  auto* original_root = original.create<Cloneable>();
+  original_root->a = original.create<Cloneable>();
+  original_root->a->b = original.create<Cloneable>();
+  original_root->b = original.create<Cloneable>();
+  original_root->b->a = original_root->a;  // Aliased
+  original_root->b->b = original.create<Cloneable>();
+  original_root->c = original_root->b;  // Aliased
+
+  //                          root
+  //        ╭──────────────────┼──────────────────╮
+  //       (a)                (b)                (c)
+  //        C  <──────┐        C  <───────────────┘
+  //   ╭────┼────╮    │   ╭────┼────╮
+  //  (a)  (b)  (c)   │  (a)  (b)  (c)
+  //        C         └───┘    C
+  //
+  // C: Clonable
+
+  ast::Module cloned;
+  CloneContext ctx(&cloned);
+  auto* cloned_root = original_root->Clone(&ctx);
+
+  EXPECT_NE(cloned_root->a, nullptr);
+  EXPECT_EQ(cloned_root->a->a, nullptr);
+  EXPECT_NE(cloned_root->a->b, nullptr);
+  EXPECT_EQ(cloned_root->a->c, nullptr);
+  EXPECT_NE(cloned_root->b, nullptr);
+  EXPECT_NE(cloned_root->b->a, nullptr);
+  EXPECT_NE(cloned_root->b->b, nullptr);
+  EXPECT_EQ(cloned_root->b->c, nullptr);
+  EXPECT_NE(cloned_root->c, nullptr);
+
+  EXPECT_NE(cloned_root->a, original_root->a);
+  EXPECT_NE(cloned_root->a->b, original_root->a->b);
+  EXPECT_NE(cloned_root->b, original_root->b);
+  EXPECT_NE(cloned_root->b->a, original_root->b->a);
+  EXPECT_NE(cloned_root->b->b, original_root->b->b);
+  EXPECT_NE(cloned_root->c, original_root->c);
+
+  EXPECT_EQ(cloned_root->b->a, cloned_root->a);  // Aliased
+  EXPECT_EQ(cloned_root->c, cloned_root->b);     // Aliased
+}
+
+TEST(CloneContext, CloneWithReplacements) {
+  ast::Module original;
+  auto* original_root = original.create<Cloneable>();
+  original_root->a = original.create<Cloneable>();
+  original_root->a->b = original.create<Replaceable>();
+  original_root->b = original.create<Replaceable>();
+  original_root->b->a = original_root->a;  // Aliased
+  original_root->c = original_root->b;     // Aliased
+
+  //                          root
+  //        ╭──────────────────┼──────────────────╮
+  //       (a)                (b)                (c)
+  //        C  <──────┐        R  <───────────────┘
+  //   ╭────┼────╮    │   ╭────┼────╮
+  //  (a)  (b)  (c)   │  (a)  (b)  (c)
+  //        R         └───┘
+  //
+  // C: Clonable
+  // R: Replaceable
+
+  ast::Module cloned;
+  CloneContext ctx(&cloned);
+  ctx.ReplaceAll([&](Replaceable* in) {
+    auto* out = cloned.create<Replacement>();
+    out->b = cloned.create<Cloneable>();
+    out->c = ctx.Clone(in->a);
+    return out;
+  });
+  auto* cloned_root = original_root->Clone(&ctx);
+
+  //                         root
+  //        ╭─────────────────┼──────────────────╮
+  //       (a)               (b)                (c)
+  //        C  <──────┐       R  <───────────────┘
+  //   ╭────┼────╮    │  ╭────┼────╮
+  //  (a)  (b)  (c)   │ (a)  (b)  (c)
+  //        R         │       C    |
+  //   ╭────┼────╮    └────────────┘
+  //  (a)  (b)  (c)
+  //        C
+  //
+  // C: Clonable
+  // R: Replacement
+
+  EXPECT_NE(cloned_root->a, nullptr);
+  EXPECT_EQ(cloned_root->a->a, nullptr);
+  EXPECT_NE(cloned_root->a->b, nullptr);     // Replaced
+  EXPECT_EQ(cloned_root->a->b->a, nullptr);  // From replacement
+  EXPECT_NE(cloned_root->a->b->b, nullptr);  // From replacement
+  EXPECT_EQ(cloned_root->a->b->c, nullptr);  // From replacement
+  EXPECT_EQ(cloned_root->a->c, nullptr);
+  EXPECT_NE(cloned_root->b, nullptr);
+  EXPECT_EQ(cloned_root->b->a, nullptr);  // From replacement
+  EXPECT_NE(cloned_root->b->b, nullptr);  // From replacement
+  EXPECT_NE(cloned_root->b->c, nullptr);  // From replacement
+  EXPECT_NE(cloned_root->c, nullptr);
+
+  EXPECT_NE(cloned_root->a, original_root->a);
+  EXPECT_NE(cloned_root->a->b, original_root->a->b);
+  EXPECT_NE(cloned_root->b, original_root->b);
+  EXPECT_NE(cloned_root->b->a, original_root->b->a);
+  EXPECT_NE(cloned_root->c, original_root->c);
+
+  EXPECT_EQ(cloned_root->b->c, cloned_root->a);  // Aliased
+  EXPECT_EQ(cloned_root->c, cloned_root->b);     // Aliased
+
+  EXPECT_FALSE(cloned_root->a->Is<Replacement>());
+  EXPECT_TRUE(cloned_root->a->b->Is<Replacement>());
+  EXPECT_FALSE(cloned_root->a->b->b->Is<Replacement>());
+  EXPECT_TRUE(cloned_root->b->Is<Replacement>());
+  EXPECT_FALSE(cloned_root->b->b->Is<Replacement>());
+}
+
+}  // namespace
+}  // namespace ast
+
+TINT_INSTANTIATE_CLASS_ID(ast::Cloneable);
+TINT_INSTANTIATE_CLASS_ID(ast::Replaceable);
+TINT_INSTANTIATE_CLASS_ID(ast::Replacement);
+
+}  // namespace tint