Refactor [de]serialization functions out of CacheKey
Storing values into the cache will need to serialize and deserialize
values in addition to keys. This patch factors the serialization
utilities out of CacheKey into a more general "Stream" utility that
supports both input and output for serialization and deserialization.
Multiple files are not renamed to make parsing the diff easier. They
will be renamed in Change If61f0466d79e7759ed32c4ddf541ad0c17247996.
Bug: dawn:1480, dawn:1481
Change-Id: If7594c4ff7117454c1ab3d0afaeee5653120add8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/96480
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/generator/templates/dawn/native/CacheKey.cpp b/generator/templates/dawn/native/CacheKey.cpp
index cf31449..82b6404 100644
--- a/generator/templates/dawn/native/CacheKey.cpp
+++ b/generator/templates/dawn/native/CacheKey.cpp
@@ -25,44 +25,44 @@
namespace {{native_namespace}} {
//
-// Cache key serializers for wgpu structures used in caching.
+// Cache key writers for wgpu structures used in caching.
//
-{% macro render_serializer(member) %}
+{% macro render_writer(member) %}
{%- set name = member.name.camelCase() -%}
{% if member.length == None %}
- key->Record(t.{{name}});
+ StreamIn(sink, t.{{name}});
{% elif member.length == "strlen" %}
- key->RecordIterable(t.{{name}}, strlen(t.{{name}}));
+ StreamIn(sink, Iterable(t.{{name}}, strlen(t.{{name}})));
{% else %}
- key->RecordIterable(t.{{name}}, t.{{member.length.name.camelCase()}});
+ StreamIn(sink, Iterable(t.{{name}}, t.{{member.length.name.camelCase()}}));
{% endif %}
{% endmacro %}
-{# Helper macro to render serializers. Should be used in a call block to provide additional custom
+{# Helper macro to render writers. Should be used in a call block to provide additional custom
handling when necessary. The optional `omit` field can be used to omit fields that are either
handled in the custom code, or unnecessary in the serialized output.
Example:
- {% call render_cache_key_serializer("struct name", omits=["omit field"]) %}
+ {% call render_cache_key_writer("struct name", omits=["omit field"]) %}
// Custom C++ code to handle special types/members that are hard to generate code for
{% endcall %}
#}
-{% macro render_cache_key_serializer(json_type, omits=[]) %}
+{% macro render_cache_key_writer(json_type, omits=[]) %}
{%- set cpp_type = types[json_type].name.CamelCase() -%}
template <>
- void CacheKeySerializer<{{cpp_type}}>::Serialize(CacheKey* key, const {{cpp_type}}& t) {
+ void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) {
{{ caller() }}
{% for member in types[json_type].members %}
{%- if not member.name.get() in omits %}
- {{render_serializer(member)}}
+ {{render_writer(member)}}
{%- endif %}
{% endfor %}
}
{% endmacro %}
-{% call render_cache_key_serializer("adapter properties") %}
+{% call render_cache_key_writer("adapter properties") %}
{% endcall %}
-{% call render_cache_key_serializer("dawn cache device descriptor") %}
+{% call render_cache_key_writer("dawn cache device descriptor") %}
{% endcall %}
} // namespace {{native_namespace}}
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index d8ddca2..31f1e84 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -334,6 +334,13 @@
"VertexFormat.cpp",
"VertexFormat.h",
"dawn_platform.h",
+ "stream/BlobSource.cpp",
+ "stream/BlobSource.h",
+ "stream/ByteVectorSink.cpp",
+ "stream/ByteVectorSink.h",
+ "stream/Sink.h",
+ "stream/Source.h",
+ "stream/Stream.cpp",
"stream/Stream.h",
"utils/WGPUHelpers.cpp",
"utils/WGPUHelpers.h",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index 9306178..d665b8d 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -193,6 +193,13 @@
"dawn_platform.h"
"webgpu_absl_format.cpp"
"webgpu_absl_format.h"
+ "stream/BlobSource.cpp"
+ "stream/BlobSource.h"
+ "stream/ByteVectorSink.cpp"
+ "stream/ByteVectorSink.h"
+ "stream/Sink.h"
+ "stream/Source.h"
+ "stream/Stream.cpp"
"stream/Stream.h"
"utils/WGPUHelpers.cpp"
"utils/WGPUHelpers.h"
diff --git a/src/dawn/native/CacheKey.cpp b/src/dawn/native/CacheKey.cpp
index 414b915..4db9535 100644
--- a/src/dawn/native/CacheKey.cpp
+++ b/src/dawn/native/CacheKey.cpp
@@ -20,32 +20,9 @@
namespace dawn::native {
-std::ostream& operator<<(std::ostream& os, const CacheKey& key) {
- os << std::hex;
- for (const int b : key) {
- os << std::setfill('0') << std::setw(2) << b << " ";
- }
- os << std::dec;
- return os;
-}
-
template <>
-void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) {
- key->Record(t.length());
- key->insert(key->end(), t.begin(), t.end());
-}
-
-template <>
-void CacheKeySerializer<std::string_view>::Serialize(CacheKey* key, const std::string_view& t) {
- key->Record(t.length());
- key->insert(key->end(), t.begin(), t.end());
-}
-
-template <>
-void CacheKeySerializer<CacheKey>::Serialize(CacheKey* key, const CacheKey& t) {
- // For nested cache keys, we do not record the length, and just copy the key so that it
- // appears we just flatten the keys into a single key.
- key->insert(key->end(), t.begin(), t.end());
+void stream::Stream<CacheKey>::Write(stream::Sink* sink, const CacheKey& t) {
+ StreamIn(sink, static_cast<const ByteVectorSink&>(t));
}
} // namespace dawn::native
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h
index 786ae86..6cec3b6 100644
--- a/src/dawn/native/CacheKey.h
+++ b/src/dawn/native/CacheKey.h
@@ -15,40 +15,16 @@
#ifndef SRC_DAWN_NATIVE_CACHEKEY_H_
#define SRC_DAWN_NATIVE_CACHEKEY_H_
-#include <algorithm>
-#include <bitset>
-#include <functional>
-#include <iostream>
-#include <limits>
-#include <memory>
-#include <type_traits>
-#include <unordered_map>
#include <utility>
-#include <vector>
-#include "dawn/common/TypedInteger.h"
-#include "dawn/common/ityp_array.h"
+#include "dawn/native/stream/ByteVectorSink.h"
+#include "dawn/native/stream/Stream.h"
namespace dawn::native {
-// Forward declare classes because of co-dependency.
-class CacheKey;
-class CachedObject;
-
-// Stream operator for CacheKey for debugging.
-std::ostream& operator<<(std::ostream& os, const CacheKey& key);
-
-// Overridable serializer struct that should be implemented for cache key serializable
-// types/classes.
-template <typename T, typename SFINAE = void>
-class CacheKeySerializer {
+class CacheKey : public stream::ByteVectorSink {
public:
- static void Serialize(CacheKey* key, const T& t);
-};
-
-class CacheKey : public std::vector<uint8_t> {
- public:
- using std::vector<uint8_t>::vector;
+ using stream::ByteVectorSink::ByteVectorSink;
enum class Type { ComputePipeline, RenderPipeline, Shader };
@@ -61,49 +37,13 @@
const T& UnsafeGetValue() const { return mValue; }
+ // Friend definition of StreamIn which can be found by ADL to override
+ // stream::StreamIn<T>.
+ friend constexpr void StreamIn(stream::Sink*, const UnsafeUnkeyedValue&) {}
+
private:
T mValue;
};
-
- template <typename T>
- CacheKey& Record(const T& t) {
- CacheKeySerializer<T>::Serialize(this, t);
- return *this;
- }
- template <typename T, typename... Args>
- CacheKey& Record(const T& t, const Args&... args) {
- CacheKeySerializer<T>::Serialize(this, t);
- return Record(args...);
- }
-
- // Records iterables by prepending the number of elements. Some common iterables are have a
- // CacheKeySerializer implemented to avoid needing to split them out when recording, i.e.
- // strings and CacheKeys, but they fundamentally do the same as this function.
- template <typename IterableT>
- CacheKey& RecordIterable(const IterableT& iterable) {
- // Always record the size of generic iterables as a size_t for now.
- Record(static_cast<size_t>(iterable.size()));
- for (auto it = iterable.begin(); it != iterable.end(); ++it) {
- Record(*it);
- }
- return *this;
- }
- template <typename Index, typename Value, size_t Size>
- CacheKey& RecordIterable(const ityp::array<Index, Value, Size>& iterable) {
- Record(static_cast<Index>(iterable.size()));
- for (auto it = iterable.begin(); it != iterable.end(); ++it) {
- Record(*it);
- }
- return *this;
- }
- template <typename Ptr>
- CacheKey& RecordIterable(const Ptr* ptr, size_t n) {
- Record(n);
- for (size_t i = 0; i < n; ++i) {
- Record(ptr[i]);
- }
- return *this;
- }
};
template <typename T>
@@ -113,8 +53,4 @@
} // namespace dawn::native
-// CacheKeySerializer implementation temporarily moved to stream/Stream.h to
-// simplify the diff in the refactor to stream::Stream.
-#include "dawn/native/stream/Stream.h"
-
#endif // SRC_DAWN_NATIVE_CACHEKEY_H_
diff --git a/src/dawn/native/CacheKeyImplTint.cpp b/src/dawn/native/CacheKeyImplTint.cpp
index 1a38486..e566954 100644
--- a/src/dawn/native/CacheKeyImplTint.cpp
+++ b/src/dawn/native/CacheKeyImplTint.cpp
@@ -20,10 +20,10 @@
// static
template <>
-void CacheKeySerializer<tint::Program>::Serialize(CacheKey* key, const tint::Program& p) {
+void stream::Stream<tint::Program>::Write(stream::Sink* sink, const tint::Program& p) {
#if TINT_BUILD_WGSL_WRITER
tint::writer::wgsl::Options options{};
- key->Record(tint::writer::wgsl::Generate(&p, options).wgsl);
+ StreamIn(sink, tint::writer::wgsl::Generate(&p, options).wgsl);
#else
// TODO(crbug.com/dawn/1481): We shouldn't need to write back to WGSL if we have a CacheKey
// built from the initial shader module input. Then, we would never need to parse the program
@@ -34,8 +34,21 @@
// static
template <>
-void CacheKeySerializer<tint::transform::BindingPoints>::Serialize(
- CacheKey* key,
+void stream::Stream<tint::sem::BindingPoint>::Write(stream::Sink* sink,
+ const tint::sem::BindingPoint& p) {
+ static_assert(offsetof(tint::sem::BindingPoint, group) == 0,
+ "Please update serialization for tint::sem::BindingPoint");
+ static_assert(offsetof(tint::sem::BindingPoint, binding) == 4,
+ "Please update serialization for tint::sem::BindingPoint");
+ static_assert(sizeof(tint::sem::BindingPoint) == 8,
+ "Please update serialization for tint::sem::BindingPoint");
+ StreamIn(sink, p.group, p.binding);
+}
+
+// static
+template <>
+void stream::Stream<tint::transform::BindingPoints>::Write(
+ stream::Sink* sink,
const tint::transform::BindingPoints& points) {
static_assert(offsetof(tint::transform::BindingPoints, plane_1) == 0,
"Please update serialization for tint::transform::BindingPoints");
@@ -43,20 +56,7 @@
"Please update serialization for tint::transform::BindingPoints");
static_assert(sizeof(tint::transform::BindingPoints) == 16,
"Please update serialization for tint::transform::BindingPoints");
- key->Record(points.plane_1, points.params);
-}
-
-// static
-template <>
-void CacheKeySerializer<tint::sem::BindingPoint>::Serialize(CacheKey* key,
- const tint::sem::BindingPoint& p) {
- static_assert(offsetof(tint::sem::BindingPoint, group) == 0,
- "Please update serialization for tint::sem::BindingPoint");
- static_assert(offsetof(tint::sem::BindingPoint, binding) == 4,
- "Please update serialization for tint::sem::BindingPoint");
- static_assert(sizeof(tint::sem::BindingPoint) == 8,
- "Please update serialization for tint::sem::BindingPoint");
- key->Record(p.group, p.binding);
+ StreamIn(sink, points.plane_1, points.params);
}
} // namespace dawn::native
diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h
index 375e6eb..df8af4a 100644
--- a/src/dawn/native/CacheRequest.h
+++ b/src/dawn/native/CacheRequest.h
@@ -157,7 +157,7 @@
#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{};
// Helper for X macro for recording cache request fields into a CacheKey.
-#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(name);
+#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) StreamIn(&key, name);
// Helper X macro to define a CacheRequest struct.
// Example usage:
@@ -177,7 +177,7 @@
/* Create a CacheKey from the request type and all members */ \
::dawn::native::CacheKey CreateCacheKey(const ::dawn::native::DeviceBase* device) const { \
::dawn::native::CacheKey key = device->GetCacheKey(); \
- key.Record(#Request); \
+ StreamIn(&key, #Request); \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \
return key; \
} \
diff --git a/src/dawn/native/CachedObject.cpp b/src/dawn/native/CachedObject.cpp
index 5fa6a0a..249b0d6 100644
--- a/src/dawn/native/CachedObject.cpp
+++ b/src/dawn/native/CachedObject.cpp
@@ -46,4 +46,10 @@
return mCacheKey;
}
+// static
+template <>
+void stream::Stream<CachedObject>::Write(stream::Sink* sink, const CachedObject& obj) {
+ StreamIn(sink, obj.GetCacheKey());
+}
+
} // namespace dawn::native
diff --git a/src/dawn/native/ComputePipeline.cpp b/src/dawn/native/ComputePipeline.cpp
index a1dcf15..b0b574d 100644
--- a/src/dawn/native/ComputePipeline.cpp
+++ b/src/dawn/native/ComputePipeline.cpp
@@ -50,7 +50,7 @@
TrackInDevice();
// Initialize the cache key to include the cache type and device information.
- mCacheKey.Record(CacheKey::Type::ComputePipeline, device->GetCacheKey());
+ StreamIn(&mCacheKey, CacheKey::Type::ComputePipeline, device->GetCacheKey());
}
ComputePipelineBase::ComputePipelineBase(DeviceBase* device) : PipelineBase(device) {
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 060826b..5026cc2 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -209,8 +209,8 @@
// Record the cache key from the properties. Note that currently, if a new extension
// descriptor is added (and probably handled here), the cache key recording needs to be
// updated.
- mDeviceCacheKey.Record(kDawnVersion, adapterProperties, mEnabledFeatures.featuresBitSet,
- mEnabledToggles.toggleBitset, cacheDesc);
+ StreamIn(&mDeviceCacheKey, kDawnVersion, adapterProperties, mEnabledFeatures.featuresBitSet,
+ mEnabledToggles.toggleBitset, cacheDesc);
}
DeviceBase::DeviceBase() : mState(State::Alive) {
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index 256cb62..5c24d00 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -622,7 +622,7 @@
TrackInDevice();
// Initialize the cache key to include the cache type and device information.
- mCacheKey.Record(CacheKey::Type::RenderPipeline, device->GetCacheKey());
+ StreamIn(&mCacheKey, CacheKey::Type::RenderPipeline, device->GetCacheKey());
}
RenderPipelineBase::RenderPipelineBase(DeviceBase* device) : PipelineBase(device) {
diff --git a/src/dawn/native/d3d12/CacheKeyD3D12.cpp b/src/dawn/native/d3d12/CacheKeyD3D12.cpp
index 0daf526..9a19477 100644
--- a/src/dawn/native/d3d12/CacheKeyD3D12.cpp
+++ b/src/dawn/native/d3d12/CacheKeyD3D12.cpp
@@ -20,120 +20,103 @@
namespace dawn::native {
template <>
-void CacheKeySerializer<D3D12_COMPUTE_PIPELINE_STATE_DESC>::Serialize(
- CacheKey* key,
- const D3D12_COMPUTE_PIPELINE_STATE_DESC& t) {
- // Don't record pRootSignature as we already record the signature blob in pipline layout.
- key->Record(t.CS).Record(t.NodeMask).Record(t.Flags);
-}
-
-template <>
-void CacheKeySerializer<D3D12_RENDER_TARGET_BLEND_DESC>::Serialize(
- CacheKey* key,
+void stream::Stream<D3D12_RENDER_TARGET_BLEND_DESC>::Write(
+ stream::Sink* sink,
const D3D12_RENDER_TARGET_BLEND_DESC& t) {
- key->Record(t.BlendEnable, t.LogicOpEnable, t.SrcBlend, t.DestBlend, t.BlendOp, t.SrcBlendAlpha,
- t.DestBlendAlpha, t.BlendOpAlpha, t.LogicOp, t.RenderTargetWriteMask);
+ StreamIn(sink, t.BlendEnable, t.LogicOpEnable, t.SrcBlend, t.DestBlend, t.BlendOp,
+ t.SrcBlendAlpha, t.DestBlendAlpha, t.BlendOpAlpha, t.LogicOp, t.RenderTargetWriteMask);
}
template <>
-void CacheKeySerializer<D3D12_BLEND_DESC>::Serialize(CacheKey* key, const D3D12_BLEND_DESC& t) {
- key->Record(t.AlphaToCoverageEnable, t.IndependentBlendEnable).Record(t.RenderTarget);
+void stream::Stream<D3D12_BLEND_DESC>::Write(stream::Sink* sink, const D3D12_BLEND_DESC& t) {
+ StreamIn(sink, t.AlphaToCoverageEnable, t.IndependentBlendEnable, t.RenderTarget);
}
template <>
-void CacheKeySerializer<D3D12_DEPTH_STENCILOP_DESC>::Serialize(
- CacheKey* key,
- const D3D12_DEPTH_STENCILOP_DESC& t) {
- key->Record(t.StencilFailOp, t.StencilDepthFailOp, t.StencilPassOp, t.StencilFunc);
+void stream::Stream<D3D12_DEPTH_STENCILOP_DESC>::Write(stream::Sink* sink,
+ const D3D12_DEPTH_STENCILOP_DESC& t) {
+ StreamIn(sink, t.StencilFailOp, t.StencilDepthFailOp, t.StencilPassOp, t.StencilFunc);
}
template <>
-void CacheKeySerializer<D3D12_DEPTH_STENCIL_DESC>::Serialize(CacheKey* key,
- const D3D12_DEPTH_STENCIL_DESC& t) {
- key->Record(t.DepthEnable, t.DepthWriteMask, t.DepthFunc, t.StencilEnable, t.StencilReadMask,
- t.StencilWriteMask, t.FrontFace, t.BackFace);
+void stream::Stream<D3D12_DEPTH_STENCIL_DESC>::Write(stream::Sink* sink,
+ const D3D12_DEPTH_STENCIL_DESC& t) {
+ StreamIn(sink, t.DepthEnable, t.DepthWriteMask, t.DepthFunc, t.StencilEnable, t.StencilReadMask,
+ t.StencilWriteMask, t.FrontFace, t.BackFace);
}
template <>
-void CacheKeySerializer<D3D12_RASTERIZER_DESC>::Serialize(CacheKey* key,
- const D3D12_RASTERIZER_DESC& t) {
- key->Record(t.FillMode, t.CullMode, t.FrontCounterClockwise, t.DepthBias, t.DepthBiasClamp,
- t.SlopeScaledDepthBias, t.DepthClipEnable, t.MultisampleEnable,
- t.AntialiasedLineEnable, t.ForcedSampleCount, t.ConservativeRaster);
+void stream::Stream<D3D12_RASTERIZER_DESC>::Write(stream::Sink* sink,
+ const D3D12_RASTERIZER_DESC& t) {
+ StreamIn(sink, t.FillMode, t.CullMode, t.FrontCounterClockwise, t.DepthBias, t.DepthBiasClamp,
+ t.SlopeScaledDepthBias, t.DepthClipEnable, t.MultisampleEnable,
+ t.AntialiasedLineEnable, t.ForcedSampleCount, t.ConservativeRaster);
}
template <>
-void CacheKeySerializer<D3D12_INPUT_ELEMENT_DESC>::Serialize(CacheKey* key,
- const D3D12_INPUT_ELEMENT_DESC& t) {
- key->Record(t.SemanticName, t.SemanticIndex, t.Format, t.InputSlot, t.AlignedByteOffset,
- t.InputSlotClass, t.InstanceDataStepRate);
+void stream::Stream<D3D12_INPUT_ELEMENT_DESC>::Write(stream::Sink* sink,
+ const D3D12_INPUT_ELEMENT_DESC& t) {
+ StreamIn(sink, std::string_view(t.SemanticName), t.SemanticIndex, t.Format, t.InputSlot,
+ t.AlignedByteOffset, t.InputSlotClass, t.InstanceDataStepRate);
}
template <>
-void CacheKeySerializer<D3D12_INPUT_LAYOUT_DESC>::Serialize(CacheKey* key,
- const D3D12_INPUT_LAYOUT_DESC& t) {
- key->RecordIterable(t.pInputElementDescs, t.NumElements);
+void stream::Stream<D3D12_INPUT_LAYOUT_DESC>::Write(stream::Sink* sink,
+ const D3D12_INPUT_LAYOUT_DESC& t) {
+ StreamIn(sink, Iterable(t.pInputElementDescs, t.NumElements));
}
template <>
-void CacheKeySerializer<D3D12_SO_DECLARATION_ENTRY>::Serialize(
- CacheKey* key,
- const D3D12_SO_DECLARATION_ENTRY& t) {
- key->Record(t.Stream, t.SemanticName, t.SemanticIndex, t.StartComponent, t.ComponentCount,
- t.OutputSlot);
+void stream::Stream<D3D12_SO_DECLARATION_ENTRY>::Write(stream::Sink* sink,
+ const D3D12_SO_DECLARATION_ENTRY& t) {
+ StreamIn(sink, t.Stream, std::string_view(t.SemanticName), t.SemanticIndex, t.StartComponent,
+ t.ComponentCount, t.OutputSlot);
}
template <>
-void CacheKeySerializer<D3D12_STREAM_OUTPUT_DESC>::Serialize(CacheKey* key,
- const D3D12_STREAM_OUTPUT_DESC& t) {
- key->RecordIterable(t.pSODeclaration, t.NumEntries)
- .RecordIterable(t.pBufferStrides, t.NumStrides)
- .Record(t.RasterizedStream);
+void stream::Stream<D3D12_STREAM_OUTPUT_DESC>::Write(stream::Sink* sink,
+ const D3D12_STREAM_OUTPUT_DESC& t) {
+ StreamIn(sink, Iterable(t.pSODeclaration, t.NumEntries),
+ Iterable(t.pBufferStrides, t.NumStrides), t.RasterizedStream);
}
template <>
-void CacheKeySerializer<DXGI_SAMPLE_DESC>::Serialize(CacheKey* key, const DXGI_SAMPLE_DESC& t) {
- key->Record(t.Count, t.Quality);
+void stream::Stream<DXGI_SAMPLE_DESC>::Write(stream::Sink* sink, const DXGI_SAMPLE_DESC& t) {
+ StreamIn(sink, t.Count, t.Quality);
}
template <>
-void CacheKeySerializer<D3D12_SHADER_BYTECODE>::Serialize(CacheKey* key,
- const D3D12_SHADER_BYTECODE& t) {
- key->RecordIterable(reinterpret_cast<const uint8_t*>(t.pShaderBytecode), t.BytecodeLength);
+void stream::Stream<D3D12_SHADER_BYTECODE>::Write(stream::Sink* sink,
+ const D3D12_SHADER_BYTECODE& t) {
+ StreamIn(sink, Iterable(reinterpret_cast<const uint8_t*>(t.pShaderBytecode), t.BytecodeLength));
}
template <>
-void CacheKeySerializer<D3D12_GRAPHICS_PIPELINE_STATE_DESC>::Serialize(
- CacheKey* key,
+void stream::Stream<D3D12_GRAPHICS_PIPELINE_STATE_DESC>::Write(
+ stream::Sink* sink,
const D3D12_GRAPHICS_PIPELINE_STATE_DESC& t) {
- // Don't record pRootSignature as we already record the signature blob in pipline layout.
- // Don't record CachedPSO as it is in the cached blob.
- key->Record(t.VS)
- .Record(t.PS)
- .Record(t.DS)
- .Record(t.HS)
- .Record(t.GS)
- .Record(t.StreamOutput)
- .Record(t.BlendState)
- .Record(t.SampleMask)
- .Record(t.RasterizerState)
- .Record(t.DepthStencilState)
- .Record(t.InputLayout)
- .Record(t.IBStripCutValue)
- .Record(t.PrimitiveTopologyType)
- .RecordIterable(t.RTVFormats, t.NumRenderTargets)
- .Record(t.DSVFormat)
- .Record(t.SampleDesc)
- .Record(t.NodeMask)
- .Record(t.Flags);
+ // Don't Serialize pRootSignature as we already serialize the signature blob in pipline layout.
+ // Don't Serialize CachedPSO as it is in the cached blob.
+ StreamIn(sink, t.VS, t.PS, t.DS, t.HS, t.GS, t.StreamOutput, t.BlendState, t.SampleMask,
+ t.RasterizerState, t.DepthStencilState, t.InputLayout, t.IBStripCutValue,
+ t.PrimitiveTopologyType, Iterable(t.RTVFormats, t.NumRenderTargets), t.DSVFormat,
+ t.SampleDesc, t.NodeMask, t.Flags);
}
template <>
-void CacheKeySerializer<ID3DBlob>::Serialize(CacheKey* key, const ID3DBlob& t) {
+void stream::Stream<D3D12_COMPUTE_PIPELINE_STATE_DESC>::Write(
+ stream::Sink* sink,
+ const D3D12_COMPUTE_PIPELINE_STATE_DESC& t) {
+ // Don't Serialize pRootSignature as we already serialize the signature blob in pipline layout.
+ StreamIn(sink, t.CS, t.NodeMask, t.Flags);
+}
+
+template <>
+void stream::Stream<ID3DBlob>::Write(stream::Sink* sink, const ID3DBlob& t) {
// Workaround: GetBufferPointer and GetbufferSize are not marked as const
ID3DBlob* pBlob = const_cast<ID3DBlob*>(&t);
- key->RecordIterable(reinterpret_cast<uint8_t*>(pBlob->GetBufferPointer()),
- pBlob->GetBufferSize());
+ StreamIn(sink, Iterable(reinterpret_cast<uint8_t*>(pBlob->GetBufferPointer()),
+ pBlob->GetBufferSize()));
}
} // namespace dawn::native
diff --git a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
index 5ba8b76..8dc05fc 100644
--- a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
@@ -62,7 +62,7 @@
ToBackend(GetLayout()), compileFlags));
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
- mCacheKey.Record(d3dDesc, ToBackend(GetLayout())->GetRootSignatureBlob());
+ StreamIn(&mCacheKey, d3dDesc, ToBackend(GetLayout())->GetRootSignatureBlob());
// Try to see if we have anything in the blob cache.
Blob blob = device->LoadCachedBlob(GetCacheKey());
diff --git a/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp
index 636fae2..aca85e8 100644
--- a/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp
+++ b/src/dawn/native/d3d12/PipelineLayoutD3D12.cpp
@@ -271,7 +271,7 @@
0, mRootSignatureBlob->GetBufferPointer(),
mRootSignatureBlob->GetBufferSize(), IID_PPV_ARGS(&mRootSignature)),
"D3D12 create root signature"));
- mCacheKey.Record(mRootSignatureBlob.Get());
+ StreamIn(&mCacheKey, mRootSignatureBlob.Get());
return {};
}
diff --git a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
index 1ec8e02..e33d19f 100644
--- a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
@@ -430,7 +430,7 @@
mD3d12PrimitiveTopology = D3D12PrimitiveTopology(GetPrimitiveTopology());
- mCacheKey.Record(descriptorD3D12, *layout->GetRootSignatureBlob());
+ StreamIn(&mCacheKey, descriptorD3D12, *layout->GetRootSignatureBlob());
// Try to see if we have anything in the blob cache.
Blob blob = device->LoadCachedBlob(GetCacheKey());
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index 320449a..f64dc91 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -62,11 +62,11 @@
}
};
-void Serialize(std::stringstream& output, const tint::ast::Access& access) {
+void StreamIn(std::stringstream& output, const tint::ast::Access& access) {
output << access;
}
-void Serialize(std::stringstream& output, const tint::transform::BindingPoint& binding_point) {
+void StreamIn(std::stringstream& output, const tint::transform::BindingPoint& binding_point) {
output << "(BindingPoint";
output << " group=" << binding_point.group;
output << " binding=" << binding_point.binding;
@@ -74,32 +74,32 @@
}
template <typename T, typename = typename std::enable_if<std::is_fundamental<T>::value>::type>
-void Serialize(std::stringstream& output, const T& val) {
+void StreamIn(std::stringstream& output, const T& val) {
output << val;
}
template <typename T>
-void Serialize(std::stringstream& output,
- const std::unordered_map<tint::transform::BindingPoint, T>& map) {
+void StreamIn(std::stringstream& output,
+ const std::unordered_map<tint::transform::BindingPoint, T>& map) {
output << "(map";
std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(), map.end());
for (auto& [bindingPoint, value] : sorted) {
output << " ";
- Serialize(output, bindingPoint);
+ StreamIn(output, bindingPoint);
output << "=";
- Serialize(output, value);
+ StreamIn(output, value);
}
output << ")";
}
-void Serialize(std::stringstream& output,
- const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) {
+void StreamIn(std::stringstream& output,
+ const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) {
output << "(ArrayLengthFromUniformOptions";
output << " ubo_binding=";
- Serialize(output, arrayLengthFromUniform.ubo_binding);
+ StreamIn(output, arrayLengthFromUniform.ubo_binding);
output << " bindpoint_to_size_index=";
- Serialize(output, arrayLengthFromUniform.bindpoint_to_size_index);
+ StreamIn(output, arrayLengthFromUniform.bindpoint_to_size_index);
output << ")";
}
@@ -344,17 +344,17 @@
stream << " disableSymbolRenaming=" << disableSymbolRenaming;
stream << " remappedBindingPoints=";
- Serialize(stream, remappedBindingPoints);
+ StreamIn(stream, remappedBindingPoints);
stream << " remappedAccessControls=";
- Serialize(stream, remappedAccessControls);
+ StreamIn(stream, remappedAccessControls);
stream << " useNumWorkgroups=" << usesNumWorkgroups;
stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
stream << " arrayLengthFromUniform=";
- Serialize(stream, arrayLengthFromUniform);
+ StreamIn(stream, arrayLengthFromUniform);
stream << " shaderModel=" << deviceInfo->shaderModel;
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
diff --git a/src/dawn/native/stream/BlobSource.cpp b/src/dawn/native/stream/BlobSource.cpp
new file mode 100644
index 0000000..deeaf12
--- /dev/null
+++ b/src/dawn/native/stream/BlobSource.cpp
@@ -0,0 +1,30 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/stream/BlobSource.h"
+
+#include <utility>
+
+namespace dawn::native::stream {
+
+BlobSource::BlobSource(Blob&& blob) : mBlob(std::move(blob)) {}
+
+MaybeError BlobSource::Read(const void** ptr, size_t bytes) {
+ DAWN_INVALID_IF(bytes > mBlob.Size() - mOffset, "Out of bounds.");
+ *ptr = mBlob.Data() + mOffset;
+ mOffset += bytes;
+ return {};
+}
+
+} // namespace dawn::native::stream
diff --git a/src/dawn/native/stream/BlobSource.h b/src/dawn/native/stream/BlobSource.h
new file mode 100644
index 0000000..d93478e
--- /dev/null
+++ b/src/dawn/native/stream/BlobSource.h
@@ -0,0 +1,38 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_STREAM_BLOBSOURCE_H_
+#define SRC_DAWN_NATIVE_STREAM_BLOBSOURCE_H_
+
+#include "dawn/native/Blob.h"
+#include "dawn/native/Error.h"
+#include "dawn/native/stream/Source.h"
+
+namespace dawn::native::stream {
+
+class BlobSource : public Source {
+ public:
+ explicit BlobSource(Blob&& blob);
+
+ // stream::Source implementation.
+ MaybeError Read(const void** ptr, size_t bytes) override;
+
+ private:
+ const Blob mBlob;
+ size_t mOffset = 0;
+};
+
+} // namespace dawn::native::stream
+
+#endif // SRC_DAWN_NATIVE_STREAM_BLOBSOURCE_H_
diff --git a/src/dawn/native/stream/ByteVectorSink.cpp b/src/dawn/native/stream/ByteVectorSink.cpp
new file mode 100644
index 0000000..20c27f2
--- /dev/null
+++ b/src/dawn/native/stream/ByteVectorSink.cpp
@@ -0,0 +1,47 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/stream/ByteVectorSink.h"
+
+#include "dawn/native/stream/Stream.h"
+
+namespace dawn::native::stream {
+
+void* ByteVectorSink::GetSpace(size_t bytes) {
+ size_t currentSize = this->size();
+ this->resize(currentSize + bytes);
+ return &this->operator[](currentSize);
+}
+
+template <>
+void stream::Stream<ByteVectorSink>::Write(stream::Sink* sink, const ByteVectorSink& vec) {
+ // For nested sinks, we do not record the length, and just copy the data so that it
+ // appears flattened.
+ size_t size = vec.size();
+ if (size > 0) {
+ void* ptr = sink->GetSpace(size);
+ memcpy(ptr, vec.data(), size);
+ }
+}
+
+std::ostream& operator<<(std::ostream& os, const ByteVectorSink& vec) {
+ os << std::hex;
+ for (const int b : vec) {
+ os << std::setfill('0') << std::setw(2) << b << " ";
+ }
+ os << std::dec;
+ return os;
+}
+
+} // namespace dawn::native::stream
diff --git a/src/dawn/native/stream/ByteVectorSink.h b/src/dawn/native/stream/ByteVectorSink.h
new file mode 100644
index 0000000..3b6016c
--- /dev/null
+++ b/src/dawn/native/stream/ByteVectorSink.h
@@ -0,0 +1,39 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_STREAM_BYTEVECTORSINK_H_
+#define SRC_DAWN_NATIVE_STREAM_BYTEVECTORSINK_H_
+
+#include <ostream>
+#include <vector>
+
+#include "dawn/native/stream/Sink.h"
+
+namespace dawn::native::stream {
+
+// Implementation of stream::Sink backed by a byte vector.
+class ByteVectorSink : public std::vector<uint8_t>, public Sink {
+ public:
+ using std::vector<uint8_t>::vector;
+
+ // Implementation of stream::Sink
+ void* GetSpace(size_t bytes) override;
+};
+
+// Stream operator for ByteVectorSink for debugging.
+std::ostream& operator<<(std::ostream& os, const ByteVectorSink& key);
+
+} // namespace dawn::native::stream
+
+#endif // SRC_DAWN_NATIVE_STREAM_BYTEVECTORSINK_H_
diff --git a/src/dawn/native/stream/Sink.h b/src/dawn/native/stream/Sink.h
new file mode 100644
index 0000000..caf1ba4
--- /dev/null
+++ b/src/dawn/native/stream/Sink.h
@@ -0,0 +1,32 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_STREAM_SINK_H_
+#define SRC_DAWN_NATIVE_STREAM_SINK_H_
+
+#include <cstddef>
+
+namespace dawn::native::stream {
+
+// Interface for a serialization sink.
+class Sink {
+ public:
+ // Allocate `bytes` space in the sink. Returns the pointer to the start
+ // of the allocation.
+ virtual void* GetSpace(size_t bytes) = 0;
+};
+
+} // namespace dawn::native::stream
+
+#endif // SRC_DAWN_NATIVE_STREAM_SINK_H_
diff --git a/src/dawn/native/stream/Source.h b/src/dawn/native/stream/Source.h
new file mode 100644
index 0000000..c7b19d0
--- /dev/null
+++ b/src/dawn/native/stream/Source.h
@@ -0,0 +1,34 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_STREAM_SOURCE_H_
+#define SRC_DAWN_NATIVE_STREAM_SOURCE_H_
+
+#include <cstddef>
+
+namespace dawn::native::stream {
+
+// Interface for a deserialization source.
+class Source {
+ public:
+ // Try to read `bytes` space from the source. The data must live as long as `Source.
+ // Returns MaybeError and writes result to |ptr| because ResultOrError uses
+ // a tagged pointer that must be 4-byte aligned. This function writes out |ptr|
+ // which may not be aligned.
+ virtual MaybeError Read(const void** ptr, size_t bytes) = 0;
+};
+
+} // namespace dawn::native::stream
+
+#endif // SRC_DAWN_NATIVE_STREAM_SOURCE_H_
diff --git a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp
new file mode 100644
index 0000000..1ca241c
--- /dev/null
+++ b/src/dawn/native/stream/Stream.cpp
@@ -0,0 +1,51 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/stream/Stream.h"
+
+#include <string>
+
+namespace dawn::native::stream {
+
+template <>
+void Stream<std::string>::Write(Sink* s, const std::string& t) {
+ StreamIn(s, t.length());
+ size_t size = t.length() * sizeof(char);
+ if (size > 0) {
+ void* ptr = s->GetSpace(size);
+ memcpy(ptr, t.data(), size);
+ }
+}
+
+template <>
+MaybeError Stream<std::string>::Read(Source* s, std::string* t) {
+ size_t length;
+ DAWN_TRY(StreamOut(s, &length));
+ const void* ptr;
+ DAWN_TRY(s->Read(&ptr, length));
+ *t = std::string(static_cast<const char*>(ptr), length);
+ return {};
+}
+
+template <>
+void Stream<std::string_view>::Write(Sink* s, const std::string_view& t) {
+ StreamIn(s, t.length());
+ size_t size = t.length() * sizeof(char);
+ if (size > 0) {
+ void* ptr = s->GetSpace(size);
+ memcpy(ptr, t.data(), size);
+ }
+}
+
+} // namespace dawn::native::stream
diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h
index 1797102..416aa02 100644
--- a/src/dawn/native/stream/Stream.h
+++ b/src/dawn/native/stream/Stream.h
@@ -18,164 +18,299 @@
#include <algorithm>
#include <bitset>
#include <functional>
+#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dawn/common/Platform.h"
#include "dawn/common/TypedInteger.h"
-#include "dawn/native/CacheKey.h"
#include "dawn/native/Error.h"
+#include "dawn/native/stream/Sink.h"
+#include "dawn/native/stream/Source.h"
-namespace dawn::native {
+namespace dawn::native::stream {
-class CacheKey;
-
-// Specialized overload for CacheKey::UnsafeIgnoredValue which does nothing.
-template <typename T>
-class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> {
+// Base Stream template for specialization. Specializations may define static methods:
+// static void Write(Sink* s, const T& v);
+// static MaybeError Read(Source* s, T* v);
+template <typename T, typename SFINAE = void>
+class Stream {
public:
- constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {}
+ static void Write(Sink* s, const T& v);
+ static MaybeError Read(Source* s, T* v);
};
-// Specialized overload for fundamental types.
+// Helper to call Stream<T>::Write.
+// By default, calling StreamIn will call this overload inside the stream namespace.
+// Other definitons of StreamIn found by ADL may override this one.
template <typename T>
-class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
+constexpr void StreamIn(Sink* s, const T& v) {
+ Stream<T>::Write(s, v);
+}
+
+// Helper to call Stream<T>::Read
+// By default, calling StreamOut will call this overload inside the stream namespace.
+// Other definitons of StreamOut found by ADL may override this one.
+template <typename T>
+MaybeError StreamOut(Source* s, T* v) {
+ return Stream<T>::Read(s, v);
+}
+
+// Helper to call StreamIn on a parameter pack.
+template <typename T, typename... Ts>
+constexpr void StreamIn(Sink* s, const T& v, const Ts&... vs) {
+ StreamIn(s, v);
+ (StreamIn(s, vs), ...);
+}
+
+// Helper to call StreamOut on a parameter pack.
+template <typename T, typename... Ts>
+MaybeError StreamOut(Source* s, T* v, Ts*... vs) {
+ DAWN_TRY(StreamOut(s, v));
+ return StreamOut(s, vs...);
+}
+
+// Stream specialization for fundamental types.
+template <typename T>
+class Stream<T, std::enable_if_t<std::is_fundamental_v<T>>> {
public:
- static void Serialize(CacheKey* key, const T t) {
- const char* it = reinterpret_cast<const char*>(&t);
- key->insert(key->end(), it, (it + sizeof(T)));
+ static void Write(Sink* s, const T& v) { memcpy(s->GetSpace(sizeof(T)), &v, sizeof(T)); }
+ static MaybeError Read(Source* s, T* v) {
+ const void* ptr;
+ DAWN_TRY(s->Read(&ptr, sizeof(T)));
+ memcpy(v, ptr, sizeof(T));
+ return {};
}
};
-// Specialized overload for bitsets that are smaller than 64.
+namespace detail {
+// NOLINTNEXTLINE(runtime/int) Alias "unsigned long long" type to match std::bitset to_ullong
+using BitsetUllong = unsigned long long;
+constexpr size_t kBitsPerUllong = 8 * sizeof(BitsetUllong);
+constexpr bool BitsetSupportsToUllong(size_t N) {
+ return N <= kBitsPerUllong;
+}
+} // namespace detail
+
+// Stream specialization for bitsets that are smaller than BitsetUllong.
template <size_t N>
-class CacheKeySerializer<std::bitset<N>, std::enable_if_t<(N <= 64)>> {
+class Stream<std::bitset<N>, std::enable_if_t<detail::BitsetSupportsToUllong(N)>> {
public:
- static void Serialize(CacheKey* key, const std::bitset<N>& t) { key->Record(t.to_ullong()); }
+ static void Write(Sink* s, const std::bitset<N>& t) { StreamIn(s, t.to_ullong()); }
+ static MaybeError Read(Source* s, std::bitset<N>* v) {
+ detail::BitsetUllong value;
+ DAWN_TRY(StreamOut(s, &value));
+ *v = std::bitset<N>(value);
+ return {};
+ }
};
-// Specialized overload for bitsets since using the built-in to_ullong have a size limit.
+// Stream specialization for bitsets since using the built-in to_ullong has a size limit.
template <size_t N>
-class CacheKeySerializer<std::bitset<N>, std::enable_if_t<(N > 64)>> {
+class Stream<std::bitset<N>, std::enable_if_t<!detail::BitsetSupportsToUllong(N)>> {
public:
- static void Serialize(CacheKey* key, const std::bitset<N>& t) {
- // Serializes the bitset into series of uint8_t, along with recording the size.
+ static void Write(Sink* s, const std::bitset<N>& t) {
+ // Iterate in chunks of detail::BitsetUllong.
+ static std::bitset<N> mask(std::numeric_limits<detail::BitsetUllong>::max());
+
+ std::bitset<N> bits = t;
+ for (size_t offset = 0; offset < N;
+ offset += detail::kBitsPerUllong, bits >>= detail::kBitsPerUllong) {
+ StreamIn(s, (mask & bits).to_ullong());
+ }
+ }
+
+ static MaybeError Read(Source* s, std::bitset<N>* v) {
static_assert(N > 0);
- key->Record(static_cast<size_t>(N));
- uint8_t value = 0;
- for (size_t i = 0; i < N; i++) {
- value <<= 1;
- // Explicitly convert to numeric since MSVC doesn't like mixing of bools.
- value |= t[i] ? 1 : 0;
- if (i % 8 == 7) {
- // Whenever we fill an 8 bit value, record it and zero it out.
- key->Record(value);
- value = 0;
- }
+ *v = {};
+
+ // Iterate in chunks of detail::BitsetUllong.
+ for (size_t offset = 0; offset < N;
+ offset += detail::kBitsPerUllong, (*v) <<= detail::kBitsPerUllong) {
+ detail::BitsetUllong ullong;
+ DAWN_TRY(StreamOut(s, &ullong));
+ *v |= std::bitset<N>(ullong);
}
- // Serialize the last value if we are not a multiple of 8.
- if (N % 8 != 0) {
- key->Record(value);
- }
+ return {};
}
};
-// Specialized overload for enums.
+// Stream specialization for enums.
template <typename T>
-class CacheKeySerializer<T, std::enable_if_t<std::is_enum_v<T>>> {
+class Stream<T, std::enable_if_t<std::is_enum_v<T>>> {
+ using U = std::underlying_type_t<T>;
+
public:
- static void Serialize(CacheKey* key, const T t) {
- CacheKeySerializer<std::underlying_type_t<T>>::Serialize(
- key, static_cast<std::underlying_type_t<T>>(t));
+ static void Write(Sink* s, const T& v) { StreamIn(s, static_cast<U>(v)); }
+
+ static MaybeError Read(Source* s, T* v) {
+ U out;
+ DAWN_TRY(StreamOut(s, &out));
+ *v = static_cast<T>(out);
+ return {};
}
};
-// Specialized overload for TypedInteger.
+// Stream specialization for TypedInteger.
template <typename Tag, typename Integer>
-class CacheKeySerializer<::detail::TypedIntegerImpl<Tag, Integer>> {
+class Stream<::detail::TypedIntegerImpl<Tag, Integer>> {
+ using T = ::detail::TypedIntegerImpl<Tag, Integer>;
+
public:
- static void Serialize(CacheKey* key, const ::detail::TypedIntegerImpl<Tag, Integer> t) {
- CacheKeySerializer<Integer>::Serialize(key, static_cast<Integer>(t));
+ static void Write(Sink* s, const T& t) { StreamIn(s, static_cast<Integer>(t)); }
+
+ static MaybeError Read(Source* s, T* v) {
+ Integer out;
+ DAWN_TRY(StreamOut(s, &out));
+ *v = T(out);
+ return {};
}
};
-// Specialized overload for pointers. Since we are serializing for a cache key, we always
-// serialize via value, not by pointer. To handle nullptr scenarios, we always serialize whether
-// the pointer was nullptr followed by the contents if applicable.
+// Stream specialization for pointers. We always serialize via value, not by pointer.
+// To handle nullptr scenarios, we always serialize whether the pointer was not nullptr,
+// followed by the contents if applicable.
template <typename T>
-class CacheKeySerializer<T, std::enable_if_t<std::is_pointer_v<T>>> {
+class Stream<T, std::enable_if_t<std::is_pointer_v<T>>> {
public:
- static void Serialize(CacheKey* key, const T t) {
- key->Record(t == nullptr);
+ static void Write(stream::Sink* sink, const T& t) {
+ using Pointee = std::decay_t<std::remove_pointer_t<T>>;
+ static_assert(!std::is_same_v<char, Pointee> && !std::is_same_v<wchar_t, Pointee> &&
+ !std::is_same_v<char16_t, Pointee> && !std::is_same_v<char32_t, Pointee>,
+ "C-str like type likely has ambiguous serialization. For a string, wrap with "
+ "std::string_view instead.");
+ StreamIn(sink, t != nullptr);
if (t != nullptr) {
- CacheKeySerializer<std::remove_cv_t<std::remove_pointer_t<T>>>::Serialize(key, *t);
+ StreamIn(sink, *t);
}
}
};
-// Specialized overload for fixed arrays of primitives.
+// Stream specialization for fixed arrays of fundamental types.
template <typename T, size_t N>
-class CacheKeySerializer<T[N], std::enable_if_t<std::is_fundamental_v<T>>> {
+class Stream<T[N], std::enable_if_t<std::is_fundamental_v<T>>> {
public:
- static void Serialize(CacheKey* key, const T (&t)[N]) {
+ static void Write(Sink* s, const T (&t)[N]) {
static_assert(N > 0);
- key->Record(static_cast<size_t>(N));
- const char* it = reinterpret_cast<const char*>(t);
- key->insert(key->end(), it, it + sizeof(t));
+ memcpy(s->GetSpace(sizeof(t)), &t, sizeof(t));
+ }
+
+ static MaybeError Read(Source* s, T (*t)[N]) {
+ static_assert(N > 0);
+ const void* ptr;
+ DAWN_TRY(s->Read(&ptr, sizeof(*t)));
+ memcpy(*t, ptr, sizeof(*t));
+ return {};
}
};
-// Specialized overload for fixed arrays of non-primitives.
+// Specialization for fixed arrays of non-fundamental types.
template <typename T, size_t N>
-class CacheKeySerializer<T[N], std::enable_if_t<!std::is_fundamental_v<T>>> {
+class Stream<T[N], std::enable_if_t<!std::is_fundamental_v<T>>> {
public:
- static void Serialize(CacheKey* key, const T (&t)[N]) {
+ static void Write(Sink* s, const T (&t)[N]) {
static_assert(N > 0);
- key->Record(static_cast<size_t>(N));
for (size_t i = 0; i < N; i++) {
- key->Record(t[i]);
+ StreamIn(s, t[i]);
}
}
-};
-// Specialized overload for CachedObjects.
-template <typename T>
-class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>> {
- public:
- static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); }
-};
-
-// Specialized overload for std::vector.
-template <typename T>
-class CacheKeySerializer<std::vector<T>> {
- public:
- static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
-};
-
-// Specialized overload for std::pair<A, B>
-template <typename A, typename B>
-class CacheKeySerializer<std::pair<A, B>> {
- public:
- static void Serialize(CacheKey* key, const std::pair<A, B>& p) {
- key->Record(p.first, p.second);
+ static MaybeError Read(Source* s, T (*t)[N]) {
+ static_assert(N > 0);
+ for (size_t i = 0; i < N; i++) {
+ DAWN_TRY(StreamOut(s, &(*t)[i]));
+ }
+ return {};
}
};
-// Specialized overload for std::unordered_map<K, V>
-template <typename K, typename V>
-class CacheKeySerializer<std::unordered_map<K, V>> {
+// Stream specialization for std::vector.
+template <typename T>
+class Stream<std::vector<T>> {
public:
- static void Serialize(CacheKey* key, const std::unordered_map<K, V>& m) {
+ static void Write(Sink* s, const std::vector<T>& v) {
+ StreamIn(s, v.size());
+ for (const T& it : v) {
+ StreamIn(s, it);
+ }
+ }
+
+ static MaybeError Read(Source* s, std::vector<T>* v) {
+ using SizeT = decltype(std::declval<std::vector<T>>().size());
+ SizeT size;
+ DAWN_TRY(StreamOut(s, &size));
+ *v = {};
+ v->reserve(size);
+ for (SizeT i = 0; i < size; ++i) {
+ T el;
+ DAWN_TRY(StreamOut(s, &el));
+ v->push_back(std::move(el));
+ }
+ return {};
+ }
+};
+
+// Stream specialization for std::pair.
+template <typename A, typename B>
+class Stream<std::pair<A, B>> {
+ public:
+ static void Write(Sink* s, const std::pair<A, B>& v) {
+ StreamIn(s, v.first);
+ StreamIn(s, v.second);
+ }
+
+ static MaybeError Read(Source* s, std::pair<A, B>* v) {
+ DAWN_TRY(StreamOut(s, &v->first));
+ DAWN_TRY(StreamOut(s, &v->second));
+ return {};
+ }
+};
+
+// Stream specialization for std::unordered_map<K, V> which sorts the entries
+// to provide a stable ordering.
+template <typename K, typename V>
+class Stream<std::unordered_map<K, V>> {
+ public:
+ static void Write(stream::Sink* sink, const std::unordered_map<K, V>& m) {
std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
std::sort(ordered.begin(), ordered.end(),
[](const std::pair<K, V>& a, const std::pair<K, V>& b) {
return std::less<K>{}(a.first, b.first);
});
- key->RecordIterable(ordered);
+ StreamIn(sink, ordered);
}
};
-} // namespace dawn::native
+// Helper class to contain the begin/end iterators of an iterable.
+namespace detail {
+template <typename Iterator>
+struct Iterable {
+ Iterator begin;
+ Iterator end;
+};
+} // namespace detail
+
+// Helper for making detail::Iterable from a pointer and count.
+template <typename T>
+auto Iterable(const T* ptr, size_t count) {
+ using Iterator = const T*;
+ return detail::Iterable<Iterator>{ptr, ptr + count};
+}
+
+// Stream specialization for detail::Iterable which writes the number of elements,
+// followed by the elements.
+template <typename Iterator>
+class Stream<detail::Iterable<Iterator>> {
+ public:
+ static void Write(stream::Sink* sink, const detail::Iterable<Iterator>& iter) {
+ StreamIn(sink, std::distance(iter.begin, iter.end));
+ for (auto it = iter.begin; it != iter.end; ++it) {
+ StreamIn(sink, *it);
+ }
+ }
+};
+
+} // namespace dawn::native::stream
#endif // SRC_DAWN_NATIVE_STREAM_STREAM_H_
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
index a87a91f..94b5a7d 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
@@ -117,7 +117,7 @@
createInfo.pBindings = bindings.data();
// Record cache key information now since the createInfo is not stored.
- mCacheKey.Record(createInfo);
+ StreamIn(&mCacheKey, createInfo);
Device* device = ToBackend(GetDevice());
DAWN_TRY(CheckVkSuccess(device->fn.CreateDescriptorSetLayout(device->GetVkDevice(), &createInfo,
diff --git a/src/dawn/native/vulkan/CacheKeyVk.cpp b/src/dawn/native/vulkan/CacheKeyVk.cpp
index 2b6d51b..dc7ea05 100644
--- a/src/dawn/native/vulkan/CacheKeyVk.cpp
+++ b/src/dawn/native/vulkan/CacheKeyVk.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include <cstring>
+#include <map>
#include "dawn/common/Assert.h"
#include "dawn/common/vulkan_platform.h"
@@ -39,7 +40,7 @@
}
template <typename VK_STRUCT_TYPE>
-void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
+void SerializePnextImpl(stream::Sink* sink, const VkBaseOutStructure* root) {
const VkBaseOutStructure* next = reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
const VK_STRUCT_TYPE* found = nullptr;
while (next != nullptr) {
@@ -55,16 +56,16 @@
next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
}
if (found != nullptr) {
- key->Record(found);
+ StreamIn(sink, found);
}
}
template <typename VK_STRUCT_TYPE,
typename... VK_STRUCT_TYPES,
typename = std::enable_if_t<(sizeof...(VK_STRUCT_TYPES) > 0)>>
-void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
- SerializePnextImpl<VK_STRUCT_TYPE>(key, root);
- SerializePnextImpl<VK_STRUCT_TYPES...>(key, root);
+void SerializePnextImpl(stream::Sink* sink, const VkBaseOutStructure* root) {
+ SerializePnextImpl<VK_STRUCT_TYPE>(sink, root);
+ SerializePnextImpl<VK_STRUCT_TYPES...>(sink, root);
}
template <typename VK_STRUCT_TYPE>
@@ -81,16 +82,16 @@
template <typename... VK_STRUCT_TYPES,
typename VK_STRUCT_TYPE,
typename = std::enable_if_t<(sizeof...(VK_STRUCT_TYPES) > 0)>>
-void SerializePnext(CacheKey* key, const VK_STRUCT_TYPE* t) {
+void SerializePnext(stream::Sink* sink, const VK_STRUCT_TYPE* t) {
const VkBaseOutStructure* root = detail::ToVkBaseOutStructure(t);
detail::ValidatePnextImpl<VK_STRUCT_TYPES...>(root);
- detail::SerializePnextImpl<VK_STRUCT_TYPES...>(key, root);
+ detail::SerializePnextImpl<VK_STRUCT_TYPES...>(sink, root);
}
// Empty template specialization so that we can put this in to ensure failures occur if new
// extensions are added without updating serialization.
template <typename VK_STRUCT_TYPE>
-void SerializePnext(CacheKey* key, const VK_STRUCT_TYPE* t) {
+void SerializePnext(stream::Sink* sink, const VK_STRUCT_TYPE* t) {
const VkBaseOutStructure* root = detail::ToVkBaseOutStructure(t);
detail::ValidatePnextImpl<>(root);
}
@@ -98,253 +99,241 @@
} // namespace
template <>
-void CacheKeySerializer<VkDescriptorSetLayoutBinding>::Serialize(
- CacheKey* key,
- const VkDescriptorSetLayoutBinding& t) {
- key->Record(t.binding, t.descriptorType, t.descriptorCount, t.stageFlags);
+void stream::Stream<VkDescriptorSetLayoutBinding>::Write(stream::Sink* sink,
+ const VkDescriptorSetLayoutBinding& t) {
+ StreamIn(sink, t.binding, t.descriptorType, t.descriptorCount, t.stageFlags);
}
template <>
-void CacheKeySerializer<VkDescriptorSetLayoutCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkDescriptorSetLayoutCreateInfo>::Write(
+ stream::Sink* sink,
const VkDescriptorSetLayoutCreateInfo& t) {
- key->Record(t.flags).RecordIterable(t.pBindings, t.bindingCount);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, Iterable(t.pBindings, t.bindingCount));
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPushConstantRange>::Serialize(CacheKey* key,
- const VkPushConstantRange& t) {
- key->Record(t.stageFlags, t.offset, t.size);
+void stream::Stream<VkPushConstantRange>::Write(stream::Sink* sink, const VkPushConstantRange& t) {
+ StreamIn(sink, t.stageFlags, t.offset, t.size);
}
template <>
-void CacheKeySerializer<VkPipelineLayoutCreateInfo>::Serialize(
- CacheKey* key,
- const VkPipelineLayoutCreateInfo& t) {
+void stream::Stream<VkPipelineLayoutCreateInfo>::Write(stream::Sink* sink,
+ const VkPipelineLayoutCreateInfo& t) {
// The set layouts are not serialized here because they are pointers to backend objects.
// They need to be cross-referenced with the frontend objects and serialized from there.
- key->Record(t.flags).RecordIterable(t.pPushConstantRanges, t.pushConstantRangeCount);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, Iterable(t.pPushConstantRanges, t.pushConstantRangeCount));
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>::Write(
+ stream::Sink* sink,
const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT& t) {
- key->Record(t.requiredSubgroupSize);
+ StreamIn(sink, t.requiredSubgroupSize);
}
template <>
-void CacheKeySerializer<VkPipelineRasterizationDepthClipStateCreateInfoEXT>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineRasterizationDepthClipStateCreateInfoEXT>::Write(
+ stream::Sink* sink,
const VkPipelineRasterizationDepthClipStateCreateInfoEXT& t) {
- key->Record(t.depthClipEnable, t.flags);
+ StreamIn(sink, t.depthClipEnable, t.flags);
}
template <>
-void CacheKeySerializer<VkSpecializationMapEntry>::Serialize(CacheKey* key,
- const VkSpecializationMapEntry& t) {
- key->Record(t.constantID, t.offset, t.size);
+void stream::Stream<VkSpecializationMapEntry>::Write(stream::Sink* sink,
+ const VkSpecializationMapEntry& t) {
+ StreamIn(sink, t.constantID, t.offset, t.size);
}
template <>
-void CacheKeySerializer<VkSpecializationInfo>::Serialize(CacheKey* key,
- const VkSpecializationInfo& t) {
- key->RecordIterable(t.pMapEntries, t.mapEntryCount)
- .RecordIterable(static_cast<const uint8_t*>(t.pData), t.dataSize);
+void stream::Stream<VkSpecializationInfo>::Write(stream::Sink* sink,
+ const VkSpecializationInfo& t) {
+ StreamIn(sink, Iterable(t.pMapEntries, t.mapEntryCount),
+ Iterable(static_cast<const uint8_t*>(t.pData), t.dataSize));
}
template <>
-void CacheKeySerializer<VkPipelineShaderStageCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineShaderStageCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineShaderStageCreateInfo& t) {
// The shader module is not serialized here because it is a pointer to a backend object.
- key->Record(t.flags, t.stage)
- .RecordIterable(t.pName, strlen(t.pName))
- .Record(t.pSpecializationInfo);
- SerializePnext<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(key, &t);
+ StreamIn(sink, t.flags, t.stage, Iterable(t.pName, strlen(t.pName)), t.pSpecializationInfo);
+ SerializePnext<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(sink, &t);
}
template <>
-void CacheKeySerializer<VkComputePipelineCreateInfo>::Serialize(
- CacheKey* key,
- const VkComputePipelineCreateInfo& t) {
+void stream::Stream<VkComputePipelineCreateInfo>::Write(stream::Sink* sink,
+ const VkComputePipelineCreateInfo& t) {
// The pipeline layout is not serialized here because it is a pointer to a backend object.
// It needs to be cross-referenced with the frontend objects and serialized from there. The
- // base pipeline information is also currently not recorded since we do not use them in our
+ // base pipeline information is also currently not serialized since we do not use them in our
// backend implementation. If we decide to use them later on, they also need to be
// cross-referenced from the frontend.
- key->Record(t.flags, t.stage);
+ StreamIn(sink, t.flags, t.stage);
}
template <>
-void CacheKeySerializer<VkVertexInputBindingDescription>::Serialize(
- CacheKey* key,
+void stream::Stream<VkVertexInputBindingDescription>::Write(
+ stream::Sink* sink,
const VkVertexInputBindingDescription& t) {
- key->Record(t.binding, t.stride, t.inputRate);
+ StreamIn(sink, t.binding, t.stride, t.inputRate);
}
template <>
-void CacheKeySerializer<VkVertexInputAttributeDescription>::Serialize(
- CacheKey* key,
+void stream::Stream<VkVertexInputAttributeDescription>::Write(
+ stream::Sink* sink,
const VkVertexInputAttributeDescription& t) {
- key->Record(t.location, t.binding, t.format, t.offset);
+ StreamIn(sink, t.location, t.binding, t.format, t.offset);
}
template <>
-void CacheKeySerializer<VkPipelineVertexInputStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineVertexInputStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineVertexInputStateCreateInfo& t) {
- key->Record(t.flags)
- .RecordIterable(t.pVertexBindingDescriptions, t.vertexBindingDescriptionCount)
- .RecordIterable(t.pVertexAttributeDescriptions, t.vertexAttributeDescriptionCount);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, Iterable(t.pVertexBindingDescriptions, t.vertexBindingDescriptionCount),
+ Iterable(t.pVertexAttributeDescriptions, t.vertexAttributeDescriptionCount));
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineInputAssemblyStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineInputAssemblyStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineInputAssemblyStateCreateInfo& t) {
- key->Record(t.flags, t.topology, t.primitiveRestartEnable);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, t.topology, t.primitiveRestartEnable);
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineTessellationStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineTessellationStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineTessellationStateCreateInfo& t) {
- key->Record(t.flags, t.patchControlPoints);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, t.patchControlPoints);
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkViewport>::Serialize(CacheKey* key, const VkViewport& t) {
- key->Record(t.x, t.y, t.width, t.height, t.minDepth, t.maxDepth);
+void stream::Stream<VkViewport>::Write(stream::Sink* sink, const VkViewport& t) {
+ StreamIn(sink, t.x, t.y, t.width, t.height, t.minDepth, t.maxDepth);
}
template <>
-void CacheKeySerializer<VkOffset2D>::Serialize(CacheKey* key, const VkOffset2D& t) {
- key->Record(t.x, t.y);
+void stream::Stream<VkOffset2D>::Write(stream::Sink* sink, const VkOffset2D& t) {
+ StreamIn(sink, t.x, t.y);
}
template <>
-void CacheKeySerializer<VkExtent2D>::Serialize(CacheKey* key, const VkExtent2D& t) {
- key->Record(t.width, t.height);
+void stream::Stream<VkExtent2D>::Write(stream::Sink* sink, const VkExtent2D& t) {
+ StreamIn(sink, t.width, t.height);
}
template <>
-void CacheKeySerializer<VkRect2D>::Serialize(CacheKey* key, const VkRect2D& t) {
- key->Record(t.offset, t.extent);
+void stream::Stream<VkRect2D>::Write(stream::Sink* sink, const VkRect2D& t) {
+ StreamIn(sink, t.offset, t.extent);
}
template <>
-void CacheKeySerializer<VkPipelineViewportStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineViewportStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineViewportStateCreateInfo& t) {
- key->Record(t.flags)
- .RecordIterable(t.pViewports, t.viewportCount)
- .RecordIterable(t.pScissors, t.scissorCount);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, Iterable(t.pViewports, t.viewportCount),
+ Iterable(t.pScissors, t.scissorCount));
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineRasterizationStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineRasterizationStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineRasterizationStateCreateInfo& t) {
- key->Record(t.flags, t.depthClampEnable, t.rasterizerDiscardEnable, t.polygonMode, t.cullMode,
- t.frontFace, t.depthBiasEnable, t.depthBiasConstantFactor, t.depthBiasClamp,
- t.depthBiasSlopeFactor, t.lineWidth);
- SerializePnext<VkPipelineRasterizationDepthClipStateCreateInfoEXT>(key, &t);
+ StreamIn(sink, t.flags, t.depthClampEnable, t.rasterizerDiscardEnable, t.polygonMode,
+ t.cullMode, t.frontFace, t.depthBiasEnable, t.depthBiasConstantFactor,
+ t.depthBiasClamp, t.depthBiasSlopeFactor, t.lineWidth);
+ SerializePnext<VkPipelineRasterizationDepthClipStateCreateInfoEXT>(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineMultisampleStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineMultisampleStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineMultisampleStateCreateInfo& t) {
- key->Record(t.flags, t.rasterizationSamples, t.sampleShadingEnable, t.minSampleShading,
- t.pSampleMask, t.alphaToCoverageEnable, t.alphaToOneEnable);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, t.rasterizationSamples, t.sampleShadingEnable, t.minSampleShading,
+ t.pSampleMask, t.alphaToCoverageEnable, t.alphaToOneEnable);
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkStencilOpState>::Serialize(CacheKey* key, const VkStencilOpState& t) {
- key->Record(t.failOp, t.passOp, t.depthFailOp, t.compareOp, t.compareMask, t.writeMask,
- t.reference);
+void stream::Stream<VkStencilOpState>::Write(stream::Sink* sink, const VkStencilOpState& t) {
+ StreamIn(sink, t.failOp, t.passOp, t.depthFailOp, t.compareOp, t.compareMask, t.writeMask,
+ t.reference);
}
template <>
-void CacheKeySerializer<VkPipelineDepthStencilStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineDepthStencilStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineDepthStencilStateCreateInfo& t) {
- key->Record(t.flags, t.depthTestEnable, t.depthWriteEnable, t.depthCompareOp,
- t.depthBoundsTestEnable, t.stencilTestEnable, t.front, t.back, t.minDepthBounds,
- t.maxDepthBounds);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, t.depthTestEnable, t.depthWriteEnable, t.depthCompareOp,
+ t.depthBoundsTestEnable, t.stencilTestEnable, t.front, t.back, t.minDepthBounds,
+ t.maxDepthBounds);
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineColorBlendAttachmentState>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineColorBlendAttachmentState>::Write(
+ stream::Sink* sink,
const VkPipelineColorBlendAttachmentState& t) {
- key->Record(t.blendEnable, t.srcColorBlendFactor, t.dstColorBlendFactor, t.colorBlendOp,
- t.srcAlphaBlendFactor, t.dstAlphaBlendFactor, t.alphaBlendOp, t.colorWriteMask);
+ StreamIn(sink, t.blendEnable, t.srcColorBlendFactor, t.dstColorBlendFactor, t.colorBlendOp,
+ t.srcAlphaBlendFactor, t.dstAlphaBlendFactor, t.alphaBlendOp, t.colorWriteMask);
}
template <>
-void CacheKeySerializer<VkPipelineColorBlendStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineColorBlendStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineColorBlendStateCreateInfo& t) {
- key->Record(t.flags, t.logicOpEnable, t.logicOp)
- .RecordIterable(t.pAttachments, t.attachmentCount)
- .Record(t.blendConstants);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, t.logicOpEnable, t.logicOp, Iterable(t.pAttachments, t.attachmentCount),
+ t.blendConstants);
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<VkPipelineDynamicStateCreateInfo>::Serialize(
- CacheKey* key,
+void stream::Stream<VkPipelineDynamicStateCreateInfo>::Write(
+ stream::Sink* sink,
const VkPipelineDynamicStateCreateInfo& t) {
- key->Record(t.flags).RecordIterable(t.pDynamicStates, t.dynamicStateCount);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, Iterable(t.pDynamicStates, t.dynamicStateCount));
+ SerializePnext(sink, &t);
}
template <>
-void CacheKeySerializer<vulkan::RenderPassCacheQuery>::Serialize(
- CacheKey* key,
- const vulkan::RenderPassCacheQuery& t) {
- key->Record(t.colorMask.to_ulong(), t.resolveTargetMask.to_ulong(), t.sampleCount);
+void stream::Stream<vulkan::RenderPassCacheQuery>::Write(stream::Sink* sink,
+ const vulkan::RenderPassCacheQuery& t) {
+ StreamIn(sink, t.colorMask.to_ulong(), t.resolveTargetMask.to_ulong(), t.sampleCount);
// Manually iterate the color attachment indices and their corresponding format/load/store
- // ops because the data is sparse and may be uninitialized. Since we record the colorMask
- // member above, recording sparse data should be fine here.
+ // ops because the data is sparse and may be uninitialized. Since we serialize the colorMask
+ // member above, serializing sparse data should be fine here.
for (ColorAttachmentIndex i : IterateBitSet(t.colorMask)) {
- key->Record(t.colorFormats[i], t.colorLoadOp[i], t.colorStoreOp[i]);
+ StreamIn(sink, t.colorFormats[i], t.colorLoadOp[i], t.colorStoreOp[i]);
}
// Serialize the depth-stencil toggle bit, and the parameters if applicable.
- key->Record(t.hasDepthStencil);
+ StreamIn(sink, t.hasDepthStencil);
if (t.hasDepthStencil) {
- key->Record(t.depthStencilFormat, t.depthLoadOp, t.depthStoreOp, t.stencilLoadOp,
- t.stencilStoreOp, t.readOnlyDepthStencil);
+ StreamIn(sink, t.depthStencilFormat, t.depthLoadOp, t.depthStoreOp, t.stencilLoadOp,
+ t.stencilStoreOp, t.readOnlyDepthStencil);
}
}
template <>
-void CacheKeySerializer<VkGraphicsPipelineCreateInfo>::Serialize(
- CacheKey* key,
- const VkGraphicsPipelineCreateInfo& t) {
+void stream::Stream<VkGraphicsPipelineCreateInfo>::Write(stream::Sink* sink,
+ const VkGraphicsPipelineCreateInfo& t) {
// The pipeline layout and render pass are not serialized here because they are pointers to
// backend objects. They need to be cross-referenced with the frontend objects and
- // serialized from there. The base pipeline information is also currently not recorded since
+ // serialized from there. The base pipeline information is also currently not serialized since
// we do not use them in our backend implementation. If we decide to use them later on, they
// also need to be cross-referenced from the frontend.
- key->Record(t.flags)
- .RecordIterable(t.pStages, t.stageCount)
- .Record(t.pVertexInputState, t.pInputAssemblyState, t.pTessellationState, t.pViewportState,
- t.pRasterizationState, t.pMultisampleState, t.pDepthStencilState,
- t.pColorBlendState, t.pDynamicState, t.subpass);
- SerializePnext(key, &t);
+ StreamIn(sink, t.flags, Iterable(t.pStages, t.stageCount), t.pVertexInputState,
+ t.pInputAssemblyState, t.pTessellationState, t.pViewportState, t.pRasterizationState,
+ t.pMultisampleState, t.pDepthStencilState, t.pColorBlendState, t.pDynamicState,
+ t.subpass);
+ SerializePnext(sink, &t);
}
} // namespace dawn::native
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp
index 25bb888..ad6fbf8 100644
--- a/src/dawn/native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -41,7 +41,7 @@
const PipelineLayout* layout = ToBackend(GetLayout());
// Vulkan devices need cache UUID field to be serialized into pipeline cache keys.
- mCacheKey.Record(device->GetDeviceInfo().properties.pipelineCacheUUID);
+ StreamIn(&mCacheKey, device->GetDeviceInfo().properties.pipelineCacheUUID);
VkComputePipelineCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
@@ -85,8 +85,8 @@
}
// Record cache key information now since the createInfo is not stored.
- mCacheKey.Record(createInfo, layout)
- .RecordIterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount);
+ StreamIn(&mCacheKey, createInfo, layout,
+ stream::Iterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount));
// Try to see if we have anything in the blob cache.
Ref<PipelineCache> cache = ToBackend(GetDevice()->GetOrCreatePipelineCache(GetCacheKey()));
diff --git a/src/dawn/native/vulkan/PipelineLayoutVk.cpp b/src/dawn/native/vulkan/PipelineLayoutVk.cpp
index 48ffc0f..a47c4ed 100644
--- a/src/dawn/native/vulkan/PipelineLayoutVk.cpp
+++ b/src/dawn/native/vulkan/PipelineLayoutVk.cpp
@@ -56,7 +56,7 @@
createInfo.pPushConstantRanges = nullptr;
// Record cache key information now since the createInfo is not stored.
- mCacheKey.RecordIterable(cachedObjects.data(), numSetLayouts).Record(createInfo);
+ StreamIn(&mCacheKey, stream::Iterable(cachedObjects.data(), numSetLayouts), createInfo);
Device* device = ToBackend(GetDevice());
DAWN_TRY(CheckVkSuccess(
diff --git a/src/dawn/native/vulkan/PipelineLayoutVk.h b/src/dawn/native/vulkan/PipelineLayoutVk.h
index ca157f8..2b8f5cf 100644
--- a/src/dawn/native/vulkan/PipelineLayoutVk.h
+++ b/src/dawn/native/vulkan/PipelineLayoutVk.h
@@ -31,6 +31,11 @@
VkPipelineLayout GetHandle() const;
+ // Friend definition of StreamIn which can be found by ADL to override stream::StreamIn<T>.
+ friend void StreamIn(stream::Sink* sink, const PipelineLayout& obj) {
+ StreamIn(sink, static_cast<const CachedObject&>(obj));
+ }
+
private:
~PipelineLayout() override;
void DestroyImpl() override;
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp
index 6f07ee3..372abb5 100644
--- a/src/dawn/native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -337,7 +337,7 @@
const PipelineLayout* layout = ToBackend(GetLayout());
// Vulkan devices need cache UUID field to be serialized into pipeline cache keys.
- mCacheKey.Record(device->GetDeviceInfo().properties.pipelineCacheUUID);
+ StreamIn(&mCacheKey, device->GetDeviceInfo().properties.pipelineCacheUUID);
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
@@ -389,7 +389,7 @@
stageCount++;
// Record cache key for each shader since it will become inaccessible later on.
- mCacheKey.Record(stage).RecordIterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount);
+ StreamIn(&mCacheKey, stream::Iterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount));
}
PipelineVertexInputStateCreateInfoTemporaryAllocations tempAllocations;
@@ -548,7 +548,7 @@
query.SetSampleCount(GetSampleCount());
- mCacheKey.Record(query);
+ StreamIn(&mCacheKey, query);
DAWN_TRY_ASSIGN(renderPass, device->GetRenderPassCache()->GetRenderPass(query));
}
@@ -577,7 +577,7 @@
createInfo.basePipelineIndex = -1;
// Record cache key information now since createInfo is not stored.
- mCacheKey.Record(createInfo, layout->GetCacheKey());
+ StreamIn(&mCacheKey, createInfo, layout->GetCacheKey());
// Try to see if we have anything in the blob cache.
Ref<PipelineCache> cache = ToBackend(GetDevice()->GetOrCreatePipelineCache(GetCacheKey()));
diff --git a/src/dawn/tests/unittests/native/CacheKeyTests.cpp b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
index 2fafd88..964823b 100644
--- a/src/dawn/tests/unittests/native/CacheKeyTests.cpp
+++ b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
@@ -15,29 +15,44 @@
#include <cstring>
#include <iomanip>
#include <string>
+#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
-#include "dawn/native/CacheKey.h"
+#include "dawn/common/TypedInteger.h"
+#include "dawn/native/Blob.h"
+#include "dawn/native/stream/BlobSource.h"
+#include "dawn/native/stream/ByteVectorSink.h"
+#include "dawn/native/stream/Stream.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "tint/tint.h"
-namespace dawn::native {
+namespace dawn::native::stream {
// Testing classes with mock serializing implemented for testing.
class A {
public:
- MOCK_METHOD(void, SerializeMock, (CacheKey*, const A&), (const));
+ MOCK_METHOD(void, WriteMock, (stream::Sink*, const A&), (const));
};
template <>
-void CacheKeySerializer<A>::Serialize(CacheKey* key, const A& t) {
- t.SerializeMock(key, t);
+void stream::Stream<A>::Write(stream::Sink* s, const A& t) {
+ t.WriteMock(s, t);
}
-// Custom printer for CacheKey for clearer debug testing messages.
-void PrintTo(const CacheKey& key, std::ostream* stream) {
+struct Nested {
+ A a1;
+ A a2;
+};
+template <>
+void stream::Stream<Nested>::Write(stream::Sink* s, const Nested& t) {
+ StreamIn(s, t.a1);
+ StreamIn(s, t.a2);
+}
+
+// Custom printer for ByteVectorSink for clearer debug testing messages.
+void PrintTo(const ByteVectorSink& key, std::ostream* stream) {
*stream << std::hex;
for (const int b : key) {
*stream << std::setfill('0') << std::setw(2) << b << " ";
@@ -52,168 +67,134 @@
using ::testing::PrintToString;
using ::testing::Ref;
-// Matcher to compare CacheKeys for easier testing.
-MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
+using TypedIntegerForTest = TypedInteger<struct TypedIntegerForTestTag, uint32_t>;
+
+// Matcher to compare ByteVectorSinks for easier testing.
+MATCHER_P(VectorEq, key, PrintToString(key)) {
return arg.size() == key.size() && memcmp(arg.data(), key.data(), key.size()) == 0;
}
-// Test that CacheKey::Record calls serialize on the single member of a struct.
-TEST(CacheKeyTests, RecordSingleMember) {
- CacheKey key;
+#define EXPECT_CACHE_KEY_EQ(lhs, rhs) \
+ do { \
+ ByteVectorSink actual; \
+ StreamIn(&actual, lhs); \
+ EXPECT_THAT(actual, VectorEq(rhs)); \
+ } while (0)
+
+// Test that ByteVectorSink calls Write on a value.
+TEST(SerializeTests, CallsWrite) {
+ ByteVectorSink sink;
A a;
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- EXPECT_THAT(key.Record(a), CacheKeyEq(CacheKey()));
+ EXPECT_CALL(a, WriteMock(NotNull(), Ref(a))).Times(1);
+ StreamIn(&sink, a);
}
-// Test that CacheKey::Record calls serialize on all members of a struct.
-TEST(CacheKeyTests, RecordManyMembers) {
- constexpr size_t kNumMembers = 100;
-
- CacheKey key;
- for (size_t i = 0; i < kNumMembers; ++i) {
- A a;
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- key.Record(a);
- }
- EXPECT_THAT(key, CacheKeyEq(CacheKey()));
-}
-
-// Test that CacheKey::Record calls serialize on all elements of an iterable.
-TEST(CacheKeyTests, RecordIterable) {
+// Test that ByteVectorSink calls Write on all elements of an iterable.
+TEST(SerializeTests, StreamInIterable) {
constexpr size_t kIterableSize = 100;
+ std::vector<A> vec(kIterableSize);
+ auto iterable = stream::Iterable(vec.data(), kIterableSize);
+
+ // Expect write to be called for each element
+ for (const auto& a : vec) {
+ EXPECT_CALL(a, WriteMock(NotNull(), Ref(a))).Times(1);
+ }
+
+ ByteVectorSink sink;
+ StreamIn(&sink, iterable);
+
// Expecting the size of the container.
- CacheKey expected;
- expected.Record(kIterableSize);
-
- std::vector<A> iterable(kIterableSize);
- {
- InSequence seq;
- for (const auto& a : iterable) {
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- }
- for (const auto& a : iterable) {
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- }
- }
-
- EXPECT_THAT(CacheKey().RecordIterable(iterable), CacheKeyEq(expected));
- EXPECT_THAT(CacheKey().RecordIterable(iterable.data(), kIterableSize), CacheKeyEq(expected));
+ ByteVectorSink expected;
+ StreamIn(&expected, kIterableSize);
+ EXPECT_THAT(sink, VectorEq(expected));
}
-// Test that CacheKey::Record calls serialize on all members and nested struct members.
-TEST(CacheKeyTests, RecordNested) {
- CacheKey expected;
- CacheKey actual;
- {
- // Recording a single member.
- A a;
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- actual.Record(CacheKey().Record(a));
- }
- {
- // Recording multiple members.
- constexpr size_t kNumMembers = 2;
- CacheKey sub;
- for (size_t i = 0; i < kNumMembers; ++i) {
- A a;
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- sub.Record(a);
- }
- actual.Record(sub);
- }
- {
- // Record an iterable.
- constexpr size_t kIterableSize = 2;
- expected.Record(kIterableSize);
- std::vector<A> iterable(kIterableSize);
- {
- InSequence seq;
- for (const auto& a : iterable) {
- EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
- }
- }
- actual.Record(CacheKey().RecordIterable(iterable));
- }
- EXPECT_THAT(actual, CacheKeyEq(expected));
+// Test that ByteVectorSink calls Write on all nested members of a struct.
+TEST(SerializeTests, StreamInNested) {
+ ByteVectorSink sink;
+
+ Nested n;
+ EXPECT_CALL(n.a1, WriteMock(NotNull(), Ref(n.a1))).Times(1);
+ EXPECT_CALL(n.a2, WriteMock(NotNull(), Ref(n.a2))).Times(1);
+ StreamIn(&sink, n);
}
-// Test that CacheKey::Record serializes integral data as expected.
-TEST(CacheKeySerializerTests, IntegralTypes) {
+// Test that ByteVectorSink serializes integral data as expected.
+TEST(SerializeTests, IntegralTypes) {
// Only testing explicitly sized types for simplicity, and using 0s for larger types to
// avoid dealing with endianess.
- EXPECT_THAT(CacheKey().Record('c'), CacheKeyEq(CacheKey({'c'})));
- EXPECT_THAT(CacheKey().Record(uint8_t(255)), CacheKeyEq(CacheKey({255})));
- EXPECT_THAT(CacheKey().Record(uint16_t(0)), CacheKeyEq(CacheKey({0, 0})));
- EXPECT_THAT(CacheKey().Record(uint32_t(0)), CacheKeyEq(CacheKey({0, 0, 0, 0})));
+ EXPECT_CACHE_KEY_EQ('c', ByteVectorSink({'c'}));
+ EXPECT_CACHE_KEY_EQ(uint8_t(255), ByteVectorSink({255}));
+ EXPECT_CACHE_KEY_EQ(uint16_t(0), ByteVectorSink({0, 0}));
+ EXPECT_CACHE_KEY_EQ(uint32_t(0), ByteVectorSink({0, 0, 0, 0}));
}
-// Test that CacheKey::Record serializes floating-point data as expected.
-TEST(CacheKeySerializerTests, FloatingTypes) {
+// Test that ByteVectorSink serializes floating-point data as expected.
+TEST(SerializeTests, FloatingTypes) {
// Using 0s to avoid dealing with implementation specific float details.
- EXPECT_THAT(CacheKey().Record(float{0}), CacheKeyEq(CacheKey(sizeof(float), 0)));
- EXPECT_THAT(CacheKey().Record(double{0}), CacheKeyEq(CacheKey(sizeof(double), 0)));
+ ByteVectorSink k1, k2;
+ EXPECT_CACHE_KEY_EQ(float{0}, ByteVectorSink(sizeof(float), 0));
+ EXPECT_CACHE_KEY_EQ(double{0}, ByteVectorSink(sizeof(double), 0));
}
-// Test that CacheKey::Record serializes literal strings as expected.
-TEST(CacheKeySerializerTests, LiteralStrings) {
+// Test that ByteVectorSink serializes literal strings as expected.
+TEST(SerializeTests, LiteralStrings) {
// Using a std::string here to help with creating the expected result.
std::string str = "string";
- CacheKey expected;
- expected.Record(size_t(7));
+ ByteVectorSink expected;
expected.insert(expected.end(), str.begin(), str.end());
expected.push_back('\0');
- EXPECT_THAT(CacheKey().Record("string"), CacheKeyEq(expected));
+ EXPECT_CACHE_KEY_EQ("string", expected);
}
-// Test that CacheKey::Record serializes std::strings as expected.
-TEST(CacheKeySerializerTests, StdStrings) {
+// Test that ByteVectorSink serializes std::strings as expected.
+TEST(SerializeTests, StdStrings) {
std::string str = "string";
- CacheKey expected;
- expected.Record(size_t(6));
+ ByteVectorSink expected;
+ StreamIn(&expected, size_t(6));
expected.insert(expected.end(), str.begin(), str.end());
- EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected));
+ EXPECT_CACHE_KEY_EQ(str, expected);
}
-// Test that CacheKey::Record serializes std::string_views as expected.
-TEST(CacheKeySerializerTests, StdStringViews) {
+// Test that ByteVectorSink serializes std::string_views as expected.
+TEST(SerializeTests, StdStringViews) {
static constexpr std::string_view str("string");
- CacheKey expected;
- expected.Record(size_t(6));
+ ByteVectorSink expected;
+ StreamIn(&expected, size_t(6));
expected.insert(expected.end(), str.begin(), str.end());
- EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected));
+ EXPECT_CACHE_KEY_EQ(str, expected);
}
-// Test that CacheKey::Record serializes other CacheKeys as expected.
-TEST(CacheKeySerializerTests, CacheKeys) {
- CacheKey data = {'d', 'a', 't', 'a'};
+// Test that ByteVectorSink serializes other ByteVectorSinks as expected.
+TEST(SerializeTests, ByteVectorSinks) {
+ ByteVectorSink data = {'d', 'a', 't', 'a'};
- CacheKey expected;
+ ByteVectorSink expected;
expected.insert(expected.end(), data.begin(), data.end());
- EXPECT_THAT(CacheKey().Record(data), CacheKeyEq(expected));
+ EXPECT_CACHE_KEY_EQ(data, expected);
}
-// Test that CacheKey::Record serializes std::pair as expected.
-TEST(CacheKeySerializerTests, StdPair) {
+// Test that ByteVectorSink serializes std::pair as expected.
+TEST(SerializeTests, StdPair) {
std::string_view s = "hi!";
- CacheKey expected;
- expected.Record(s);
- expected.Record(uint32_t(42));
+ ByteVectorSink expected;
+ StreamIn(&expected, s, uint32_t(42));
- EXPECT_THAT(CacheKey().Record(std::make_pair(s, uint32_t(42))), CacheKeyEq(expected));
+ EXPECT_CACHE_KEY_EQ(std::make_pair(s, uint32_t(42)), expected);
}
-// Test that CacheKey::Record serializes std::unordered_map as expected.
-TEST(CacheKeySerializerTests, StdUnorderedMap) {
+// Test that ByteVectorSink serializes std::unordered_map as expected.
+TEST(SerializeTests, StdUnorderedMap) {
std::unordered_map<uint32_t, std::string_view> m;
m[4] = "hello";
@@ -222,32 +203,175 @@
m[3] = "data";
// Expect the number of entries, followed by (K, V) pairs sorted in order of key.
- CacheKey expected;
- expected.Record(size_t(4));
- expected.Record(std::make_pair(uint32_t(1), m[1]));
- expected.Record(std::make_pair(uint32_t(3), m[3]));
- expected.Record(std::make_pair(uint32_t(4), m[4]));
- expected.Record(std::make_pair(uint32_t(7), m[7]));
+ ByteVectorSink expected;
+ StreamIn(&expected, size_t(4), std::make_pair(uint32_t(1), m[1]),
+ std::make_pair(uint32_t(3), m[3]), std::make_pair(uint32_t(4), m[4]),
+ std::make_pair(uint32_t(7), m[7]));
- EXPECT_THAT(CacheKey().Record(m), CacheKeyEq(expected));
+ EXPECT_CACHE_KEY_EQ(m, expected);
}
-// Test that CacheKey::Record serializes tint::sem::BindingPoint as expected.
-TEST(CacheKeySerializerTests, TintSemBindingPoint) {
+// Test that ByteVectorSink serializes tint::sem::BindingPoint as expected.
+TEST(SerializeTests, TintSemBindingPoint) {
tint::sem::BindingPoint bp{3, 6};
- EXPECT_THAT(CacheKey().Record(bp), CacheKeyEq(CacheKey().Record(uint32_t(3), uint32_t(6))));
+
+ ByteVectorSink expected;
+ StreamIn(&expected, uint32_t(3), uint32_t(6));
+
+ EXPECT_CACHE_KEY_EQ(bp, expected);
}
-// Test that CacheKey::Record serializes tint::transform::BindingPoints as expected.
-TEST(CacheKeySerializerTests, TintTransformBindingPoints) {
+// Test that ByteVectorSink serializes tint::transform::BindingPoints as expected.
+TEST(SerializeTests, TintTransformBindingPoints) {
tint::transform::BindingPoints points{
tint::sem::BindingPoint{1, 4},
tint::sem::BindingPoint{3, 7},
};
- EXPECT_THAT(CacheKey().Record(points),
- CacheKeyEq(CacheKey().Record(uint32_t(1), uint32_t(4), uint32_t(3), uint32_t(7))));
+
+ ByteVectorSink expected;
+ StreamIn(&expected, uint32_t(1), uint32_t(4), uint32_t(3), uint32_t(7));
+
+ EXPECT_CACHE_KEY_EQ(points, expected);
}
+// Test that serializing then deserializing a param pack yields the same values.
+TEST(StreamTests, SerializeDeserializeParamPack) {
+ int a = 1;
+ float b = 2.0;
+ std::pair<std::string, double> c = std::make_pair("dawn", 3.4);
+
+ ByteVectorSink sink;
+ StreamIn(&sink, a, b, c);
+
+ BlobSource source(CreateBlob(std::move(sink)));
+ int aOut;
+ float bOut;
+ std::pair<std::string, double> cOut;
+ auto err = StreamOut(&source, &aOut, &bOut, &cOut);
+ if (err.IsError()) {
+ FAIL() << err.AcquireError()->GetFormattedMessage();
+ }
+ EXPECT_EQ(a, aOut);
+ EXPECT_EQ(b, bOut);
+ EXPECT_EQ(c, cOut);
+}
+
+template <size_t N>
+std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) {
+ // N - 1 because the last character is the null terminator.
+ return std::bitset<N - 1>(str, N - 1);
+}
+
+static auto kStreamValueVectorParams = std::make_tuple(
+ // Test primitives.
+ std::vector<int>{4, 5, 6, 2},
+ std::vector<float>{6.50, 78.28, 92., 8.28},
+ // Test various types of strings.
+ std::vector<std::string>{"abcdefg", "9461849495", ""},
+ // Test pairs.
+ std::vector<std::pair<int, float>>{{1, 3.}, {6, 4.}},
+ // Test TypedIntegers
+ std::vector<TypedIntegerForTest>{TypedIntegerForTest(42), TypedIntegerForTest(13)},
+ // Test enums
+ std::vector<wgpu::TextureUsage>{wgpu::TextureUsage::CopyDst,
+ wgpu::TextureUsage::RenderAttachment},
+ // Test bitsets of various sizes.
+ std::vector<std::bitset<7>>{0b1001011, 0b0011010, 0b0000000, 0b1111111},
+ std::vector<std::bitset<17>>{0x0000, 0xFFFF1},
+ std::vector<std::bitset<32>>{0x0C0FFEE0, 0xDEADC0DE, 0x00000000, 0xFFFFFFFF},
+ std::vector<std::bitset<57>>{
+ BitsetFromBitString("100110010101011001100110101011001100101010110011001011011"),
+ BitsetFromBitString("000110010101011000100110101011001100101010010011001010100"),
+ BitsetFromBitString("111111111111111111111111111111111111111111111111111111111"), 0},
+ // Test vectors.
+ std::vector<std::vector<int>>{{}, {1, 5, 2, 7, 4}, {3, 3, 3, 3, 3, 3, 3}});
+
+static auto kStreamValueInitListParams = std::make_tuple(
+ std::initializer_list<char[12]>{"test string", "string test"},
+ std::initializer_list<double[3]>{{5.435, 32.3, 1.23}, {8.2345, 0.234532, 4.435}});
+
+template <typename, typename>
+struct StreamValueTestTypesImpl;
+
+template <typename... T, typename... T2>
+struct StreamValueTestTypesImpl<std::tuple<std::vector<T>...>,
+ std::tuple<std::initializer_list<T2>...>> {
+ using type = ::testing::Types<T..., T2...>;
+};
+
+using StreamValueTestTypes =
+ typename StreamValueTestTypesImpl<decltype(kStreamValueVectorParams),
+ decltype(kStreamValueInitListParams)>::type;
+
+template <typename T>
+class StreamParameterizedTests : public ::testing::Test {
+ protected:
+ static std::vector<T> GetParams() { return std::get<std::vector<T>>(kStreamValueVectorParams); }
+
+ void ExpectEq(const T& lhs, const T& rhs) { EXPECT_EQ(lhs, rhs); }
+};
+
+template <typename T, size_t N>
+class StreamParameterizedTests<T[N]> : public ::testing::Test {
+ protected:
+ static std::initializer_list<T[N]> GetParams() {
+ return std::get<std::initializer_list<T[N]>>(kStreamValueInitListParams);
+ }
+
+ void ExpectEq(const T lhs[N], const T rhs[N]) { EXPECT_EQ(memcmp(lhs, rhs, sizeof(T[N])), 0); }
+};
+
+TYPED_TEST_SUITE_P(StreamParameterizedTests);
+
+// Test that serializing a value, then deserializing it yields the same value.
+TYPED_TEST_P(StreamParameterizedTests, SerializeDeserialize) {
+ for (const auto& value : this->GetParams()) {
+ ByteVectorSink sink;
+ StreamIn(&sink, value);
+
+ BlobSource source(CreateBlob(std::move(sink)));
+ TypeParam deserialized;
+ auto err = StreamOut(&source, &deserialized);
+ if (err.IsError()) {
+ FAIL() << err.AcquireError()->GetFormattedMessage();
+ }
+ this->ExpectEq(deserialized, value);
+ }
+}
+
+// Test that serializing a value, then deserializing it with insufficient space, an error is raised.
+TYPED_TEST_P(StreamParameterizedTests, SerializeDeserializeOutOfBounds) {
+ for (const auto& value : this->GetParams()) {
+ ByteVectorSink sink;
+ StreamIn(&sink, value);
+
+ // Make the vector 1 byte too small.
+ std::vector<uint8_t> src = sink;
+ src.pop_back();
+
+ BlobSource source(CreateBlob(std::move(src)));
+ TypeParam deserialized;
+ auto err = StreamOut(&source, &deserialized);
+ EXPECT_TRUE(err.IsError());
+ err.AcquireError();
+ }
+}
+
+// Test that deserializing from an empty source raises an error.
+TYPED_TEST_P(StreamParameterizedTests, DeserializeEmpty) {
+ BlobSource source(CreateBlob(0));
+ TypeParam deserialized;
+ auto err = StreamOut(&source, &deserialized);
+ EXPECT_TRUE(err.IsError());
+ err.AcquireError();
+}
+
+REGISTER_TYPED_TEST_SUITE_P(StreamParameterizedTests,
+ SerializeDeserialize,
+ SerializeDeserializeOutOfBounds,
+ DeserializeEmpty);
+INSTANTIATE_TYPED_TEST_SUITE_P(DawnUnittests, StreamParameterizedTests, StreamValueTestTypes, );
+
} // namespace
-} // namespace dawn::native
+} // namespace dawn::native::stream
diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
index 995de7f..4a3e717 100644
--- a/src/dawn/tests/unittests/native/CacheRequestTests.cpp
+++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
@@ -102,7 +102,8 @@
// Make the expected key.
CacheKey expectedKey;
- expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);
+ StreamIn(&expectedKey, GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b,
+ req.c);
// Expect a call to LoadData with the expected key.
EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))