Add CacheRequest utilities and tests This CL adds a DAWN_MAKE_CACHE_REQUEST X macro which helps in building a CacheRequest struct. A CacheRequest struct may be passed to LoadOrRun which will generate a CacheKey from the struct and load a result if there is a cache hit, or it will call the provided cache miss function to compute a value. The request struct helps enforce that precisely the inputs that go into a computation are all also included inside the CacheKey for that computation. Bug: dawn:549 Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063 Reviewed-by: Loko Kung <lokokung@google.com> Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn index 4294b26..137ce1b 100644 --- a/src/dawn/native/BUILD.gn +++ b/src/dawn/native/BUILD.gn
@@ -200,6 +200,9 @@ "Buffer.h", "CacheKey.cpp", "CacheKey.h", + "CacheRequest.cpp", + "CacheRequest.h", + "CacheResult.h", "CachedObject.cpp", "CachedObject.h", "CallbackTaskManager.cpp",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt index bbb50a0..8542b33 100644 --- a/src/dawn/native/CMakeLists.txt +++ b/src/dawn/native/CMakeLists.txt
@@ -59,6 +59,9 @@ "CachedObject.h" "CacheKey.cpp" "CacheKey.h" + "CacheRequest.cpp" + "CacheRequest.h" + "CacheResult.h" "CallbackTaskManager.cpp" "CallbackTaskManager.h" "CommandAllocator.cpp"
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h index 357ce4b..c2901db 100644 --- a/src/dawn/native/CacheKey.h +++ b/src/dawn/native/CacheKey.h
@@ -18,8 +18,10 @@ #include <bitset> #include <iostream> #include <limits> +#include <memory> #include <string> #include <type_traits> +#include <utility> #include <vector> #include "dawn/common/TypedInteger.h" @@ -49,6 +51,19 @@ enum class Type { ComputePipeline, RenderPipeline, Shader }; template <typename T> + class UnsafeUnkeyedValue { + public: + UnsafeUnkeyedValue() = default; + // NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity + UnsafeUnkeyedValue(T&& value) : mValue(std::forward<T>(value)) {} + + const T& UnsafeGetValue() const { return mValue; } + + private: + T mValue; + }; + + template <typename T> CacheKey& Record(const T& t) { CacheKeySerializer<T>::Serialize(this, t); return *this; @@ -89,6 +104,18 @@ } }; +template <typename T> +CacheKey::UnsafeUnkeyedValue<T> UnsafeUnkeyedValue(T&& value) { + return CacheKey::UnsafeUnkeyedValue<T>(std::forward<T>(value)); +} + +// Specialized overload for CacheKey::UnsafeIgnoredValue which does nothing. +template <typename T> +class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> { + public: + constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {} +}; + // Specialized overload for fundamental types. template <typename T> class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> { @@ -197,6 +224,13 @@ static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); } }; +// Specialized overload for std::vector. +template <typename T> +class CacheKeySerializer<std::vector<T>> { + public: + static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); } +}; + } // namespace dawn::native #endif // SRC_DAWN_NATIVE_CACHEKEY_H_
diff --git a/src/dawn/native/CacheRequest.cpp b/src/dawn/native/CacheRequest.cpp new file mode 100644 index 0000000..2b35b12 --- /dev/null +++ b/src/dawn/native/CacheRequest.cpp
@@ -0,0 +1,25 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn/native/CacheRequest.h" + +#include "dawn/common/Log.h" + +namespace dawn::native::detail { + +void LogCacheHitError(std::unique_ptr<ErrorData> error) { + dawn::ErrorLog() << error->GetFormattedMessage(); +} + +} // namespace dawn::native::detail
diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h new file mode 100644 index 0000000..0fecb0d --- /dev/null +++ b/src/dawn/native/CacheRequest.h
@@ -0,0 +1,186 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_ +#define SRC_DAWN_NATIVE_CACHEREQUEST_H_ + +#include <memory> +#include <utility> + +#include "dawn/common/Assert.h" +#include "dawn/common/Compiler.h" +#include "dawn/native/Blob.h" +#include "dawn/native/BlobCache.h" +#include "dawn/native/CacheKey.h" +#include "dawn/native/CacheResult.h" +#include "dawn/native/Device.h" +#include "dawn/native/Error.h" + +namespace dawn::native { + +namespace detail { + +template <typename T> +struct UnwrapResultOrError { + using type = T; +}; + +template <typename T> +struct UnwrapResultOrError<ResultOrError<T>> { + using type = T; +}; + +template <typename T> +struct IsResultOrError { + static constexpr bool value = false; +}; + +template <typename T> +struct IsResultOrError<ResultOrError<T>> { + static constexpr bool value = true; +}; + +void LogCacheHitError(std::unique_ptr<ErrorData> error); + +} // namespace detail + +// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found +// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function +// name. +// +// Example usage: +// Request r = { ... }; +// ResultOrError<CacheResult<T>> cacheResult = +// LoadOrRun(device, std::move(r), +// [](Blob blob) -> T { /* handle cache hit */ }, +// [](Request r) -> ResultOrError<T> { /* handle cache miss */ } +// ); +// Or with free functions: +/// T OnCacheHit(Blob blob) { ... } +// ResultOrError<T> OnCacheMiss(Request r) { ... } +// // ... +// Request r = { ... }; +// auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss); +// +// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache +// hit, calls CacheHitFn and returns a CacheResult<T>. On cache miss or if CacheHitFn returned an +// Error, calls CacheMissFn -> ResultOrError<T> with the request data and returns a +// ResultOrError<CacheResult<T>>. CacheHitFn must return the same unwrapped type as CacheMissFn. +// i.e. it doesn't need to be wrapped in ResultOrError. +// +// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function +// which captures additional information, so it can only operate on the request data. This is +// enforced with a compile-time static_assert, and ensures that the result created from the +// computation is exactly the data included in the CacheKey. +template <typename Request> +class CacheRequestImpl { + public: + CacheRequestImpl() = default; + + // Require CacheRequests to be move-only to avoid unnecessary copies. + CacheRequestImpl(CacheRequestImpl&&) = default; + CacheRequestImpl& operator=(CacheRequestImpl&&) = default; + CacheRequestImpl(const CacheRequestImpl&) = delete; + CacheRequestImpl& operator=(const CacheRequestImpl&) = delete; + + template <typename CacheHitFn, typename CacheMissFn> + friend auto LoadOrRun(DeviceBase* device, + Request&& r, + CacheHitFn cacheHitFn, + CacheMissFn cacheMissFn) { + // Get return types and check that CacheMissReturnType can be cast to a raw function + // pointer. This means it's not a std::function or lambda that captures additional data. + using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>())); + using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>())); + static_assert( + std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>, + "CacheMissFn function signature does not match, or it is not a free function."); + + static_assert(detail::IsResultOrError<CacheMissReturnType>::value, + "CacheMissFn should return a ResultOrError."); + using UnwrappedReturnType = typename detail::UnwrapResultOrError<CacheMissReturnType>::type; + + static_assert(std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type, + UnwrappedReturnType>, + "If CacheMissFn returns T, CacheHitFn must return T or ResultOrError<T>."); + + using CacheResultType = CacheResult<UnwrappedReturnType>; + using ReturnType = ResultOrError<CacheResultType>; + + CacheKey key = r.CreateCacheKey(device); + BlobCache* cache = device->GetBlobCache(); + Blob blob; + if (cache != nullptr) { + blob = cache->Load(key); + } + + if (!blob.Empty()) { + // Cache hit. Handle the cached blob. + auto result = cacheHitFn(std::move(blob)); + + if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) { + // If the result type is not a ResultOrError, return it. + return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result))); + } else { + // Otherwise, if the value is a success, also return it. + if (DAWN_LIKELY(result.IsSuccess())) { + return ReturnType( + CacheResultType::CacheHit(std::move(key), result.AcquireSuccess())); + } + // On error, continue to the cache miss path and log the error. + detail::LogCacheHitError(result.AcquireError()); + } + } + // Cache miss, or the CacheHitFn failed. + auto result = cacheMissFn(std::move(r)); + if (DAWN_LIKELY(result.IsSuccess())) { + return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess())); + } + return ReturnType(result.AcquireError()); + } +}; + +// Helper for X macro to declare a struct member. +#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{}; + +// Helper for X macro for recording cache request fields into a CacheKey. +#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(name); + +// Helper X macro to define a CacheRequest struct. +// Example usage: +// #define REQUEST_MEMBERS(X) \ +// X(int, a) \ +// X(float, b) \ +// X(Foo, foo) \ +// X(Bar, bar) +// DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS) +// #undef REQUEST_MEMBERS +#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \ + class Request : public CacheRequestImpl<Request> { \ + public: \ + Request() = default; \ + MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \ + \ + /* Create a CacheKey from the request type and all members */ \ + CacheKey CreateCacheKey(const DeviceBase* device) const { \ + CacheKey key = device->GetCacheKey(); \ + key.Record(#Request); \ + MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \ + return key; \ + } \ + }; + +} // namespace dawn::native + +#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_
diff --git a/src/dawn/native/CacheResult.h b/src/dawn/native/CacheResult.h new file mode 100644 index 0000000..a2750fe --- /dev/null +++ b/src/dawn/native/CacheResult.h
@@ -0,0 +1,76 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_NATIVE_CACHERESULT_H_ +#define SRC_DAWN_NATIVE_CACHERESULT_H_ + +#include <memory> +#include <utility> + +#include "dawn/common/Assert.h" + +namespace dawn::native { + +template <typename T> +class CacheResult { + public: + static CacheResult CacheHit(CacheKey key, T value) { + return CacheResult(std::move(key), std::move(value), true); + } + + static CacheResult CacheMiss(CacheKey key, T value) { + return CacheResult(std::move(key), std::move(value), false); + } + + CacheResult() : mKey(), mValue(), mIsCached(false), mIsValid(false) {} + + bool IsCached() const { + ASSERT(mIsValid); + return mIsCached; + } + const CacheKey& GetCacheKey() { + ASSERT(mIsValid); + return mKey; + } + + // Note: Getting mValue is always const, since mutating it would invalidate consistency with + // mKey. + const T* operator->() const { + ASSERT(mIsValid); + return &mValue; + } + const T& operator*() const { + ASSERT(mIsValid); + return mValue; + } + + T Acquire() { + ASSERT(mIsValid); + mIsValid = false; + return std::move(mValue); + } + + private: + CacheResult(CacheKey key, T value, bool isCached) + : mKey(std::move(key)), mValue(std::move(value)), mIsCached(isCached), mIsValid(true) {} + + CacheKey mKey; + T mValue; + bool mIsCached; + bool mIsValid; +}; + +} // namespace dawn::native + +#endif // SRC_DAWN_NATIVE_CACHERESULT_H_
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index 69c41a2..57bb666 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn
@@ -192,6 +192,7 @@ ":gmock_and_gtest", ":mock_webgpu_gen", ":native_mocks_sources", + ":platform_mocks_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn/common", @@ -254,6 +255,7 @@ "unittests/VersionTests.cpp", "unittests/native/BlobTests.cpp", "unittests/native/CacheKeyTests.cpp", + "unittests/native/CacheRequestTests.cpp", "unittests/native/CommandBufferEncodingTests.cpp", "unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/DestroyObjectTests.cpp", @@ -380,9 +382,9 @@ # Dawn end2end tests targets ############################################################################### -# Source code for mocks used for end2end testing are separated from the rest of +# Source code for mocks used for platform testing are separated from the rest of # sources so that they aren't included in non-test builds. -source_set("end2end_mocks_sources") { +source_set("platform_mocks_sources") { configs += [ "${dawn_root}/src/dawn/native:internal" ] testonly = true @@ -392,8 +394,8 @@ ] sources = [ - "end2end/mocks/CachingInterfaceMock.cpp", - "end2end/mocks/CachingInterfaceMock.h", + "mocks/platform/CachingInterfaceMock.cpp", + "mocks/platform/CachingInterfaceMock.h", ] } @@ -401,7 +403,7 @@ testonly = true deps = [ - ":end2end_mocks_sources", + ":platform_mocks_sources", ":test_infra_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc",
diff --git a/src/dawn/tests/DawnNativeTest.cpp b/src/dawn/tests/DawnNativeTest.cpp index 163413d..fbf8030 100644 --- a/src/dawn/tests/DawnNativeTest.cpp +++ b/src/dawn/tests/DawnNativeTest.cpp
@@ -20,6 +20,9 @@ #include "dawn/common/Assert.h" #include "dawn/dawn_proc.h" #include "dawn/native/ErrorData.h" +#include "dawn/native/Instance.h" +#include "dawn/native/dawn_platform.h" +#include "dawn/platform/DawnPlatform.h" namespace dawn::native { @@ -43,6 +46,9 @@ void DawnNativeTest::SetUp() { instance = std::make_unique<dawn::native::Instance>(); + platform = CreateTestPlatform(); + dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get()); + instance->DiscoverDefaultAdapters(); std::vector<dawn::native::Adapter> adapters = instance->GetAdapters(); @@ -66,7 +72,9 @@ device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr); } -void DawnNativeTest::TearDown() {} +std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() { + return nullptr; +} WGPUDevice DawnNativeTest::CreateTestDevice() { // Disabled disallowing unsafe APIs so we can test them.
diff --git a/src/dawn/tests/DawnNativeTest.h b/src/dawn/tests/DawnNativeTest.h index e92bf67..dd3532f 100644 --- a/src/dawn/tests/DawnNativeTest.h +++ b/src/dawn/tests/DawnNativeTest.h
@@ -38,12 +38,13 @@ ~DawnNativeTest() override; void SetUp() override; - void TearDown() override; + virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform(); virtual WGPUDevice CreateTestDevice(); protected: std::unique_ptr<dawn::native::Instance> instance; + std::unique_ptr<dawn::platform::Platform> platform; dawn::native::Adapter adapter; wgpu::Device device;
diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp index 9cd4042..0d3bcdf 100644 --- a/src/dawn/tests/end2end/D3D12CachingTests.cpp +++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp
@@ -17,7 +17,7 @@ #include <utility> #include "dawn/tests/DawnTest.h" -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h"
diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp index 23ee708..8318c48 100644 --- a/src/dawn/tests/end2end/PipelineCachingTests.cpp +++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp
@@ -16,7 +16,7 @@ #include <string_view> #include "dawn/tests/DawnTest.h" -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h"
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp similarity index 97% rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp index 8507e9c..a52d4c2 100644 --- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp +++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
@@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" using ::testing::Invoke;
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h similarity index 92% rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.h rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.h index cc61d80..0e9e6af 100644 --- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h +++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
@@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ -#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ +#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ #include <dawn/platform/DawnPlatform.h> #include <gmock/gmock.h> @@ -70,4 +70,4 @@ dawn::platform::CachingInterface* mCachingInterface = nullptr; }; -#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#endif // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp new file mode 100644 index 0000000..995de7f --- /dev/null +++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
@@ -0,0 +1,320 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <utility> +#include <vector> + +#include "dawn/native/Blob.h" +#include "dawn/native/CacheRequest.h" +#include "dawn/tests/DawnNativeTest.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" + +namespace dawn::native { + +namespace { + +using ::testing::_; +using ::testing::ByMove; +using ::testing::Invoke; +using ::testing::MockFunction; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::WithArg; + +class CacheRequestTests : public DawnNativeTest { + protected: + std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override { + return std::make_unique<DawnCachingMockPlatform>(&mMockCache); + } + + WGPUDevice CreateTestDevice() override { + wgpu::DeviceDescriptor deviceDescriptor = {}; + wgpu::DawnTogglesDeviceDescriptor togglesDesc = {}; + deviceDescriptor.nextInChain = &togglesDesc; + + const char* toggle = "enable_blob_cache"; + togglesDesc.forceEnabledToggles = &toggle; + togglesDesc.forceEnabledTogglesCount = 1; + + return adapter.CreateDevice(&deviceDescriptor); + } + + DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); } + + StrictMock<CachingInterfaceMock> mMockCache; +}; + +struct Foo { + int value; +}; + +#define REQUEST_MEMBERS(X) \ + X(int, a) \ + X(float, b) \ + X(std::vector<unsigned int>, c) \ + X(CacheKey::UnsafeUnkeyedValue<int*>, d) \ + X(CacheKey::UnsafeUnkeyedValue<Foo>, e) + +DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS) + +#undef REQUEST_MEMBERS + +// static_assert the expected types for various return types from the cache hit handler and cache +// miss handler. +TEST_F(CacheRequestTests, CacheResultTypes) { + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0)); + + // (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>. + auto v1 = LoadOrRun( + GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; }, + [](CacheRequestForTesting) -> ResultOrError<int> { return 1; }); + v1.AcquireSuccess(); + static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>); + + // (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>. + auto v2 = LoadOrRun( + GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; }, + [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; }); + v2.AcquireSuccess(); + static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>); +} + +// Test that using a CacheRequest builds a key from the device key, the request type enum, and all +// of the request members. +TEST_F(CacheRequestTests, MakesCacheKey) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + // Make the expected key. + CacheKey expectedKey; + expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c); + + // Expect a call to LoadData with the expected key. + EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0)) + .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) { + EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0); + return 0; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), [](Blob) -> int { return 0; }, + [](CacheRequestForTesting) -> ResultOrError<int> { return 0; }) + .AcquireSuccess(); + + // The created cache key should be saved on the result. + EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size()); + EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0); +} + +// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key. +TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) { + // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct + // definition). + int v1, v2; + CacheRequestForTesting req1; + req1.d = &v1; + req1.e = Foo{42}; + + CacheRequestForTesting req2; + req2.d = &v2; + req2.e = Foo{24}; + + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0)); + + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + // Load the first request, and check that the unsafe unkeyed values were passed though + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { + EXPECT_EQ(req.d.UnsafeGetValue(), &v1); + EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 42); + return 0; + }))); + auto r1 = LoadOrRun( + GetDevice(), std::move(req1), [](Blob) { return 0; }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Load the second request, and check that the unsafe unkeyed values were passed though + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { + EXPECT_EQ(req.d.UnsafeGetValue(), &v2); + EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 24); + return 0; + }))); + auto r2 = LoadOrRun( + GetDevice(), std::move(req2), [](Blob) { return 0; }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect their keys to be the same. + EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size()); + EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0); +} + +// Test the expected code path when there is a cache miss. +TEST_F(CacheRequestTests, CacheMiss) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + unsigned int* cPtr = req.c.data(); + + static StrictMock<MockFunction<int(Blob)>> cacheHitFn; + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + // Mock a cache miss. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)); + + // Expect the cache miss, and return some value. + int rv = 42; + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { + // Expect the request contents to be the same. The data pointer for |c| is also the same + // since it was moved. + EXPECT_EQ(req.a, 1); + EXPECT_FLOAT_EQ(req.b, 0.2); + EXPECT_EQ(req.c.data(), cPtr); + return rv; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_FALSE(result.IsCached()); +} + +// Test the expected code path when there is a cache hit. +TEST_F(CacheRequestTests, CacheHit) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + static StrictMock<MockFunction<int(Blob)>> cacheHitFn; + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + static constexpr char kCachedData[] = "hello world!"; + + // Mock a cache hit, and load the cached data. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); + EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) + .WillOnce(WithArg<2>(Invoke([](void* dataOut) { + memcpy(dataOut, kCachedData, sizeof(kCachedData)); + return sizeof(kCachedData); + }))); + + // Expect the cache hit, and return some value. + int rv = 1337; + EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { + // Expect the cached blob contents to match the cached data. + EXPECT_EQ(blob.Size(), sizeof(kCachedData)); + EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); + + return rv; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_TRUE(result.IsCached()); +} + +// Test the expected code path when there is a cache hit but the handler errors. +TEST_F(CacheRequestTests, CacheHitError) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + unsigned int* cPtr = req.c.data(); + + static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn; + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + static constexpr char kCachedData[] = "hello world!"; + + // Mock a cache hit, and load the cached data. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); + EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) + .WillOnce(WithArg<2>(Invoke([](void* dataOut) { + memcpy(dataOut, kCachedData, sizeof(kCachedData)); + return sizeof(kCachedData); + }))); + + // Expect the cache hit. + EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { + // Expect the cached blob contents to match the cached data. + EXPECT_EQ(blob.Size(), sizeof(kCachedData)); + EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); + + // Return an error. + return DAWN_VALIDATION_ERROR("fake test error"); + }))); + + // Expect the cache miss handler since the cache hit errored. + int rv = 79; + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { + // Expect the request contents to be the same. The data pointer for |c| is also the same + // since it was moved. + EXPECT_EQ(req.a, 1); + EXPECT_FLOAT_EQ(req.b, 0.2); + EXPECT_EQ(req.c.data(), cPtr); + return rv; + }))); + + // Load the request. + auto result = + LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_FALSE(result.IsCached()); +} + +} // namespace + +} // namespace dawn::native