Add CacheRequest utilities and tests This CL adds a DAWN_MAKE_CACHE_REQUEST X macro which helps in building a CacheRequest struct. A CacheRequest struct may be passed to LoadOrRun which will generate a CacheKey from the struct and load a result if there is a cache hit, or it will call the provided cache miss function to compute a value. The request struct helps enforce that precisely the inputs that go into a computation are all also included inside the CacheKey for that computation. Bug: dawn:549 Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063 Reviewed-by: Loko Kung <lokokung@google.com> Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index 69c41a2..57bb666 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn
@@ -192,6 +192,7 @@ ":gmock_and_gtest", ":mock_webgpu_gen", ":native_mocks_sources", + ":platform_mocks_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn/common", @@ -254,6 +255,7 @@ "unittests/VersionTests.cpp", "unittests/native/BlobTests.cpp", "unittests/native/CacheKeyTests.cpp", + "unittests/native/CacheRequestTests.cpp", "unittests/native/CommandBufferEncodingTests.cpp", "unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/DestroyObjectTests.cpp", @@ -380,9 +382,9 @@ # Dawn end2end tests targets ############################################################################### -# Source code for mocks used for end2end testing are separated from the rest of +# Source code for mocks used for platform testing are separated from the rest of # sources so that they aren't included in non-test builds. -source_set("end2end_mocks_sources") { +source_set("platform_mocks_sources") { configs += [ "${dawn_root}/src/dawn/native:internal" ] testonly = true @@ -392,8 +394,8 @@ ] sources = [ - "end2end/mocks/CachingInterfaceMock.cpp", - "end2end/mocks/CachingInterfaceMock.h", + "mocks/platform/CachingInterfaceMock.cpp", + "mocks/platform/CachingInterfaceMock.h", ] } @@ -401,7 +403,7 @@ testonly = true deps = [ - ":end2end_mocks_sources", + ":platform_mocks_sources", ":test_infra_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc",
diff --git a/src/dawn/tests/DawnNativeTest.cpp b/src/dawn/tests/DawnNativeTest.cpp index 163413d..fbf8030 100644 --- a/src/dawn/tests/DawnNativeTest.cpp +++ b/src/dawn/tests/DawnNativeTest.cpp
@@ -20,6 +20,9 @@ #include "dawn/common/Assert.h" #include "dawn/dawn_proc.h" #include "dawn/native/ErrorData.h" +#include "dawn/native/Instance.h" +#include "dawn/native/dawn_platform.h" +#include "dawn/platform/DawnPlatform.h" namespace dawn::native { @@ -43,6 +46,9 @@ void DawnNativeTest::SetUp() { instance = std::make_unique<dawn::native::Instance>(); + platform = CreateTestPlatform(); + dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get()); + instance->DiscoverDefaultAdapters(); std::vector<dawn::native::Adapter> adapters = instance->GetAdapters(); @@ -66,7 +72,9 @@ device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr); } -void DawnNativeTest::TearDown() {} +std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() { + return nullptr; +} WGPUDevice DawnNativeTest::CreateTestDevice() { // Disabled disallowing unsafe APIs so we can test them.
diff --git a/src/dawn/tests/DawnNativeTest.h b/src/dawn/tests/DawnNativeTest.h index e92bf67..dd3532f 100644 --- a/src/dawn/tests/DawnNativeTest.h +++ b/src/dawn/tests/DawnNativeTest.h
@@ -38,12 +38,13 @@ ~DawnNativeTest() override; void SetUp() override; - void TearDown() override; + virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform(); virtual WGPUDevice CreateTestDevice(); protected: std::unique_ptr<dawn::native::Instance> instance; + std::unique_ptr<dawn::platform::Platform> platform; dawn::native::Adapter adapter; wgpu::Device device;
diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp index 9cd4042..0d3bcdf 100644 --- a/src/dawn/tests/end2end/D3D12CachingTests.cpp +++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp
@@ -17,7 +17,7 @@ #include <utility> #include "dawn/tests/DawnTest.h" -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h"
diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp index 23ee708..8318c48 100644 --- a/src/dawn/tests/end2end/PipelineCachingTests.cpp +++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp
@@ -16,7 +16,7 @@ #include <string_view> #include "dawn/tests/DawnTest.h" -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h"
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp similarity index 97% rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp index 8507e9c..a52d4c2 100644 --- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp +++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp
@@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" using ::testing::Invoke;
diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h similarity index 92% rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.h rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.h index cc61d80..0e9e6af 100644 --- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h +++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h
@@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ -#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ +#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ #include <dawn/platform/DawnPlatform.h> #include <gmock/gmock.h> @@ -70,4 +70,4 @@ dawn::platform::CachingInterface* mCachingInterface = nullptr; }; -#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#endif // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp new file mode 100644 index 0000000..995de7f --- /dev/null +++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp
@@ -0,0 +1,320 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <utility> +#include <vector> + +#include "dawn/native/Blob.h" +#include "dawn/native/CacheRequest.h" +#include "dawn/tests/DawnNativeTest.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" + +namespace dawn::native { + +namespace { + +using ::testing::_; +using ::testing::ByMove; +using ::testing::Invoke; +using ::testing::MockFunction; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::WithArg; + +class CacheRequestTests : public DawnNativeTest { + protected: + std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override { + return std::make_unique<DawnCachingMockPlatform>(&mMockCache); + } + + WGPUDevice CreateTestDevice() override { + wgpu::DeviceDescriptor deviceDescriptor = {}; + wgpu::DawnTogglesDeviceDescriptor togglesDesc = {}; + deviceDescriptor.nextInChain = &togglesDesc; + + const char* toggle = "enable_blob_cache"; + togglesDesc.forceEnabledToggles = &toggle; + togglesDesc.forceEnabledTogglesCount = 1; + + return adapter.CreateDevice(&deviceDescriptor); + } + + DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); } + + StrictMock<CachingInterfaceMock> mMockCache; +}; + +struct Foo { + int value; +}; + +#define REQUEST_MEMBERS(X) \ + X(int, a) \ + X(float, b) \ + X(std::vector<unsigned int>, c) \ + X(CacheKey::UnsafeUnkeyedValue<int*>, d) \ + 
X(CacheKey::UnsafeUnkeyedValue<Foo>, e) + +DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS) + +#undef REQUEST_MEMBERS + +// static_assert the expected types for various return types from the cache hit handler and cache +// miss handler. +TEST_F(CacheRequestTests, CacheResultTypes) { + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0)); + + // (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>. + auto v1 = LoadOrRun( + GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; }, + [](CacheRequestForTesting) -> ResultOrError<int> { return 1; }); + v1.AcquireSuccess(); + static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>); + + // (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>. + auto v2 = LoadOrRun( + GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; }, + [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; }); + v2.AcquireSuccess(); + static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>); +} + +// Test that using a CacheRequest builds a key from the device key, the request type enum, and all +// of the request members. +TEST_F(CacheRequestTests, MakesCacheKey) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + // Make the expected key. + CacheKey expectedKey; + expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c); + + // Expect a call to LoadData with the expected key. + EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0)) + .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) { + EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0); + return 0; + }))); + + // Load the request. 
+ auto result = LoadOrRun( + GetDevice(), std::move(req), [](Blob) -> int { return 0; }, + [](CacheRequestForTesting) -> ResultOrError<int> { return 0; }) + .AcquireSuccess(); + + // The created cache key should be saved on the result. + EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size()); + EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0); +} + +// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key. +TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) { + // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct + // definition). + int v1, v2; + CacheRequestForTesting req1; + req1.d = &v1; + req1.e = Foo{42}; + + CacheRequestForTesting req2; + req2.d = &v2; + req2.e = Foo{24}; + + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0)); + + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + // Load the first request, and check that the unsafe unkeyed values were passed through + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { + EXPECT_EQ(req.d.UnsafeGetValue(), &v1); + EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 42); + return 0; + }))); + auto r1 = LoadOrRun( + GetDevice(), std::move(req1), [](Blob) { return 0; }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Load the second request, and check that the unsafe unkeyed values were passed through + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { + EXPECT_EQ(req.d.UnsafeGetValue(), &v2); + EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 24); + return 0; + }))); + auto r2 = LoadOrRun( + GetDevice(), std::move(req2), [](Blob) { return 0; }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + 
.AcquireSuccess(); + + // Expect their keys to be the same. + EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size()); + EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0); +} + +// Test the expected code path when there is a cache miss. +TEST_F(CacheRequestTests, CacheMiss) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + unsigned int* cPtr = req.c.data(); + + static StrictMock<MockFunction<int(Blob)>> cacheHitFn; + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + // Mock a cache miss. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)); + + // Expect the cache miss, and return some value. + int rv = 42; + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { + // Expect the request contents to be the same. The data pointer for |c| is also the same + // since it was moved. + EXPECT_EQ(req.a, 1); + EXPECT_FLOAT_EQ(req.b, 0.2); + EXPECT_EQ(req.c.data(), cPtr); + return rv; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_FALSE(result.IsCached()); +} + +// Test the expected code path when there is a cache hit. +TEST_F(CacheRequestTests, CacheHit) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + static StrictMock<MockFunction<int(Blob)>> cacheHitFn; + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + static constexpr char kCachedData[] = "hello world!"; + + // Mock a cache hit, and load the cached data. 
+ EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); + EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) + .WillOnce(WithArg<2>(Invoke([](void* dataOut) { + memcpy(dataOut, kCachedData, sizeof(kCachedData)); + return sizeof(kCachedData); + }))); + + // Expect the cache hit, and return some value. + int rv = 1337; + EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { + // Expect the cached blob contents to match the cached data. + EXPECT_EQ(blob.Size(), sizeof(kCachedData)); + EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); + + return rv; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_TRUE(result.IsCached()); +} + +// Test the expected code path when there is a cache hit but the handler errors. +TEST_F(CacheRequestTests, CacheHitError) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + unsigned int* cPtr = req.c.data(); + + static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn; + static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; + + static constexpr char kCachedData[] = "hello world!"; + + // Mock a cache hit, and load the cached data. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); + EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) + .WillOnce(WithArg<2>(Invoke([](void* dataOut) { + memcpy(dataOut, kCachedData, sizeof(kCachedData)); + return sizeof(kCachedData); + }))); + + // Expect the cache hit. 
+ EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { + // Expect the cached blob contents to match the cached data. + EXPECT_EQ(blob.Size(), sizeof(kCachedData)); + EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); + + // Return an error. + return DAWN_VALIDATION_ERROR("fake test error"); + }))); + + // Expect the cache miss handler since the cache hit errored. + int rv = 79; + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { + // Expect the request contents to be the same. The data pointer for |c| is also the same + // since it was moved. + EXPECT_EQ(req.a, 1); + EXPECT_FLOAT_EQ(req.b, 0.2); + EXPECT_EQ(req.c.data(), cPtr); + return rv; + }))); + + // Load the request. + auto result = + LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError<int> { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_FALSE(result.IsCached()); +} + +} // namespace + +} // namespace dawn::native