Dawn: Support stream for types used in shader module reflection
This CL add stream in/out support for types used in the shader module
reflection, including std::wstring, std::variant, std::unique_ptr,
ityp::array, absl::flat_hash_map, andabsl::flat_hash_set. This CL also
add support for stream out std::reference_wrapper, which are used for
sorting the items in map/set without copying them. Related unittests are
also added.
Bug: 402772740
Change-Id: I0b93c3638879ef749428c4f37420b132b2a0849f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241474
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
diff --git a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp
index 0e66c3c..1682b63 100644
--- a/src/dawn/native/stream/Stream.cpp
+++ b/src/dawn/native/stream/Stream.cpp
@@ -54,6 +54,27 @@
}
template <>
+void Stream<std::wstring>::Write(Sink* s, const std::wstring& t) {
+ StreamIn(s, t.length());
+ size_t size = t.length() * sizeof(wchar_t);
+ if (size > 0) {
+ void* ptr = s->GetSpace(size);
+ memcpy(ptr, t.data(), size);
+ }
+}
+
+template <>
+MaybeError Stream<std::wstring>::Read(Source* s, std::wstring* t) {
+ size_t length;
+ DAWN_TRY(StreamOut(s, &length));
+ size_t size = length * sizeof(wchar_t);
+ const void* ptr;
+ DAWN_TRY(s->Read(&ptr, size));
+ *t = std::wstring(static_cast<const wchar_t*>(ptr), length);
+ return {};
+}
+
+template <>
void Stream<std::string_view>::Write(Sink* s, const std::string_view& t) {
StreamIn(s, t.length());
size_t size = t.length() * sizeof(char);
diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h
index 77106f9..d501925 100644
--- a/src/dawn/native/stream/Stream.h
+++ b/src/dawn/native/stream/Stream.h
@@ -32,6 +32,7 @@
#include <bitset>
#include <functional>
#include <limits>
+#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -39,8 +40,11 @@
#include <optional>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "dawn/common/Platform.h"
#include "dawn/common/TypedInteger.h"
+#include "dawn/common/ityp_array.h"
#include "dawn/native/Error.h"
#include "dawn/native/stream/Sink.h"
#include "dawn/native/stream/Source.h"
@@ -225,6 +229,43 @@
}
};
+// Stream specialization for unique pointers. We always serialize/deserialize via value, not by
+// pointer. To handle nullptr scenarios, we always serialize whether the pointer was not nullptr,
+// followed by the contents if applicable.
+template <typename T>
+class Stream<std::unique_ptr<T>, std::enable_if_t<!std::is_pointer_v<T>>> {
+ public:
+ static void Write(stream::Sink* sink, const std::unique_ptr<T>& t) {
+ StreamIn(sink, t != nullptr);
+ if (t != nullptr) {
+ StreamIn(sink, *t);
+ }
+ }
+
+ static MaybeError Read(Source* source, std::unique_ptr<T>* t) {
+ bool notNullptr;
+ DAWN_TRY(StreamOut(source, ¬Nullptr));
+ if (notNullptr) {
+ T out;
+ DAWN_TRY(StreamOut(source, &out));
+ *t = std::make_unique<T>(std::move(out));
+ } else {
+ *t = nullptr;
+ }
+ return {};
+ }
+};
+
+// Stream specialization for reference_wrapper. For serialization, unwrap it to the reference const
+// T& and call Write(sink, const T&).
+template <typename T>
+class Stream<std::reference_wrapper<T>> {
+ public:
+ static void Write(stream::Sink* sink, const std::reference_wrapper<T>& t) {
+ StreamIn(sink, t.get());
+ }
+};
+
// Stream specialization for std::optional
template <typename T>
class Stream<std::optional<T>> {
@@ -342,6 +383,26 @@
}
};
+// Stream specialization for ityp::array<Index, Value, Size>.
+template <typename Index, typename Value, size_t Size>
+class Stream<ityp::array<Index, Value, Size>> {
+ public:
+ using ArrayType = ityp::array<Index, Value, Size>;
+
+ static void Write(Sink* s, const ArrayType& v) {
+ for (const Value& it : v) {
+ StreamIn(s, it);
+ }
+ }
+
+ static MaybeError Read(Source* s, ArrayType* v) {
+ for (auto& el : *v) {
+ DAWN_TRY(StreamOut(s, el));
+ }
+ return {};
+ }
+};
+
// Stream specialization for std::pair.
template <typename A, typename B>
class Stream<std::pair<A, B>> {
@@ -358,26 +419,44 @@
}
};
-// Stream specialization for std::unordered_map<K, V> which sorts the entries
-// to provide a stable ordering.
-template <typename K, typename V>
-class Stream<std::unordered_map<K, V>> {
+template <typename M>
+concept IsMapLike = std::is_same_v<M,
+ std::unordered_map<typename M::key_type,
+ typename M::mapped_type,
+ typename M::hasher,
+ typename M::key_equal,
+ typename M::allocator_type>> ||
+ std::is_same_v<M,
+ absl::flat_hash_map<typename M::key_type,
+ typename M::mapped_type,
+ typename M::hasher,
+ typename M::key_equal,
+ typename M::allocator_type>>;
+
+// Stream specialization for std::unordered_map<K, V> and absl::flat_hash_map which sorts the
+// entries to provide a stable ordering.
+template <IsMapLike MapType>
+class Stream<MapType> {
public:
- static void Write(stream::Sink* sink, const std::unordered_map<K, V>& m) {
- std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
- std::sort(
- ordered.begin(), ordered.end(),
- [](const std::pair<K, V>& a, const std::pair<K, V>& b) { return a.first < b.first; });
- StreamIn(sink, ordered);
+ using ConstRefWrapper = std::reference_wrapper<const typename MapType::value_type>;
+ using RefVector = std::vector<ConstRefWrapper>;
+ static void Write(stream::Sink* sink, const MapType& m) {
+ // Use a vector of wrapped reference for sorting to avoid copying the elements.
+ RefVector refVector(m.cbegin(), m.cend());
+ std::sort(refVector.begin(), refVector.end(),
+ [](const ConstRefWrapper& a, const ConstRefWrapper& b) {
+ return a.get().first < b.get().first;
+ });
+ StreamIn(sink, refVector);
}
- static MaybeError Read(Source* s, std::unordered_map<K, V>* m) {
- using SizeT = decltype(std::declval<std::vector<std::pair<K, V>>>().size());
+ static MaybeError Read(Source* s, MapType* m) {
+ using SizeT = decltype(std::declval<RefVector>().size());
SizeT size;
DAWN_TRY(StreamOut(s, &size));
*m = {};
m->reserve(size);
for (SizeT i = 0; i < size; ++i) {
- std::pair<K, V> p;
+ std::pair<typename MapType::key_type, typename MapType::mapped_type> p;
DAWN_TRY(StreamOut(s, &p));
m->insert(std::move(p));
}
@@ -385,24 +464,41 @@
}
};
-// Stream specialization for std::unordered_set<V> which sorts the entries
+template <typename S>
+concept IsSetLike = std::is_same_v<S,
+ std::unordered_set<typename S::key_type,
+ typename S::hasher,
+ typename S::key_equal,
+ typename S::allocator_type>> ||
+ std::is_same_v<S,
+ absl::flat_hash_set<typename S::key_type,
+ typename S::hasher,
+ typename S::key_equal,
+ typename S::allocator_type>>;
+
+// Stream specialization for std::unordered_set<V> and absl::flat_hash_set which sorts the entries
// to provide a stable ordering.
-template <typename V>
-class Stream<std::unordered_set<V>> {
+template <IsSetLike SetType>
+class Stream<SetType> {
public:
- static void Write(stream::Sink* sink, const std::unordered_set<V>& s) {
- std::vector<V> ordered(s.begin(), s.end());
- std::sort(ordered.begin(), ordered.end(), [](const V& a, const V& b) { return a < b; });
- StreamIn(sink, ordered);
+ using ConstRefWrapper = std::reference_wrapper<const typename SetType::value_type>;
+ using RefVector = std::vector<ConstRefWrapper>;
+ static void Write(stream::Sink* sink, const SetType& s) {
+ // Use a vector of wrapped reference for sorting to avoid copying the elements.
+ RefVector refVector(s.cbegin(), s.cend());
+ std::sort(
+ refVector.begin(), refVector.end(),
+ [](const ConstRefWrapper& a, const ConstRefWrapper& b) { return a.get() < b.get(); });
+ StreamIn(sink, refVector);
}
- static MaybeError Read(Source* source, std::unordered_set<V>* s) {
- using SizeT = decltype(std::declval<std::vector<V>>().size());
+ static MaybeError Read(Source* source, SetType* s) {
+ using SizeT = decltype(std::declval<RefVector>().size());
SizeT size;
DAWN_TRY(StreamOut(source, &size));
*s = {};
s->reserve(size);
for (SizeT i = 0; i < size; ++i) {
- V v;
+ typename SetType::key_type v;
DAWN_TRY(StreamOut(source, &v));
s->insert(std::move(v));
}
@@ -410,6 +506,83 @@
}
};
+// Stream specialization for std::variant<Types...> which read/write the type id and the typed
+// value.
+template <typename... Types>
+class Stream<std::variant<Types...>> {
+ public:
+ using VariantType = std::variant<Types...>;
+
+ static void Write(stream::Sink* sink, const VariantType& t) { WriteImpl<0, Types...>(sink, t); }
+
+ static MaybeError Read(stream::Source* source, VariantType* t) {
+ size_t typeId;
+ DAWN_TRY(StreamOut(source, &typeId));
+ if (typeId >= sizeof...(Types)) {
+ return DAWN_VALIDATION_ERROR("Invalid variant type id");
+ } else {
+ return ReadImpl<0, Types...>(source, t, typeId);
+ }
+ }
+
+ private:
+ // WriteImpl template for trying multiple possible value types
+ template <size_t N,
+ typename TryType,
+ typename... RemainingTypes,
+ typename = std::enable_if_t<sizeof...(RemainingTypes) != 0>>
+ static inline void WriteImpl(stream::Sink* sink, const VariantType& t) {
+ if (std::holds_alternative<TryType>(t)) {
+ // Record the type index
+ StreamIn(sink, N);
+ // Record the value
+ StreamIn(sink, std::get<TryType>(t));
+ } else {
+ // Try the next type
+ WriteImpl<N + 1, RemainingTypes...>(sink, t);
+ }
+ }
+ // WriteImpl template for trying the last possible type
+ template <size_t N, typename LastType>
+ static inline void WriteImpl(stream::Sink* sink, const VariantType& t) {
+ // Variant must hold the last possible type if no previous match.
+ DAWN_ASSERT(std::holds_alternative<LastType>(t));
+ // Record the type index
+ StreamIn(sink, N);
+ // Record the value
+ StreamIn(sink, std::get<LastType>(t));
+ }
+ // ReadImpl template for trying multiple possible value types
+ template <size_t N,
+ typename TryType,
+ typename... RemainingTypes,
+ typename = std::enable_if_t<sizeof...(RemainingTypes) != 0>>
+ static inline MaybeError ReadImpl(stream::Source* source, VariantType* t, size_t typeId) {
+ if (typeId == N) {
+ // Read the value
+ TryType value;
+ DAWN_TRY(StreamOut(source, &value));
+ *t = VariantType(std::move(value));
+ return {};
+ } else {
+ // Try the next type
+ return ReadImpl<N + 1, RemainingTypes...>(source, t, typeId);
+ }
+ }
+ // ReadImpl template for trying the last possible type
+ template <size_t N, typename LastType>
+ static inline MaybeError ReadImpl(stream::Source* source, VariantType* t, size_t typeId) {
+ // typeId must be the id of last possible type N if not being 0..N-1, since it has been
+ // validated in range 0..N
+ DAWN_ASSERT(typeId == N);
+ // Read the value
+ LastType value;
+ DAWN_TRY(StreamOut(source, &value));
+ *t = VariantType(std::move(value));
+ return {};
+ }
+};
+
// Helper class to contain the begin/end iterators of an iterable.
namespace detail {
template <typename Iterator>
diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp
index 4bc3a82..e3b270f 100644
--- a/src/dawn/tests/unittests/native/StreamTests.cpp
+++ b/src/dawn/tests/unittests/native/StreamTests.cpp
@@ -29,16 +29,19 @@
#include <cstring>
#include <iomanip>
#include <limits>
+#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
+#include <variant>
#include <vector>
#include "dawn/common/TypedInteger.h"
#include "dawn/native/Blob.h"
#include "dawn/native/Serializable.h"
+#include "dawn/native/ShaderModule.h"
#include "dawn/native/TintUtils.h"
#include "dawn/native/stream/BlobSource.h"
#include "dawn/native/stream/ByteVectorSink.h"
@@ -185,6 +188,20 @@
EXPECT_CACHE_KEY_EQ(str, expected);
}
+// Test that ByteVectorSink serializes std::wstrings as expected.
+TEST(SerializeTests, StdWStrings) {
+ // Letter 𐩯 takes 4 bytes in UTF16
+ std::wstring str = L"∂y/∂x𐩯";
+
+ ByteVectorSink expected;
+
+ StreamIn(&expected, size_t(str.length()));
+ size_t bytes = str.length() * sizeof(wchar_t);
+ memcpy(expected.GetSpace(bytes), str.data(), bytes);
+
+ EXPECT_CACHE_KEY_EQ(str, expected);
+}
+
// Test that ByteVectorSink serializes std::string_views as expected.
TEST(SerializeTests, StdStringViews) {
static constexpr std::string_view str("string");
@@ -283,6 +300,60 @@
}
}
+// Test that ByteVectorSink serializes std::variant as expected.
+TEST(SerializeTests, StdVariant) {
+ using VariantType = std::variant<std::string_view, uint32_t>;
+ std::string_view stringViewInput = "hello";
+ uint32_t u32Input = 42;
+ {
+ // Type id of std::string_view is 0 in VariantType
+ VariantType v1 = stringViewInput;
+ ByteVectorSink expected;
+ StreamIn(&expected, /* Type id */ size_t(0), stringViewInput);
+ EXPECT_CACHE_KEY_EQ(v1, expected);
+ }
+ {
+ // Type id of uint32_t is 1 in VariantType
+ VariantType v2 = u32Input;
+ ByteVectorSink expected;
+ StreamIn(&expected, /* Type id */ size_t(1), u32Input);
+ EXPECT_CACHE_KEY_EQ(v2, expected);
+ }
+}
+
+// Test that ByteVectorSink serializes std::unique_ptr as expected.
+TEST(SerializeTests, StdUniquePtr) {
+ // Test serializing a non-nullptr unique_ptr
+ {
+ std::unique_ptr<int> ptr = std::make_unique<int>(456);
+ ByteVectorSink expected;
+ StreamIn(&expected, true, 456);
+ EXPECT_CACHE_KEY_EQ(ptr, expected);
+ }
+ // Test serializing a nullptr unique_ptr ByteVectorSink expected;
+ {
+ ByteVectorSink expected;
+ StreamIn(&expected, false);
+ EXPECT_CACHE_KEY_EQ(std::unique_ptr<int>(), expected);
+ }
+}
+
+// Test that ByteVectorSink serializes std::reference_wrapper as expected.
+TEST(SerializeTests, StdReferenceWrapper) {
+ const int value1 = 123;
+ int value2 = 789;
+ std::reference_wrapper<const int> cref = std::cref(value1);
+ std::reference_wrapper<int> ref = std::ref(value2);
+ auto inputPair = std::make_pair(cref, ref);
+ auto inputPairRef = std::ref(inputPair);
+
+ // Expect all reference are serialized as the referenced value.
+ ByteVectorSink expected;
+ StreamIn(&expected, 123, 789);
+
+ EXPECT_CACHE_KEY_EQ(inputPairRef, expected);
+}
+
// Test that ByteVectorSink serializes std::unordered_map as expected.
TEST(SerializeTests, StdUnorderedMap) {
std::unordered_map<uint32_t, std::string_view> m;
@@ -312,6 +383,49 @@
EXPECT_CACHE_KEY_EQ(input, expected);
}
+// Test that ByteVectorSink serializes ityp::array as expected.
+TEST(SerializeTests, ItypArray) {
+ const ityp::array<TypedIntegerForTest, TypedIntegerForTest, 4> input = {
+ TypedIntegerForTest(99), TypedIntegerForTest(4), TypedIntegerForTest(6),
+ TypedIntegerForTest(1)};
+
+ // Expect all values.
+ ByteVectorSink expected;
+ StreamIn(&expected, TypedIntegerForTest(99), TypedIntegerForTest(4), TypedIntegerForTest(6),
+ TypedIntegerForTest(1));
+
+ EXPECT_CACHE_KEY_EQ(input, expected);
+}
+
+// Test that ByteVectorSink serializes absl::flat_hash_map as expected.
+TEST(SerializeTests, AbslFlatHashMap) {
+ absl::flat_hash_map<uint32_t, std::string_view> m;
+
+ m[4] = "hello";
+ m[1] = "world";
+ m[7] = "test";
+ m[3] = "data";
+
+ // Expect the number of entries, followed by (K, V) pairs sorted in order of key.
+ ByteVectorSink expected;
+ StreamIn(&expected, size_t(4), std::make_pair(uint32_t(1), m[1]),
+ std::make_pair(uint32_t(3), m[3]), std::make_pair(uint32_t(4), m[4]),
+ std::make_pair(uint32_t(7), m[7]));
+
+ EXPECT_CACHE_KEY_EQ(m, expected);
+}
+
+// Test that ByteVectorSink serializes absl::flat_hash_set as expected.
+TEST(SerializeTests, AbslFlatHashSet) {
+ const absl::flat_hash_set<int> input = {99, 4, 6, 1};
+
+ // Expect the number of entries, followed by values sorted in order of key.
+ ByteVectorSink expected;
+ StreamIn(&expected, size_t(4), 1, 4, 6, 99);
+
+ EXPECT_CACHE_KEY_EQ(input, expected);
+}
+
// Test that ByteVectorSink serializes tint::BindingPoint as expected.
TEST(SerializeTests, TintSemBindingPoint) {
tint::BindingPoint bp{3, 6};
@@ -412,6 +526,65 @@
}
}
+// Test that serializing then deserializing a std::unique_ptr yields the same data.
+// Tested here instead of in the type-parameterized tests since std::unique_ptr are not copyable.
+TEST(StreamTests, SerializeDeserializeUniquePtr) {
+ // Test a null unique_ptr
+ {
+ std::unique_ptr<int> in = nullptr;
+
+ ByteVectorSink sink;
+ StreamIn(&sink, in);
+
+ BlobSource src(CreateBlob(sink));
+ // Initialize the unique_ptr to a non-null value to check if it gets set to nullptr.
+ std::unique_ptr<int> out = std::make_unique<int>(123);
+ auto err = StreamOut(&src, &out);
+ EXPECT_FALSE(err.IsError());
+ EXPECT_EQ(out, nullptr);
+ }
+
+ // Test a unique_ptr holding data
+ {
+ std::unique_ptr<int> in = std::make_unique<int>(456);
+
+ ByteVectorSink sink;
+ StreamIn(&sink, in);
+
+ BlobSource src(CreateBlob(sink));
+ // Initialize the unique_ptr to a nullptr. When deserializing, it should be pointed to a new
+ // allocated memory holding the expected data.
+ std::unique_ptr<int> out = nullptr;
+ auto err = StreamOut(&src, &out);
+ EXPECT_FALSE(err.IsError());
+ EXPECT_NE(out, nullptr);
+ // in and out should point to different memory locations, but the values should be the same.
+ EXPECT_NE(in, out);
+ EXPECT_EQ(*in, *out);
+ }
+}
+
+// Test that serializing then deserializing a ityp::array yields the same data.
+// Tested here instead of in the type-parameterized tests since ityp::array don't have operator==.
+TEST(StreamTests, SerializeDeserializeItypArray) {
+ using Array = ityp::array<TypedIntegerForTest, int, 5>;
+ constexpr Array in = {1, 5, 2, 4, 7};
+
+ ByteVectorSink sink;
+ StreamIn(&sink, in);
+
+ BlobSource src(CreateBlob(sink));
+ // Initialize the unique_ptr to a nullptr. When deserializing, it should be pointed to a new
+ // allocated memory holding the expected data.
+ Array out;
+ auto err = StreamOut(&src, &out);
+ EXPECT_FALSE(err.IsError());
+ // Check every element of the out array is the same as in.
+ for (TypedIntegerForTest i = TypedIntegerForTest(); i < in.size(); i++) {
+ EXPECT_EQ(in[i], out[i]);
+ }
+}
+
template <size_t N>
std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) {
// N - 1 because the last character is the null terminator.
@@ -424,6 +597,8 @@
std::vector<float>{6.50, 78.28, 92., 8.28},
// Test various types of strings.
std::vector<std::string>{"abcdefg", "9461849495", ""},
+ // Test various types of wstrings.
+ std::vector<std::wstring>{L"abcde54321", L"∂y/∂x𐩯" /* Letter 𐩯 takes 4 bytes in UTF16 */, L""},
// Test pairs.
std::vector<std::pair<int, float>>{{1, 3.}, {6, 4.}},
// Test TypedIntegers
@@ -439,14 +614,28 @@
BitsetFromBitString("100110010101011001100110101011001100101010110011001011011"),
BitsetFromBitString("000110010101011000100110101011001100101010010011001010100"),
BitsetFromBitString("111111111111111111111111111111111111111111111111111111111"), 0},
+ // Test unordered_maps.
+ std::vector<std::unordered_map<int, int>>{{},
+ {{4, 5}, {6, 8}, {99, 42}, {0, 0}},
+ {{100, 1}, {2, 300}, {300, 2}}},
// Test unordered_sets.
std::vector<std::unordered_set<int>>{{}, {4, 6, 99, 0}, {100, 300, 300}},
+ // Test absl::flat_hash_map.
+ std::vector<absl::flat_hash_map<int, int>>{{},
+ {{4, 5}, {6, 8}, {99, 42}, {0, 0}},
+ {{100, 1}, {2, 300}, {300, 2}}},
+ // Test absl::flat_hash_set.
+ std::vector<absl::flat_hash_set<int>>{{}, {4, 6, 99, 0}, {100, 300, 300}},
// Test vectors.
std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}},
+ // Test variants.
+ std::vector<std::variant<uint32_t, std::string, std::vector<std::bitset<7>>>>{
+ uint32_t{123}, std::string{"1, 5, 2, 7, 4"},
+ std::vector<std::bitset<7>>{0b1001011, 0b0011010, 0b0000000, 0b1111111}},
// Test different size of arrays.
std::vector<std::array<int, 3>>{{1, 5, 2}, {-3, -3, -3}},
std::vector<std::array<uint8_t, 5>>{{5, 2, 7, 9, 6}, {3, 3, 3, 3, 42}},
- // array of non-fundamental type
+ // Test array of non-fundamental type.
std::vector<std::array<std::string, 2>>{{"abcd", "efg"}, {"123hij", ""}});
static auto kStreamValueInitListParams = std::make_tuple(