| // Copyright 2022 The Dawn Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "dawn/native/Blob.h" |
| #include "dawn/native/CacheRequest.h" |
| #include "dawn/tests/DawnNativeTest.h" |
| #include "dawn/tests/mocks/platform/CachingInterfaceMock.h" |
| |
| namespace dawn::native { |
| |
| namespace { |
| |
| using ::testing::_; |
| using ::testing::ByMove; |
| using ::testing::Invoke; |
| using ::testing::MockFunction; |
| using ::testing::Return; |
| using ::testing::StrictMock; |
| using ::testing::WithArg; |
| |
| class CacheRequestTests : public DawnNativeTest { |
| protected: |
| std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override { |
| return std::make_unique<DawnCachingMockPlatform>(&mMockCache); |
| } |
| |
| WGPUDevice CreateTestDevice() override { |
| wgpu::DeviceDescriptor deviceDescriptor = {}; |
| wgpu::DawnTogglesDeviceDescriptor togglesDesc = {}; |
| deviceDescriptor.nextInChain = &togglesDesc; |
| |
| const char* toggle = "enable_blob_cache"; |
| togglesDesc.forceEnabledToggles = &toggle; |
| togglesDesc.forceEnabledTogglesCount = 1; |
| |
| return adapter.CreateDevice(&deviceDescriptor); |
| } |
| |
| DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); } |
| |
| StrictMock<CachingInterfaceMock> mMockCache; |
| }; |
| |
// Plain aggregate used as the payload of an unkeyed request member in the tests below.
struct Foo {
    int value;
};
| |
| #define REQUEST_MEMBERS(X) \ |
| X(int, a) \ |
| X(float, b) \ |
| X(std::vector<unsigned int>, c) \ |
| X(CacheKey::UnsafeUnkeyedValue<int*>, d) \ |
| X(CacheKey::UnsafeUnkeyedValue<Foo>, e) |
| |
| DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS); |
| |
| #undef REQUEST_MEMBERS |
| |
| // static_assert the expected types for various return types from the cache hit handler and cache |
| // miss handler. |
| TEST_F(CacheRequestTests, CacheResultTypes) { |
| EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0)); |
| |
| // (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>. |
| auto v1 = LoadOrRun( |
| GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; }, |
| [](CacheRequestForTesting) -> ResultOrError<int> { return 1; }); |
| v1.AcquireSuccess(); |
| static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>); |
| |
| // (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>. |
| auto v2 = LoadOrRun( |
| GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; }, |
| [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; }); |
| v2.AcquireSuccess(); |
| static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>); |
| } |
| |
| // Test that using a CacheRequest builds a key from the device key, the request type enum, and all |
| // of the request members. |
| TEST_F(CacheRequestTests, MakesCacheKey) { |
| // Make a request. |
| CacheRequestForTesting req; |
| req.a = 1; |
| req.b = 0.2; |
| req.c = {3, 4, 5}; |
| |
| // Make the expected key. |
| CacheKey expectedKey; |
| StreamIn(&expectedKey, GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, |
| req.c); |
| |
| // Expect a call to LoadData with the expected key. |
| EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0)) |
| .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) { |
| EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0); |
| return 0; |
| }))); |
| |
| // Load the request. |
| auto result = LoadOrRun( |
| GetDevice(), std::move(req), [](Blob) -> int { return 0; }, |
| [](CacheRequestForTesting) -> ResultOrError<int> { return 0; }) |
| .AcquireSuccess(); |
| |
| // The created cache key should be saved on the result. |
| EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size()); |
| EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0); |
| } |
| |
| // Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key. |
| TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) { |
| // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct |
| // definition). |
| int v1, v2; |
| CacheRequestForTesting req1; |
| req1.d = &v1; |
| req1.e = Foo{42}; |
| |
| CacheRequestForTesting req2; |
| req2.d = &v2; |
| req2.e = Foo{24}; |
| |
| EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0)); |
| |
| static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; |
| |
| // Load the first request, and check that the unsafe unkeyed values were passed though |
| EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { |
| EXPECT_EQ(req.d.UnsafeGetValue(), &v1); |
| EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 42); |
| return 0; |
| }))); |
| auto r1 = LoadOrRun( |
| GetDevice(), std::move(req1), [](Blob) { return 0; }, |
| [](CacheRequestForTesting req) -> ResultOrError<int> { |
| return cacheMissFn.Call(std::move(req)); |
| }) |
| .AcquireSuccess(); |
| |
| // Load the second request, and check that the unsafe unkeyed values were passed though |
| EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { |
| EXPECT_EQ(req.d.UnsafeGetValue(), &v2); |
| EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 24); |
| return 0; |
| }))); |
| auto r2 = LoadOrRun( |
| GetDevice(), std::move(req2), [](Blob) { return 0; }, |
| [](CacheRequestForTesting req) -> ResultOrError<int> { |
| return cacheMissFn.Call(std::move(req)); |
| }) |
| .AcquireSuccess(); |
| |
| // Expect their keys to be the same. |
| EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size()); |
| EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0); |
| } |
| |
| // Test the expected code path when there is a cache miss. |
| TEST_F(CacheRequestTests, CacheMiss) { |
| // Make a request. |
| CacheRequestForTesting req; |
| req.a = 1; |
| req.b = 0.2; |
| req.c = {3, 4, 5}; |
| |
| unsigned int* cPtr = req.c.data(); |
| |
| static StrictMock<MockFunction<int(Blob)>> cacheHitFn; |
| static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; |
| |
| // Mock a cache miss. |
| EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)); |
| |
| // Expect the cache miss, and return some value. |
| int rv = 42; |
| EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { |
| // Expect the request contents to be the same. The data pointer for |c| is also the same |
| // since it was moved. |
| EXPECT_EQ(req.a, 1); |
| EXPECT_FLOAT_EQ(req.b, 0.2); |
| EXPECT_EQ(req.c.data(), cPtr); |
| return rv; |
| }))); |
| |
| // Load the request. |
| auto result = LoadOrRun( |
| GetDevice(), std::move(req), |
| [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, |
| [](CacheRequestForTesting req) -> ResultOrError<int> { |
| return cacheMissFn.Call(std::move(req)); |
| }) |
| .AcquireSuccess(); |
| |
| // Expect the result to store the value. |
| EXPECT_EQ(*result, rv); |
| EXPECT_FALSE(result.IsCached()); |
| } |
| |
| // Test the expected code path when there is a cache hit. |
| TEST_F(CacheRequestTests, CacheHit) { |
| // Make a request. |
| CacheRequestForTesting req; |
| req.a = 1; |
| req.b = 0.2; |
| req.c = {3, 4, 5}; |
| |
| static StrictMock<MockFunction<int(Blob)>> cacheHitFn; |
| static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; |
| |
| static constexpr char kCachedData[] = "hello world!"; |
| |
| // Mock a cache hit, and load the cached data. |
| EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); |
| EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) |
| .WillOnce(WithArg<2>(Invoke([](void* dataOut) { |
| memcpy(dataOut, kCachedData, sizeof(kCachedData)); |
| return sizeof(kCachedData); |
| }))); |
| |
| // Expect the cache hit, and return some value. |
| int rv = 1337; |
| EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { |
| // Expect the cached blob contents to match the cached data. |
| EXPECT_EQ(blob.Size(), sizeof(kCachedData)); |
| EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); |
| |
| return rv; |
| }))); |
| |
| // Load the request. |
| auto result = LoadOrRun( |
| GetDevice(), std::move(req), |
| [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, |
| [](CacheRequestForTesting req) -> ResultOrError<int> { |
| return cacheMissFn.Call(std::move(req)); |
| }) |
| .AcquireSuccess(); |
| |
| // Expect the result to store the value. |
| EXPECT_EQ(*result, rv); |
| EXPECT_TRUE(result.IsCached()); |
| } |
| |
| // Test the expected code path when there is a cache hit but the handler errors. |
| TEST_F(CacheRequestTests, CacheHitError) { |
| // Make a request. |
| CacheRequestForTesting req; |
| req.a = 1; |
| req.b = 0.2; |
| req.c = {3, 4, 5}; |
| |
| unsigned int* cPtr = req.c.data(); |
| |
| static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn; |
| static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn; |
| |
| static constexpr char kCachedData[] = "hello world!"; |
| |
| // Mock a cache hit, and load the cached data. |
| EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); |
| EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) |
| .WillOnce(WithArg<2>(Invoke([](void* dataOut) { |
| memcpy(dataOut, kCachedData, sizeof(kCachedData)); |
| return sizeof(kCachedData); |
| }))); |
| |
| // Expect the cache hit. |
| EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { |
| // Expect the cached blob contents to match the cached data. |
| EXPECT_EQ(blob.Size(), sizeof(kCachedData)); |
| EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); |
| |
| // Return an error. |
| return DAWN_VALIDATION_ERROR("fake test error"); |
| }))); |
| |
| // Expect the cache miss handler since the cache hit errored. |
| int rv = 79; |
| EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { |
| // Expect the request contents to be the same. The data pointer for |c| is also the same |
| // since it was moved. |
| EXPECT_EQ(req.a, 1); |
| EXPECT_FLOAT_EQ(req.b, 0.2); |
| EXPECT_EQ(req.c.data(), cPtr); |
| return rv; |
| }))); |
| |
| // Load the request. |
| auto result = |
| LoadOrRun( |
| GetDevice(), std::move(req), |
| [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); }, |
| [](CacheRequestForTesting req) -> ResultOrError<int> { |
| return cacheMissFn.Call(std::move(req)); |
| }) |
| .AcquireSuccess(); |
| |
| // Expect the result to store the value. |
| EXPECT_EQ(*result, rv); |
| EXPECT_FALSE(result.IsCached()); |
| } |
| |
| } // namespace |
| |
| } // namespace dawn::native |