Add CacheRequest utilities and tests
This CL adds a DAWN_MAKE_CACHE_REQUEST X macro
which helps in building a CacheRequest struct.
A CacheRequest struct may be passed to LoadOrRun
which will generate a CacheKey from the struct and
load a result if there is a cache hit, or it will
call the provided cache miss function to compute a value.
The request struct helps enforce that every input
that goes into a computation is also included in
the CacheKey for that computation.
Bug: dawn:549
Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 4294b26..137ce1b 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -200,6 +200,9 @@
"Buffer.h",
"CacheKey.cpp",
"CacheKey.h",
+ "CacheRequest.cpp",
+ "CacheRequest.h",
+ "CacheResult.h",
"CachedObject.cpp",
"CachedObject.h",
"CallbackTaskManager.cpp",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index bbb50a0..8542b33 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -59,6 +59,9 @@
"CachedObject.h"
"CacheKey.cpp"
"CacheKey.h"
+ "CacheRequest.cpp"
+ "CacheRequest.h"
+ "CacheResult.h"
"CallbackTaskManager.cpp"
"CallbackTaskManager.h"
"CommandAllocator.cpp"
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h
index 357ce4b..c2901db 100644
--- a/src/dawn/native/CacheKey.h
+++ b/src/dawn/native/CacheKey.h
@@ -18,8 +18,10 @@
#include <bitset>
#include <iostream>
#include <limits>
+#include <memory>
#include <string>
#include <type_traits>
+#include <utility>
#include <vector>
#include "dawn/common/TypedInteger.h"
@@ -49,6 +51,19 @@
enum class Type { ComputePipeline, RenderPipeline, Shader };
template <typename T>
+ class UnsafeUnkeyedValue {
+ public:
+ UnsafeUnkeyedValue() = default;
+ // NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity
+ UnsafeUnkeyedValue(T&& value) : mValue(std::forward<T>(value)) {}
+
+ const T& UnsafeGetValue() const { return mValue; }
+
+ private:
+ T mValue;
+ };
+
+ template <typename T>
CacheKey& Record(const T& t) {
CacheKeySerializer<T>::Serialize(this, t);
return *this;
@@ -89,6 +104,22 @@
}
};
+// Wraps |value| in a CacheKey::UnsafeUnkeyedValue, deducing the wrapped type from the argument.
+// The wrapped type is std::decay_t<T>: an rvalue argument is moved in and an lvalue is copied, so
+// the wrapper never holds a reference to a caller's variable that could later dangle.
+template <typename T, typename U = std::decay_t<T>>
+CacheKey::UnsafeUnkeyedValue<U> UnsafeUnkeyedValue(T&& value) {
+    return CacheKey::UnsafeUnkeyedValue<U>(U(std::forward<T>(value)));
+}
+
+// Specialized overload for CacheKey::UnsafeUnkeyedValue which does nothing: the wrapped value is
+// intentionally not recorded into the key.
+template <typename T>
+class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> {
+  public:
+    constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {}
+};
+
// Specialized overload for fundamental types.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
@@ -197,6 +228,13 @@
static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); }
};
+// Specialized overload for std::vector.
+template <typename T>
+class CacheKeySerializer<std::vector<T>> {
+  public:
+    static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
+};
+
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CACHEKEY_H_
diff --git a/src/dawn/native/CacheRequest.cpp b/src/dawn/native/CacheRequest.cpp
new file mode 100644
index 0000000..2b35b12
--- /dev/null
+++ b/src/dawn/native/CacheRequest.cpp
@@ -0,0 +1,27 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/CacheRequest.h"
+
+#include "dawn/common/Log.h"
+
+namespace dawn::native::detail {
+
+// Logs |error|, which was returned by a CacheHitFn in LoadOrRun. LoadOrRun then falls back to the
+// cache miss path, so the error is reported here but never propagated to the caller.
+void LogCacheHitError(std::unique_ptr<ErrorData> error) {
+ dawn::ErrorLog() << error->GetFormattedMessage();
+}
+
+} // namespace dawn::native::detail
diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h
new file mode 100644
index 0000000..0fecb0d
--- /dev/null
+++ b/src/dawn/native/CacheRequest.h
@@ -0,0 +1,190 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_
+#define SRC_DAWN_NATIVE_CACHEREQUEST_H_
+
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "dawn/common/Assert.h"
+#include "dawn/common/Compiler.h"
+#include "dawn/native/Blob.h"
+#include "dawn/native/BlobCache.h"
+#include "dawn/native/CacheKey.h"
+#include "dawn/native/CacheResult.h"
+#include "dawn/native/Device.h"
+#include "dawn/native/Error.h"
+
+namespace dawn::native {
+
+namespace detail {
+
+// UnwrapResultOrError<T>::type is U when T is ResultOrError<U>, and T itself otherwise.
+template <typename T>
+struct UnwrapResultOrError {
+    using type = T;
+};
+
+template <typename T>
+struct UnwrapResultOrError<ResultOrError<T>> {
+    using type = T;
+};
+
+// IsResultOrError<T>::value is true iff T is ResultOrError<U> for some U.
+template <typename T>
+struct IsResultOrError {
+    static constexpr bool value = false;
+};
+
+template <typename T>
+struct IsResultOrError<ResultOrError<T>> {
+    static constexpr bool value = true;
+};
+
+// Logs an error returned by a CacheHitFn. Defined in CacheRequest.cpp.
+void LogCacheHitError(std::unique_ptr<ErrorData> error);
+
+}  // namespace detail
+
+// Base class for CacheRequest types (see DAWN_MAKE_CACHE_REQUEST). It provides a LoadOrRun friend
+// function which is found via argument-dependent lookup, so it doesn't need to be called with a
+// fully qualified function name.
+//
+// Example usage:
+//   Request r = { ... };
+//   ResultOrError<CacheResult<T>> cacheResult =
+//       LoadOrRun(device, std::move(r),
+//                 [](Blob blob) -> T { /* handle cache hit */ },
+//                 [](Request r) -> ResultOrError<T> { /* handle cache miss */ }
+//       );
+// Or with free functions:
+//   T OnCacheHit(Blob blob) { ... }
+//   ResultOrError<T> OnCacheMiss(Request r) { ... }
+//   // ...
+//   Request r = { ... };
+//   auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss);
+//
+// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache
+// hit, calls CacheHitFn with the loaded blob. On cache miss, or if CacheHitFn returned an error
+// (which is logged), calls CacheMissFn with the request data instead. Both paths produce a
+// ResultOrError<CacheResult<T>>, where T is the type wrapped by CacheMissFn's ResultOrError<T>.
+// CacheHitFn must return either T or ResultOrError<T>.
+//
+// CacheMissFn may not have any additional data bound to it: it may not be a capturing lambda or a
+// std::function, so it can only operate on the request data. This is enforced with a compile-time
+// static_assert. Since every request member is recorded into the CacheKey (except those wrapped in
+// CacheKey::UnsafeUnkeyedValue), any unkeyed input to the computation must be explicitly marked.
+template <typename Request>
+class CacheRequestImpl {
+  public:
+    CacheRequestImpl() = default;
+
+    // Require CacheRequests to be move-only to avoid unnecessary copies.
+    CacheRequestImpl(CacheRequestImpl&&) = default;
+    CacheRequestImpl& operator=(CacheRequestImpl&&) = default;
+    CacheRequestImpl(const CacheRequestImpl&) = delete;
+    CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;
+
+    template <typename CacheHitFn, typename CacheMissFn>
+    friend auto LoadOrRun(DeviceBase* device,
+                          Request&& r,
+                          CacheHitFn cacheHitFn,
+                          CacheMissFn cacheMissFn) {
+        // CacheMissFn must convert to a plain function pointer, i.e. it must not capture any data.
+        using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>()));
+        using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>()));
+        static_assert(
+            std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>,
+            "CacheMissFn function signature does not match, or it is not a free function.");
+
+        static_assert(detail::IsResultOrError<CacheMissReturnType>::value,
+                      "CacheMissFn should return a ResultOrError.");
+        using UnwrappedReturnType = typename detail::UnwrapResultOrError<CacheMissReturnType>::type;
+
+        static_assert(std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type,
+                                     UnwrappedReturnType>,
+                      "If CacheMissFn returns ResultOrError<T>, CacheHitFn must return T or "
+                      "ResultOrError<T>.");
+
+        using CacheResultType = CacheResult<UnwrappedReturnType>;
+        using ReturnType = ResultOrError<CacheResultType>;
+
+        CacheKey key = r.CreateCacheKey(device);
+        BlobCache* cache = device->GetBlobCache();
+        Blob blob;
+        if (cache != nullptr) {
+            blob = cache->Load(key);
+        }
+
+        if (!blob.Empty()) {
+            // Cache hit. Handle the cached blob.
+            auto result = cacheHitFn(std::move(blob));
+
+            if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) {
+                // If the result type is not a ResultOrError, return it.
+                return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result)));
+            } else {
+                // Otherwise, if the value is a success, also return it.
+                if (DAWN_LIKELY(result.IsSuccess())) {
+                    return ReturnType(
+                        CacheResultType::CacheHit(std::move(key), result.AcquireSuccess()));
+                }
+                // On error, continue to the cache miss path and log the error.
+                detail::LogCacheHitError(result.AcquireError());
+            }
+        }
+        // Cache miss, or the CacheHitFn failed.
+        auto result = cacheMissFn(std::move(r));
+        if (DAWN_LIKELY(result.IsSuccess())) {
+            return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess()));
+        }
+        return ReturnType(result.AcquireError());
+    }
+};
+
+// Helper for X macro to declare a struct member.
+#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{};
+
+// Helper for X macro for recording cache request fields into a CacheKey.
+#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(name);
+
+// Helper X macro to define a CacheRequest struct.
+/* Example usage (a block comment, because GCC's -Wcomment warns about // lines ending in '\'):
+     #define REQUEST_MEMBERS(X) \
+         X(int, a)              \
+         X(float, b)            \
+         X(Foo, foo)            \
+         X(Bar, bar)
+     DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
+     #undef REQUEST_MEMBERS */
+#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS)                     \
+    class Request : public CacheRequestImpl<Request> {                \
+      public:                                                         \
+        Request() = default;                                          \
+        MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER)       \
+                                                                      \
+        /* Create a CacheKey from the request type and all members */ \
+        CacheKey CreateCacheKey(const DeviceBase* device) const {     \
+            CacheKey key = device->GetCacheKey();                     \
+            key.Record(#Request);                                     \
+            MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY)           \
+            return key;                                               \
+        }                                                             \
+    };
+
+} // namespace dawn::native
+
+#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_
diff --git a/src/dawn/native/CacheResult.h b/src/dawn/native/CacheResult.h
new file mode 100644
index 0000000..a2750fe
--- /dev/null
+++ b/src/dawn/native/CacheResult.h
@@ -0,0 +1,85 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_CACHERESULT_H_
+#define SRC_DAWN_NATIVE_CACHERESULT_H_
+
+#include <memory>
+#include <utility>
+
+#include "dawn/common/Assert.h"
+#include "dawn/native/CacheKey.h"
+
+namespace dawn::native {
+
+// The result of a cache lookup (see LoadOrRun in CacheRequest.h): a value of type T together with
+// the CacheKey that identifies it, and whether the value was loaded from the cache (IsCached()) or
+// computed on a cache miss. A default-constructed CacheResult, or one whose value has been
+// Acquire()d, is invalid and must not be queried (checked with ASSERT).
+template <typename T>
+class CacheResult {
+  public:
+    // Creates a valid result for a value that was loaded from the cache.
+    static CacheResult CacheHit(CacheKey key, T value) {
+        return CacheResult(std::move(key), std::move(value), true);
+    }
+
+    // Creates a valid result for a value that was computed because of a cache miss.
+    static CacheResult CacheMiss(CacheKey key, T value) {
+        return CacheResult(std::move(key), std::move(value), false);
+    }
+
+    CacheResult() : mKey(), mValue(), mIsCached(false), mIsValid(false) {}
+
+    // Returns true if the value came from the cache, false if it was computed.
+    bool IsCached() const {
+        ASSERT(mIsValid);
+        return mIsCached;
+    }
+    const CacheKey& GetCacheKey() const {
+        ASSERT(mIsValid);
+        return mKey;
+    }
+
+    // Note: Getting mValue is always const, since mutating it would invalidate consistency with
+    // mKey.
+    const T* operator->() const {
+        ASSERT(mIsValid);
+        return &mValue;
+    }
+    const T& operator*() const {
+        ASSERT(mIsValid);
+        return mValue;
+    }
+
+    // Moves the value out of the result. The result is invalid afterwards.
+    T Acquire() {
+        ASSERT(mIsValid);
+        mIsValid = false;
+        return std::move(mValue);
+    }
+
+  private:
+    CacheResult(CacheKey key, T value, bool isCached)
+        : mKey(std::move(key)), mValue(std::move(value)), mIsCached(isCached), mIsValid(true) {}
+
+    CacheKey mKey;
+    T mValue;
+    bool mIsCached;
+    bool mIsValid;
+};
+
+}  // namespace dawn::native
+
+#endif  // SRC_DAWN_NATIVE_CACHERESULT_H_
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 69c41a2..57bb666 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -192,6 +192,7 @@
":gmock_and_gtest",
":mock_webgpu_gen",
":native_mocks_sources",
+ ":platform_mocks_sources",
"${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common",
@@ -254,6 +255,7 @@
"unittests/VersionTests.cpp",
"unittests/native/BlobTests.cpp",
"unittests/native/CacheKeyTests.cpp",
+ "unittests/native/CacheRequestTests.cpp",
"unittests/native/CommandBufferEncodingTests.cpp",
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
"unittests/native/DestroyObjectTests.cpp",
@@ -380,9 +382,9 @@
# Dawn end2end tests targets
###############################################################################
-# Source code for mocks used for end2end testing are separated from the rest of
+# Source code for mocks used for platform testing are separated from the rest of
# sources so that they aren't included in non-test builds.
-source_set("end2end_mocks_sources") {
+source_set("platform_mocks_sources") {
configs += [ "${dawn_root}/src/dawn/native:internal" ]
testonly = true
@@ -392,8 +394,8 @@
]
sources = [
- "end2end/mocks/CachingInterfaceMock.cpp",
- "end2end/mocks/CachingInterfaceMock.h",
+ "mocks/platform/CachingInterfaceMock.cpp",
+ "mocks/platform/CachingInterfaceMock.h",
]
}
@@ -401,7 +403,7 @@
testonly = true
deps = [
- ":end2end_mocks_sources",
+ ":platform_mocks_sources",
":test_infra_sources",
"${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc",
diff --git a/src/dawn/tests/DawnNativeTest.cpp b/src/dawn/tests/DawnNativeTest.cpp
index 163413d..fbf8030 100644
--- a/src/dawn/tests/DawnNativeTest.cpp
+++ b/src/dawn/tests/DawnNativeTest.cpp
@@ -20,6 +20,9 @@
#include "dawn/common/Assert.h"
#include "dawn/dawn_proc.h"
#include "dawn/native/ErrorData.h"
+#include "dawn/native/Instance.h"
+#include "dawn/native/dawn_platform.h"
+#include "dawn/platform/DawnPlatform.h"
namespace dawn::native {
@@ -43,6 +46,9 @@
void DawnNativeTest::SetUp() {
instance = std::make_unique<dawn::native::Instance>();
+ platform = CreateTestPlatform();
+ dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get());
+
instance->DiscoverDefaultAdapters();
std::vector<dawn::native::Adapter> adapters = instance->GetAdapters();
@@ -66,7 +72,9 @@
device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
}
-void DawnNativeTest::TearDown() {}
+std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() {
+ return nullptr;
+}
WGPUDevice DawnNativeTest::CreateTestDevice() {
// Disabled disallowing unsafe APIs so we can test them.
diff --git a/src/dawn/tests/DawnNativeTest.h b/src/dawn/tests/DawnNativeTest.h
index e92bf67..dd3532f 100644
--- a/src/dawn/tests/DawnNativeTest.h
+++ b/src/dawn/tests/DawnNativeTest.h
@@ -38,12 +38,15 @@
~DawnNativeTest() override;
void SetUp() override;
- void TearDown() override;
+    virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform();
virtual WGPUDevice CreateTestDevice();
protected:
+    // Declared before |instance| so that it is destroyed after it: SetUp() hands the instance a raw
+    // pointer to this platform (SetPlatformForTesting), so the platform should outlive the instance.
+    std::unique_ptr<dawn::platform::Platform> platform;
std::unique_ptr<dawn::native::Instance> instance;
dawn::native::Adapter adapter;
wgpu::Device device;
diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp
index 9cd4042..0d3bcdf 100644
--- a/src/dawn/tests/end2end/D3D12CachingTests.cpp
+++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp
@@ -17,7 +17,7 @@
#include <utility>
#include "dawn/tests/DawnTest.h"
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"
diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp
index 23ee708..8318c48 100644
--- a/src/dawn/tests/end2end/PipelineCachingTests.cpp
+++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp
@@ -16,7 +16,7 @@
#include <string_view>
#include "dawn/tests/DawnTest.h"
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
similarity index 97%
rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp
rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
index 8507e9c..a52d4c2 100644
--- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp
+++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
using ::testing::Invoke;
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
similarity index 92%
rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.h
rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.h
index cc61d80..0e9e6af 100644
--- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h
+++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
-#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
+#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
+#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
#include <dawn/platform/DawnPlatform.h>
#include <gmock/gmock.h>
@@ -70,4 +70,4 @@
dawn::platform::CachingInterface* mCachingInterface = nullptr;
};
-#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
+#endif // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
new file mode 100644
index 0000000..995de7f
--- /dev/null
+++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
@@ -0,0 +1,320 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "dawn/native/Blob.h"
+#include "dawn/native/CacheRequest.h"
+#include "dawn/tests/DawnNativeTest.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
+
+namespace dawn::native {
+
+namespace {
+
+using ::testing::_;
+using ::testing::ByMove;
+using ::testing::Invoke;
+using ::testing::MockFunction;
+using ::testing::Return;
+using ::testing::StrictMock;
+using ::testing::WithArg;
+
+class CacheRequestTests : public DawnNativeTest {
+ protected:
+ std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
+ return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
+ }
+
+ WGPUDevice CreateTestDevice() override {
+ wgpu::DeviceDescriptor deviceDescriptor = {};
+ wgpu::DawnTogglesDeviceDescriptor togglesDesc = {};
+ deviceDescriptor.nextInChain = &togglesDesc;
+
+ const char* toggle = "enable_blob_cache";
+ togglesDesc.forceEnabledToggles = &toggle;
+ togglesDesc.forceEnabledTogglesCount = 1;
+
+ return adapter.CreateDevice(&deviceDescriptor);
+ }
+
+ DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }
+
+ StrictMock<CachingInterfaceMock> mMockCache;
+};
+
+struct Foo {
+ int value;
+};
+
+#define REQUEST_MEMBERS(X) \
+ X(int, a) \
+ X(float, b) \
+ X(std::vector<unsigned int>, c) \
+ X(CacheKey::UnsafeUnkeyedValue<int*>, d) \
+ X(CacheKey::UnsafeUnkeyedValue<Foo>, e)
+
+DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS)
+
+#undef REQUEST_MEMBERS
+
+// static_assert the expected types for various return types from the cache hit handler and cache
+// miss handler.
+TEST_F(CacheRequestTests, CacheResultTypes) {
+ EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));
+
+ // (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>.
+ auto v1 = LoadOrRun(
+ GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
+ [](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
+ v1.AcquireSuccess();
+ static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>);
+
+ // (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>.
+ auto v2 = LoadOrRun(
+ GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; },
+ [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; });
+ v2.AcquireSuccess();
+ static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>);
+}
+
+// Test that using a CacheRequest builds a key from the device key, the request type name, and all
+// of the request members.
+TEST_F(CacheRequestTests, MakesCacheKey) {
+ // Make a request.
+ CacheRequestForTesting req;
+ req.a = 1;
+ req.b = 0.2;
+ req.c = {3, 4, 5};
+
+ // Make the expected key.
+ CacheKey expectedKey;
+ expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);
+
+ // Expect a call to LoadData with the expected key.
+ EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
+ .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) {
+ EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
+ return 0;
+ })));
+
+ // Load the request.
+ auto result = LoadOrRun(
+ GetDevice(), std::move(req), [](Blob) -> int { return 0; },
+ [](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
+ .AcquireSuccess();
+
+ // The created cache key should be saved on the result.
+ EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size());
+ EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0);
+}
+
+// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key.
+TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeUnkeyedValue) {
+    // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the
+    // struct definition).
+    int v1, v2;
+    CacheRequestForTesting req1;
+    req1.d = &v1;
+    req1.e = Foo{42};
+
+    CacheRequestForTesting req2;
+    req2.d = &v2;
+    req2.e = Foo{24};
+
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));
+
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    // Load the first request, and check that the unsafe unkeyed values were passed through.
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
+        EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
+        EXPECT_EQ(req.e.UnsafeGetValue().value, 42);
+        return 0;
+    })));
+    auto r1 = LoadOrRun(
+                  GetDevice(), std::move(req1), [](Blob) { return 0; },
+                  [](CacheRequestForTesting req) -> ResultOrError<int> {
+                      return cacheMissFn.Call(std::move(req));
+                  })
+                  .AcquireSuccess();
+
+    // Load the second request, and check that the unsafe unkeyed values were passed through.
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
+        EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
+        EXPECT_EQ(req.e.UnsafeGetValue().value, 24);
+        return 0;
+    })));
+    auto r2 = LoadOrRun(
+                  GetDevice(), std::move(req2), [](Blob) { return 0; },
+                  [](CacheRequestForTesting req) -> ResultOrError<int> {
+                      return cacheMissFn.Call(std::move(req));
+                  })
+                  .AcquireSuccess();
+
+    // Expect their keys to be the same.
+    EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size());
+    EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0);
+}
+
+// Test the expected code path when there is a cache miss.
+TEST_F(CacheRequestTests, CacheMiss) {
+ // Make a request.
+ CacheRequestForTesting req;
+ req.a = 1;
+ req.b = 0.2;
+ req.c = {3, 4, 5};
+
+ unsigned int* cPtr = req.c.data();
+
+ static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
+ static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+ // Mock a cache miss.
+ EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));
+
+ // Expect the cache miss, and return some value.
+ int rv = 42;
+ EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
+ // Expect the request contents to be the same. The data pointer for |c| is also the same
+ // since it was moved.
+ EXPECT_EQ(req.a, 1);
+ EXPECT_FLOAT_EQ(req.b, 0.2);
+ EXPECT_EQ(req.c.data(), cPtr);
+ return rv;
+ })));
+
+ // Load the request.
+ auto result = LoadOrRun(
+ GetDevice(), std::move(req),
+ [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
+ [](CacheRequestForTesting req) -> ResultOrError<int> {
+ return cacheMissFn.Call(std::move(req));
+ })
+ .AcquireSuccess();
+
+ // Expect the result to store the value.
+ EXPECT_EQ(*result, rv);
+ EXPECT_FALSE(result.IsCached());
+}
+
+// Test the expected code path when there is a cache hit.
+TEST_F(CacheRequestTests, CacheHit) {
+ // Make a request.
+ CacheRequestForTesting req;
+ req.a = 1;
+ req.b = 0.2;
+ req.c = {3, 4, 5};
+
+ static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
+ static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+ static constexpr char kCachedData[] = "hello world!";
+
+ // Mock a cache hit, and load the cached data.
+ EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
+ EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
+ .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
+ memcpy(dataOut, kCachedData, sizeof(kCachedData));
+ return sizeof(kCachedData);
+ })));
+
+ // Expect the cache hit, and return some value.
+ int rv = 1337;
+ EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
+ // Expect the cached blob contents to match the cached data.
+ EXPECT_EQ(blob.Size(), sizeof(kCachedData));
+ EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
+
+ return rv;
+ })));
+
+ // Load the request.
+ auto result = LoadOrRun(
+ GetDevice(), std::move(req),
+ [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
+ [](CacheRequestForTesting req) -> ResultOrError<int> {
+ return cacheMissFn.Call(std::move(req));
+ })
+ .AcquireSuccess();
+
+ // Expect the result to store the value.
+ EXPECT_EQ(*result, rv);
+ EXPECT_TRUE(result.IsCached());
+}
+
+// Test the expected code path when there is a cache hit but the handler errors.
+TEST_F(CacheRequestTests, CacheHitError) {
+ // Make a request.
+ CacheRequestForTesting req;
+ req.a = 1;
+ req.b = 0.2;
+ req.c = {3, 4, 5};
+
+ unsigned int* cPtr = req.c.data();
+
+ static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
+ static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+ static constexpr char kCachedData[] = "hello world!";
+
+ // Mock a cache hit, and load the cached data.
+ EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
+ EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
+ .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
+ memcpy(dataOut, kCachedData, sizeof(kCachedData));
+ return sizeof(kCachedData);
+ })));
+
+ // Expect the cache hit.
+ EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
+ // Expect the cached blob contents to match the cached data.
+ EXPECT_EQ(blob.Size(), sizeof(kCachedData));
+ EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
+
+ // Return an error.
+ return DAWN_VALIDATION_ERROR("fake test error");
+ })));
+
+ // Expect the cache miss handler since the cache hit errored.
+ int rv = 79;
+ EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
+ // Expect the request contents to be the same. The data pointer for |c| is also the same
+ // since it was moved.
+ EXPECT_EQ(req.a, 1);
+ EXPECT_FLOAT_EQ(req.b, 0.2);
+ EXPECT_EQ(req.c.data(), cPtr);
+ return rv;
+ })));
+
+ // Load the request.
+ auto result =
+ LoadOrRun(
+ GetDevice(), std::move(req),
+ [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
+ [](CacheRequestForTesting req) -> ResultOrError<int> {
+ return cacheMissFn.Call(std::move(req));
+ })
+ .AcquireSuccess();
+
+ // Expect the result to store the value.
+ EXPECT_EQ(*result, rv);
+ EXPECT_FALSE(result.IsCached());
+}
+
+} // namespace
+
+} // namespace dawn::native