blob: 0fecb0da573db055eaff4af3bfdee9d0150b1b7b [file] [log] [blame]
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_
#define SRC_DAWN_NATIVE_CACHEREQUEST_H_
#include <memory>
#include <type_traits>
#include <utility>

#include "dawn/common/Assert.h"
#include "dawn/common/Compiler.h"
#include "dawn/native/Blob.h"
#include "dawn/native/BlobCache.h"
#include "dawn/native/CacheKey.h"
#include "dawn/native/CacheResult.h"
#include "dawn/native/Device.h"
#include "dawn/native/Error.h"
namespace dawn::native {
namespace detail {
// Primary template: for a type that is not a ResultOrError, the "unwrapped" type is simply the
// type itself.
template <typename Plain>
struct UnwrapResultOrError {
    using type = Plain;
};
// Partial specialization: for ResultOrError<Payload>, the unwrapped type is Payload.
template <typename Payload>
struct UnwrapResultOrError<ResultOrError<Payload>> {
    using type = Payload;
};
// Primary template: an arbitrary type is not a ResultOrError, so `value` is false.
template <typename>
struct IsResultOrError {
    static constexpr bool value{false};
};
// Partial specialization: any ResultOrError<Payload> reports `value` as true.
template <typename Payload>
struct IsResultOrError<ResultOrError<Payload>> {
    static constexpr bool value{true};
};
// Called by LoadOrRun when CacheHitFn returns an error, just before falling back to the
// cache-miss path. Takes ownership of `error`. Only declared here; the definition is not part
// of this header.
void LogCacheHitError(std::unique_ptr<ErrorData> error);
} // namespace detail
// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found
// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function
// name.
//
// Example usage:
// Request r = { ... };
// ResultOrError<CacheResult<T>> cacheResult =
// LoadOrRun(device, std::move(r),
// [](Blob blob) -> T { /* handle cache hit */ },
// [](Request r) -> ResultOrError<T> { /* handle cache miss */ }
// );
// Or with free functions:
// T OnCacheHit(Blob blob) { ... }
// ResultOrError<T> OnCacheMiss(Request r) { ... }
// // ...
// Request r = { ... };
// auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss);
//
// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache
// hit, calls CacheHitFn and returns a CacheResult<T>. On cache miss or if CacheHitFn returned an
// Error, calls CacheMissFn -> ResultOrError<T> with the request data and returns a
// ResultOrError<CacheResult<T>>. CacheHitFn must return the same unwrapped type as CacheMissFn,
// i.e. it may return either T or ResultOrError<T>.
//
// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function
// which captures additional information, so it can only operate on the request data. This is
// enforced with a compile-time static_assert, and ensures that the result created from the
// computation is exactly the data included in the CacheKey.
// Base class template for cache requests. `Request` is the concrete request type; LoadOrRun
// requires it to provide a `CreateCacheKey(device)` member whose result initializes a CacheKey.
// The class itself carries no data: it only enforces move-only semantics and injects the
// LoadOrRun friend function so that it can be found through argument-dependent lookup.
template <typename Request>
class CacheRequestImpl {
  public:
    CacheRequestImpl() = default;

    // Require CacheRequests to be move-only to avoid unnecessary copies.
    CacheRequestImpl(CacheRequestImpl&&) = default;
    CacheRequestImpl& operator=(CacheRequestImpl&&) = default;
    CacheRequestImpl(const CacheRequestImpl&) = delete;
    CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;

    // Looks up the request in the device's blob cache and either decodes the cached blob or
    // computes the value from the request.
    //
    //   device      - supplies the BlobCache (which may be null) and is passed to
    //                 Request::CreateCacheKey.
    //   r           - the request; it is moved into cacheMissFn on the cache-miss path.
    //   cacheHitFn  - invoked with the loaded Blob; returns T or ResultOrError<T>.
    //   cacheMissFn - invoked with the Request; must return ResultOrError<T> and must be
    //                 convertible to a plain function pointer (no captured state).
    //
    // Returns ResultOrError<CacheResult<T>>: a CacheHit result when a non-empty blob was loaded
    // and cacheHitFn succeeded, a CacheMiss result when cacheMissFn succeeded, and the error of
    // cacheMissFn otherwise. An error from cacheHitFn is logged and is not propagated.
    template <typename CacheHitFn, typename CacheMissFn>
    friend auto LoadOrRun(DeviceBase* device,
                          Request&& r,
                          CacheHitFn cacheHitFn,
                          CacheMissFn cacheMissFn) {
        // Get the return types and check that CacheMissFn can be converted to a raw function
        // pointer. This means it's not a std::function or a lambda that captures additional
        // data.
        using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>()));
        using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>()));
        static_assert(
            std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>,
            "CacheMissFn function signature does not match, or it is not a free function.");

        static_assert(detail::IsResultOrError<CacheMissReturnType>::value,
                      "CacheMissFn should return a ResultOrError.");
        using UnwrappedReturnType =
            typename detail::UnwrapResultOrError<CacheMissReturnType>::type;

        // CacheHitFn may return either the bare value or a ResultOrError of it, but the
        // underlying value type has to agree with CacheMissFn's.
        static_assert(
            std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type,
                           UnwrappedReturnType>,
            "If CacheMissFn returns ResultOrError<T>, CacheHitFn must return T or "
            "ResultOrError<T>.");

        using CacheResultType = CacheResult<UnwrappedReturnType>;
        using ReturnType = ResultOrError<CacheResultType>;

        CacheKey key = r.CreateCacheKey(device);

        // The device may have no blob cache; in that case `blob` stays default-constructed and
        // the emptiness check below sends us to the cache-miss path.
        BlobCache* cache = device->GetBlobCache();
        Blob blob;
        if (cache != nullptr) {
            blob = cache->Load(key);
        }

        if (!blob.Empty()) {
            // Cache hit. Handle the cached blob.
            auto result = cacheHitFn(std::move(blob));

            if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) {
                // If the result type is not a ResultOrError, return it.
                return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result)));
            } else {
                // Otherwise, if the value is a success, also return it.
                if (DAWN_LIKELY(result.IsSuccess())) {
                    return ReturnType(
                        CacheResultType::CacheHit(std::move(key), result.AcquireSuccess()));
                }
                // On error, continue to the cache miss path and log the error.
                detail::LogCacheHitError(result.AcquireError());
            }
        }

        // Cache miss, or the CacheHitFn failed.
        auto result = cacheMissFn(std::move(r));
        if (DAWN_LIKELY(result.IsSuccess())) {
            return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess()));
        }
        return ReturnType(result.AcquireError());
    }
};
// X-macro helper: expands one (type, name) entry into a value-initialized member declaration,
// i.e. `fieldType fieldName{};`.
#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(fieldType, fieldName) \
    fieldType fieldName{};
// X-macro helper: expands one (type, name) entry into `key.Record(fieldName);`. A variable named
// `key` must be in scope at the expansion site. The type argument is accepted only so the helper
// matches the X-macro entry shape; it is not used in the expansion.
#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(fieldType, fieldName) \
    key.Record(fieldName);
// Helper X macro to define a CacheRequest struct.
//
// DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) expands to a complete class definition, including
// the closing `};`:
//   - `Request` derives publicly from CacheRequestImpl<Request> and has a defaulted default
//     constructor.
//   - MEMBERS must be a function-like macro taking another macro X and invoking X(type, name)
//     once per field. Each entry becomes a public, value-initialized member `type name{};`.
//   - CreateCacheKey(device) starts from device->GetCacheKey(), records the stringified name of
//     the request class, then records every member in the order MEMBERS lists them, and returns
//     the resulting key.
//
// Example usage:
// #define REQUEST_MEMBERS(X) \
// X(int, a) \
// X(float, b) \
// X(Foo, foo) \
// X(Bar, bar)
// DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
// #undef REQUEST_MEMBERS
#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \
class Request : public CacheRequestImpl<Request> { \
public: \
Request() = default; \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \
\
/* Create a CacheKey from the request type and all members */ \
CacheKey CreateCacheKey(const DeviceBase* device) const { \
CacheKey key = device->GetCacheKey(); \
key.Record(#Request); \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \
return key; \
} \
};
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_