Dawn: Support stream for types used in shader module reflection

This CL add stream in/out support for types used in the shader module
reflection, including std::wstring, std::variant, std::unique_ptr,
ityp::array, absl::flat_hash_map, andabsl::flat_hash_set. This CL also
add support for stream out std::reference_wrapper, which are used for
sorting the items in map/set without copying them. Related unittests are
also added.

Bug: 402772740
Change-Id: I0b93c3638879ef749428c4f37420b132b2a0849f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241474
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
diff --git a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp
index 0e66c3c..1682b63 100644
--- a/src/dawn/native/stream/Stream.cpp
+++ b/src/dawn/native/stream/Stream.cpp
@@ -54,6 +54,27 @@
 }
 
 template <>
+void Stream<std::wstring>::Write(Sink* s, const std::wstring& t) {
+    StreamIn(s, t.length());
+    size_t size = t.length() * sizeof(wchar_t);
+    if (size > 0) {
+        void* ptr = s->GetSpace(size);
+        memcpy(ptr, t.data(), size);
+    }
+}
+
+template <>
+MaybeError Stream<std::wstring>::Read(Source* s, std::wstring* t) {
+    size_t length;
+    DAWN_TRY(StreamOut(s, &length));
+    size_t size = length * sizeof(wchar_t);
+    const void* ptr;
+    DAWN_TRY(s->Read(&ptr, size));
+    *t = std::wstring(static_cast<const wchar_t*>(ptr), length);
+    return {};
+}
+
+template <>
 void Stream<std::string_view>::Write(Sink* s, const std::string_view& t) {
     StreamIn(s, t.length());
     size_t size = t.length() * sizeof(char);
diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h
index 77106f9..d501925 100644
--- a/src/dawn/native/stream/Stream.h
+++ b/src/dawn/native/stream/Stream.h
@@ -32,6 +32,7 @@
 #include <bitset>
 #include <functional>
 #include <limits>
+#include <memory>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
@@ -39,8 +40,11 @@
 
 #include <optional>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "dawn/common/Platform.h"
 #include "dawn/common/TypedInteger.h"
+#include "dawn/common/ityp_array.h"
 #include "dawn/native/Error.h"
 #include "dawn/native/stream/Sink.h"
 #include "dawn/native/stream/Source.h"
@@ -225,6 +229,43 @@
     }
 };
 
+// Stream specialization for unique pointers. We always serialize/deserialize via value, not by
+// pointer. To handle nullptr scenarios, we always serialize whether the pointer was not nullptr,
+// followed by the contents if applicable.
+template <typename T>
+class Stream<std::unique_ptr<T>, std::enable_if_t<!std::is_pointer_v<T>>> {
+  public:
+    static void Write(stream::Sink* sink, const std::unique_ptr<T>& t) {
+        StreamIn(sink, t != nullptr);
+        if (t != nullptr) {
+            StreamIn(sink, *t);
+        }
+    }
+
+    static MaybeError Read(Source* source, std::unique_ptr<T>* t) {
+        bool notNullptr;
+        DAWN_TRY(StreamOut(source, &notNullptr));
+        if (notNullptr) {
+            T out;
+            DAWN_TRY(StreamOut(source, &out));
+            *t = std::make_unique<T>(std::move(out));
+        } else {
+            *t = nullptr;
+        }
+        return {};
+    }
+};
+
+// Stream specialization for reference_wrapper. For serialization, unwrap it to the reference const
+// T& and call Write(sink, const T&).
+template <typename T>
+class Stream<std::reference_wrapper<T>> {
+  public:
+    static void Write(stream::Sink* sink, const std::reference_wrapper<T>& t) {
+        StreamIn(sink, t.get());
+    }
+};
+
 // Stream specialization for std::optional
 template <typename T>
 class Stream<std::optional<T>> {
@@ -342,6 +383,26 @@
     }
 };
 
+// Stream specialization for ityp::array<Index, Value, Size>.
+template <typename Index, typename Value, size_t Size>
+class Stream<ityp::array<Index, Value, Size>> {
+  public:
+    using ArrayType = ityp::array<Index, Value, Size>;
+
+    static void Write(Sink* s, const ArrayType& v) {
+        for (const Value& it : v) {
+            StreamIn(s, it);
+        }
+    }
+
+    static MaybeError Read(Source* s, ArrayType* v) {
+        for (auto& el : *v) {
+            DAWN_TRY(StreamOut(s, el));
+        }
+        return {};
+    }
+};
+
 // Stream specialization for std::pair.
 template <typename A, typename B>
 class Stream<std::pair<A, B>> {
@@ -358,26 +419,44 @@
     }
 };
 
-// Stream specialization for std::unordered_map<K, V> which sorts the entries
-// to provide a stable ordering.
-template <typename K, typename V>
-class Stream<std::unordered_map<K, V>> {
+template <typename M>
+concept IsMapLike = std::is_same_v<M,
+                                   std::unordered_map<typename M::key_type,
+                                                      typename M::mapped_type,
+                                                      typename M::hasher,
+                                                      typename M::key_equal,
+                                                      typename M::allocator_type>> ||
+                    std::is_same_v<M,
+                                   absl::flat_hash_map<typename M::key_type,
+                                                       typename M::mapped_type,
+                                                       typename M::hasher,
+                                                       typename M::key_equal,
+                                                       typename M::allocator_type>>;
+
+// Stream specialization for std::unordered_map<K, V> and absl::flat_hash_map which sorts the
+// entries to provide a stable ordering.
+template <IsMapLike MapType>
+class Stream<MapType> {
   public:
-    static void Write(stream::Sink* sink, const std::unordered_map<K, V>& m) {
-        std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
-        std::sort(
-            ordered.begin(), ordered.end(),
-            [](const std::pair<K, V>& a, const std::pair<K, V>& b) { return a.first < b.first; });
-        StreamIn(sink, ordered);
+    using ConstRefWrapper = std::reference_wrapper<const typename MapType::value_type>;
+    using RefVector = std::vector<ConstRefWrapper>;
+    static void Write(stream::Sink* sink, const MapType& m) {
+        // Use a vector of wrapped reference for sorting to avoid copying the elements.
+        RefVector refVector(m.cbegin(), m.cend());
+        std::sort(refVector.begin(), refVector.end(),
+                  [](const ConstRefWrapper& a, const ConstRefWrapper& b) {
+                      return a.get().first < b.get().first;
+                  });
+        StreamIn(sink, refVector);
     }
-    static MaybeError Read(Source* s, std::unordered_map<K, V>* m) {
-        using SizeT = decltype(std::declval<std::vector<std::pair<K, V>>>().size());
+    static MaybeError Read(Source* s, MapType* m) {
+        using SizeT = decltype(std::declval<RefVector>().size());
         SizeT size;
         DAWN_TRY(StreamOut(s, &size));
         *m = {};
         m->reserve(size);
         for (SizeT i = 0; i < size; ++i) {
-            std::pair<K, V> p;
+            std::pair<typename MapType::key_type, typename MapType::mapped_type> p;
             DAWN_TRY(StreamOut(s, &p));
             m->insert(std::move(p));
         }
@@ -385,24 +464,41 @@
     }
 };
 
-// Stream specialization for std::unordered_set<V> which sorts the entries
+template <typename S>
+concept IsSetLike = std::is_same_v<S,
+                                   std::unordered_set<typename S::key_type,
+                                                      typename S::hasher,
+                                                      typename S::key_equal,
+                                                      typename S::allocator_type>> ||
+                    std::is_same_v<S,
+                                   absl::flat_hash_set<typename S::key_type,
+                                                       typename S::hasher,
+                                                       typename S::key_equal,
+                                                       typename S::allocator_type>>;
+
+// Stream specialization for std::unordered_set<V> and absl::flat_hash_set which sorts the entries
 // to provide a stable ordering.
-template <typename V>
-class Stream<std::unordered_set<V>> {
+template <IsSetLike SetType>
+class Stream<SetType> {
   public:
-    static void Write(stream::Sink* sink, const std::unordered_set<V>& s) {
-        std::vector<V> ordered(s.begin(), s.end());
-        std::sort(ordered.begin(), ordered.end(), [](const V& a, const V& b) { return a < b; });
-        StreamIn(sink, ordered);
+    using ConstRefWrapper = std::reference_wrapper<const typename SetType::value_type>;
+    using RefVector = std::vector<ConstRefWrapper>;
+    static void Write(stream::Sink* sink, const SetType& s) {
+        // Use a vector of wrapped reference for sorting to avoid copying the elements.
+        RefVector refVector(s.cbegin(), s.cend());
+        std::sort(
+            refVector.begin(), refVector.end(),
+            [](const ConstRefWrapper& a, const ConstRefWrapper& b) { return a.get() < b.get(); });
+        StreamIn(sink, refVector);
     }
-    static MaybeError Read(Source* source, std::unordered_set<V>* s) {
-        using SizeT = decltype(std::declval<std::vector<V>>().size());
+    static MaybeError Read(Source* source, SetType* s) {
+        using SizeT = decltype(std::declval<RefVector>().size());
         SizeT size;
         DAWN_TRY(StreamOut(source, &size));
         *s = {};
         s->reserve(size);
         for (SizeT i = 0; i < size; ++i) {
-            V v;
+            typename SetType::key_type v;
             DAWN_TRY(StreamOut(source, &v));
             s->insert(std::move(v));
         }
@@ -410,6 +506,83 @@
     }
 };
 
+// Stream specialization for std::variant<Types...> which read/write the type id and the typed
+// value.
+template <typename... Types>
+class Stream<std::variant<Types...>> {
+  public:
+    using VariantType = std::variant<Types...>;
+
+    static void Write(stream::Sink* sink, const VariantType& t) { WriteImpl<0, Types...>(sink, t); }
+
+    static MaybeError Read(stream::Source* source, VariantType* t) {
+        size_t typeId;
+        DAWN_TRY(StreamOut(source, &typeId));
+        if (typeId >= sizeof...(Types)) {
+            return DAWN_VALIDATION_ERROR("Invalid variant type id");
+        } else {
+            return ReadImpl<0, Types...>(source, t, typeId);
+        }
+    }
+
+  private:
+    // WriteImpl template for trying multiple possible value types
+    template <size_t N,
+              typename TryType,
+              typename... RemainingTypes,
+              typename = std::enable_if_t<sizeof...(RemainingTypes) != 0>>
+    static inline void WriteImpl(stream::Sink* sink, const VariantType& t) {
+        if (std::holds_alternative<TryType>(t)) {
+            // Record the type index
+            StreamIn(sink, N);
+            // Record the value
+            StreamIn(sink, std::get<TryType>(t));
+        } else {
+            // Try the next type
+            WriteImpl<N + 1, RemainingTypes...>(sink, t);
+        }
+    }
+    // WriteImpl template for trying the last possible type
+    template <size_t N, typename LastType>
+    static inline void WriteImpl(stream::Sink* sink, const VariantType& t) {
+        // Variant must hold the last possible type if no previous match.
+        DAWN_ASSERT(std::holds_alternative<LastType>(t));
+        // Record the type index
+        StreamIn(sink, N);
+        // Record the value
+        StreamIn(sink, std::get<LastType>(t));
+    }
+    // ReadImpl template for trying multiple possible value types
+    template <size_t N,
+              typename TryType,
+              typename... RemainingTypes,
+              typename = std::enable_if_t<sizeof...(RemainingTypes) != 0>>
+    static inline MaybeError ReadImpl(stream::Source* source, VariantType* t, size_t typeId) {
+        if (typeId == N) {
+            // Read the value
+            TryType value;
+            DAWN_TRY(StreamOut(source, &value));
+            *t = VariantType(std::move(value));
+            return {};
+        } else {
+            // Try the next type
+            return ReadImpl<N + 1, RemainingTypes...>(source, t, typeId);
+        }
+    }
+    // ReadImpl template for trying the last possible type
+    template <size_t N, typename LastType>
+    static inline MaybeError ReadImpl(stream::Source* source, VariantType* t, size_t typeId) {
+        // typeId must be the id of last possible type N if not being 0..N-1, since it has been
+        // validated in range 0..N
+        DAWN_ASSERT(typeId == N);
+        // Read the value
+        LastType value;
+        DAWN_TRY(StreamOut(source, &value));
+        *t = VariantType(std::move(value));
+        return {};
+    }
+};
+
 // Helper class to contain the begin/end iterators of an iterable.
 namespace detail {
 template <typename Iterator>
diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp
index 4bc3a82..e3b270f 100644
--- a/src/dawn/tests/unittests/native/StreamTests.cpp
+++ b/src/dawn/tests/unittests/native/StreamTests.cpp
@@ -29,16 +29,19 @@
 #include <cstring>
 #include <iomanip>
 #include <limits>
+#include <memory>
 #include <string>
 #include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "dawn/common/TypedInteger.h"
 #include "dawn/native/Blob.h"
 #include "dawn/native/Serializable.h"
+#include "dawn/native/ShaderModule.h"
 #include "dawn/native/TintUtils.h"
 #include "dawn/native/stream/BlobSource.h"
 #include "dawn/native/stream/ByteVectorSink.h"
@@ -185,6 +188,20 @@
     EXPECT_CACHE_KEY_EQ(str, expected);
 }
 
+// Test that ByteVectorSink serializes std::wstrings as expected.
+TEST(SerializeTests, StdWStrings) {
+    // Letter 𐩯 takes 4 bytes in UTF16
+    std::wstring str = L"∂y/∂x𐩯";
+
+    ByteVectorSink expected;
+
+    StreamIn(&expected, size_t(str.length()));
+    size_t bytes = str.length() * sizeof(wchar_t);
+    memcpy(expected.GetSpace(bytes), str.data(), bytes);
+
+    EXPECT_CACHE_KEY_EQ(str, expected);
+}
+
 // Test that ByteVectorSink serializes std::string_views as expected.
 TEST(SerializeTests, StdStringViews) {
     static constexpr std::string_view str("string");
@@ -283,6 +300,60 @@
     }
 }
 
+// Test that ByteVectorSink serializes std::variant as expected.
+TEST(SerializeTests, StdVariant) {
+    using VariantType = std::variant<std::string_view, uint32_t>;
+    std::string_view stringViewInput = "hello";
+    uint32_t u32Input = 42;
+    {
+        // Type id of std::string_view is 0 in VariantType
+        VariantType v1 = stringViewInput;
+        ByteVectorSink expected;
+        StreamIn(&expected, /* Type id */ size_t(0), stringViewInput);
+        EXPECT_CACHE_KEY_EQ(v1, expected);
+    }
+    {
+        // Type id of uint32_t is 1 in VariantType
+        VariantType v2 = u32Input;
+        ByteVectorSink expected;
+        StreamIn(&expected, /* Type id */ size_t(1), u32Input);
+        EXPECT_CACHE_KEY_EQ(v2, expected);
+    }
+}
+
+// Test that ByteVectorSink serializes std::unique_ptr as expected.
+TEST(SerializeTests, StdUniquePtr) {
+    // Test serializing a non-nullptr unique_ptr
+    {
+        std::unique_ptr<int> ptr = std::make_unique<int>(456);
+        ByteVectorSink expected;
+        StreamIn(&expected, true, 456);
+        EXPECT_CACHE_KEY_EQ(ptr, expected);
+    }
+    // Test serializing a nullptr unique_ptr ByteVectorSink expected;
+    {
+        ByteVectorSink expected;
+        StreamIn(&expected, false);
+        EXPECT_CACHE_KEY_EQ(std::unique_ptr<int>(), expected);
+    }
+}
+
+// Test that ByteVectorSink serializes std::reference_wrapper as expected.
+TEST(SerializeTests, StdReferenceWrapper) {
+    const int value1 = 123;
+    int value2 = 789;
+    std::reference_wrapper<const int> cref = std::cref(value1);
+    std::reference_wrapper<int> ref = std::ref(value2);
+    auto inputPair = std::make_pair(cref, ref);
+    auto inputPairRef = std::ref(inputPair);
+
+    // Expect all reference are serialized as the referenced value.
+    ByteVectorSink expected;
+    StreamIn(&expected, 123, 789);
+
+    EXPECT_CACHE_KEY_EQ(inputPairRef, expected);
+}
+
 // Test that ByteVectorSink serializes std::unordered_map as expected.
 TEST(SerializeTests, StdUnorderedMap) {
     std::unordered_map<uint32_t, std::string_view> m;
@@ -312,6 +383,49 @@
     EXPECT_CACHE_KEY_EQ(input, expected);
 }
 
+// Test that ByteVectorSink serializes ityp::array as expected.
+TEST(SerializeTests, ItypArray) {
+    const ityp::array<TypedIntegerForTest, TypedIntegerForTest, 4> input = {
+        TypedIntegerForTest(99), TypedIntegerForTest(4), TypedIntegerForTest(6),
+        TypedIntegerForTest(1)};
+
+    // Expect all values.
+    ByteVectorSink expected;
+    StreamIn(&expected, TypedIntegerForTest(99), TypedIntegerForTest(4), TypedIntegerForTest(6),
+             TypedIntegerForTest(1));
+
+    EXPECT_CACHE_KEY_EQ(input, expected);
+}
+
+// Test that ByteVectorSink serializes absl::flat_hash_map as expected.
+TEST(SerializeTests, AbslFlatHashMap) {
+    absl::flat_hash_map<uint32_t, std::string_view> m;
+
+    m[4] = "hello";
+    m[1] = "world";
+    m[7] = "test";
+    m[3] = "data";
+
+    // Expect the number of entries, followed by (K, V) pairs sorted in order of key.
+    ByteVectorSink expected;
+    StreamIn(&expected, size_t(4), std::make_pair(uint32_t(1), m[1]),
+             std::make_pair(uint32_t(3), m[3]), std::make_pair(uint32_t(4), m[4]),
+             std::make_pair(uint32_t(7), m[7]));
+
+    EXPECT_CACHE_KEY_EQ(m, expected);
+}
+
+// Test that ByteVectorSink serializes absl::flat_hash_set as expected.
+TEST(SerializeTests, AbslFlatHashSet) {
+    const absl::flat_hash_set<int> input = {99, 4, 6, 1};
+
+    // Expect the number of entries, followed by values sorted in order of key.
+    ByteVectorSink expected;
+    StreamIn(&expected, size_t(4), 1, 4, 6, 99);
+
+    EXPECT_CACHE_KEY_EQ(input, expected);
+}
+
 // Test that ByteVectorSink serializes tint::BindingPoint as expected.
 TEST(SerializeTests, TintSemBindingPoint) {
     tint::BindingPoint bp{3, 6};
@@ -412,6 +526,65 @@
     }
 }
 
+// Test that serializing then deserializing a std::unique_ptr yields the same data.
+// Tested here instead of in the type-parameterized tests since std::unique_ptr are not copyable.
+TEST(StreamTests, SerializeDeserializeUniquePtr) {
+    // Test a null unique_ptr
+    {
+        std::unique_ptr<int> in = nullptr;
+
+        ByteVectorSink sink;
+        StreamIn(&sink, in);
+
+        BlobSource src(CreateBlob(sink));
+        // Initialize the unique_ptr to a non-null value to check if it gets set to nullptr.
+        std::unique_ptr<int> out = std::make_unique<int>(123);
+        auto err = StreamOut(&src, &out);
+        EXPECT_FALSE(err.IsError());
+        EXPECT_EQ(out, nullptr);
+    }
+
+    // Test a unique_ptr holding  data
+    {
+        std::unique_ptr<int> in = std::make_unique<int>(456);
+
+        ByteVectorSink sink;
+        StreamIn(&sink, in);
+
+        BlobSource src(CreateBlob(sink));
+        // Initialize the unique_ptr to a nullptr. When deserializing, it should be pointed to a new
+        // allocated memory holding the expected data.
+        std::unique_ptr<int> out = nullptr;
+        auto err = StreamOut(&src, &out);
+        EXPECT_FALSE(err.IsError());
+        EXPECT_NE(out, nullptr);
+        // in and out should point to different memory locations, but the values should be the same.
+        EXPECT_NE(in, out);
+        EXPECT_EQ(*in, *out);
+    }
+}
+
+// Test that serializing then deserializing a ityp::array yields the same data.
+// Tested here instead of in the type-parameterized tests since ityp::array don't have operator==.
+TEST(StreamTests, SerializeDeserializeItypArray) {
+    using Array = ityp::array<TypedIntegerForTest, int, 5>;
+    constexpr Array in = {1, 5, 2, 4, 7};
+
+    ByteVectorSink sink;
+    StreamIn(&sink, in);
+
+    BlobSource src(CreateBlob(sink));
+    // Initialize the unique_ptr to a nullptr. When deserializing, it should be pointed to a new
+    // allocated memory holding the expected data.
+    Array out;
+    auto err = StreamOut(&src, &out);
+    EXPECT_FALSE(err.IsError());
+    // Check every element of the out array is the same as in.
+    for (TypedIntegerForTest i = TypedIntegerForTest(); i < in.size(); i++) {
+        EXPECT_EQ(in[i], out[i]);
+    }
+}
+
 template <size_t N>
 std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) {
     // N - 1 because the last character is the null terminator.
@@ -424,6 +597,8 @@
     std::vector<float>{6.50, 78.28, 92., 8.28},
     // Test various types of strings.
     std::vector<std::string>{"abcdefg", "9461849495", ""},
+    // Test various types of wstrings.
+    std::vector<std::wstring>{L"abcde54321", L"∂y/∂x𐩯" /* Letter 𐩯 takes 4 bytes in UTF16 */, L""},
     // Test pairs.
     std::vector<std::pair<int, float>>{{1, 3.}, {6, 4.}},
     // Test TypedIntegers
@@ -439,14 +614,28 @@
         BitsetFromBitString("100110010101011001100110101011001100101010110011001011011"),
         BitsetFromBitString("000110010101011000100110101011001100101010010011001010100"),
         BitsetFromBitString("111111111111111111111111111111111111111111111111111111111"), 0},
+    // Test unordered_maps.
+    std::vector<std::unordered_map<int, int>>{{},
+                                              {{4, 5}, {6, 8}, {99, 42}, {0, 0}},
+                                              {{100, 1}, {2, 300}, {300, 2}}},
     // Test unordered_sets.
     std::vector<std::unordered_set<int>>{{}, {4, 6, 99, 0}, {100, 300, 300}},
+    // Test absl::flat_hash_map.
+    std::vector<absl::flat_hash_map<int, int>>{{},
+                                               {{4, 5}, {6, 8}, {99, 42}, {0, 0}},
+                                               {{100, 1}, {2, 300}, {300, 2}}},
+    // Test absl::flat_hash_set.
+    std::vector<absl::flat_hash_set<int>>{{}, {4, 6, 99, 0}, {100, 300, 300}},
     // Test vectors.
     std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}},
+    // Test variants.
+    std::vector<std::variant<uint32_t, std::string, std::vector<std::bitset<7>>>>{
+        uint32_t{123}, std::string{"1, 5, 2, 7, 4"},
+        std::vector<std::bitset<7>>{0b1001011, 0b0011010, 0b0000000, 0b1111111}},
     // Test different size of arrays.
     std::vector<std::array<int, 3>>{{1, 5, 2}, {-3, -3, -3}},
     std::vector<std::array<uint8_t, 5>>{{5, 2, 7, 9, 6}, {3, 3, 3, 3, 42}},
-    // array of non-fundamental type
+    // Test array of non-fundamental type.
     std::vector<std::array<std::string, 2>>{{"abcd", "efg"}, {"123hij", ""}});
 
 static auto kStreamValueInitListParams = std::make_tuple(