Add tint::Cloneable base class

The CloneContext was previously dealing with pointers to CastableBase, which has no guarantees that the object was actually cloneable.
Add a Cloneable base class that CloneContext can use instead.

Improves readability and produces cleaner compiler errors if you try to clone a non-cloneable object.

Change-Id: I4352fc5dab3da434e4ab160a54c4c82d50e427b4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41722
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/node.h b/src/ast/node.h
index 0090c35..c29c5e5 100644
--- a/src/ast/node.h
+++ b/src/ast/node.h
@@ -19,7 +19,7 @@
 #include <string>
 #include <vector>
 
-#include "src/castable.h"
+#include "src/clone_context.h"
 #include "src/source.h"
 
 namespace tint {
@@ -36,16 +36,10 @@
 namespace ast {
 
 /// AST base class node
-class Node : public Castable<Node> {
+class Node : public Castable<Node, Cloneable> {
  public:
   ~Node() override;
 
-  /// Clones this node and all transitive child nodes using the `CloneContext`
-  /// `ctx`.
-  /// @param ctx the clone context
-  /// @return the newly cloned node
-  virtual Node* Clone(CloneContext* ctx) const = 0;
-
   /// @returns the node source data
   const Source& source() const { return source_; }
 
diff --git a/src/clone_context.cc b/src/clone_context.cc
index f56ec78..1cb1ca2 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -19,6 +19,8 @@
 #include "src/program.h"
 #include "src/program_builder.h"
 
+TINT_INSTANTIATE_CLASS_ID(tint::Cloneable);
+
 namespace tint {
 
 CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
diff --git a/src/clone_context.h b/src/clone_context.h
index f528daf..5f9605d 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -28,15 +28,23 @@
 namespace tint {
 
 // Forward declarations
+class CloneContext;
 class Program;
 class ProgramBuilder;
 
 namespace ast {
-
 class FunctionList;
-
 }  // namespace ast
 
+/// Cloneable is the base class for all objects that can be cloned
+class Cloneable : public Castable<Cloneable> {
+ public:
+  /// Performs a deep clone of this object using the CloneContext `ctx`.
+  /// @param ctx the clone context
+  /// @return the newly cloned object
+  virtual Cloneable* Clone(CloneContext* ctx) const = 0;
+};
+
 /// CloneContext holds the state used while cloning AST nodes and types.
 class CloneContext {
  public:
@@ -186,7 +194,7 @@
   ///
   /// `replacer` must be function-like with the signature:
   ///   `T* (CloneContext*, T*)`
-  ///  where `T` is a type deriving from CastableBase.
+  ///  where `T` is a type deriving from Cloneable.
   ///
   /// If `replacer` returns a nullptr then Clone() will attempt the next
   /// registered replacer function that matches the object type. If no replacers
@@ -210,13 +218,13 @@
   /// references of the original object. A type mismatch will result in an
   /// assertion in debug builds, and undefined behavior in release builds.
   /// @param replacer a function or function-like object with the signature
-  ///        `T* (CloneContext*, T*)`, where `T` derives from CastableBase
+  ///        `T* (CloneContext*, T*)`, where `T` derives from Cloneable
   /// @returns this CloneContext so calls can be chained
   template <typename F>
   CloneContext& ReplaceAll(F replacer) {
     using TPtr = traits::ParamTypeT<F, 1>;
     using T = typename std::remove_pointer<TPtr>::type;
-    transforms_.emplace_back([=](CastableBase* in) {
+    transforms_.emplace_back([=](Cloneable* in) {
       auto* in_as_t = in->As<T>();
       return in_as_t != nullptr ? replacer(this, in_as_t) : nullptr;
     });
@@ -264,7 +272,7 @@
   Program const* const src;
 
  private:
-  using Transform = std::function<CastableBase*(CastableBase*)>;
+  using Transform = std::function<Cloneable*(Cloneable*)>;
 
   CloneContext(const CloneContext&) = delete;
   CloneContext& operator=(const CloneContext&) = delete;
@@ -272,7 +280,7 @@
   /// LookupOrTransform is the template-independent logic of Clone().
   /// This is outside of Clone() to reduce the amount of template-instantiated
   /// code.
-  CastableBase* LookupOrTransform(CastableBase* a) {
+  Cloneable* LookupOrTransform(Cloneable* a) {
     // Have we seen this object before? If so, return the previously cloned
     // version instead of making yet another copy.
     auto it = cloned_.find(a);
@@ -282,7 +290,7 @@
 
     // Attempt to clone using the registered replacer functions.
     for (auto& f : transforms_) {
-      if (CastableBase* c = f(a)) {
+      if (Cloneable* c = f(a)) {
         cloned_.emplace(a, c);
         return c;
       }
@@ -301,16 +309,16 @@
     return cast;
   }
 
-  /// A vector of CastableBase*
-  using CastableList = std::vector<CastableBase*>;
+  /// A vector of Cloneable*
+  using CloneableList = std::vector<Cloneable*>;
 
   /// A map of object in #src to their cloned equivalent in #dst
-  std::unordered_map<CastableBase*, CastableBase*> cloned_;
+  std::unordered_map<Cloneable*, Cloneable*> cloned_;
 
   /// A map of object in #src to the list of cloned objects in #dst.
   /// Clone(const std::vector<T*>& v) will use this to insert the map-value list
   /// into the target vector/ before cloning and inserting the map-key.
-  std::unordered_map<CastableBase*, CastableList> insert_before_;
+  std::unordered_map<Cloneable*, CloneableList> insert_before_;
 
   /// Transform functions registered with ReplaceAll()
   std::vector<Transform> transforms_;
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index edeed3a..bb32334 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -25,18 +25,17 @@
 namespace tint {
 namespace {
 
-struct Cloneable : public Castable<Cloneable, ast::Node> {
-  explicit Cloneable(const Source& source, std::string n)
-      : Base(source), name(n) {}
+struct Node : public Castable<Node, ast::Node> {
+  explicit Node(const Source& source, std::string n) : Base(source), name(n) {}
 
   std::string name;
-  Cloneable* a = nullptr;
-  Cloneable* b = nullptr;
-  Cloneable* c = nullptr;
-  std::vector<Cloneable*> vec;
+  Node* a = nullptr;
+  Node* b = nullptr;
+  Node* c = nullptr;
+  std::vector<Node*> vec;
 
-  Cloneable* Clone(CloneContext* ctx) const override {
-    auto* out = ctx->dst->create<Cloneable>(name);
+  Node* Clone(CloneContext* ctx) const override {
+    auto* out = ctx->dst->create<Node>(name);
     out->a = ctx->Clone(a);
     out->b = ctx->Clone(b);
     out->c = ctx->Clone(c);
@@ -48,18 +47,18 @@
   void to_str(const semantic::Info&, std::ostream&, size_t) const override {}
 };
 
-struct Replaceable : public Castable<Replaceable, Cloneable> {
+struct Replaceable : public Castable<Replaceable, Node> {
   explicit Replaceable(const Source& source, std::string n) : Base(source, n) {}
 };
 struct Replacement : public Castable<Replacement, Replaceable> {
   explicit Replacement(const Source& source, std::string n) : Base(source, n) {}
 };
 
-struct NotACloneable : public Castable<NotACloneable, ast::Node> {
-  explicit NotACloneable(const Source& source) : Base(source) {}
+struct NotANode : public Castable<NotANode, ast::Node> {
+  explicit NotANode(const Source& source) : Base(source) {}
 
-  NotACloneable* Clone(CloneContext* ctx) const override {
-    return ctx->dst->create<NotACloneable>();
+  NotANode* Clone(CloneContext* ctx) const override {
+    return ctx->dst->create<NotANode>();
   }
 
   bool IsValid() const override { return true; }
@@ -68,24 +67,24 @@
 
 TEST(CloneContext, Clone) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>("root");
-  original_root->a = builder.create<Cloneable>("a");
-  original_root->a->b = builder.create<Cloneable>("a->b");
-  original_root->b = builder.create<Cloneable>("b");
+  auto* original_root = builder.create<Node>("root");
+  original_root->a = builder.create<Node>("a");
+  original_root->a->b = builder.create<Node>("a->b");
+  original_root->b = builder.create<Node>("b");
   original_root->b->a = original_root->a;  // Aliased
-  original_root->b->b = builder.create<Cloneable>("b->b");
+  original_root->b->b = builder.create<Node>("b->b");
   original_root->c = original_root->b;  // Aliased
   Program original(std::move(builder));
 
   //                          root
   //        ╭──────────────────┼──────────────────╮
   //       (a)                (b)                (c)
-  //        C  <──────┐        C  <───────────────┘
+  //        N  <──────┐        N  <───────────────┘
   //   ╭────┼────╮    │   ╭────┼────╮
   //  (a)  (b)  (c)   │  (a)  (b)  (c)
-  //        C         └───┘    C
+  //        N         └───┘    N
   //
-  // C: Clonable
+  // N: Node
 
   ProgramBuilder cloned;
   auto* cloned_root = CloneContext(&cloned, &original).Clone(original_root);
@@ -119,8 +118,8 @@
 
 TEST(CloneContext, CloneWithReplacements) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>("root");
-  original_root->a = builder.create<Cloneable>("a");
+  auto* original_root = builder.create<Node>("root");
+  original_root->a = builder.create<Node>("a");
   original_root->a->b = builder.create<Replaceable>("a->b");
   original_root->b = builder.create<Replaceable>("b");
   original_root->b->a = original_root->a;  // Aliased
@@ -130,12 +129,12 @@
   //                          root
   //        ╭──────────────────┼──────────────────╮
   //       (a)                (b)                (c)
-  //        C  <──────┐        R  <───────────────┘
+  //        N  <──────┐        R  <───────────────┘
   //   ╭────┼────╮    │   ╭────┼────╮
   //  (a)  (b)  (c)   │  (a)  (b)  (c)
   //        R         └───┘
   //
-  // C: Clonable
+  // N: Node
   // R: Replaceable
 
   ProgramBuilder cloned;
@@ -143,7 +142,7 @@
       CloneContext(&cloned, &original)
           .ReplaceAll([&](CloneContext* ctx, Replaceable* in) {
             auto* out = cloned.create<Replacement>("replacement:" + in->name);
-            out->b = cloned.create<Cloneable>("replacement-child:" + in->name);
+            out->b = cloned.create<Node>("replacement-child:" + in->name);
             out->c = ctx->Clone(in->a);
             return out;
           })
@@ -152,15 +151,15 @@
   //                         root
   //        ╭─────────────────┼──────────────────╮
   //       (a)               (b)                (c)
-  //        C  <──────┐       R  <───────────────┘
+  //        N  <──────┐       R  <───────────────┘
   //   ╭────┼────╮    │  ╭────┼────╮
   //  (a)  (b)  (c)   │ (a)  (b)  (c)
-  //        R         │       C    |
+  //        R         │       N    |
   //   ╭────┼────╮    └────────────┘
   //  (a)  (b)  (c)
-  //        C
+  //        N
   //
-  // C: Clonable
+  // N: Node
   // R: Replacement
 
   EXPECT_NE(cloned_root->a, nullptr);
@@ -201,10 +200,10 @@
 
 TEST(CloneContext, CloneWithReplace) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>("root");
-  original_root->a = builder.create<Cloneable>("a");
-  original_root->b = builder.create<Cloneable>("b");
-  original_root->c = builder.create<Cloneable>("c");
+  auto* original_root = builder.create<Node>("root");
+  original_root->a = builder.create<Node>("a");
+  original_root->b = builder.create<Node>("b");
+  original_root->c = builder.create<Node>("c");
   Program original(std::move(builder));
 
   //                          root
@@ -213,7 +212,7 @@
   //                        Replaced
 
   ProgramBuilder cloned;
-  auto* replacement = cloned.create<Cloneable>("replacement");
+  auto* replacement = cloned.create<Node>("replacement");
 
   auto* cloned_root = CloneContext(&cloned, &original)
                           .Replace(original_root->b, replacement)
@@ -231,15 +230,15 @@
 
 TEST(CloneContext, CloneWithInsertBefore) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>("root");
-  original_root->a = builder.create<Cloneable>("a");
-  original_root->b = builder.create<Cloneable>("b");
-  original_root->c = builder.create<Cloneable>("c");
+  auto* original_root = builder.create<Node>("root");
+  original_root->a = builder.create<Node>("a");
+  original_root->b = builder.create<Node>("b");
+  original_root->c = builder.create<Node>("c");
   original_root->vec = {original_root->a, original_root->b, original_root->c};
   Program original(std::move(builder));
 
   ProgramBuilder cloned;
-  auto* insertion = cloned.create<Cloneable>("insertion");
+  auto* insertion = cloned.create<Node>("insertion");
 
   auto* cloned_root = CloneContext(&cloned, &original)
                           .InsertBefore(original_root->b, insertion)
@@ -257,12 +256,12 @@
   EXPECT_EQ(cloned_root->vec[3]->name, "c");
 }
 
-TEST(CloneContext, CloneWithReplace_WithNotACloneable) {
+TEST(CloneContext, CloneWithReplace_WithNotANode) {
   ProgramBuilder builder;
-  auto* original_root = builder.create<Cloneable>("root");
-  original_root->a = builder.create<Cloneable>("a");
-  original_root->b = builder.create<Cloneable>("b");
-  original_root->c = builder.create<Cloneable>("c");
+  auto* original_root = builder.create<Node>("root");
+  original_root->a = builder.create<Node>("a");
+  original_root->b = builder.create<Node>("b");
+  original_root->c = builder.create<Node>("c");
   Program original(std::move(builder));
 
   //                          root
@@ -271,7 +270,7 @@
   //                        Replaced
 
   ProgramBuilder cloned;
-  auto* replacement = cloned.create<NotACloneable>();
+  auto* replacement = cloned.create<NotANode>();
 
   CloneContext ctx(&cloned, &original);
   ctx.Replace(original_root->b, replacement);
@@ -289,9 +288,9 @@
 
 }  // namespace
 
-TINT_INSTANTIATE_CLASS_ID(Cloneable);
+TINT_INSTANTIATE_CLASS_ID(Node);
 TINT_INSTANTIATE_CLASS_ID(Replaceable);
 TINT_INSTANTIATE_CLASS_ID(Replacement);
-TINT_INSTANTIATE_CLASS_ID(NotACloneable);
+TINT_INSTANTIATE_CLASS_ID(NotANode);
 
 }  // namespace tint
diff --git a/src/type/type.h b/src/type/type.h
index 2cb4c21..c6ef50e 100644
--- a/src/type/type.h
+++ b/src/type/type.h
@@ -17,7 +17,6 @@
 
 #include <string>
 
-#include "src/castable.h"
 #include "src/clone_context.h"
 
 namespace tint {
@@ -32,17 +31,12 @@
 enum class MemoryLayout { kUniformBuffer, kStorageBuffer };
 
 /// Base class for a type in the system
-class Type : public Castable<Type> {
+class Type : public Castable<Type, Cloneable> {
  public:
   /// Move constructor
   Type(Type&&);
   ~Type() override;
 
-  /// Clones this type and all transitive types using the `CloneContext` `ctx`.
-  /// @param ctx the clone context
-  /// @return the newly cloned type
-  virtual Type* Clone(CloneContext* ctx) const = 0;
-
   /// @returns the name for this type. The type name is unique over all types.
   virtual std::string type_name() const = 0;
 
@@ -114,16 +108,6 @@
 
  protected:
   Type();
-
-  /// A helper method for cloning the `Type` `t` if it is not null.
-  /// If `t` is null, then `Clone()` returns null.
-  /// @param b the program builder to clone `n` into
-  /// @param t the `Type` to clone (if not null)
-  /// @return the cloned type
-  template <typename T>
-  static T* Clone(ProgramBuilder* b, const T* t) {
-    return (t != nullptr) ? static_cast<T*>(t->Clone(b)) : nullptr;
-  }
 };
 
 }  // namespace type