blob: b861419ce886b27b0731891cf2409ae10e86c583 [file] [log] [blame]
// Copyright 2022 The Dawn Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include "dawn/common/Assert.h"
#include "dawn/common/Compiler.h"
#include "dawn/native/Blob.h"
#include "dawn/native/BlobCache.h"
#include "dawn/native/CacheKey.h"
#include "dawn/native/CacheResult.h"
#include "dawn/native/Device.h"
#include "dawn/native/Error.h"
#include "dawn/native/VisitableMembers.h"
namespace dawn::native {
namespace detail {
// Type trait mapping a type to the value type it carries. This primary template handles every
// type that is not a ResultOrError and simply yields the type itself.
template <typename T>
struct UnwrapResultOrError {
    using type = T;
};
// Partial specialization for ResultOrError<T>: strips the wrapper so that `type` is the success
// type T.
template <typename T>
struct UnwrapResultOrError<ResultOrError<T>> {
    using type = T;
};
// Type trait reporting whether a type is a ResultOrError. This primary template covers every
// other type, so `value` is false.
template <typename T>
struct IsResultOrError {
    static constexpr bool value = false;
};
// Partial specialization for ResultOrError<T>: `value` is true for any success type T.
template <typename T>
struct IsResultOrError<ResultOrError<T>> {
    static constexpr bool value = true;
};
// Declared here only; the definition lives outside this header. Takes ownership of `error`.
// NOTE(review): presumably reports the error produced when a cache-hit handler fails, before
// the request falls back to the cache-miss path - confirm against the definition.
void LogCacheHitError(std::unique_ptr<ErrorData> error);
} // namespace detail
// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found
// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function
// name.
// Example usage:
// Request r = { ... };
// ResultOrError<CacheResult<T>> cacheResult =
// LoadOrRun(device, std::move(r),
// [](Blob blob) -> T { /* handle cache hit */ },
// [](Request r) -> ResultOrError<T> { /* handle cache miss */ }
// );
// Or with free functions:
// T OnCacheHit(Blob blob) { ... }
// ResultOrError<T> OnCacheMiss(Request r) { ... }
// // ...
// Request r = { ... };
// auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss);
// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache
// hit, calls CacheHitFn and returns a CacheResult<T>. On cache miss or if CacheHitFn returned an
// Error, calls CacheMissFn -> ResultOrError<T> with the request data and returns a
// ResultOrError<CacheResult<T>>. CacheHitFn must return the same unwrapped type as CacheMissFn,
// i.e. it may return either T or ResultOrError<T>.
// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function
// which captures additional information, so it can only operate on the request data. This is
// enforced with a compile-time static_assert, and ensures that the result created from the
// computation is exactly the data included in the CacheKey.
template <typename Request>
class CacheRequestImpl {
CacheRequestImpl() = default;
// Require CacheRequests to be move-only to avoid unnecessary copies.
CacheRequestImpl(CacheRequestImpl&&) = default;
CacheRequestImpl& operator=(CacheRequestImpl&&) = default;
CacheRequestImpl(const CacheRequestImpl&) = delete;
CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;
// Create a CacheKey from the request type and all members
CacheKey CreateCacheKey(const DeviceBase* device) const {
CacheKey key = device->GetCacheKey();
StreamIn(&key, Request::kName);
static_cast<const Request*>(this)->VisitAll(
[&](const auto&... members) { StreamIn(&key, members...); });
return key;
template <typename CacheHitFn, typename CacheMissFn>
friend auto LoadOrRun(DeviceBase* device,
Request&& r,
CacheHitFn cacheHitFn,
CacheMissFn cacheMissFn) {
// Get return types and check that CacheMissReturnType can be cast to a raw function
// pointer. This means it's not a std::function or lambda that captures additional data.
using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>()));
using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>()));
std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>,
"CacheMissFn function signature does not match, or it is not a free function.");
"CacheMissFn should return a ResultOrError.");
using UnwrappedReturnType = typename detail::UnwrapResultOrError<CacheMissReturnType>::type;
static_assert(std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type,
"If CacheMissFn returns T, CacheHitFn must return T or ResultOrError<T>.");
using CacheResultType = CacheResult<UnwrappedReturnType>;
using ReturnType = ResultOrError<CacheResultType>;
CacheKey key = r.CreateCacheKey(device);
Blob blob = device->GetBlobCache()->Load(key);
if (!blob.Empty()) {
// Cache hit. Handle the cached blob.
auto result = cacheHitFn(std::move(blob));
if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) {
// If the result type is not a ResultOrError, return it.
return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result)));
} else {
// Otherwise, if the value is a success, also return it.
if (DAWN_LIKELY(result.IsSuccess())) {
return ReturnType(
CacheResultType::CacheHit(std::move(key), result.AcquireSuccess()));
// On error, continue to the cache miss path and log the error.
// Cache miss, or the CacheHitFn failed.
auto result = cacheMissFn(std::move(r));
if (DAWN_LIKELY(result.IsSuccess())) {
return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess()));
return ReturnType(result.AcquireError());
} // namespace dawn::native
// Helper for X macro to declare a struct member.
// Helper for X macro for recording cache request fields into a CacheKey.
// Expands to a statement that streams `name` into a variable called `key`, which must already
// be in scope at the expansion site. The `type` argument is accepted but not used.
#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) StreamIn(&key, name);
// Helper X macro to define a CacheRequest struct.
// Example usage:
//   #define REQUEST_MEMBERS(X) \
//       X(int, a)              \
//       X(float, b)            \
//       X(Foo, foo)            \
//       X(Bar, bar)
//   DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
//   #undef REQUEST_MEMBERS
//
// The generated class derives from CacheRequestImpl<Request>, is default-constructible, and
// exposes `kName`, the stringified class name that CreateCacheKey streams into the CacheKey.
// No trailing semicolon is emitted; the use site supplies it.
// NOTE(review): the `#define` line and the DAWN_VISITABLE_MEMBERS(MEMBERS) member list were
// absent from this copy of the file and have been restored - confirm the macro name and
// parameters against call sites and dawn/native/VisitableMembers.h.
#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS)                      \
    class Request : public ::dawn::native::CacheRequestImpl<Request> { \
      public:                                                          \
        static constexpr char kName[] = #Request;                      \
        Request() = default;                                           \
        DAWN_VISITABLE_MEMBERS(MEMBERS)                                \
    }

// Helper macro for the common pattern of DAWN_TRY_ASSIGN around LoadOrRun.
// `var` is passed through as the first argument of DAWN_TRY_ASSIGN; every remaining argument is
// forwarded unchanged to LoadOrRun.
// Requires an #include of dawn/native/Error.h
#define DAWN_TRY_LOAD_OR_RUN(var, ...) DAWN_TRY_ASSIGN(var, LoadOrRun(__VA_ARGS__))