// blob: e06ec3d866cbf12b1cf8c7bd4a86ad3a2add1107 [file] [log] [blame]
// Copyright 2022 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <memory>
#include <utility>
#include <vector>
#include "dawn/native/Blob.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/tests/DawnNativeTest.h"
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
#include "dawn/webgpu_cpp_print.h"
namespace dawn::native {
namespace {
using ::testing::_;
using ::testing::ByMove;
using ::testing::MockFunction;
using ::testing::Return;
using ::testing::StrictMock;
using ::testing::WithArg;
// Parameter type for the value-parameterized CacheRequestTests suite.
struct CacheRequestTestParam {
    // Selects whether the test device is created with the hash validation toggle enabled (true)
    // or explicitly disabled (false). Defaults to false so that a default-constructed parameter
    // never holds an indeterminate value; the struct remains an aggregate.
    bool enableHashValidation = false;
};
// Fixture for the CacheRequest tests. The device is created on a platform whose caching interface
// is a strict mock, and the hash validation toggle is driven by the test parameter.
class CacheRequestTests : public DawnNativeTest,
                          public ::testing::WithParamInterface<CacheRequestTestParam> {
  public:
    // Builds the toggles descriptor for the test device: the hash validation toggle is listed as
    // enabled or as disabled, depending on the current test parameter.
    wgpu::DawnTogglesDescriptor DeviceToggles() override {
        wgpu::DawnTogglesDescriptor descriptor = {};
        const bool validateHashes = GetParam().enableHashValidation;
        if (!validateHashes) {
            // Explicitly disable the toggle rather than relying on its default.
            descriptor.disabledToggleCount = 1;
            descriptor.disabledToggles = &mHashValidationToggle;
            return descriptor;
        }
        descriptor.enabledToggleCount = 1;
        descriptor.enabledToggles = &mHashValidationToggle;
        return descriptor;
    }

  protected:
    // Creates the platform used by the test; it is handed a pointer to |mMockCache|.
    std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
        return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
    }

    // Returns the native device backing the fixture's API-level |device|.
    DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }

    // Strict mock: any caching call without a matching expectation fails the test.
    StrictMock<CachingInterfaceMock> mMockCache;

    // Name of the toggle controlled by the test parameter. It has static storage duration, so the
    // pointer to it stored in the returned descriptor stays valid.
    static constexpr const char* mHashValidationToggle = "blob_cache_hash_validation";
};
// Minimal payload type used by the tests as the wrapped type of an UnsafeUnserializedValue
// request member.
struct Foo {
    // Zero by default so that a default-constructed Foo never carries an indeterminate value.
    // Aggregate initialization such as Foo{42} keeps working.
    int value = 0;
};
// X-macro listing the (type, name) members of the request type used by the tests in this file.
// It is passed to DAWN_MAKE_CACHE_REQUEST to declare CacheRequestForTesting and is undefined
// right afterwards so that it does not leak into the rest of the file.
// Members |a|, |b| and |c| are the values the tests stream into the expected cache key; |d| and
// |e| are wrapped in UnsafeUnserializedValue, which the tests expect to have no effect on the
// key.
#define REQUEST_MEMBERS(X) \
X(int, a) \
X(float, b) \
X(std::vector<unsigned int>, c) \
X(UnsafeUnserializedValue<int*>, d) \
X(UnsafeUnserializedValue<Foo>, e)
DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS);
#undef REQUEST_MEMBERS
// Compile-time check (via static_assert) of the type LoadOrRun returns for different combinations
// of cache hit handler and cache miss handler return types.
TEST_P(CacheRequestTests, CacheResultTypes) {
    // Report a size of 0 for every lookup so that each request below is a cache miss.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));

    // Hit handler returns int, miss handler returns ResultOrError<int>:
    // the combined result is ResultOrError<CacheResult<int>>.
    auto intResult = LoadOrRun(
        GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
        [](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
    static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(intResult)>);
    intResult.AcquireSuccess();

    // Both handlers return ResultOrError<float>:
    // the combined result is ResultOrError<CacheResult<float>>.
    auto floatResult = LoadOrRun(
        GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; },
        [](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; });
    static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(floatResult)>);
    floatResult.AcquireSuccess();
}
// Test that using a CacheRequest builds a key from the device key, the request type enum, and all
// of the request members.
TEST_P(CacheRequestTests, MakesCacheKey) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};

    // Make the expected key from the device key, the request name and the serialized members.
    CacheKey expectedKey;
    StreamIn(&expectedKey, GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b,
             req.c);

    // Expect a call to LoadData with the expected key. The matcher already pins the key size, so
    // comparing expectedKey.size() bytes inside the action is safe.
    EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
        .WillOnce(WithArg<0>([&](const void* actualKeyData) {
            EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
            // A size of 0 reports a cache miss.
            return 0;
        }));

    // Load the request.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req), [](Blob) -> int { return 0; },
                      [](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
                      .AcquireSuccess();

    // The created cache key should be saved on the result.
    const auto& actualKey = result.GetCacheKey();
    // The size check is fatal: if the sizes differ, the memcmp below would read past the end of
    // the shorter key.
    ASSERT_EQ(actualKey.size(), expectedKey.size());
    EXPECT_EQ(memcmp(actualKey.data(), expectedKey.data(), expectedKey.size()), 0);
}
// Test that members that are wrapped in UnsafeUnserializedValue do not impact the key.
TEST_P(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) {
    // Make two requests that differ only in their UnsafeUnserializedValue members (|d| and |e|
    // are declared with that wrapper in the request definition). Only the addresses of |v1| and
    // |v2| matter to the test.
    int v1 = 0;
    int v2 = 0;
    CacheRequestForTesting req1;
    req1.d = UnsafeUnserializedValue(&v1);
    req1.e = UnsafeUnserializedValue(Foo{42});
    CacheRequestForTesting req2;
    req2.d = UnsafeUnserializedValue(&v2);
    req2.e = UnsafeUnserializedValue(Foo{24});

    // Both lookups are cache misses.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));

    // The miss handlers below are captureless lambdas, so the mock they forward to is static. A
    // static mock outlives the test and is shared by every parameterized instantiation, so it is
    // explicitly verified and cleared after each use: this checks the expectations inside the
    // test, avoids interleaving EXPECT_CALL with calls on the same mock, and drops the actions
    // that capture this test's locals by reference.
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    // Load the first request, and check that the unserialized values were passed through.
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>([&](CacheRequestForTesting req) {
        EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
        // Foo::value is an int, so compare exactly.
        EXPECT_EQ(req.e.UnsafeGetValue().value, 42);
        return 0;
    }));
    auto r1 = LoadOrRun(
                  GetDevice(), std::move(req1), [](Blob) { return 0; },
                  [](CacheRequestForTesting req) -> ResultOrError<int> {
                      return cacheMissFn.Call(std::move(req));
                  })
                  .AcquireSuccess();
    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);

    // Load the second request, and check that the unserialized values were passed through.
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>([&](CacheRequestForTesting req) {
        EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
        EXPECT_EQ(req.e.UnsafeGetValue().value, 24);
        return 0;
    }));
    auto r2 = LoadOrRun(
                  GetDevice(), std::move(req2), [](Blob) { return 0; },
                  [](CacheRequestForTesting req) -> ResultOrError<int> {
                      return cacheMissFn.Call(std::move(req));
                  })
                  .AcquireSuccess();
    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);

    // Expect their keys to be the same. The size check is fatal so that the memcmp never reads
    // past the end of the shorter key.
    const auto& key1 = r1.GetCacheKey();
    const auto& key2 = r2.GetCacheKey();
    ASSERT_EQ(key1.size(), key2.size());
    EXPECT_EQ(memcmp(key1.data(), key2.data(), key1.size()), 0);
}
// Test the expected code path when there is a cache miss.
TEST_P(CacheRequestTests, CacheMiss) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};
    // Remember the vector's storage so the handler can check that the request was moved, not
    // copied.
    unsigned int* cPtr = req.c.data();

    // The handlers passed to LoadOrRun are captureless lambdas, so the mocks they forward to are
    // static. No expectation is set on |cacheHitFn|: being strict, it fails the test if called.
    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    // Mock a cache miss.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));

    // Expect the cache miss, and return some value.
    int rv = 42;
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>([=](CacheRequestForTesting req) {
        // Expect the request contents to be the same. The data pointer for |c| is also the same
        // since it was moved.
        EXPECT_EQ(req.a, 1);
        EXPECT_FLOAT_EQ(req.b, 0.2);
        EXPECT_EQ(req.c.data(), cPtr);
        return rv;
    }));

    // Load the request.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req),
                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
                      [](CacheRequestForTesting req) -> ResultOrError<int> {
                          return cacheMissFn.Call(std::move(req));
                      })
                      .AcquireSuccess();

    // The static mocks outlive this test and are shared by every parameterized instantiation.
    // Verify and reset them here so unmet expectations are reported inside this test and the next
    // instantiation starts from a clean mock.
    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);

    // Expect the result to store the value.
    EXPECT_EQ(*result, rv);
    EXPECT_FALSE(result.IsCached());
}
// Test the expected code path when there is a cache hit.
TEST_P(CacheRequestTests, CacheHit) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};

    // The handlers passed to LoadOrRun are captureless lambdas, so the mocks they forward to are
    // static. No expectation is set on |cacheMissFn|: being strict, it fails the test if called.
    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    static constexpr char kCachedData[] = "hello world!";
    // Bytes actually stored into and loaded from Blob cache might be different from raw given data.
    Blob actualStoredData = GetDevice()->GetBlobCache()->GenerateActualStoredBlobForTesting(
        sizeof(kCachedData), kCachedData);

    // Mock a cache hit, and load the cached data: the first call queries the size, the second
    // call fills the provided buffer.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(actualStoredData.Size()));
    EXPECT_CALL(mMockCache, LoadData(_, _, _, actualStoredData.Size()))
        .WillOnce(WithArg<2>([&actualStoredData](void* dataOut) {
            memcpy(dataOut, actualStoredData.Data(), actualStoredData.Size());
            return actualStoredData.Size();
        }));

    // Expect the cache hit, and return some value.
    int rv = 1337;
    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>([=](Blob blob) {
        // Expect the cached blob contents to match the cached data. Only compare the bytes when
        // the size matches, so a size mismatch cannot turn into an out-of-bounds read.
        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
        if (blob.Size() == sizeof(kCachedData)) {
            EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
        }
        return rv;
    }));

    // Load the request.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req),
                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
                      [](CacheRequestForTesting req) -> ResultOrError<int> {
                          return cacheMissFn.Call(std::move(req));
                      })
                      .AcquireSuccess();

    // The static mocks outlive this test and are shared by every parameterized instantiation.
    // Verify and reset them here so unmet expectations are reported inside this test and the next
    // instantiation starts from a clean mock.
    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);

    // Expect the result to store the value.
    EXPECT_EQ(*result, rv);
    EXPECT_TRUE(result.IsCached());
}
// Test the expected code path when there is a cache hit but the handler errors.
TEST_P(CacheRequestTests, CacheHitError) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};
    // Remember the vector's storage so the miss handler can check that the request was moved, not
    // copied.
    unsigned int* cPtr = req.c.data();

    // The handlers passed to LoadOrRun are captureless lambdas, so the mocks they forward to are
    // static.
    static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    static constexpr char kCachedData[] = "hello world!";
    // Bytes actually stored into and loaded from Blob cache might be different from raw given data.
    Blob actualStoredData = GetDevice()->GetBlobCache()->GenerateActualStoredBlobForTesting(
        sizeof(kCachedData), kCachedData);

    // Mock a cache hit, and load the cached data: the first call queries the size, the second
    // call fills the provided buffer.
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(actualStoredData.Size()));
    EXPECT_CALL(mMockCache, LoadData(_, _, _, actualStoredData.Size()))
        .WillOnce(WithArg<2>([&actualStoredData](void* dataOut) {
            memcpy(dataOut, actualStoredData.Data(), actualStoredData.Size());
            return actualStoredData.Size();
        }));

    // Expect the cache hit.
    EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>([=](Blob blob) {
        // Expect the cached blob contents to match the cached data. Only compare the bytes when
        // the size matches, so a size mismatch cannot turn into an out-of-bounds read.
        EXPECT_EQ(blob.Size(), sizeof(kCachedData));
        if (blob.Size() == sizeof(kCachedData)) {
            EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
        }
        // Return an error.
        return DAWN_VALIDATION_ERROR("fake test error");
    }));

    // Expect the cache miss handler since the cache hit errored.
    int rv = 79;
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>([=](CacheRequestForTesting req) {
        // Expect the request contents to be the same. The data pointer for |c| is also the same
        // since it was moved.
        EXPECT_EQ(req.a, 1);
        EXPECT_FLOAT_EQ(req.b, 0.2);
        EXPECT_EQ(req.c.data(), cPtr);
        return rv;
    }));

    // Load the request.
    auto result =
        LoadOrRun(
            GetDevice(), std::move(req),
            [](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
            [](CacheRequestForTesting req) -> ResultOrError<int> {
                return cacheMissFn.Call(std::move(req));
            })
            .AcquireSuccess();

    // The static mocks outlive this test and are shared by every parameterized instantiation.
    // Verify and reset them here so unmet expectations are reported inside this test and the next
    // instantiation starts from a clean mock.
    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);

    // Expect the result to store the value from the miss handler.
    EXPECT_EQ(*result, rv);
    EXPECT_FALSE(result.IsCached());
}
// Test that a cache miss occurs if the two LoadData calls return different sizes.
TEST_P(CacheRequestTests, CacheHitDifferentLoadSizes) {
    // Make a request.
    CacheRequestForTesting req;
    req.a = 1;
    req.b = 0.2;
    req.c = {3, 4, 5};
    // Remember the vector's storage so the miss handler can check that the request was moved, not
    // copied.
    unsigned int* cPtr = req.c.data();

    // The handlers passed to LoadOrRun are captureless lambdas, so the mocks they forward to are
    // static. No expectation is set on |cacheHitFn|: being strict, it fails the test if called.
    static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
    static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;

    // Mock a cache hit, but with different sizes returned from LoadData: the size query reports
    // |kExpectedSize| while the actual load reports |kActualSize|.
    const size_t kExpectedSize = 10;
    const size_t kActualSize = 5;
    EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(kExpectedSize));
    EXPECT_CALL(mMockCache, LoadData(_, _, _, kExpectedSize)).WillOnce(Return(kActualSize));

    // Expect the cache miss handler since the load sizes were different.
    int rv = 79;
    EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>([=](CacheRequestForTesting req) {
        // Expect the request contents to be the same. The data pointer for |c| is also the same
        // since it was moved.
        EXPECT_EQ(req.a, 1);
        EXPECT_FLOAT_EQ(req.b, 0.2);
        EXPECT_EQ(req.c.data(), cPtr);
        return rv;
    }));

    // Load the request.
    auto result = LoadOrRun(
                      GetDevice(), std::move(req),
                      [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
                      [](CacheRequestForTesting req) -> ResultOrError<int> {
                          return cacheMissFn.Call(std::move(req));
                      })
                      .AcquireSuccess();

    // The static mocks outlive this test and are shared by every parameterized instantiation.
    // Verify and reset them here so unmet expectations are reported inside this test and the next
    // instantiation starts from a clean mock.
    ::testing::Mock::VerifyAndClearExpectations(&cacheHitFn);
    ::testing::Mock::VerifyAndClearExpectations(&cacheMissFn);

    // Expect the result to store the value from the miss handler.
    EXPECT_EQ(*result, rv);
    EXPECT_FALSE(result.IsCached());
}
// Test the expected code path when hash validation is enabled, there is a cache hit but the hash
// validation fails. This should be treated as a cache miss, and the cache miss handler should be
// called.
TEST_P(CacheRequestTests, CacheHitHashValidationFailed) {
    // Only run this test if hash validation is enabled.
    if (!GetParam().enableHashValidation) {
        GTEST_SKIP();
    }

    static constexpr char kCachedData[] = "hello world!";
    static constexpr size_t kCachedDataSize = sizeof(kCachedData);
    // Bytes actually stored into and loaded from Blob cache might be different from raw given data.
    Blob actualStoredData = GetDevice()->GetBlobCache()->GenerateActualStoredBlobForTesting(
        kCachedDataSize, kCachedData);
    const size_t sizeWithHash = actualStoredData.Size();
    // With hash validation enabled, the actual stored data size is larger than kCachedData.
    ASSERT_GT(sizeWithHash, kCachedDataSize);
    const size_t addedSize = sizeWithHash - kCachedDataSize;

    // Runs one load in which the mocked cache returns |loadSize| bytes from |loadBuffer|, and
    // checks that the hit handler (|expectHashValidationSuccess| == true) or the miss handler
    // (false) produced the result. Note that a failing ASSERT_* inside this lambda only returns
    // from the lambda, not from the test body.
    auto DoTest = [&](const void* loadBuffer, size_t loadSize, bool expectHashValidationSuccess) {
        // LoadOrRun requires its cache miss handler to be a free function (i.e. without any
        // capture), so the mock function has to be static to be used in it. However, using the
        // same mock function object for different test cases might cause their expectations to
        // get mixed up, so we use a static unique_ptr to ensure that each test case constructs
        // and destructs its own mock function object.
        static std::unique_ptr<StrictMock<MockFunction<int(Blob)>>> cacheHitFn;
        static std::unique_ptr<StrictMock<MockFunction<int(CacheRequestForTesting)>>> cacheMissFn;
        constexpr int rvCacheHit = 21;
        constexpr int rvCacheMiss = 42;

        // Make a request.
        CacheRequestForTesting req;
        req.a = 1;
        req.b = 0.2;
        req.c = {3, 4, 5};
        unsigned int* cPtr = req.c.data();

        // Mock a cache hit with given data buffer. The action is stored on |mMockCache|, which
        // outlives this lambda invocation, so the buffer pointer and size are captured by value
        // rather than by reference to this lambda's parameters.
        EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(loadSize));
        EXPECT_CALL(mMockCache, LoadData(_, _, _, loadSize))
            .WillOnce(WithArg<2>([loadBuffer, loadSize](void* dataOut) {
                memcpy(dataOut, loadBuffer, loadSize);
                return loadSize;
            }));

        // Construct mock functions for current test.
        ASSERT_FALSE(cacheHitFn);
        ASSERT_FALSE(cacheMissFn);
        cacheHitFn = std::make_unique<StrictMock<MockFunction<int(Blob)>>>();
        cacheMissFn = std::make_unique<StrictMock<MockFunction<int(CacheRequestForTesting)>>>();

        if (expectHashValidationSuccess) {
            // Expect the cache hit handler to be called with the loaded blob.
            EXPECT_CALL(*cacheHitFn, Call(_)).WillOnce(WithArg<0>([=](Blob blob) {
                // Expect the loaded blob contents to match the cached data. Only compare the
                // bytes when the size matches, so a size mismatch cannot turn into an
                // out-of-bounds read.
                EXPECT_EQ(blob.Size(), sizeof(kCachedData));
                if (blob.Size() == sizeof(kCachedData)) {
                    EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
                }
                return rvCacheHit;
            }));
        } else {
            // Expect cacheMissFn to be called and return some value, since hash validation
            // failures are treated as misses.
            EXPECT_CALL(*cacheMissFn, Call(_)).WillOnce(WithArg<0>([=](CacheRequestForTesting req) {
                // Expect the request contents to be the same. The data pointer for |c| is also
                // the same since it was moved.
                EXPECT_EQ(req.a, 1);
                EXPECT_FLOAT_EQ(req.b, 0.2);
                EXPECT_EQ(req.c.data(), cPtr);
                return rvCacheMiss;
            }));
        }

        // Load the request.
        auto result = LoadOrRun(
                          GetDevice(), std::move(req),
                          [](Blob blob) -> int { return cacheHitFn->Call(std::move(blob)); },
                          [](CacheRequestForTesting req) -> ResultOrError<int> {
                              return cacheMissFn->Call(std::move(req));
                          })
                          .AcquireSuccess();

        if (expectHashValidationSuccess) {
            // Expect the result to hold the return value from cache hit.
            EXPECT_EQ(*result, rvCacheHit);
            EXPECT_TRUE(result.IsCached());
        } else {
            // Expect the result to hold the return value from cache miss.
            EXPECT_EQ(*result, rvCacheMiss);
            EXPECT_FALSE(result.IsCached());
        }

        // Destruct the mock functions to ensure all expected behavior happened.
        cacheHitFn.reset();
        cacheMissFn.reset();
    };

    // Control case: hash validation success.
    {
        DoTest(actualStoredData.Data(), sizeWithHash, true);
    }

    // Hash validation failure case 1: loaded blob size too small.
    {
        static constexpr uint8_t tooSmallBuffer[] = "0";
        static constexpr size_t tooSmallBufferSize = sizeof(tooSmallBuffer);
        // The buffer must be smaller than the bytes added on top of the raw data for this case to
        // exercise the "too small" path.
        ASSERT_LT(tooSmallBufferSize, addedSize);
        DoTest(tooSmallBuffer, tooSmallBufferSize, false);
    }

    // Hash validation failure case 2: loaded blob hash mismatched.
    {
        Blob modifiedStoredData = CreateBlob(sizeWithHash);
        memcpy(modifiedStoredData.Data(), actualStoredData.Data(), sizeWithHash);
        // Modify the last byte to make the hash mismatch.
        modifiedStoredData.Data()[sizeWithHash - 1] = ~modifiedStoredData.Data()[sizeWithHash - 1];
        DoTest(modifiedStoredData.Data(), sizeWithHash, false);
    }
}
// Instantiate every CacheRequestTests test twice: once with hash validation disabled and once
// with it enabled. The instantiation name prefix (first macro argument) is intentionally empty.
INSTANTIATE_TEST_SUITE_P(,
CacheRequestTests,
testing::Values(CacheRequestTestParam{.enableHashValidation = false},
CacheRequestTestParam{.enableHashValidation = true}));
} // anonymous namespace
} // namespace dawn::native