CloneContext: Use As<T>() instead of static_cast<T>

And assert that the cast succeeded.

There is a danger with Replace() or ReplaceAll(), where you can end up replacing a node with another node of an incompatible type for some reference of that object. Previously this would silently cast to the incorrect type, and Bad Things would happen. Now we will assert in this situation.

I have not observed this issue happening (all current uses of Replace() and ReplaceAll() are believed to be safe). This is just an edge case I've spotted and wanted to add some safety belts for.

Change-Id: Icf4a4fe76f7bc14bcc6b274de68f7d0b3d85d71f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41546
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h
index 17109e2..8424b73 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -15,6 +15,7 @@
 #ifndef SRC_CLONE_CONTEXT_H_
 #define SRC_CLONE_CONTEXT_H_
 
+#include <cassert>
 #include <functional>
 #include <unordered_map>
 #include <vector>
@@ -72,14 +73,49 @@
     // previously cloned pointer.
     // If we haven't cloned this before, try cloning using a replacer transform.
     if (auto* c = LookupOrTransform(a)) {
-      return static_cast<T*>(c);
+      return CheckedCast<T>(c);
     }
 
     // First time clone and no replacer transforms matched.
     // Clone with T::Clone().
     auto* c = a->Clone(this);
     cloned_.emplace(a, c);
-    return static_cast<T*>(c);
+    return CheckedCast<T>(c);
+  }
+
+  /// Clones the Node or type::Type `a` into the ProgramBuilder #dst if `a` is
+  /// not null. If `a` is null, then Clone() returns null. If `a` has been
+  /// cloned already by this CloneContext then the same cloned pointer is
+  /// returned.
+  ///
+  /// Unlike Clone(), this method does not invoke or use any transformations
+  /// registered by ReplaceAll().
+  ///
+  /// The Node or type::Type `a` must be owned by the Program #src.
+  ///
+  /// @note Semantic information such as resolved expression type and intrinsic
+  /// information is not cloned.
+  /// @param a the `Node` or `type::Type` to clone
+  /// @return the cloned node
+  template <typename T>
+  T* CloneWithoutTransform(T* a) {
+    // If the input is nullptr, there's nothing to clone - just return nullptr.
+    if (a == nullptr) {
+      return nullptr;
+    }
+
+    // Have we seen this object before? If so, return the previously cloned
+    // version instead of making yet another copy.
+    auto it = cloned_.find(a);
+    if (it != cloned_.end()) {
+      return CheckedCast<T>(it->second);
+    }
+
+    // First time clone and no replacer transforms matched.
+    // Clone with T::Clone().
+    auto* c = a->Clone(this);
+    cloned_.emplace(a, c);
+    return CheckedCast<T>(c);
   }
 
   /// Clones the Source `s` into `dst`
@@ -150,6 +186,9 @@
   ///     }).Clone();
   /// ```
   ///
+  /// @warning The replacement object must be of the correct type for all
+  /// references of the original object. A type mismatch will result in an
+  /// assertion in debug builds, and undefined behavior in release builds.
   /// @param replacer a function or function-like object with the signature
   ///        `T* (CloneContext*, T*)`, where `T` derives from CastableBase
   /// @returns this CloneContext so calls can be chained
@@ -168,12 +207,15 @@
   /// when calling Clone().
   /// @param what a pointer to the object in #src that will be replaced with
   /// `with`
-  /// @param with a pointer to the replacement object that will be used when
-  /// cloning into #dst
+  /// @param with a pointer to the replacement object owned by #dst that will be
+  /// used as a replacement for `what`
+  /// @warning The replacement object must be of the correct type for all
+  /// references of the original object. A type mismatch will result in an
+  /// assertion in debug builds, and undefined behavior in release builds.
   /// @returns this CloneContext so calls can be chained
-  template <typename T>
-  CloneContext& Replace(T* what, T* with) {
-    cloned_.emplace(what, with);
+  template <typename WHAT, typename WITH>
+  CloneContext& Replace(WHAT* what, WITH* with) {
+    cloned_[what] = with;
     return *this;
   }
 
@@ -215,6 +257,15 @@
     return nullptr;
   }
 
+  /// Cast `obj` from type `FROM` to type `TO`, returning the cast object.
+  /// Asserts if the cast failed.
+  template <typename TO, typename FROM>
+  TO* CheckedCast(FROM* obj) {
+    TO* cast = obj->template As<TO>();
+    assert(cast /* cloned object was not of the expected type */);
+    return cast;
+  }
+
   std::unordered_map<CastableBase*, CastableBase*> cloned_;
   std::vector<Transform> transforms_;
 };
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index 55d0f52..5e47cb2 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -49,6 +49,17 @@
   explicit Replacement(const Source& source) : Base(source) {}
 };
 
+struct NotACloneable : public Castable<NotACloneable, ast::Node> {
+  explicit NotACloneable(const Source& source) : Base(source) {}
+
+  NotACloneable* Clone(CloneContext* ctx) const override {
+    return ctx->dst->create<NotACloneable>();
+  }
+
+  bool IsValid() const override { return true; }
+  void to_str(const semantic::Info&, std::ostream&, size_t) const override {}
+};
+
 TEST(CloneContext, Clone) {
   ProgramBuilder builder;
   auto* original_root = builder.create<Cloneable>();
@@ -193,10 +204,41 @@
   EXPECT_NE(cloned_root->c, replacement);
 }
 
+TEST(CloneContext, CloneWithReplace_WithNotACloneable) {
+  ProgramBuilder builder;
+  auto* original_root = builder.create<Cloneable>();
+  original_root->a = builder.create<Cloneable>();
+  original_root->b = builder.create<Cloneable>();
+  original_root->c = builder.create<Cloneable>();
+  Program original(std::move(builder));
+
+  //                          root
+  //        ╭──────────────────┼──────────────────╮
+  //       (a)                (b)                (c)
+  //                        Replaced
+
+  ProgramBuilder cloned;
+  auto* replacement = cloned.create<NotACloneable>();
+
+  CloneContext ctx(&cloned, &original);
+  ctx.Replace(original_root->b, replacement);
+
+#ifndef NDEBUG
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wused-but-marked-unused"
+#pragma clang diagnostic ignored "-Wcovered-switch-default"
+
+  EXPECT_DEATH_IF_SUPPORTED(ctx.Clone(original_root), "");
+
+#pragma clang diagnostic pop
+#endif  // NDEBUG
+}
+
 }  // namespace
 
 TINT_INSTANTIATE_CLASS_ID(Cloneable);
 TINT_INSTANTIATE_CLASS_ID(Replaceable);
 TINT_INSTANTIATE_CLASS_ID(Replacement);
+TINT_INSTANTIATE_CLASS_ID(NotACloneable);
 
 }  // namespace tint