blob: 995de7f5203f029b137bdaa6513b0b6d990e653c [file] [log] [blame]
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include <vector>
#include "dawn/native/Blob.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/tests/DawnNativeTest.h"
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
namespace dawn::native {
namespace {
using ::testing::_;
using ::testing::ByMove;
using ::testing::Invoke;
using ::testing::MockFunction;
using ::testing::Return;
using ::testing::StrictMock;
using ::testing::WithArg;
// Fixture for the CacheRequest tests. The test platform is constructed with a pointer to a strict
// caching mock, and the test device is created with the "enable_blob_cache" toggle force-enabled.
class CacheRequestTests : public DawnNativeTest {
  protected:
    // Builds the mock platform, handing it the address of |mMockCache|.
    std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
        return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
    }

    // Creates the device with a toggles descriptor chained onto the device descriptor; the toggles
    // descriptor force-enables exactly one toggle, "enable_blob_cache".
    WGPUDevice CreateTestDevice() override {
        const char* blobCacheToggle = "enable_blob_cache";

        wgpu::DawnTogglesDeviceDescriptor toggles = {};
        toggles.forceEnabledToggles = &blobCacheToggle;
        toggles.forceEnabledTogglesCount = 1;

        wgpu::DeviceDescriptor desc = {};
        desc.nextInChain = &toggles;

        return adapter.CreateDevice(&desc);
    }

    // Converts the fixture's API-level |device| into the native DeviceBase pointer.
    DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }

    // Strict mock: any cache call without a matching expectation fails the test.
    StrictMock<CachingInterfaceMock> mMockCache;
};
// Minimal aggregate payload holding a single int; the test request wraps it in
// CacheKey::UnsafeUnkeyedValue.
struct Foo { int value; };
// X-macro listing the (type, name) pairs of the request used by the tests in this file. The tests
// access the members as req.a ... req.e; |d| and |e| are wrapped in CacheKey::UnsafeUnkeyedValue.
#define REQUEST_MEMBERS(X) \
X(int, a) \
X(float, b) \
X(std::vector<unsigned int>, c) \
X(CacheKey::UnsafeUnkeyedValue<int*>, d) \
X(CacheKey::UnsafeUnkeyedValue<Foo>, e)
// Hands the member list to DAWN_MAKE_CACHE_REQUEST under the name CacheRequestForTesting, which is
// the request type the tests instantiate.
DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS)
// The member list macro is not used after this point.
#undef REQUEST_MEMBERS
// Compile-time check of the type returned by LoadOrRun for different combinations of cache hit
// handler and cache miss handler return types.
TEST_F(CacheRequestTests, CacheResultTypes) {
    // Any LoadData call made with a null output buffer and zero size returns 0.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));

    // Hit handler returns int, miss handler returns ResultOrError<int>:
    // the combined result should be ResultOrError<CacheResult<int>>.
    using ExpectedIntResult = ResultOrError<CacheResult<int>>;
    auto intResult = LoadOrRun(
        GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
        [](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
    static_assert(std::is_same_v<ExpectedIntResult, decltype(intResult)>);
    intResult.AcquireSuccess();

    // Both handlers return ResultOrError<float>:
    // the combined result should be ResultOrError<CacheResult<float>>.
    using ExpectedFloatResult = ResultOrError<CacheResult<float>>;
    auto floatResult = LoadOrRun(
        GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; },
        [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; });
    static_assert(std::is_same_v<ExpectedFloatResult, decltype(floatResult)>);
    floatResult.AcquireSuccess();
}
// Test that using a CacheRequest builds a key from the device key, the request type enum, and all
// of the request members.
TEST_F(CacheRequestTests, MakesCacheKey) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};

    // Make the expected key by recording the device key, the request name, and the keyed members.
    CacheKey expectedKey;
    expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);

    // Expect a call to LoadData with the expected key. The size is checked by the matcher, so the
    // action can safely compare expectedKey.size() bytes.
    EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
        .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) {
            EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
            return 0;
        })));

    // Load the request.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req), [](Blob) -> int { return 0; },
                      [](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
                      .AcquireSuccess();

    // The created cache key should be saved on the result. The size check is fatal so that the
    // byte comparison below never reads past the end of a shorter key.
    ASSERT_EQ(result.GetCacheKey().size(), expectedKey.size());
    EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0);
}
// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key.
TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) {
    // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct
    // definition). Only the addresses of |v1| and |v2| are used.
    int v1 = 0;
    int v2 = 0;

    CacheRequestForTesting req1;
    req1.d = &v1;
    req1.e = Foo{42};

    CacheRequestForTesting req2;
    req2.d = &v2;
    req2.e = Foo{24};

    // Both loads miss in the cache.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));

    // The mock is static so that the capture-less miss handlers below can refer to it.
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    // Load the first request, and check that the unsafe unkeyed values were passed through.
    // Foo::value is an int, so it is compared exactly.
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
        EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
        EXPECT_EQ(req.e.UnsafeGetValue().value, 42);
        return 0;
    })));
    auto r1 = LoadOrRun(
                  GetDevice(), std::move(req1), [](Blob) { return 0; },
                  [](CacheRequestForTesting req) -> ResultOrError<int> {
                      return cacheMissFn.Call(std::move(req));
                  })
                  .AcquireSuccess();

    // Load the second request, and check that the unsafe unkeyed values were passed through.
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
        EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
        EXPECT_EQ(req.e.UnsafeGetValue().value, 24);
        return 0;
    })));
    auto r2 = LoadOrRun(
                  GetDevice(), std::move(req2), [](Blob) { return 0; },
                  [](CacheRequestForTesting req) -> ResultOrError<int> {
                      return cacheMissFn.Call(std::move(req));
                  })
                  .AcquireSuccess();

    // Expect their keys to be the same. The size check is fatal so that the byte comparison never
    // reads past the end of the shorter key.
    ASSERT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size());
    EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0);
}
// Exercise the path taken when the cache has no entry: only the miss handler runs, and its value
// ends up in a result that is not marked as cached.
TEST_F(CacheRequestTests, CacheMiss) {
    // Build the request and remember where |c| stores its elements.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};
    unsigned int* originalCData = req.c.data();

    // Static so the capture-less handlers passed to LoadOrRun can reach them. The hit mock has no
    // expectations, and it is strict, so any call to it fails the test.
    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    // The cache lookup reports a miss.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));

    // The miss handler is called once; it validates the request and returns a known value.
    int expectedValue = 42;
    EXPECT_CALL(cacheMissFn, Call(_))
        .WillOnce(WithArg<0>(Invoke([originalCData, expectedValue](CacheRequestForTesting moved) {
            // The request arrives unchanged. Because it was moved, |c| still points at the same
            // storage.
            EXPECT_EQ(moved.a, 1);
            EXPECT_FLOAT_EQ(moved.b, 0.2);
            EXPECT_EQ(moved.c.data(), originalCData);
            return expectedValue;
        })));

    // Run the request through LoadOrRun.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req),
                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
                      [](CacheRequestForTesting r) -> ResultOrError<int> {
                          return cacheMissFn.Call(std::move(r));
                      })
                      .AcquireSuccess();

    // The result carries the miss handler's value and is not flagged as cached.
    EXPECT_EQ(*result, expectedValue);
    EXPECT_FALSE(result.IsCached());
}
// Test the expected code path when there is a cache hit.
TEST_F(CacheRequestTests, CacheHit) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};

    // The mocks are static so that the capture-less handlers passed to LoadOrRun can refer to
    // them. The miss mock is strict and has no expectations, so any call to it fails the test.
    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
    static constexpr char kCachedData[] = "hello world!";

    // Mock a cache hit, and load the cached data: the call with a null buffer reports the size,
    // and the call with a buffer of that size copies the data out.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
    EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
        .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
            memcpy(dataOut, kCachedData, sizeof(kCachedData));
            return sizeof(kCachedData);
        })));

    // Expect the cache hit, and return some value.
    int rv = 1337;
    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
        // Expect the cached blob contents to match the cached data. The bytes are only compared
        // when the size matches, so a wrong-sized blob is reported as a failure instead of being
        // read past its end. (ASSERT_* cannot be used here because the lambda returns a value.)
        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
        if (blob.Size() == sizeof(kCachedData)) {
            EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
        }
        return rv;
    })));

    // Load the request.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req),
                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
                      [](CacheRequestForTesting req) -> ResultOrError<int> {
                          return cacheMissFn.Call(std::move(req));
                      })
                      .AcquireSuccess();

    // Expect the result to store the value, and to be marked as coming from the cache.
    EXPECT_EQ(*result, rv);
    EXPECT_TRUE(result.IsCached());
}
// Test the expected code path when there is a cache hit but the handler errors.
TEST_F(CacheRequestTests, CacheHitError) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};
    unsigned int* cPtr = req.c.data();

    // The mocks are static so that the capture-less handlers passed to LoadOrRun can refer to
    // them.
    static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
    static constexpr char kCachedData[] = "hello world!";

    // Mock a cache hit, and load the cached data: the call with a null buffer reports the size,
    // and the call with a buffer of that size copies the data out.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
    EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
        .WillOnce(WithArg<2>(Invoke([](void* dataOut) {
            memcpy(dataOut, kCachedData, sizeof(kCachedData));
            return sizeof(kCachedData);
        })));

    // Expect the cache hit.
    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
        // Expect the cached blob contents to match the cached data. The bytes are only compared
        // when the size matches, so a wrong-sized blob is reported as a failure instead of being
        // read past its end. (ASSERT_* cannot be used here because the lambda returns a value.)
        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
        if (blob.Size() == sizeof(kCachedData)) {
            EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
        }
        // Return an error.
        return DAWN_VALIDATION_ERROR("fake test error");
    })));

    // Expect the cache miss handler since the cache hit errored.
    int rv = 79;
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
        // Expect the request contents to be the same. The data pointer for |c| is also the same
        // since it was moved.
        EXPECT_EQ(req.a, 1);
        EXPECT_FLOAT_EQ(req.b, 0.2);
        EXPECT_EQ(req.c.data(), cPtr);
        return rv;
    })));

    // Load the request.
    auto result =
        LoadOrRun(
            GetDevice(), std::move(req),
            [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
            [](CacheRequestForTesting req) -> ResultOrError<int> {
                return cacheMissFn.Call(std::move(req));
            })
            .AcquireSuccess();

    // Expect the result to store the miss handler's value, and not to be marked as cached.
    EXPECT_EQ(*result, rv);
    EXPECT_FALSE(result.IsCached());
}
} // namespace
} // namespace dawn::native