Add CacheRequest utilities and tests

This CL adds a DAWN_MAKE_CACHE_REQUEST macro which
defines a CacheRequest struct from an X-macro list
of (type, name) members.

A CacheRequest struct may be passed to LoadOrRun,
which generates a CacheKey from the struct. On a
cache hit, the cached blob is passed to the provided
cache hit function. On a cache miss (or if the cache
hit function fails), the provided cache miss function
is called to compute the value.

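The request struct helps enforce that exactly the
inputs that go into a computation are included in
the CacheKey for that computation. Members that must
not be keyed can be wrapped in
CacheKey::UnsafeUnkeyedValue.

CachingInterfaceMock moves from end2end/mocks to
mocks/platform so unit tests can use it, and
DawnNativeTest gains a CreateTestPlatform() hook for
injecting a custom platform.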

Bug: dawn:549
Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 69c41a2..57bb666 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -192,6 +192,7 @@
     ":gmock_and_gtest",
     ":mock_webgpu_gen",
     ":native_mocks_sources",
+    ":platform_mocks_sources",
     "${dawn_root}/src/dawn:cpp",
     "${dawn_root}/src/dawn:proc",
     "${dawn_root}/src/dawn/common",
@@ -254,6 +255,7 @@
     "unittests/VersionTests.cpp",
     "unittests/native/BlobTests.cpp",
     "unittests/native/CacheKeyTests.cpp",
+    "unittests/native/CacheRequestTests.cpp",
     "unittests/native/CommandBufferEncodingTests.cpp",
     "unittests/native/CreatePipelineAsyncTaskTests.cpp",
     "unittests/native/DestroyObjectTests.cpp",
@@ -380,9 +382,9 @@
 # Dawn end2end tests targets
 ###############################################################################
 
-# Source code for mocks used for end2end testing are separated from the rest of
+# Mock sources used for platform testing are separated from the rest of
 # sources so that they aren't included in non-test builds.
-source_set("end2end_mocks_sources") {
+source_set("platform_mocks_sources") {
   configs += [ "${dawn_root}/src/dawn/native:internal" ]
   testonly = true
 
@@ -392,8 +394,8 @@
   ]
 
   sources = [
-    "end2end/mocks/CachingInterfaceMock.cpp",
-    "end2end/mocks/CachingInterfaceMock.h",
+    "mocks/platform/CachingInterfaceMock.cpp",
+    "mocks/platform/CachingInterfaceMock.h",
   ]
 }
 
@@ -401,7 +403,7 @@
   testonly = true
 
   deps = [
-    ":end2end_mocks_sources",
+    ":platform_mocks_sources",
     ":test_infra_sources",
     "${dawn_root}/src/dawn:cpp",
     "${dawn_root}/src/dawn:proc",
diff --git a/src/dawn/tests/DawnNativeTest.cpp b/src/dawn/tests/DawnNativeTest.cpp
index 163413d..fbf8030 100644
--- a/src/dawn/tests/DawnNativeTest.cpp
+++ b/src/dawn/tests/DawnNativeTest.cpp
@@ -20,6 +20,9 @@
 #include "dawn/common/Assert.h"
 #include "dawn/dawn_proc.h"
 #include "dawn/native/ErrorData.h"
+#include "dawn/native/Instance.h"
+#include "dawn/native/dawn_platform.h"
+#include "dawn/platform/DawnPlatform.h"
 
 namespace dawn::native {
 
@@ -43,6 +46,9 @@
 
 void DawnNativeTest::SetUp() {
     instance = std::make_unique<dawn::native::Instance>();
+    platform = CreateTestPlatform();
+    dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get());
+
     instance->DiscoverDefaultAdapters();
 
     std::vector<dawn::native::Adapter> adapters = instance->GetAdapters();
@@ -66,7 +72,9 @@
     device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
 }
 
-void DawnNativeTest::TearDown() {}
+std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() {
+    return {};  // No custom platform by default; subclasses override this to inject one.
+}
 
 WGPUDevice DawnNativeTest::CreateTestDevice() {
     // Disabled disallowing unsafe APIs so we can test them.
diff --git a/src/dawn/tests/DawnNativeTest.h b/src/dawn/tests/DawnNativeTest.h
index e92bf67..dd3532f 100644
--- a/src/dawn/tests/DawnNativeTest.h
+++ b/src/dawn/tests/DawnNativeTest.h
@@ -38,12 +38,13 @@
     ~DawnNativeTest() override;
 
     void SetUp() override;
-    void TearDown() override;
 
+    virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform();
     virtual WGPUDevice CreateTestDevice();
 
   protected:
     std::unique_ptr<dawn::native::Instance> instance;
+    std::unique_ptr<dawn::platform::Platform> platform;
     dawn::native::Adapter adapter;
     wgpu::Device device;
 
diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp
index 9cd4042..0d3bcdf 100644
--- a/src/dawn/tests/end2end/D3D12CachingTests.cpp
+++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp
@@ -17,7 +17,7 @@
 #include <utility>
 
 #include "dawn/tests/DawnTest.h"
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
 #include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
 
diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp
index 23ee708..8318c48 100644
--- a/src/dawn/tests/end2end/PipelineCachingTests.cpp
+++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp
@@ -16,7 +16,7 @@
 #include <string_view>
 
 #include "dawn/tests/DawnTest.h"
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
 #include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
 
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
similarity index 97%
rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp
rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
index 8507e9c..a52d4c2 100644
--- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp
+++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
 
 using ::testing::Invoke;
 
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
similarity index 92%
rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.h
rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.h
index cc61d80..0e9e6af 100644
--- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h
+++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
-#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
+#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
+#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
 
 #include <dawn/platform/DawnPlatform.h>
 #include <gmock/gmock.h>
@@ -70,4 +70,4 @@
     dawn::platform::CachingInterface* mCachingInterface = nullptr;
 };
 
-#endif  // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
+#endif  // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
new file mode 100644
index 0000000..995de7f
--- /dev/null
+++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
@@ -0,0 +1,320 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "dawn/native/Blob.h"
+#include "dawn/native/CacheRequest.h"
+#include "dawn/tests/DawnNativeTest.h"
+#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
+
+namespace dawn::native {
+
+namespace {
+
+using ::testing::_;
+using ::testing::ByMove;
+using ::testing::Invoke;
+using ::testing::MockFunction;
+using ::testing::Return;
+using ::testing::StrictMock;
+using ::testing::WithArg;
+
+class CacheRequestTests : public DawnNativeTest {
+  public:
+    // Release the device while |mMockCache| is still alive. Derived members are destroyed before
+    // the DawnNativeTest base members, so otherwise the device (created on a platform that was
+    // handed a pointer to |mMockCache|) would only be released after the mock is gone.
+    ~CacheRequestTests() override { device = wgpu::Device(); }
+
+  protected:
+    // Install a platform constructed around |mMockCache| so that the tests can set expectations
+    // on the caching interface.
+    std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
+        return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
+    }
+
+    // Create the device with the "enable_blob_cache" toggle forced on (presumably required for
+    // LoadOrRun to consult the blob cache at all).
+    WGPUDevice CreateTestDevice() override {
+        wgpu::DeviceDescriptor deviceDescriptor = {};
+        wgpu::DawnTogglesDeviceDescriptor togglesDesc = {};
+        deviceDescriptor.nextInChain = &togglesDesc;
+
+        const char* toggle = "enable_blob_cache";
+        togglesDesc.forceEnabledToggles = &toggle;
+        togglesDesc.forceEnabledTogglesCount = 1;
+
+        return adapter.CreateDevice(&deviceDescriptor);
+    }
+
+    // The native DeviceBase behind |device|, which is what LoadOrRun takes.
+    DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }
+
+    // Strict: any caching call a test does not explicitly expect fails the test.
+    StrictMock<CachingInterfaceMock> mMockCache;
+};
+
+// Payload type for the UnsafeUnkeyedValue<Foo> request member. |value| defaults to 0 so that
+// requests which never assign |e| still carry a determinate value when they are moved around.
+struct Foo {
+    int value = 0;
+};
+
+// X-macro listing the members of CacheRequestForTesting as (type, name) pairs, in key order.
+// |a|, |b| and |c| are recorded in the cache key; |d| and |e| are wrapped in
+// CacheKey::UnsafeUnkeyedValue and are expected to be left out of it.
+#define CACHE_REQUEST_FOR_TESTING_MEMBERS(X) \
+    X(int, a)                                \
+    X(float, b)                              \
+    X(std::vector<unsigned int>, c)          \
+    X(CacheKey::UnsafeUnkeyedValue<int*>, d) \
+    X(CacheKey::UnsafeUnkeyedValue<Foo>, e)
+
+DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, CACHE_REQUEST_FOR_TESTING_MEMBERS)
+
+#undef CACHE_REQUEST_FOR_TESTING_MEMBERS
+
+// Check at compile time which CacheResult type LoadOrRun produces for different combinations of
+// cache hit / cache miss handler return types.
+TEST_F(CacheRequestTests, CacheResultTypes) {
+    // Every load in this test misses the cache, so only the miss handlers actually run.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));
+
+    // Hit handler returns int, miss handler returns ResultOrError<int>.
+    auto intResult = LoadOrRun(
+        GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
+        [](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
+    static_assert(std::is_same_v<decltype(intResult), ResultOrError<CacheResult<int>>>);
+    intResult.AcquireSuccess();
+
+    // Both handlers return ResultOrError<float>.
+    auto floatResult = LoadOrRun(
+        GetDevice(), CacheRequestForTesting{},
+        [](Blob) -> ResultOrError<float> { return 0.0f; },
+        [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0f; });
+    static_assert(std::is_same_v<decltype(floatResult), ResultOrError<CacheResult<float>>>);
+    floatResult.AcquireSuccess();
+}
+
+// Test that using a CacheRequest builds a key from the device key, the request type name, and all
+// of the keyed request members.
+TEST_F(CacheRequestTests, MakesCacheKey) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    // Make the expected key. The UnsafeUnkeyedValue members |d| and |e| are not recorded.
+    CacheKey expectedKey;
+    expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);
+
+    // Expect a call to LoadData with the expected key. The size argument is matched first, so the
+    // memcmp inside the action never reads past either buffer. Returning 0 reports a cache miss.
+    EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
+        .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) {
+            EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
+            return 0;
+        })));
+
+    // Load the request.
+    auto result = LoadOrRun(
+                      GetDevice(), std::move(req), [](Blob) -> int { return 0; },
+                      [](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
+                      .AcquireSuccess();
+
+    // The created cache key should be saved on the result. The size check is fatal so a mismatch
+    // cannot make the following memcmp read out of bounds.
+    ASSERT_EQ(result.GetCacheKey().size(), expectedKey.size());
+    EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0);
+}
+
+// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key.
+TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) {
+    // Make two requests that differ only in their UnsafeUnkeyedValue members (|d| and |e| are
+    // declared UnsafeUnkeyed on the struct definition). Only the addresses of v1/v2 are used.
+    int v1 = 0;
+    int v2 = 0;
+    CacheRequestForTesting req1;
+    req1.d = &v1;
+    req1.e = Foo{42};
+
+    CacheRequestForTesting req2;
+    req2.d = &v2;
+    req2.e = Foo{24};
+
+    // Both loads miss the cache.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));
+
+    // The handlers passed to LoadOrRun below are capture-less lambdas, so they forward to a
+    // function-local static mock.
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    // Load the first request, and check that the unsafe unkeyed values were passed through.
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
+        EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
+        EXPECT_EQ(req.e.UnsafeGetValue().value, 42);
+        return 0;
+    })));
+    auto r1 = LoadOrRun(
+                  GetDevice(), std::move(req1), [](Blob) { return 0; },
+                  [](CacheRequestForTesting req) -> ResultOrError<int> {
+                      return cacheMissFn.Call(std::move(req));
+                  })
+                  .AcquireSuccess();
+
+    // Load the second request, and check that the unsafe unkeyed values were passed through.
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
+        EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
+        EXPECT_EQ(req.e.UnsafeGetValue().value, 24);
+        return 0;
+    })));
+    auto r2 = LoadOrRun(
+                  GetDevice(), std::move(req2), [](Blob) { return 0; },
+                  [](CacheRequestForTesting req) -> ResultOrError<int> {
+                      return cacheMissFn.Call(std::move(req));
+                  })
+                  .AcquireSuccess();
+
+    // |cacheMissFn| is static, so gmock would otherwise only verify its expectations when it is
+    // destroyed at process exit. Verify and reset them here so failures belong to this test.
+    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);
+
+    // Expect their keys to be the same. The size check is fatal so memcmp stays in bounds.
+    ASSERT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size());
+    EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0);
+}
+
+// Test the expected code path when there is a cache miss.
+TEST_F(CacheRequestTests, CacheMiss) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    // Remember |c|'s buffer to check that the request is moved (not copied) into the handler.
+    unsigned int* cPtr = req.c.data();
+
+    // The handlers passed to LoadOrRun below are capture-less lambdas, so they forward to
+    // function-local static mocks.
+    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    // Mock a cache miss.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));
+
+    // Expect the cache miss, and return some value. The strict |cacheHitFn| must not be called.
+    int rv = 42;
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
+        // Expect the request contents to be the same. The data pointer for |c| is also the same
+        // since it was moved.
+        EXPECT_EQ(req.a, 1);
+        EXPECT_FLOAT_EQ(req.b, 0.2);
+        EXPECT_EQ(req.c.data(), cPtr);
+        return rv;
+    })));
+
+    // Load the request.
+    auto result = LoadOrRun(
+                      GetDevice(), std::move(req),
+                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
+                      [](CacheRequestForTesting req) -> ResultOrError<int> {
+                          return cacheMissFn.Call(std::move(req));
+                      })
+                      .AcquireSuccess();
+
+    // Expect the result to store the value.
+    EXPECT_EQ(*result, rv);
+    EXPECT_FALSE(result.IsCached());
+
+    // The mocks are static, so gmock would otherwise only verify them when they are destroyed at
+    // process exit. Verify and reset them here so failures are attributed to this test.
+    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
+    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);
+}
+
+// Test the expected code path when there is a cache hit.
+TEST_F(CacheRequestTests, CacheHit) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    // The handlers passed to LoadOrRun below are capture-less lambdas, so they forward to
+    // function-local static mocks.
+    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    static constexpr char kCachedData[] = "hello world!";
+
+    // Mock a cache hit: the first LoadData call (no buffer) reports the size of the cached data,
+    // and the second call copies the data into the provided buffer.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
+    EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
+        .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
+            memcpy(dataOut, kCachedData, sizeof(kCachedData));
+            return sizeof(kCachedData);
+        })));
+
+    // Expect the cache hit, and return some value. The strict |cacheMissFn| must not be called.
+    int rv = 1337;
+    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
+        // Expect the cached blob contents to match the cached data.
+        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
+        EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
+
+        return rv;
+    })));
+
+    // Load the request.
+    auto result = LoadOrRun(
+                      GetDevice(), std::move(req),
+                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
+                      [](CacheRequestForTesting req) -> ResultOrError<int> {
+                          return cacheMissFn.Call(std::move(req));
+                      })
+                      .AcquireSuccess();
+
+    // Expect the result to store the value.
+    EXPECT_EQ(*result, rv);
+    EXPECT_TRUE(result.IsCached());
+
+    // The mocks are static, so gmock would otherwise only verify them when they are destroyed at
+    // process exit. Verify and reset them here so failures are attributed to this test.
+    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
+    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);
+}
+
+// Test the expected code path when there is a cache hit but the handler errors.
+TEST_F(CacheRequestTests, CacheHitError) {
+    // Make a request.
+    CacheRequestForTesting req;
+    req.a = 1;
+    req.b = 0.2;
+    req.c = {3, 4, 5};
+
+    // Remember |c|'s buffer to check that the request is moved (not copied) into the handler.
+    unsigned int* cPtr = req.c.data();
+
+    // The handlers passed to LoadOrRun below are capture-less lambdas, so they forward to
+    // function-local static mocks.
+    static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
+    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
+
+    static constexpr char kCachedData[] = "hello world!";
+
+    // Mock a cache hit: the first LoadData call (no buffer) reports the size of the cached data,
+    // and the second call copies the data into the provided buffer.
+    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
+    EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
+        .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
+            memcpy(dataOut, kCachedData, sizeof(kCachedData));
+            return sizeof(kCachedData);
+        })));
+
+    // Expect the cache hit.
+    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
+        // Expect the cached blob contents to match the cached data.
+        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
+        EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
+
+        // Return an error.
+        return DAWN_VALIDATION_ERROR("fake test error");
+    })));
+
+    // Expect the cache miss handler since the cache hit errored.
+    int rv = 79;
+    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
+        // Expect the request contents to be the same. The data pointer for |c| is also the same
+        // since it was moved.
+        EXPECT_EQ(req.a, 1);
+        EXPECT_FLOAT_EQ(req.b, 0.2);
+        EXPECT_EQ(req.c.data(), cPtr);
+        return rv;
+    })));
+
+    // Load the request.
+    auto result =
+        LoadOrRun(
+            GetDevice(), std::move(req),
+            [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
+            [](CacheRequestForTesting req) -> ResultOrError<int> {
+                return cacheMissFn.Call(std::move(req));
+            })
+            .AcquireSuccess();
+
+    // Expect the result to store the value.
+    EXPECT_EQ(*result, rv);
+    EXPECT_FALSE(result.IsCached());
+
+    // The mocks are static, so gmock would otherwise only verify them when they are destroyed at
+    // process exit. Verify and reset them here so failures are attributed to this test.
+    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
+    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);
+}
+
+}  // namespace
+
+}  // namespace dawn::native