tint/utils: Add support for unsafe pointer downcasts

Bug: tint:1779
Change-Id: Icfd27680edf7dfaedbfb70f25641dc762d23f42a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113020
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h
index cefd536..2371136 100644
--- a/src/tint/utils/vector.h
+++ b/src/tint/utils/vector.h
@@ -106,12 +106,20 @@
     auto rend() const { return std::reverse_iterator<const T*>(begin()); }
 };
 
+/// Mode enumerator for ReinterpretSlice
+enum class ReinterpretMode {
+    /// Only upcasts of pointers are permitted
+    kSafe,
+    /// Potentially unsafe downcasts of pointers are also permitted
+    kUnsafe,
+};
+
 namespace detail {
 
 /// Private implementation of tint::utils::CanReinterpretSlice.
 /// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of
 /// the base classes, which can be troublesome if the slice is of an incomplete type.
-template <typename TO, typename FROM>
+template <ReinterpretMode MODE, typename TO, typename FROM>
 struct CanReinterpretSlice {
     /// True if a slice of FROM can be reinterpreted as a slice of TO
     static constexpr bool value =
@@ -122,13 +130,14 @@
          !std::is_const_v<std::remove_pointer_t<FROM>>)&&  //
         // TO and FROM are both Castable
         IsCastable<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>> &&  //
-        // FROM is of, or derives from TO
-        traits::IsTypeOrDerived<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>>;
+        // MODE is kUnsafe, or FROM is of, or derives from TO
+        (MODE == ReinterpretMode::kUnsafe ||
+         traits::IsTypeOrDerived<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>>);
 };
 
 /// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types.
-template <typename T>
-struct CanReinterpretSlice<T, T> {
+template <typename T, ReinterpretMode MODE>
+struct CanReinterpretSlice<MODE, T, T> {
     /// Always `true` as TO and FROM are the same type.
     static constexpr bool value = true;
 };
@@ -140,16 +149,16 @@
 /// CastableBase, and the pointee type of `TO` is of the same type as, or is an ancestor of the
 /// pointee type of `FROM`. Vectors of non-`const` Castable pointers can be converted to a vector of
 /// `const` Castable pointers.
-template <typename TO, typename FROM>
-static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice<TO, FROM>::value;
+template <ReinterpretMode MODE, typename TO, typename FROM>
+static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice<MODE, TO, FROM>::value;
 
 /// Reinterprets `const Slice<FROM>*` as `const Slice<TO>*`
 /// @param slice a pointer to the slice to reinterpret
 /// @returns the reinterpreted slice
 /// @see CanReinterpretSlice
-template <typename TO, typename FROM>
+template <ReinterpretMode MODE, typename TO, typename FROM>
 const Slice<TO>* ReinterpretSlice(const Slice<FROM>* slice) {
-    static_assert(CanReinterpretSlice<TO, FROM>);
+    static_assert(CanReinterpretSlice<MODE, TO, FROM>);
     return Bitcast<const Slice<TO>*>(slice);
 }
 
@@ -157,9 +166,9 @@
 /// @param slice a pointer to the slice to reinterpret
 /// @returns the reinterpreted slice
 /// @see CanReinterpretSlice
-template <typename TO, typename FROM>
+template <ReinterpretMode MODE, typename TO, typename FROM>
 Slice<TO>* ReinterpretSlice(Slice<FROM>* slice) {
-    static_assert(CanReinterpretSlice<TO, FROM>);
+    static_assert(CanReinterpretSlice<MODE, TO, FROM>);
     return Bitcast<Slice<TO>*>(slice);
 }
 
@@ -230,15 +239,21 @@
     /// Copy constructor with covariance / const conversion
     /// @param other the vector to copy
     /// @see CanReinterpretSlice for rules about conversion
-    template <typename U, size_t N2, typename = std::enable_if_t<CanReinterpretSlice<T, U>>>
+    template <typename U,
+              size_t N2,
+              ReinterpretMode MODE,
+              typename = std::enable_if_t<CanReinterpretSlice<MODE, T, U>>>
     Vector(const Vector<U, N2>& other) {  // NOLINT(runtime/explicit)
-        Copy(*ReinterpretSlice<T>(&other.impl_.slice));
+        Copy(*ReinterpretSlice<MODE, T>(&other.impl_.slice));
     }
 
     /// Move constructor with covariance / const conversion
     /// @param other the vector to move
     /// @see CanReinterpretSlice for rules about conversion
-    template <typename U, size_t N2, typename = std::enable_if_t<CanReinterpretSlice<T, U>>>
+    template <typename U,
+              size_t N2,
+              ReinterpretMode MODE,
+              typename = std::enable_if_t<CanReinterpretSlice<MODE, T, U>>>
     Vector(Vector<U, N2>&& other) {  // NOLINT(runtime/explicit)
         MoveOrCopy(VectorRef<T>(std::move(other)));
     }
@@ -701,6 +716,11 @@
     /// Constructor
     VectorRef(EmptyType) : slice_(EmptySlice()) {}  // NOLINT(runtime/explicit)
 
+    /// Constructor from a Slice
+    /// @param slice the slice
+    VectorRef(Slice& slice)  // NOLINT(runtime/explicit)
+        : slice_(slice) {}
+
     /// Constructor from a Vector
     /// @param vector the vector to create a reference of
     template <size_t N>
@@ -729,29 +749,37 @@
 
     /// Copy constructor with covariance / const conversion
     /// @param other the other vector reference
-    template <typename U, typename = std::enable_if_t<CanReinterpretSlice<T, U>>>
+    template <typename U,
+              typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
     VectorRef(const VectorRef<U>& other)  // NOLINT(runtime/explicit)
-        : slice_(*ReinterpretSlice<T>(&other.slice_)) {}
+        : slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&other.slice_)) {}
 
     /// Move constructor with covariance / const conversion
     /// @param other the vector reference
-    template <typename U, typename = std::enable_if_t<CanReinterpretSlice<T, U>>>
+    template <typename U,
+              typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
     VectorRef(VectorRef<U>&& other)  // NOLINT(runtime/explicit)
-        : slice_(*ReinterpretSlice<T>(&other.slice_)), can_move_(other.can_move_) {}
+        : slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&other.slice_)),
+          can_move_(other.can_move_) {}
 
     /// Constructor from a Vector with covariance / const conversion
     /// @param vector the vector to create a reference of
     /// @see CanReinterpretSlice for rules about conversion
-    template <typename U, size_t N, typename = std::enable_if_t<CanReinterpretSlice<T, U>>>
+    template <typename U,
+              size_t N,
+              typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
     VectorRef(Vector<U, N>& vector)  // NOLINT(runtime/explicit)
-        : slice_(*ReinterpretSlice<T>(&vector.impl_.slice)) {}
+        : slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&vector.impl_.slice)) {}
 
     /// Constructor from a moved Vector with covariance / const conversion
     /// @param vector the vector to create a reference of
     /// @see CanReinterpretSlice for rules about conversion
-    template <typename U, size_t N, typename = std::enable_if_t<CanReinterpretSlice<T, U>>>
+    template <typename U,
+              size_t N,
+              typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
     VectorRef(Vector<U, N>&& vector)  // NOLINT(runtime/explicit)
-        : slice_(*ReinterpretSlice<T>(&vector.impl_.slice)), can_move_(vector.impl_.CanMove()) {}
+        : slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&vector.impl_.slice)),
+          can_move_(vector.impl_.CanMove()) {}
 
     /// Index operator
     /// @param i the element index. Must be less than `len`.
@@ -765,6 +793,14 @@
     /// be made
     size_t Capacity() const { return slice_.cap; }
 
+    /// @return a reinterpretation of this VectorRef as elements of type U.
+    /// @note this is doing a reinterpret_cast of elements. It is up to the caller to ensure that
+    /// this is a safe operation.
+    template <typename U>
+    VectorRef<U> ReinterpretCast() const {
+        return {*ReinterpretSlice<ReinterpretMode::kUnsafe, U>(&slice_)};
+    }
+
     /// @returns true if the vector is empty.
     bool IsEmpty() const { return slice_.len == 0; }
 
diff --git a/src/tint/utils/vector_test.cc b/src/tint/utils/vector_test.cc
index 8e6fa7a..74d15b7 100644
--- a/src/tint/utils/vector_test.cc
+++ b/src/tint/utils/vector_test.cc
@@ -79,22 +79,30 @@
 static_assert(std::is_same_v<VectorCommonType<C2a*, const C2b*>, const C1*>);
 static_assert(std::is_same_v<VectorCommonType<const C2a*, const C2b*>, const C1*>);
 
-static_assert(CanReinterpretSlice<const C0*, C0*>, "apply const");
-static_assert(!CanReinterpretSlice<C0*, const C0*>, "remove const");
-static_assert(CanReinterpretSlice<C0*, C1*>, "up cast");
-static_assert(CanReinterpretSlice<const C0*, const C1*>, "up cast");
-static_assert(CanReinterpretSlice<const C0*, C1*>, "up cast, apply const");
-static_assert(!CanReinterpretSlice<C0*, const C1*>, "up cast, remove const");
-static_assert(!CanReinterpretSlice<C1*, C0*>, "down cast");
-static_assert(!CanReinterpretSlice<const C1*, const C0*>, "down cast");
-static_assert(!CanReinterpretSlice<const C1*, C0*>, "down cast, apply const");
-static_assert(!CanReinterpretSlice<C1*, const C0*>, "down cast, remove const");
-static_assert(!CanReinterpretSlice<const C1*, C0*>, "down cast, apply const");
-static_assert(!CanReinterpretSlice<C1*, const C0*>, "down cast, remove const");
-static_assert(!CanReinterpretSlice<C2a*, C2b*>, "sideways cast");
-static_assert(!CanReinterpretSlice<const C2a*, const C2b*>, "sideways cast");
-static_assert(!CanReinterpretSlice<const C2a*, C2b*>, "sideways cast, apply const");
-static_assert(!CanReinterpretSlice<C2a*, const C2b*>, "sideways cast, remove const");
+static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, C0*>, "apply const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C0*, const C0*>, "remove const");
+static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, C0*, C1*>, "up cast");
+static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, const C1*>, "up cast");
+static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, C1*>, "up cast, apply const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C0*, const C1*>,
+              "up cast, remove const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, C0*>, "down cast");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, const C0*>, "down cast");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, C0*>,
+              "down cast, apply const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, const C0*>,
+              "down cast, remove const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, C0*>,
+              "down cast, apply const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, const C0*>,
+              "down cast, remove const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C2a*, C2b*>, "sideways cast");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C2a*, const C2b*>,
+              "sideways cast");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C2a*, C2b*>,
+              "sideways cast, apply const");
+static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C2a*, const C2b*>,
+              "sideways cast, remove const");
 
 ////////////////////////////////////////////////////////////////////////////////
 // TintVectorTest
@@ -2001,6 +2009,18 @@
     EXPECT_TRUE(AllExternallyHeld(vec_b));  // Moved, not copied
 }
 
+TEST(TintVectorRefTest, MoveVector_ReinterpretCast) {
+    C2a c2a;
+    C2b c2b;
+    Vector<C0*, 1> vec_a{&c2a, &c2b};
+    VectorRef<const C0*> vec_ref(std::move(vec_a));  // Move
+    EXPECT_EQ(vec_ref[0], &c2a);
+    EXPECT_EQ(vec_ref[1], &c2b);
+    VectorRef<const C1*> reinterpret = vec_ref.ReinterpretCast<const C1*>();
+    EXPECT_EQ(reinterpret[0], &c2a);
+    EXPECT_EQ(reinterpret[1], &c2b);
+}
+
 TEST(TintVectorRefTest, Index) {
     Vector<std::string, 2> vec{"one", "two"};
     VectorRef<std::string> vec_ref(vec);