Add CacheRequest utilities and tests

This CL adds a DAWN_MAKE_CACHE_REQUEST X macro
which helps in building a CacheRequest struct.

A CacheRequest struct may be passed to LoadOrRun,
which generates a CacheKey from the struct and loads
the cached result on a cache hit, or otherwise calls
the provided cache-miss function to compute the value.

The request struct helps enforce that all of the
inputs to a computation are also recorded in the
CacheKey for that computation, except members that
are explicitly wrapped in UnsafeUnkeyedValue.

Bug: dawn:549
Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 4294b26..137ce1b 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -200,6 +200,9 @@
     "Buffer.h",
     "CacheKey.cpp",
     "CacheKey.h",
+    "CacheRequest.cpp",
+    "CacheRequest.h",
+    "CacheResult.h",
     "CachedObject.cpp",
     "CachedObject.h",
     "CallbackTaskManager.cpp",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index bbb50a0..8542b33 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -59,6 +59,9 @@
     "CachedObject.h"
     "CacheKey.cpp"
     "CacheKey.h"
+    "CacheRequest.cpp"
+    "CacheRequest.h"
+    "CacheResult.h"
     "CallbackTaskManager.cpp"
     "CallbackTaskManager.h"
     "CommandAllocator.cpp"
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h
index 357ce4b..c2901db 100644
--- a/src/dawn/native/CacheKey.h
+++ b/src/dawn/native/CacheKey.h
@@ -18,8 +18,10 @@
 #include <bitset>
 #include <iostream>
 #include <limits>
+#include <memory>
 #include <string>
 #include <type_traits>
+#include <utility>
 #include <vector>
 
 #include "dawn/common/TypedInteger.h"
@@ -49,6 +51,19 @@
     enum class Type { ComputePipeline, RenderPipeline, Shader };
 
     template <typename T>
+    class UnsafeUnkeyedValue {  // Holds a value that is not recorded into CacheKeys.
+      public:
+        UnsafeUnkeyedValue() = default;
+        // NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity
+        UnsafeUnkeyedValue(T&& value) : mValue(std::forward<T>(value)) {}
+        // Returns the wrapped value. It is not part of the CacheKey, so use it with care.
+        const T& UnsafeGetValue() const { return mValue; }
+
+      private:
+        T mValue;
+    };
+
+    template <typename T>
     CacheKey& Record(const T& t) {
         CacheKeySerializer<T>::Serialize(this, t);
         return *this;
@@ -89,6 +104,18 @@
     }
 };
 
+template <typename T>
+CacheKey::UnsafeUnkeyedValue<T> UnsafeUnkeyedValue(T&& value) {
+    return CacheKey::UnsafeUnkeyedValue<T>(std::forward<T>(value));
+}
+
+// Specialized overload for CacheKey::UnsafeUnkeyedValue which intentionally records nothing.
+template <typename T>
+class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> {
+  public:
+    constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {}
+};
+
 // Specialized overload for fundamental types.
 template <typename T>
 class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
@@ -197,6 +224,13 @@
     static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); }
 };
 
+// Specialized overload for std::vector<T>; delegates to CacheKey::RecordIterable.
+template <typename T>
+class CacheKeySerializer<std::vector<T>> {
+  public:
+    static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
+};
+
 }  // namespace dawn::native
 
 #endif  // SRC_DAWN_NATIVE_CACHEKEY_H_
diff --git a/src/dawn/native/CacheRequest.cpp b/src/dawn/native/CacheRequest.cpp
new file mode 100644
index 0000000..2b35b12
--- /dev/null
+++ b/src/dawn/native/CacheRequest.cpp
@@ -0,0 +1,25 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/CacheRequest.h"
+
+#include "dawn/common/Log.h"
+
+namespace dawn::native::detail {
+// Logs |error|'s formatted message; LoadOrRun calls this when a CacheHitFn returns an error.
+void LogCacheHitError(std::unique_ptr<ErrorData> error) {
+    dawn::ErrorLog() << error->GetFormattedMessage();
+}
+
+}  // namespace dawn::native::detail
diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h
new file mode 100644
index 0000000..0fecb0d
--- /dev/null
+++ b/src/dawn/native/CacheRequest.h
@@ -0,0 +1,186 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_
+#define SRC_DAWN_NATIVE_CACHEREQUEST_H_
+
+#include <memory>
+#include <utility>
+
+#include "dawn/common/Assert.h"
+#include "dawn/common/Compiler.h"
+#include "dawn/native/Blob.h"
+#include "dawn/native/BlobCache.h"
+#include "dawn/native/CacheKey.h"
+#include "dawn/native/CacheResult.h"
+#include "dawn/native/Device.h"
+#include "dawn/native/Error.h"
+
+namespace dawn::native {
+
+namespace detail {
+// Traits used by LoadOrRun to inspect the return types of CacheHitFn and CacheMissFn.
+template <typename T>
+struct UnwrapResultOrError {
+    using type = T;
+};
+
+template <typename T>
+struct UnwrapResultOrError<ResultOrError<T>> {
+    using type = T;
+};
+
+template <typename T>
+struct IsResultOrError {
+    static constexpr bool value = false;
+};
+
+template <typename T>
+struct IsResultOrError<ResultOrError<T>> {
+    static constexpr bool value = true;
+};
+// Logs an error returned by CacheHitFn; LoadOrRun then falls back to CacheMissFn.
+void LogCacheHitError(std::unique_ptr<ErrorData> error);
+
+}  // namespace detail
+
+// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found
+// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function
+// name.
+//
+// Example usage:
+//   Request r = { ... };
+//   ResultOrError<CacheResult<T>> cacheResult =
+//       LoadOrRun(device, std::move(r),
+//                 [](Blob blob) -> T { /* handle cache hit */ },
+//                 [](Request r) -> ResultOrError<T> { /* handle cache miss */ }
+//       );
+// Or with free functions:
+//   T OnCacheHit(Blob blob) { ... }
+//   ResultOrError<T> OnCacheMiss(Request r) { ... }
+//   // ...
+//   Request r = { ... };
+//   auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss);
+//
+// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache
+// hit, calls CacheHitFn and wraps its value in a CacheResult<T>. On cache miss, or if CacheHitFn
+// returned an Error, calls CacheMissFn -> ResultOrError<T> with the request data. In all cases,
+// the return type is ResultOrError<CacheResult<T>>. CacheHitFn must return the same unwrapped
+// type as CacheMissFn, i.e. it may return T directly or a ResultOrError<T>.
+//
+// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function
+// which captures additional information, so it can only operate on the request data. This is
+// enforced with a compile-time static_assert, and ensures that the result is computed only from
+// the request, all of which is in the CacheKey except members wrapped in UnsafeUnkeyedValue.
+template <typename Request>
+class CacheRequestImpl {
+  public:
+    CacheRequestImpl() = default;
+
+    // Require CacheRequests to be move-only to avoid unnecessary copies.
+    CacheRequestImpl(CacheRequestImpl&&) = default;
+    CacheRequestImpl& operator=(CacheRequestImpl&&) = default;
+    CacheRequestImpl(const CacheRequestImpl&) = delete;
+    CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;
+
+    template <typename CacheHitFn, typename CacheMissFn>
+    friend auto LoadOrRun(DeviceBase* device,
+                          Request&& r,
+                          CacheHitFn cacheHitFn,
+                          CacheMissFn cacheMissFn) {
+        // Get return types and check that CacheMissFn can be converted to a raw function
+        // pointer. This means it's not a std::function or lambda that captures additional data.
+        using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>()));
+        using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>()));
+        static_assert(
+            std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>,
+            "CacheMissFn function signature does not match, or it is not a free function.");
+
+        static_assert(detail::IsResultOrError<CacheMissReturnType>::value,
+                      "CacheMissFn should return a ResultOrError.");
+        using UnwrappedReturnType = typename detail::UnwrapResultOrError<CacheMissReturnType>::type;
+
+        static_assert(std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type,
+                                     UnwrappedReturnType>,
+                      "If CacheMissFn returns T, CacheHitFn must return T or ResultOrError<T>.");
+
+        using CacheResultType = CacheResult<UnwrappedReturnType>;
+        using ReturnType = ResultOrError<CacheResultType>;
+
+        CacheKey key = r.CreateCacheKey(device);
+        BlobCache* cache = device->GetBlobCache();
+        Blob blob;
+        if (cache != nullptr) {
+            blob = cache->Load(key);
+        }
+
+        if (!blob.Empty()) {
+            // Cache hit. Handle the cached blob.
+            auto result = cacheHitFn(std::move(blob));
+
+            if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) {
+                // If the result type is not a ResultOrError, return it.
+                return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result)));
+            } else {
+                // Otherwise, if the value is a success, also return it.
+                if (DAWN_LIKELY(result.IsSuccess())) {
+                    return ReturnType(
+                        CacheResultType::CacheHit(std::move(key), result.AcquireSuccess()));
+                }
+                // On error, log it and then fall through to the cache miss path.
+                detail::LogCacheHitError(result.AcquireError());
+            }
+        }
+        // Cache miss, or the CacheHitFn failed.
+        auto result = cacheMissFn(std::move(r));
+        if (DAWN_LIKELY(result.IsSuccess())) {
+            return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess()));
+        }
+        return ReturnType(result.AcquireError());
+    }
+};
+
+// Helper for X macro to declare a struct member.
+#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{};
+
+// Helper for X macro to record a member into `key`; `this->` avoids shadowing by locals/params.
+#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(this->name);
+
+// Helper X macro to define a CacheRequest struct.
+// Example usage:
+//   #define REQUEST_MEMBERS(X) \
+//       X(int, a)              \
+//       X(float, b)            \
+//       X(Foo, foo)            \
+//       X(Bar, bar)
+//   DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
+//   #undef REQUEST_MEMBERS
+#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS)                     \
+    class Request : public CacheRequestImpl<Request> {                \
+      public:                                                         \
+        Request() = default;                                          \
+        MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER)       \
+                                                                      \
+        /* Create a CacheKey from the request type and all members */ \
+        CacheKey CreateCacheKey(const DeviceBase* device) const {     \
+            CacheKey key = device->GetCacheKey();                     \
+            key.Record(#Request);                                     \
+            MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY)           \
+            return key;                                               \
+        }                                                             \
+    };
+
+}  // namespace dawn::native
+
+#endif  // SRC_DAWN_NATIVE_CACHEREQUEST_H_
diff --git a/src/dawn/native/CacheResult.h b/src/dawn/native/CacheResult.h
new file mode 100644
index 0000000..a2750fe
--- /dev/null
+++ b/src/dawn/native/CacheResult.h
@@ -0,0 +1,76 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_CACHERESULT_H_
+#define SRC_DAWN_NATIVE_CACHERESULT_H_
+
+#include <memory>
+#include <utility>
+
+#include "dawn/common/Assert.h"
+#include "dawn/native/CacheKey.h"
+namespace dawn::native {
+// Holds a value of type T together with the CacheKey it was loaded or computed for.
+template <typename T>
+class CacheResult {
+  public:
+    static CacheResult CacheHit(CacheKey key, T value) {
+        return CacheResult(std::move(key), std::move(value), true);
+    }
+
+    static CacheResult CacheMiss(CacheKey key, T value) {
+        return CacheResult(std::move(key), std::move(value), false);
+    }
+
+    CacheResult() : mKey(), mValue(), mIsCached(false), mIsValid(false) {}
+
+    bool IsCached() const {
+        ASSERT(mIsValid);
+        return mIsCached;
+    }
+    const CacheKey& GetCacheKey() const {
+        ASSERT(mIsValid);
+        return mKey;
+    }
+
+    // Note: Getting mValue is always const, since mutating it would invalidate consistency with
+    // mKey.
+    const T* operator->() const {
+        ASSERT(mIsValid);
+        return &mValue;
+    }
+    const T& operator*() const {
+        ASSERT(mIsValid);
+        return mValue;
+    }
+    // Moves the value out; the result is invalid afterwards and its accessors will ASSERT.
+    T Acquire() {
+        ASSERT(mIsValid);
+        mIsValid = false;
+        return std::move(mValue);
+    }
+
+  private:
+    CacheResult(CacheKey key, T value, bool isCached)
+        : mKey(std::move(key)), mValue(std::move(value)), mIsCached(isCached), mIsValid(true) {}
+
+    CacheKey mKey;
+    T mValue;
+    bool mIsCached;
+    bool mIsValid;
+};
+
+}  // namespace dawn::native
+
+#endif  // SRC_DAWN_NATIVE_CACHERESULT_H_
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 69c41a2..57bb666 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -192,6 +192,7 @@
     ":gmock_and_gtest",
     ":mock_webgpu_gen",
     ":native_mocks_sources",
+    ":platform_mocks_sources",
     "${dawn_root}/src/dawn:cpp",
     "${dawn_root}/src/dawn:proc",
     "${dawn_root}/src/dawn/common",
@@ -254,6 +255,7 @@
     "unittests/VersionTests.cpp",
     "unittests/native/BlobTests.cpp",
     "unittests/native/CacheKeyTests.cpp",
+    "unittests/native/CacheRequestTests.cpp",
     "unittests/native/CommandBufferEncodingTests.cpp",
     "unittests/native/CreatePipelineAsyncTaskTests.cpp",
     "unittests/native/DestroyObjectTests.cpp",
@@ -380,9 +382,9 @@
 # Dawn end2end tests targets
 ###############################################################################
 
-# Source code for mocks used for end2end testing are separated from the rest of
+# Source code for mocks used for platform testing is separated from the rest of
 # sources so that they aren't included in non-test builds.
-source_set("end2end_mocks_sources") {
+source_set("platform_mocks_sources") {
   configs += [ "${dawn_root}/src/dawn/native:internal" ]
   testonly = true
 
@@ -392,8 +394,8 @@
   ]
 
   sources = [
-    "end2end/mocks/CachingInterfaceMock.cpp",
-    "end2end/mocks/CachingInterfaceMock.h",
+    "mocks/platform/CachingInterfaceMock.cpp",
+    "mocks/platform/CachingInterfaceMock.h",
   ]
 }
 
@@ -401,7 +403,7 @@
   testonly = true
 
   deps = [
-    ":end2end_mocks_sources",
+    ":platform_mocks_sources",
     ":test_infra_sources",
     "${dawn_root}/src/dawn:cpp",
     "${dawn_root}/src/dawn:proc",
diff --git a/src/dawn/tests/DawnNativeTest.cpp b/src/dawn/tests/DawnNativeTest.cpp
index 163413d..fbf8030 100644
--- a/src/dawn/tests/DawnNativeTest.cpp
+++ b/src/dawn/tests/DawnNativeTest.cpp
@@ -20,6 +20,9 @@
 #include "dawn/common/Assert.h"
 #include "dawn/dawn_proc.h"
 #include "dawn/native/ErrorData.h"
+#include "dawn/native/Instance.h"
+#include "dawn/native/dawn_platform.h"
+#include "dawn/platform/DawnPlatform.h"
 
 namespace dawn::native {
 
@@ -43,6 +46,9 @@
 
 void DawnNativeTest::SetUp() {
     instance = std::make_unique<dawn::native::Instance>();
+    platform = CreateTestPlatform();
+    dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get());
+
     instance->DiscoverDefaultAdapters();
 
     std::vector<dawn::native::Adapter> adapters = instance->GetAdapters();
@@ -66,7 +72,9 @@
     device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
 }
 
-void DawnNativeTest::TearDown() {}
+std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() {
+    return nullptr;  // No custom platform by default; fixtures override this to supply one.
+}
 
 WGPUDevice DawnNativeTest::CreateTestDevice() {
     // Disabled disallowing unsafe APIs so we can test them.
diff --git a/src/dawn/tests/DawnNativeTest.h b/src/dawn/tests/DawnNativeTest.h
index e92bf67..dd3532f 100644
--- a/src/dawn/tests/DawnNativeTest.h
+++ b/src/dawn/tests/DawnNativeTest.h
@@ -38,12 +38,13 @@
     ~DawnNativeTest() override;
 
     void SetUp() override;
-    void TearDown() override;
 
+    virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform();
     virtual WGPUDevice CreateTestDevice();
 
   protected:
     std::unique_ptr<dawn::native::Instance> instance;
+    std::unique_ptr<dawn::platform::Platform> platform;
     dawn::native::Adapter adapter;
     wgpu::Device device;
 
diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp
index 9cd4042..0d3bcdf 100644
--- a/src/dawn/tests/end2end/D3D12CachingTests.cpp
+++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp
@@ -17,7 +17,7 @@
 #include <utility>
 
 #include "dawn/tests/DawnTest.h"
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
 #include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
 
diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp
index 23ee708..8318c48 100644
--- a/src/dawn/tests/end2end/PipelineCachingTests.cpp
+++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp
@@ -16,7 +16,7 @@
 #include <string_view>
 
 #include "dawn/tests/DawnTest.h"
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
 #include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
 
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
similarity index 97%
rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp
rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
index 8507e9c..a52d4c2 100644
--- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp
+++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
 
 using ::testing::Invoke;
 
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
similarity index 92%
rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.h
rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.h
index cc61d80..0e9e6af 100644
--- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h
+++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
-#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
+#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
+#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
 
 #include <dawn/platform/DawnPlatform.h>
 #include <gmock/gmock.h>
@@ -70,4 +70,4 @@
     dawn::platform::CachingInterface* mCachingInterface = nullptr;
 };
 
-#endif  // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
+#endif  // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
new file mode 100644
index 0000000..995de7f
--- /dev/null
+++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
@@ -0,0 +1,320 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "dawn/native/Blob.h"
+#include "dawn/native/CacheRequest.h"
+#include "dawn/tests/DawnNativeTest.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
+
+namespace dawn::native {
+
+namespace {
+
+using ::testing::_;
+using ::testing::ByMove;
+using ::testing::Invoke;
+using ::testing::MockFunction;
+using ::testing::Return;
+using ::testing::StrictMock;
+using ::testing::WithArg;
+
+class CacheRequestTests : public DawnNativeTest {
+  protected:
+    std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
+        return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
+    }
+
+    WGPUDevice CreateTestDevice() override {
+        wgpu::DeviceDescriptor deviceDescriptor = {};
+        wgpu::DawnTogglesDeviceDescriptor togglesDesc = {};
+        deviceDescriptor.nextInChain = &togglesDesc;
+
+        const char* toggle = "enable_blob_cache";
+        togglesDesc.forceEnabledToggles = &toggle;
+        togglesDesc.forceEnabledTogglesCount = 1;
+
+        return adapter.CreateDevice(&deviceDescriptor);
+    }
+    // Returns the native DeviceBase backing the wgpu::Device under test.
+    DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }
+    // Backs the mock platform from CreateTestPlatform(). StrictMock: unexpected calls fail.
+    StrictMock<CachingInterfaceMock> mMockCache;
+};
+
+struct Foo {
+    int value;
+};
+// Test request with keyed members (a, b, c) and UnsafeUnkeyedValue members (d, e).
+#define REQUEST_MEMBERS(X)                   \
+    X(int, a)                                \
+    X(float, b)                              \
+    X(std::vector<unsigned int>, c)          \
+    X(CacheKey::UnsafeUnkeyedValue<int*>, d) \
+    X(CacheKey::UnsafeUnkeyedValue<Foo>, e)
+
+DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS)
+
+#undef REQUEST_MEMBERS
+
+// static_assert the expected types for various return types from the cache hit handler and cache
+// miss handler. LoadData always misses here, so each value comes from the cache miss handler.
+TEST_F(CacheRequestTests, CacheResultTypes) {
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));
+
+    // (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>.
+    auto v1 = LoadOrRun(
+        GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
+        [](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
+    EXPECT_EQ(*v1.AcquireSuccess(), 1);
+    static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>);
+
+    // (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>.
+    auto v2 = LoadOrRun(
+        GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0f; },
+        [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0f; });
+    EXPECT_FLOAT_EQ(*v2.AcquireSuccess(), 1.0f);
+    static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>);
+}
+
+// Test that using a CacheRequest builds a key from the device key, the request type name, and all
+// of the keyed request members.
+TEST_F(CacheRequestTests, MakesCacheKey) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    // Make the expected key.
+    CacheKey expectedKey;
+    expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);
+
+    // Expect a call to LoadData with the expected key.
+    EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
+        .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) {
+            EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
+            return 0;
+        })));
+
+    // Load the request.
+    auto result = LoadOrRun(
+                      GetDevice(), std::move(req), [](Blob) -> int { return 0; },
+                      [](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
+                      .AcquireSuccess();
+
+    // The created cache key should be saved on the result.
+    EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size());
+    EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0);
+}
+
+// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key.
+TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) {
+    // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct
+    // definition).
+    int v1, v2;
+    CacheRequestForTesting req1;
+    req1.d = &v1;
+    req1.e = Foo{42};
+
+    CacheRequestForTesting req2;
+    req2.d = &v2;
+    req2.e = Foo{24};
+
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));
+    // Static so that the captureless cache miss lambdas below can call it.
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    // Load the first request, and check that the unsafe unkeyed values were passed through.
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
+        EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
+        EXPECT_EQ(req.e.UnsafeGetValue().value, 42);
+        return 0;
+    })));
+    auto r1 = LoadOrRun(
+                  GetDevice(), std::move(req1), [](Blob) { return 0; },
+                  [](CacheRequestForTesting req) -> ResultOrError<int> {
+                      return cacheMissFn.Call(std::move(req));
+                  })
+                  .AcquireSuccess();
+
+    // Load the second request, and check that the unsafe unkeyed values were passed through.
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
+        EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
+        EXPECT_EQ(req.e.UnsafeGetValue().value, 24);
+        return 0;
+    })));
+    auto r2 = LoadOrRun(
+                  GetDevice(), std::move(req2), [](Blob) { return 0; },
+                  [](CacheRequestForTesting req) -> ResultOrError<int> {
+                      return cacheMissFn.Call(std::move(req));
+                  })
+                  .AcquireSuccess();
+
+    // Expect their keys to be the same.
+    EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size());
+    EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0);
+}
+
+// Test the cache miss path: CacheMissFn gets the moved request and the result is not cached.
+TEST_F(CacheRequestTests, CacheMiss) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    unsigned int* cPtr = req.c.data();
+    // Static so captureless lambdas can call them. cacheHitFn (StrictMock) must not be called.
+    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    // Mock a cache miss.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));
+
+    // Expect the cache miss, and return some value.
+    int rv = 42;
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
+        // Expect the request contents to be the same. The data pointer for |c| is also the same
+        // since it was moved.
+        EXPECT_EQ(req.a, 1);
+        EXPECT_FLOAT_EQ(req.b, 0.2);
+        EXPECT_EQ(req.c.data(), cPtr);
+        return rv;
+    })));
+
+    // Load the request.
+    auto result = LoadOrRun(
+                      GetDevice(), std::move(req),
+                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
+                      [](CacheRequestForTesting req) -> ResultOrError<int> {
+                          return cacheMissFn.Call(std::move(req));
+                      })
+                      .AcquireSuccess();
+
+    // Expect the result to store the value.
+    EXPECT_EQ(*result, rv);
+    EXPECT_FALSE(result.IsCached());
+}
+
+// Test the cache hit path: CacheHitFn gets the loaded blob and the result is marked cached.
+TEST_F(CacheRequestTests, CacheHit) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+    // Static so captureless lambdas can call them. cacheMissFn (StrictMock) must not be called.
+    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    static constexpr char kCachedData[] = "hello world!";
+
+    // Mock a cache hit: the first LoadData call queries the size, the second copies the data.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
+    EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
+        .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
+            memcpy(dataOut, kCachedData, sizeof(kCachedData));
+            return sizeof(kCachedData);
+        })));
+
+    // Expect the cache hit, and return some value.
+    int rv = 1337;
+    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
+        // Expect the cached blob contents to match the cached data.
+        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
+        EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
+
+        return rv;
+    })));
+
+    // Load the request.
+    auto result = LoadOrRun(
+                      GetDevice(), std::move(req),
+                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
+                      [](CacheRequestForTesting req) -> ResultOrError<int> {
+                          return cacheMissFn.Call(std::move(req));
+                      })
+                      .AcquireSuccess();
+
+    // Expect the result to store the value.
+    EXPECT_EQ(*result, rv);
+    EXPECT_TRUE(result.IsCached());
+}
+
+// Test that when CacheHitFn returns an error, LoadOrRun falls back to CacheMissFn.
+TEST_F(CacheRequestTests, CacheHitError) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    unsigned int* cPtr = req.c.data();
+    // Static so that the captureless handler lambdas below can call them.
+    static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    static constexpr char kCachedData[] = "hello world!";
+
+    // Mock a cache hit: the first LoadData call queries the size, the second copies the data.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
+    EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
+        .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
+            memcpy(dataOut, kCachedData, sizeof(kCachedData));
+            return sizeof(kCachedData);
+        })));
+
+    // Expect the cache hit.
+    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
+        // Expect the cached blob contents to match the cached data.
+        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
+        EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
+
+        // Return an error.
+        return DAWN_VALIDATION_ERROR("fake test error");
+    })));
+
+    // Expect the cache miss handler since the cache hit errored.
+    int rv = 79;
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
+        // Expect the request contents to be the same. The data pointer for |c| is also the same
+        // since it was moved.
+        EXPECT_EQ(req.a, 1);
+        EXPECT_FLOAT_EQ(req.b, 0.2);
+        EXPECT_EQ(req.c.data(), cPtr);
+        return rv;
+    })));
+
+    // Load the request.
+    auto result =
+        LoadOrRun(
+            GetDevice(), std::move(req),
+            [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
+            [](CacheRequestForTesting req) -> ResultOrError<int> {
+                return cacheMissFn.Call(std::move(req));
+            })
+            .AcquireSuccess();
+
+    // Expect the result to store the value.
+    EXPECT_EQ(*result, rv);
+    EXPECT_FALSE(result.IsCached());
+}
+
+}  // namespace
+
+}  // namespace dawn::native