Initial implementation of a multithreaded frontend cache.
- Adds a unit test suite to verify basic usage of the cache.
- Note: This change does NOT remove the device lock that is currently
  acquired for object deletion, so its effects should not be
  apparent yet. Furthermore, this will most likely not be enabled
  on its own without a parallel effort to add a WeakRef concept so
  that the cache holds WeakRefs to objects instead of raw pointers.
Bug: dawn:1769
Change-Id: I0f6a11f01c558875c2b120a55aa3c4232b501a3c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/136582
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
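
For illustration only (not part of this change), the intended get-or-create
pattern against the new cache looks roughly like the sketch below; CachedThing
and MakeThing are hypothetical stand-ins for the real types in Device.cpp:

    // Sketch: find an existing live object by blueprint, otherwise create and
    // insert one, keeping whichever instance won a racing insert.
    Ref<CachedThing> GetOrCreateThing(ContentLessObjectCache<CachedThing>& cache,
                                      CachedThing* blueprint) {
        if (Ref<CachedThing> existing = cache.Find(blueprint)) {
            return existing;  // Cache hit: reuse the live object.
        }
        Ref<CachedThing> created = MakeThing(blueprint);
        auto [winner, inserted] = cache.Insert(created.Get());
        if (inserted) {
            winner->SetIsCachedReference();  // This instance now owns the cache slot.
        }
        return winner;  // Either `created` or the object another thread inserted first.
    }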
diff --git a/src/dawn/common/BUILD.gn b/src/dawn/common/BUILD.gn
index 7509d0d..762312d 100644
--- a/src/dawn/common/BUILD.gn
+++ b/src/dawn/common/BUILD.gn
@@ -233,6 +233,7 @@
"Compiler.h",
"ConcurrentCache.h",
"Constants.h",
+ "ContentLessObjectCache.h",
"CoreFoundationRef.h",
"DynamicLib.cpp",
"DynamicLib.h",
diff --git a/src/dawn/common/CMakeLists.txt b/src/dawn/common/CMakeLists.txt
index c78d6b0..302be93 100644
--- a/src/dawn/common/CMakeLists.txt
+++ b/src/dawn/common/CMakeLists.txt
@@ -40,6 +40,7 @@
"Compiler.h"
"ConcurrentCache.h"
"Constants.h"
+ "ContentLessObjectCache.h"
"CoreFoundationRef.h"
"DynamicLib.cpp"
"DynamicLib.h"
diff --git a/src/dawn/common/ContentLessObjectCache.h b/src/dawn/common/ContentLessObjectCache.h
new file mode 100644
index 0000000..e67c05a
--- /dev/null
+++ b/src/dawn/common/ContentLessObjectCache.h
@@ -0,0 +1,104 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
+#define SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
+
+#include <mutex>
+#include <tuple>
+#include <type_traits>
+#include <unordered_set>
+#include <utility>
+
+#include "dawn/common/RefCounted.h"
+
+namespace dawn {
+
+// The ContentLessObjectCache stores raw pointers to living Refs without adding to their refcounts.
+// This means that any RefCountedT inserted into the cache must ensure that its DeleteThis function
+// erases it from the cache. Otherwise, the cache can grow indefinitely via leaked pointers to
+// deleted Refs.
+template <typename RefCountedT, typename BlueprintT = RefCountedT>
+class ContentLessObjectCache {
+ static_assert(std::is_convertible_v<RefCountedT*, BlueprintT*>,
+ "RefCountedT* must be convertible to a BlueprintT*.");
+
+ public:
+ // The dtor asserts that the cache is empty to aid in finding pointer leaks that are possible
+ // if the RefCountedT doesn't correctly implement the DeleteThis function to erase itself from
+ // the cache.
+ ~ContentLessObjectCache() { ASSERT(Empty()); }
+
+ // Inserts the object into the cache, returning a pair where the first element is a Ref to the
+ // inserted or existing object, and the second is a bool that is true if `object` was inserted
+ // and false otherwise.
+ std::pair<Ref<RefCountedT>, bool> Insert(RefCountedT* object) {
+ std::lock_guard<std::mutex> lock(mMutex);
+ auto [it, inserted] = mCache.insert(object);
+ if (inserted) {
+ return {object, inserted};
+ } else {
+ // We need to check that the found instance isn't about to be destroyed. If it is, we
+ // actually want to remove the old instance from the cache and insert this one. This can
+ // happen when the last ref of the instance currently in the cache is already in the
+ // process of being released but its deletion hasn't completed yet.
+ Ref<RefCountedT> ref = TryGetRef(static_cast<RefCountedT*>(*it));
+ if (ref != nullptr) {
+ return {ref, false};
+ } else {
+ mCache.erase(it);
+ auto result = mCache.insert(object);
+ ASSERT(result.second);
+ return {object, true};
+ }
+ }
+ }
+
+ // Returns a valid Ref<RefCountedT> if the underlying RefCounted object's refcount has not
+ // reached 0. Otherwise, returns nullptr.
+ Ref<RefCountedT> Find(BlueprintT* blueprint) {
+ std::lock_guard<std::mutex> lock(mMutex);
+ auto it = mCache.find(blueprint);
+ if (it != mCache.end()) {
+ return TryGetRef(static_cast<RefCountedT*>(*it));
+ }
+ return nullptr;
+ }
+
+ // Erases the object from the cache if it exists and is pointer-equal. Otherwise, does not
+ // modify the cache.
+ void Erase(RefCountedT* object) {
+ std::lock_guard<std::mutex> lock(mMutex);
+ auto it = mCache.find(object);
+ if (it != mCache.end() && *it == object) {
+ mCache.erase(it);
+ }
+ }
+
+ // Returns true iff the cache is empty.
+ bool Empty() {
+ std::lock_guard<std::mutex> lock(mMutex);
+ return mCache.empty();
+ }
+
+ private:
+ std::mutex mMutex;
+ std::unordered_set<BlueprintT*, typename BlueprintT::HashFunc,
+                    typename BlueprintT::EqualityFunc>
+     mCache;
+};
+
+} // namespace dawn
+
+#endif // SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
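
Note (illustration only, not part of this patch): the self-uncaching requirement
described in the header comment is expected to look roughly like the following,
assuming a hypothetical CachedWidget type and that DeleteThis is overridable as
the comment above describes:

    // Sketch: the cached type erases itself before its memory is released so the
    // raw pointer stored by the ContentLessObjectCache never dangles.
    class CachedWidget : public RefCounted {
      public:
        explicit CachedWidget(ContentLessObjectCache<CachedWidget>* cache) : mCache(cache) {}

        struct HashFunc {
            size_t operator()(const CachedWidget*) const { return 42; }  // Content hash in practice.
        };
        struct EqualityFunc {
            bool operator()(const CachedWidget* l, const CachedWidget* r) const { return l == r; }
        };

      private:
        void DeleteThis() override {
            mCache->Erase(this);       // Uncache first...
            RefCounted::DeleteThis();  // ...then perform the actual deletion.
        }

        ContentLessObjectCache<CachedWidget>* mCache;
    };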
diff --git a/src/dawn/common/RefCounted.cpp b/src/dawn/common/RefCounted.cpp
index aa91c85..0c5bd12 100644
--- a/src/dawn/common/RefCounted.cpp
+++ b/src/dawn/common/RefCounted.cpp
@@ -56,6 +56,26 @@
mRefCount.fetch_add(kRefCountIncrement, std::memory_order_relaxed);
}
+bool RefCount::TryIncrement() {
+ uint64_t current = mRefCount.load(std::memory_order_relaxed);
+ bool success = false;
+ do {
+ if ((current & ~kPayloadMask) == 0u) {
+ return false;
+ }
+ // The relaxed ordering guarantees only the atomicity of the update. This is fine because:
+ // - If another thread's decrement happens before this increment, the increment should
+ // fail.
+ // - If another thread's decrement happens after this increment, the decrement shouldn't
+ // delete the object, because the ref count > 0.
+ // See Boost library for reference:
+ // https://github.com/boostorg/smart_ptr/blob/develop/include/boost/smart_ptr/detail/sp_counted_base_std_atomic.hpp#L62
+ success = mRefCount.compare_exchange_weak(current, current + kRefCountIncrement,
+ std::memory_order_relaxed);
+ } while (!success);
+ return true;
+}
+
bool RefCount::Decrement() {
ASSERT((mRefCount & ~kPayloadMask) != 0);
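
For intuition (standalone sketch, not Dawn code): the same try-increment idiom
can be written against a bare std::atomic, where a count of zero means the
object is already being destroyed and must not be resurrected:

    #include <atomic>
    #include <cstdint>

    // Increment only if the count is still non-zero, mirroring RefCount::TryIncrement.
    bool TryIncrementCount(std::atomic<uint64_t>& count) {
        uint64_t current = count.load(std::memory_order_relaxed);
        do {
            if (current == 0u) {
                return false;  // Destruction already won; do not resurrect the object.
            }
            // On failure, compare_exchange_weak reloads `current`, so the zero check reruns.
        } while (!count.compare_exchange_weak(current, current + 1, std::memory_order_relaxed));
        return true;
    }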
diff --git a/src/dawn/common/RefCounted.h b/src/dawn/common/RefCounted.h
index 0ab7427..e6cdb90 100644
--- a/src/dawn/common/RefCounted.h
+++ b/src/dawn/common/RefCounted.h
@@ -17,6 +17,7 @@
#include <atomic>
#include <cstdint>
+#include <type_traits>
#include "dawn/common/RefBase.h"
@@ -32,6 +33,9 @@
// Add a reference.
void Increment();
+ // Tries to add a reference. Returns false if the refcount is already at 0. This is used when
+ // operating on a raw pointer to a RefCounted that may soon be deleted, rather than on a valid Ref.
+ bool TryIncrement();
// Remove a reference. Returns true if this was the last reference.
bool Decrement();
@@ -40,6 +44,9 @@
std::atomic<uint64_t> mRefCount;
};
+template <typename T>
+class Ref;
+
class RefCounted {
public:
explicit RefCounted(uint64_t payload = 0);
@@ -52,6 +59,19 @@
// synchronization in place for destruction.
void Release();
+ // Tries to return a valid Ref to `object` if its internal refcount is not already 0. If the
+ // internal refcount has already reached 0, returns nullptr instead.
+ template <typename T, typename = typename std::is_convertible<T, RefCounted>>
+ friend Ref<T> TryGetRef(T* object) {
+ // Since this is called on the RefCounted class directly, and can race with destruction, we
+ // verify that we can safely increment the refcount first, create the Ref, then decrement
+ // the refcount in that order to ensure that the resultant Ref is a valid Ref.
+ if (!object->mRefCount.TryIncrement()) {
+ return nullptr;
+ }
+ return AcquireRef(object);
+ }
+
void APIReference() { Reference(); }
// APIRelease() can be called without any synchronization guarantees so we need to use a Release
// method that will call LockAndDeleteThis() on destruction.
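
Usage sketch (illustrative; SomeObject and `rawFromContainer` are hypothetical):
TryGetRef promotes a raw pointer from a non-owning container to a Ref only while
the object is still alive, tolerating a racing destruction:

    // Returns a live Ref, or nullptr if the object's refcount already hit 0.
    Ref<SomeObject> Promote(SomeObject* rawFromContainer) {
        Ref<SomeObject> ref = TryGetRef(rawFromContainer);
        if (ref == nullptr) {
            // The last reference was already dropped; treat this as "not found".
            return nullptr;
        }
        return ref;  // Holding `ref` keeps the object alive from here on.
    }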
diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h
index b861419..27da114 100644
--- a/src/dawn/native/CacheRequest.h
+++ b/src/dawn/native/CacheRequest.h
@@ -32,26 +32,6 @@
namespace detail {
-template <typename T>
-struct UnwrapResultOrError {
- using type = T;
-};
-
-template <typename T>
-struct UnwrapResultOrError<ResultOrError<T>> {
- using type = T;
-};
-
-template <typename T>
-struct IsResultOrError {
- static constexpr bool value = false;
-};
-
-template <typename T>
-struct IsResultOrError<ResultOrError<T>> {
- static constexpr bool value = true;
-};
-
void LogCacheHitError(std::unique_ptr<ErrorData> error);
} // namespace detail
diff --git a/src/dawn/native/CachedObject.h b/src/dawn/native/CachedObject.h
index 3fbba63..1040f54 100644
--- a/src/dawn/native/CachedObject.h
+++ b/src/dawn/native/CachedObject.h
@@ -23,13 +23,13 @@
namespace dawn::native {
-// Some objects are cached so that instead of creating new duplicate objects,
-// we increase the refcount of an existing object.
-// When an object is successfully created, the device should call
+// Some objects are cached so that instead of creating new duplicate objects, we increase the
+// refcount of an existing object. When an object is successfully created, the device should call
// SetIsCachedReference() and insert the object into the cache.
class CachedObject {
public:
bool IsCachedReference() const;
+ void SetIsCachedReference();
// Functor necessary for the unordered_set<CachedObject*>-based cache.
struct HashFunc {
@@ -47,14 +47,11 @@
CacheKey mCacheKey;
private:
- friend class DeviceBase;
- void SetIsCachedReference();
-
- bool mIsCachedReference = false;
-
// Called by ObjectContentHasher upon creation to record the object.
virtual size_t ComputeContentHash() = 0;
+ bool mIsCachedReference = false;
+
size_t mContentHash = 0;
bool mIsContentHashInitialized = false;
};
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index b55a06b..5d0b4a7 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -17,7 +17,7 @@
#include <algorithm>
#include <array>
#include <mutex>
-#include <unordered_set>
+#include <utility>
#include "dawn/common/Log.h"
#include "dawn/common/Version_autogen.h"
@@ -60,24 +60,8 @@
// DeviceBase sub-structures
-// The caches are unordered_sets of pointers with special hash and compare functions
-// to compare the value of the objects, instead of the pointers.
-template <typename Object>
-using ContentLessObjectCache =
- std::unordered_set<Object*, typename Object::HashFunc, typename Object::EqualityFunc>;
-
struct DeviceBase::Caches {
- ~Caches() {
- ASSERT(attachmentStates.empty());
- ASSERT(bindGroupLayouts.empty());
- ASSERT(computePipelines.empty());
- ASSERT(pipelineLayouts.empty());
- ASSERT(renderPipelines.empty());
- ASSERT(samplers.empty());
- ASSERT(shaderModules.empty());
- }
-
- ContentLessObjectCache<AttachmentStateBlueprint> attachmentStates;
+ ContentLessObjectCache<AttachmentState, AttachmentStateBlueprint> attachmentStates;
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
ContentLessObjectCache<ComputePipelineBase> computePipelines;
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
@@ -86,6 +70,44 @@
ContentLessObjectCache<ShaderModuleBase> shaderModules;
};
+// Tries to find the blueprint in the cache, creating and inserting into the cache if not found.
+template <typename RefCountedT, typename BlueprintT, typename CreateFn>
+auto GetOrCreate(ContentLessObjectCache<RefCountedT, BlueprintT>& cache,
+ BlueprintT* blueprint,
+ CreateFn createFn) {
+ using ReturnType = decltype(createFn());
+
+ // If we find the blueprint in the cache we can just return it.
+ Ref<RefCountedT> result = cache.Find(blueprint);
+ if (result != nullptr) {
+ return ReturnType(result);
+ }
+
+ using UnwrappedReturnType = typename detail::UnwrapResultOrError<ReturnType>::type;
+ static_assert(std::is_same_v<UnwrappedReturnType, Ref<RefCountedT>>,
+ "CreateFn should return an unwrapped type that is the same as Ref<RefCountedT>.");
+
+ // Create the result and try inserting it. Note that inserts can race because the critical
+ // sections here are disjoint, hence the checks to verify whether this thread inserted.
+ if constexpr (!detail::IsResultOrError<ReturnType>::value) {
+ result = createFn();
+ } else {
+ auto resultOrError = createFn();
+ if (DAWN_UNLIKELY(resultOrError.IsError())) {
+ return ReturnType(std::move(resultOrError.AcquireError()));
+ }
+ result = resultOrError.AcquireSuccess();
+ }
+ ASSERT(result.Get() != nullptr);
+
+ bool inserted = false;
+ std::tie(result, inserted) = cache.Insert(result.Get());
+ if (inserted) {
+ result->SetIsCachedReference();
+ }
+ return ReturnType(result);
+}
+
struct DeviceBase::DeprecationWarnings {
std::unordered_set<std::string> emitted;
size_t count = 0;
@@ -823,24 +845,19 @@
const size_t blueprintHash = blueprint.ComputeContentHash();
blueprint.SetContentHash(blueprintHash);
- Ref<BindGroupLayoutBase> result;
- auto iter = mCaches->bindGroupLayouts.find(&blueprint);
- if (iter != mCaches->bindGroupLayouts.end()) {
- result = *iter;
- } else {
- DAWN_TRY_ASSIGN(result, CreateBindGroupLayoutImpl(descriptor, pipelineCompatibilityToken));
- result->SetIsCachedReference();
- result->SetContentHash(blueprintHash);
- mCaches->bindGroupLayouts.insert(result.Get());
- }
-
- return std::move(result);
+ return GetOrCreate(
+ mCaches->bindGroupLayouts, &blueprint, [&]() -> ResultOrError<Ref<BindGroupLayoutBase>> {
+ Ref<BindGroupLayoutBase> result;
+ DAWN_TRY_ASSIGN(result,
+ CreateBindGroupLayoutImpl(descriptor, pipelineCompatibilityToken));
+ result->SetContentHash(blueprintHash);
+ return result;
+ });
}
void DeviceBase::UncacheBindGroupLayout(BindGroupLayoutBase* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->bindGroupLayouts.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->bindGroupLayouts.Erase(obj);
}
// Private function used at initialization
@@ -872,53 +889,41 @@
Ref<ComputePipelineBase> DeviceBase::GetCachedComputePipeline(
ComputePipelineBase* uninitializedComputePipeline) {
- Ref<ComputePipelineBase> cachedPipeline;
- auto iter = mCaches->computePipelines.find(uninitializedComputePipeline);
- if (iter != mCaches->computePipelines.end()) {
- cachedPipeline = *iter;
- }
-
- return cachedPipeline;
+ return mCaches->computePipelines.Find(uninitializedComputePipeline);
}
Ref<RenderPipelineBase> DeviceBase::GetCachedRenderPipeline(
RenderPipelineBase* uninitializedRenderPipeline) {
- Ref<RenderPipelineBase> cachedPipeline;
- auto iter = mCaches->renderPipelines.find(uninitializedRenderPipeline);
- if (iter != mCaches->renderPipelines.end()) {
- cachedPipeline = *iter;
- }
- return cachedPipeline;
+ return mCaches->renderPipelines.Find(uninitializedRenderPipeline);
}
Ref<ComputePipelineBase> DeviceBase::AddOrGetCachedComputePipeline(
Ref<ComputePipelineBase> computePipeline) {
ASSERT(IsLockedByCurrentThreadIfNeeded());
- auto [cachedPipeline, inserted] = mCaches->computePipelines.insert(computePipeline.Get());
+ auto [cachedPipeline, inserted] = mCaches->computePipelines.Insert(computePipeline.Get());
if (inserted) {
computePipeline->SetIsCachedReference();
return computePipeline;
} else {
- return *cachedPipeline;
+ return std::move(cachedPipeline);
}
}
Ref<RenderPipelineBase> DeviceBase::AddOrGetCachedRenderPipeline(
Ref<RenderPipelineBase> renderPipeline) {
ASSERT(IsLockedByCurrentThreadIfNeeded());
- auto [cachedPipeline, inserted] = mCaches->renderPipelines.insert(renderPipeline.Get());
+ auto [cachedPipeline, inserted] = mCaches->renderPipelines.Insert(renderPipeline.Get());
if (inserted) {
renderPipeline->SetIsCachedReference();
return renderPipeline;
} else {
- return *cachedPipeline;
+ return std::move(cachedPipeline);
}
}
void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->computePipelines.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->computePipelines.Erase(obj);
}
ResultOrError<Ref<TextureViewBase>>
@@ -957,30 +962,23 @@
const size_t blueprintHash = blueprint.ComputeContentHash();
blueprint.SetContentHash(blueprintHash);
- Ref<PipelineLayoutBase> result;
- auto iter = mCaches->pipelineLayouts.find(&blueprint);
- if (iter != mCaches->pipelineLayouts.end()) {
- result = *iter;
- } else {
- DAWN_TRY_ASSIGN(result, CreatePipelineLayoutImpl(descriptor));
- result->SetIsCachedReference();
- result->SetContentHash(blueprintHash);
- mCaches->pipelineLayouts.insert(result.Get());
- }
-
- return std::move(result);
+ return GetOrCreate(mCaches->pipelineLayouts, &blueprint,
+ [&]() -> ResultOrError<Ref<PipelineLayoutBase>> {
+ Ref<PipelineLayoutBase> result;
+ DAWN_TRY_ASSIGN(result, CreatePipelineLayoutImpl(descriptor));
+ result->SetContentHash(blueprintHash);
+ return result;
+ });
}
void DeviceBase::UncachePipelineLayout(PipelineLayoutBase* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->pipelineLayouts.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->pipelineLayouts.Erase(obj);
}
void DeviceBase::UncacheRenderPipeline(RenderPipelineBase* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->renderPipelines.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->renderPipelines.Erase(obj);
}
ResultOrError<Ref<SamplerBase>> DeviceBase::GetOrCreateSampler(
@@ -990,24 +988,17 @@
const size_t blueprintHash = blueprint.ComputeContentHash();
blueprint.SetContentHash(blueprintHash);
- Ref<SamplerBase> result;
- auto iter = mCaches->samplers.find(&blueprint);
- if (iter != mCaches->samplers.end()) {
- result = *iter;
- } else {
+ return GetOrCreate(mCaches->samplers, &blueprint, [&]() -> ResultOrError<Ref<SamplerBase>> {
+ Ref<SamplerBase> result;
DAWN_TRY_ASSIGN(result, CreateSamplerImpl(descriptor));
- result->SetIsCachedReference();
result->SetContentHash(blueprintHash);
- mCaches->samplers.insert(result.Get());
- }
-
- return std::move(result);
+ return result;
+ });
}
void DeviceBase::UncacheSampler(SamplerBase* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->samplers.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->samplers.Erase(obj);
}
ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
@@ -1021,46 +1012,35 @@
const size_t blueprintHash = blueprint.ComputeContentHash();
blueprint.SetContentHash(blueprintHash);
- Ref<ShaderModuleBase> result;
- auto iter = mCaches->shaderModules.find(&blueprint);
- if (iter != mCaches->shaderModules.end()) {
- result = *iter;
- } else {
- if (!parseResult->HasParsedShader()) {
- // We skip the parse on creation if validation isn't enabled which let's us quickly
- // lookup in the cache without validating and parsing. We need the parsed module
- // now.
- ASSERT(!IsValidationEnabled());
- DAWN_TRY(
- ValidateAndParseShaderModule(this, descriptor, parseResult, compilationMessages));
- }
- DAWN_TRY_ASSIGN(result,
- CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
- result->SetIsCachedReference();
- result->SetContentHash(blueprintHash);
- mCaches->shaderModules.insert(result.Get());
- }
-
- return std::move(result);
+ return GetOrCreate(
+ mCaches->shaderModules, &blueprint, [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
+ Ref<ShaderModuleBase> result;
+ if (!parseResult->HasParsedShader()) {
+ // We skip the parse on creation if validation isn't enabled, which lets us quickly
+ // look up in the cache without validating and parsing. We need the parsed module
+ // now.
+ ASSERT(!IsValidationEnabled());
+ DAWN_TRY(ValidateAndParseShaderModule(this, descriptor, parseResult,
+ compilationMessages));
+ }
+ DAWN_TRY_ASSIGN(result,
+ CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
+ result->SetContentHash(blueprintHash);
+ return result;
+ });
}
void DeviceBase::UncacheShaderModule(ShaderModuleBase* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->shaderModules.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->shaderModules.Erase(obj);
}
Ref<AttachmentState> DeviceBase::GetOrCreateAttachmentState(AttachmentStateBlueprint* blueprint) {
- auto iter = mCaches->attachmentStates.find(blueprint);
- if (iter != mCaches->attachmentStates.end()) {
- return static_cast<AttachmentState*>(*iter);
- }
-
- Ref<AttachmentState> attachmentState = AcquireRef(new AttachmentState(this, *blueprint));
- attachmentState->SetIsCachedReference();
- attachmentState->SetContentHash(attachmentState->ComputeContentHash());
- mCaches->attachmentStates.insert(attachmentState.Get());
- return attachmentState;
+ return GetOrCreate(mCaches->attachmentStates, blueprint, [&]() -> Ref<AttachmentState> {
+ Ref<AttachmentState> attachmentState = AcquireRef(new AttachmentState(this, *blueprint));
+ attachmentState->SetContentHash(attachmentState->ComputeContentHash());
+ return attachmentState;
+ });
}
Ref<AttachmentState> DeviceBase::GetOrCreateAttachmentState(
@@ -1083,8 +1063,7 @@
void DeviceBase::UncacheAttachmentState(AttachmentState* obj) {
ASSERT(obj->IsCachedReference());
- size_t removedCount = mCaches->attachmentStates.erase(obj);
- ASSERT(removedCount == 1);
+ mCaches->attachmentStates.Erase(obj);
}
Ref<PipelineCacheBase> DeviceBase::GetOrCreatePipelineCache(const CacheKey& key) {
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index d059fae..4193830 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -21,6 +21,7 @@
#include <utility>
#include <vector>
+#include "dawn/common/ContentLessObjectCache.h"
#include "dawn/common/Mutex.h"
#include "dawn/native/CacheKey.h"
#include "dawn/native/Commands.h"
diff --git a/src/dawn/native/Error.h b/src/dawn/native/Error.h
index 7cfe284..058a0e5 100644
--- a/src/dawn/native/Error.h
+++ b/src/dawn/native/Error.h
@@ -41,6 +41,30 @@
template <typename T>
using ResultOrError = Result<T, ErrorData>;
+namespace detail {
+
+template <typename T>
+struct UnwrapResultOrError {
+ using type = T;
+};
+
+template <typename T>
+struct UnwrapResultOrError<ResultOrError<T>> {
+ using type = T;
+};
+
+template <typename T>
+struct IsResultOrError {
+ static constexpr bool value = false;
+};
+
+template <typename T>
+struct IsResultOrError<ResultOrError<T>> {
+ static constexpr bool value = true;
+};
+
+} // namespace detail
+
// Returning a success is done like so:
// return {}; // for Error
// return SomethingOfTypeT; // for ResultOrError<T>
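
For reference (illustration only, not added by this change), the behavior of the
relocated traits can be summarized with a few static_asserts, using an arbitrary
placeholder payload type:

    struct SomePayload {};
    static_assert(std::is_same_v<detail::UnwrapResultOrError<SomePayload>::type, SomePayload>);
    static_assert(std::is_same_v<detail::UnwrapResultOrError<ResultOrError<SomePayload>>::type,
                                 SomePayload>);
    static_assert(!detail::IsResultOrError<SomePayload>::value);
    static_assert(detail::IsResultOrError<ResultOrError<SomePayload>>::value);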
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 1f7b5aa..c291118 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -281,6 +281,7 @@
"unittests/ChainUtilsTests.cpp",
"unittests/CommandAllocatorTests.cpp",
"unittests/ConcurrentCacheTests.cpp",
+ "unittests/ContentLessObjectCacheTests.cpp",
"unittests/EnumClassBitmasksTests.cpp",
"unittests/EnumMaskIteratorTests.cpp",
"unittests/ErrorTests.cpp",
diff --git a/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp b/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
new file mode 100644
index 0000000..b117814
--- /dev/null
+++ b/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
@@ -0,0 +1,275 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "dawn/common/ContentLessObjectCache.h"
+#include "gtest/gtest.h"
+
+namespace dawn {
+namespace {
+
+using ::testing::Test;
+using ::testing::Types;
+
+class BlueprintT {
+ public:
+ explicit BlueprintT(size_t value) : mValue(value) {}
+
+ size_t GetValue() const { return mValue; }
+
+ struct HashFunc {
+ size_t operator()(const BlueprintT* x) const { return x->mValue; }
+ };
+
+ struct EqualityFunc {
+ bool operator()(const BlueprintT* l, const BlueprintT* r) const {
+ return l->mValue == r->mValue;
+ }
+ };
+
+ protected:
+ size_t mValue;
+};
+
+class RefCountedT : public BlueprintT, public RefCounted {
+ public:
+ explicit RefCountedT(size_t value) : BlueprintT(value) {}
+ RefCountedT(size_t value, std::function<void(RefCountedT*)> deleteFn)
+ : BlueprintT(value), mDeleteFn(deleteFn) {}
+
+ ~RefCountedT() override { mDeleteFn(this); }
+
+ struct HashFunc {
+ size_t operator()(const RefCountedT* x) const { return x->mValue; }
+ };
+
+ struct EqualityFunc {
+ bool operator()(const RefCountedT* l, const RefCountedT* r) const {
+ return l->mValue == r->mValue;
+ }
+ };
+
+ private:
+ std::function<void(RefCountedT*)> mDeleteFn = [](RefCountedT*) -> void {};
+};
+
+template <typename Blueprint>
+class ContentLessObjectCacheTest : public Test {};
+
+class BlueprintTypeNames {
+ public:
+ template <typename T>
+ static std::string GetName(int) {
+ if (std::is_same<T, RefCountedT>()) {
+ return "RefCountedT";
+ }
+ if (std::is_same<T, BlueprintT>()) {
+ return "BlueprintT";
+ }
+ return "UnknownType";
+ }
+};
+using BlueprintTypes = Types<RefCountedT, BlueprintT>;
+TYPED_TEST_SUITE(ContentLessObjectCacheTest, BlueprintTypes, BlueprintTypeNames);
+
+// Empty cache returns true on Empty().
+TYPED_TEST(ContentLessObjectCacheTest, Empty) {
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ EXPECT_TRUE(cache.Empty());
+}
+
+// Non-empty cache returns false on Empty().
+TYPED_TEST(ContentLessObjectCacheTest, NonEmpty) {
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object =
+ AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+ EXPECT_TRUE(cache.Insert(object.Get()).second);
+ EXPECT_FALSE(cache.Empty());
+}
+
+// Objects inserted into the cache are findable.
+TYPED_TEST(ContentLessObjectCacheTest, Insert) {
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object =
+ AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+ EXPECT_TRUE(cache.Insert(object.Get()).second);
+
+ TypeParam blueprint(1);
+ Ref<RefCountedT> cached = cache.Find(&blueprint);
+ EXPECT_TRUE(object.Get() == cached.Get());
+}
+
+// Duplicate Insert calls on different objects with the same hash only insert the first object.
+TYPED_TEST(ContentLessObjectCacheTest, InsertDuplicate) {
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object1 =
+ AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+ EXPECT_TRUE(cache.Insert(object1.Get()).second);
+
+ Ref<RefCountedT> object2 = AcquireRef(new RefCountedT(1));
+ EXPECT_FALSE(cache.Insert(object2.Get()).second);
+
+ TypeParam blueprint(1);
+ Ref<RefCountedT> cached = cache.Find(&blueprint);
+ EXPECT_TRUE(object1.Get() == cached.Get());
+}
+
+// Erasing the only entry leaves the cache empty.
+TYPED_TEST(ContentLessObjectCacheTest, Erase) {
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object = AcquireRef(new RefCountedT(1));
+ EXPECT_TRUE(cache.Insert(object.Get()).second);
+ EXPECT_FALSE(cache.Empty());
+
+ cache.Erase(object.Get());
+ EXPECT_TRUE(cache.Empty());
+}
+
+// Erasing a hash equivalent but not pointer equivalent entry is a no-op.
+TYPED_TEST(ContentLessObjectCacheTest, EraseDuplicate) {
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object1 =
+ AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+ EXPECT_TRUE(cache.Insert(object1.Get()).second);
+ EXPECT_FALSE(cache.Empty());
+
+ Ref<RefCountedT> object2 = AcquireRef(new RefCountedT(1));
+ cache.Erase(object2.Get());
+ EXPECT_FALSE(cache.Empty());
+}
+
+// Helper struct that acts as a binary semaphore to coordinate control flow between threads.
+struct Signal {
+ std::mutex mutex;
+ std::condition_variable cv;
+ bool signaled = false;
+
+ void Fire() {
+ std::lock_guard<std::mutex> lock(mutex);
+ signaled = true;
+ cv.notify_one();
+ }
+ void Wait() {
+ std::unique_lock<std::mutex> lock(mutex);
+ while (!signaled) {
+ cv.wait(lock);
+ }
+ signaled = false;
+ }
+};
+
+// Inserting and finding elements should respect the results from the insert call.
+TYPED_TEST(ContentLessObjectCacheTest, InsertingAndFinding) {
+ constexpr size_t kNumObjects = 100;
+ constexpr size_t kNumThreads = 8;
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ std::vector<Ref<RefCountedT>> objects(kNumObjects);
+
+ auto f = [&]() {
+ for (size_t i = 0; i < kNumObjects; i++) {
+ Ref<RefCountedT> object =
+ AcquireRef(new RefCountedT(i, [&](RefCountedT* x) { cache.Erase(x); }));
+ if (cache.Insert(object.Get()).second) {
+ // This shouldn't race because exactly 1 thread should successfully insert.
+ objects[i] = object;
+ }
+ }
+ for (size_t i = 0; i < kNumObjects; i++) {
+ TypeParam blueprint(i);
+ Ref<RefCountedT> cached = cache.Find(&blueprint);
+ EXPECT_NE(cached.Get(), nullptr);
+ EXPECT_EQ(cached.Get(), objects[i].Get());
+ }
+ };
+
+ std::vector<std::thread> threads;
+ for (size_t t = 0; t < kNumThreads; t++) {
+ threads.emplace_back(f);
+ }
+ for (size_t t = 0; t < kNumThreads; t++) {
+ threads[t].join();
+ }
+}
+
+// Finding an element that is in the process of deletion should return nullptr.
+TYPED_TEST(ContentLessObjectCacheTest, FindDeleting) {
+ Signal signalA, signalB;
+
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object = AcquireRef(new RefCountedT(1, [&](RefCountedT* x) {
+ signalA.Fire();
+ signalB.Wait();
+ cache.Erase(x);
+ }));
+ EXPECT_TRUE(cache.Insert(object.Get()).second);
+
+ // Thread A will release the last reference of the original object.
+ auto threadA = [&]() { object = nullptr; };
+ // Thread B will try to Find the entry before it is completely destroyed.
+ auto threadB = [&]() {
+ signalA.Wait();
+ TypeParam blueprint(1);
+ EXPECT_TRUE(cache.Find(&blueprint) == nullptr);
+ signalB.Fire();
+ };
+
+ std::thread tA(threadA);
+ std::thread tB(threadB);
+ tA.join();
+ tB.join();
+}
+
+// Inserting an element whose hash-equivalent cache entry is in the process of being deleted
+// should insert the new object.
+TYPED_TEST(ContentLessObjectCacheTest, InsertDeleting) {
+ Signal signalA, signalB;
+
+ ContentLessObjectCache<RefCountedT, TypeParam> cache;
+ Ref<RefCountedT> object1 = AcquireRef(new RefCountedT(1, [&](RefCountedT* x) {
+ signalA.Fire();
+ signalB.Wait();
+ cache.Erase(x);
+ }));
+ EXPECT_TRUE(cache.Insert(object1.Get()).second);
+
+ Ref<RefCountedT> object2 =
+ AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+
+ // Thread A will release the last reference of the original object.
+ auto threadA = [&]() { object1 = nullptr; };
+ // Thread B will try to Insert a hash equivalent entry before the original is completely
+ // destroyed.
+ auto threadB = [&]() {
+ signalA.Wait();
+ EXPECT_TRUE(cache.Insert(object2.Get()).second);
+ signalB.Fire();
+ };
+
+ std::thread tA(threadA);
+ std::thread tB(threadB);
+ tA.join();
+ tB.join();
+
+ TypeParam blueprint(1);
+ Ref<RefCountedT> cached = cache.Find(&blueprint);
+ EXPECT_TRUE(object2.Get() == cached.Get());
+}
+
+} // anonymous namespace
+} // namespace dawn