Initial implementation of a multithreaded frontend cache.

- Adds a unit test suite to verify basic usage of the cache.
- Note: This change does NOT remove the device lock that is currently
        acquired for object deletion, so its actual effects should not
        yet be apparent. Furthermore, this change will most likely not
        be enabled on its own without a parallel effort to add a
        WeakRef concept so that the cache holds WeakRefs to objects
        instead of raw pointers.
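- Illustrative usage (a sketch only; Object, Blueprint, desc, and
  blueprint below are hypothetical stand-ins for a cached type and a
  blueprint type providing the HashFunc/EqualityFunc functors the
  cache expects):

      ContentLessObjectCache<Object, Blueprint> cache;
      Ref<Object> obj = AcquireRef(new Object(desc));
      // {Ref to obj, true} if inserted, else {Ref to existing, false}.
      auto [ref, inserted] = cache.Insert(obj.Get());
      // nullptr if no entry matches or its refcount already hit 0.
      Ref<Object> hit = cache.Find(&blueprint);
      // Object::DeleteThis() must call cache.Erase(this) on its way out.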

Bug: dawn:1769
Change-Id: I0f6a11f01c558875c2b120a55aa3c4232b501a3c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/136582
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/common/BUILD.gn b/src/dawn/common/BUILD.gn
index 7509d0d..762312d 100644
--- a/src/dawn/common/BUILD.gn
+++ b/src/dawn/common/BUILD.gn
@@ -233,6 +233,7 @@
       "Compiler.h",
       "ConcurrentCache.h",
       "Constants.h",
+      "ContentLessObjectCache.h",
       "CoreFoundationRef.h",
       "DynamicLib.cpp",
       "DynamicLib.h",
diff --git a/src/dawn/common/CMakeLists.txt b/src/dawn/common/CMakeLists.txt
index c78d6b0..302be93 100644
--- a/src/dawn/common/CMakeLists.txt
+++ b/src/dawn/common/CMakeLists.txt
@@ -40,6 +40,7 @@
     "Compiler.h"
     "ConcurrentCache.h"
     "Constants.h"
+    "ContentLessObjectCache.h"
     "CoreFoundationRef.h"
     "DynamicLib.cpp"
     "DynamicLib.h"
diff --git a/src/dawn/common/ContentLessObjectCache.h b/src/dawn/common/ContentLessObjectCache.h
new file mode 100644
index 0000000..e67c05a
--- /dev/null
+++ b/src/dawn/common/ContentLessObjectCache.h
@@ -0,0 +1,104 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
+#define SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
+
+#include <mutex>
+#include <tuple>
+#include <type_traits>
+#include <unordered_set>
+#include <utility>
+
+#include "dawn/common/RefCounted.h"
+
+namespace dawn {
+
+// The ContentLessObjectCache stores raw pointers to living Refs without adding to their refcounts.
+// This means that any RefCountedT that is inserted into the cache needs to make sure that their
+// DeleteThis function erases itself from the cache. Otherwise, the cache can grow indefinitely via
+// leaked pointers to deleted Refs.
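+//
+// A minimal sketch of that contract (CachedThing and mCache below are hypothetical and not part
+// of this change):
+//
+//   class CachedThing : public RefCounted {
+//     private:
+//       void DeleteThis() override {
+//           mCache->Erase(this);      // Drop the raw pointer before the memory is freed.
+//           RefCounted::DeleteThis();
+//       }
+//       ContentLessObjectCache<CachedThing>* mCache;
+//   };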
+template <typename RefCountedT, typename BlueprintT = RefCountedT>
+class ContentLessObjectCache {
+    static_assert(std::is_convertible_v<RefCountedT*, BlueprintT*>,
+                  "RefCountedT* must be convertible to a BlueprintT*.");
+
+  public:
+    // The dtor asserts that the cache is empty to aid in finding pointer leaks that are possible
+    // if the RefCountedT doesn't correctly implement the DeleteThis function to erase itself from
+    // the cache.
+    ~ContentLessObjectCache() { ASSERT(Empty()); }
+
+    // Inserts the object into the cache, returning a pair where the first element is a Ref to
+    // the inserted or existing object, and the second is a bool that is true if we inserted
+    // `object` and false otherwise.
+    std::pair<Ref<RefCountedT>, bool> Insert(RefCountedT* object) {
+        std::lock_guard<std::mutex> lock(mMutex);
+        auto [it, inserted] = mCache.insert(object);
+        if (inserted) {
+            return {object, inserted};
+        } else {
+            // We need to check that the found instance isn't about to be destroyed. If it is, we
+            // actually want to remove the old instance from the cache and insert this one. This
+            // can happen if the last ref of the instance currently in the cache has already been
+            // dropped, but its deletion hasn't finished erasing it from the cache yet.
+            Ref<RefCountedT> ref = TryGetRef(static_cast<RefCountedT*>(*it));
+            if (ref != nullptr) {
+                return {ref, false};
+            } else {
+                mCache.erase(it);
+                auto result = mCache.insert(object);
+                ASSERT(result.second);
+                return {object, true};
+            }
+        }
+    }
+
+    // Returns a valid Ref to a cached object that matches the blueprint, provided that object's
+    // refcount has not already reached 0. Otherwise, returns nullptr.
+    Ref<RefCountedT> Find(BlueprintT* blueprint) {
+        std::lock_guard<std::mutex> lock(mMutex);
+        auto it = mCache.find(blueprint);
+        if (it != mCache.end()) {
+            return TryGetRef(static_cast<RefCountedT*>(*it));
+        }
+        return nullptr;
+    }
+
+    // Erases the object from the cache if it exists and is pointer equal. Otherwise, does not
+    // modify the cache.
+    void Erase(RefCountedT* object) {
+        std::lock_guard<std::mutex> lock(mMutex);
+        auto it = mCache.find(object);
+        if (it != mCache.end() && *it == object) {
+            mCache.erase(it);
+        }
+    }
+
+    // Returns true iff the cache is empty.
+    bool Empty() {
+        std::lock_guard<std::mutex> lock(mMutex);
+        return mCache.empty();
+    }
+
+  private:
+    std::mutex mMutex;
+    std::unordered_set<BlueprintT*,
+                       typename BlueprintT::HashFunc,
+                       typename BlueprintT::EqualityFunc>
+        mCache;
+};
+
+}  // namespace dawn
+
+#endif  // SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
diff --git a/src/dawn/common/RefCounted.cpp b/src/dawn/common/RefCounted.cpp
index aa91c85..0c5bd12 100644
--- a/src/dawn/common/RefCounted.cpp
+++ b/src/dawn/common/RefCounted.cpp
@@ -56,6 +56,26 @@
     mRefCount.fetch_add(kRefCountIncrement, std::memory_order_relaxed);
 }
 
+bool RefCount::TryIncrement() {
+    uint64_t current = mRefCount.load(std::memory_order_relaxed);
+    bool success = false;
+    do {
+        if ((current & ~kPayloadMask) == 0u) {
+            return false;
+        }
+        // The relaxed ordering guarantees only the atomicity of the update. This is fine because:
+        //   - If another thread's decrement happens before this increment, the increment should
+        //     fail.
+        //   - If another thread's decrement happens after this increment, the decrement shouldn't
+        //     delete the object, because the ref count > 0.
+        // See Boost library for reference:
+        //   https://github.com/boostorg/smart_ptr/blob/develop/include/boost/smart_ptr/detail/sp_counted_base_std_atomic.hpp#L62
+        success = mRefCount.compare_exchange_weak(current, current + kRefCountIncrement,
+                                                  std::memory_order_relaxed);
+    } while (!success);
+    return true;
+}
+
 bool RefCount::Decrement() {
     ASSERT((mRefCount & ~kPayloadMask) != 0);
 
diff --git a/src/dawn/common/RefCounted.h b/src/dawn/common/RefCounted.h
index 0ab7427..e6cdb90 100644
--- a/src/dawn/common/RefCounted.h
+++ b/src/dawn/common/RefCounted.h
@@ -17,6 +17,7 @@
 
 #include <atomic>
 #include <cstdint>
+#include <type_traits>
 
 #include "dawn/common/RefBase.h"
 
@@ -32,6 +33,9 @@
 
     // Add a reference.
     void Increment();
+    // Tries to add a reference. Returns false if the ref count is already at 0. This is used when
+    // operating on a raw pointer to a RefCounted, rather than a valid Ref, because the object may
+    // soon be deleted.
+    bool TryIncrement();
 
     // Remove a reference. Returns true if this was the last reference.
     bool Decrement();
@@ -40,6 +44,9 @@
     std::atomic<uint64_t> mRefCount;
 };
 
+template <typename T>
+class Ref;
+
 class RefCounted {
   public:
     explicit RefCounted(uint64_t payload = 0);
@@ -52,6 +59,19 @@
     // synchronization in place for destruction.
     void Release();
 
+    // Tries to return a valid Ref to `object` if its internal refcount is not already 0. If the
+    // internal refcount has already reached 0, returns nullptr instead.
+    template <typename T, typename = typename std::is_convertible<T, RefCounted>>
+    friend Ref<T> TryGetRef(T* object) {
+        // Since this is called on a raw RefCounted pointer and can race with destruction, we
+        // first verify that we can safely increment the refcount; only then do we adopt that
+        // newly added reference into a Ref (via AcquireRef) so the resulting Ref is valid.
+        if (!object->mRefCount.TryIncrement()) {
+            return nullptr;
+        }
+        return AcquireRef(object);
+    }
+
     void APIReference() { Reference(); }
     // APIRelease() can be called without any synchronization guarantees so we need to use a Release
     // method that will call LockAndDeleteThis() on destruction.
diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h
index b861419..27da114 100644
--- a/src/dawn/native/CacheRequest.h
+++ b/src/dawn/native/CacheRequest.h
@@ -32,26 +32,6 @@
 
 namespace detail {
 
-template <typename T>
-struct UnwrapResultOrError {
-    using type = T;
-};
-
-template <typename T>
-struct UnwrapResultOrError<ResultOrError<T>> {
-    using type = T;
-};
-
-template <typename T>
-struct IsResultOrError {
-    static constexpr bool value = false;
-};
-
-template <typename T>
-struct IsResultOrError<ResultOrError<T>> {
-    static constexpr bool value = true;
-};
-
 void LogCacheHitError(std::unique_ptr<ErrorData> error);
 
 }  // namespace detail
diff --git a/src/dawn/native/CachedObject.h b/src/dawn/native/CachedObject.h
index 3fbba63..1040f54 100644
--- a/src/dawn/native/CachedObject.h
+++ b/src/dawn/native/CachedObject.h
@@ -23,13 +23,13 @@
 
 namespace dawn::native {
 
-// Some objects are cached so that instead of creating new duplicate objects,
-// we increase the refcount of an existing object.
-// When an object is successfully created, the device should call
+// Some objects are cached so that instead of creating new duplicate objects, we increase the
+// refcount of an existing object. When an object is successfully created, the device should call
 // SetIsCachedReference() and insert the object into the cache.
 class CachedObject {
   public:
     bool IsCachedReference() const;
+    void SetIsCachedReference();
 
     // Functor necessary for the unordered_set<CachedObject*>-based cache.
     struct HashFunc {
@@ -47,14 +47,11 @@
     CacheKey mCacheKey;
 
   private:
-    friend class DeviceBase;
-    void SetIsCachedReference();
-
-    bool mIsCachedReference = false;
-
     // Called by ObjectContentHasher upon creation to record the object.
     virtual size_t ComputeContentHash() = 0;
 
+    bool mIsCachedReference = false;
+
     size_t mContentHash = 0;
     bool mIsContentHashInitialized = false;
 };
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index b55a06b..5d0b4a7 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -17,7 +17,7 @@
 #include <algorithm>
 #include <array>
 #include <mutex>
-#include <unordered_set>
+#include <utility>
 
 #include "dawn/common/Log.h"
 #include "dawn/common/Version_autogen.h"
@@ -60,24 +60,8 @@
 
 // DeviceBase sub-structures
 
-// The caches are unordered_sets of pointers with special hash and compare functions
-// to compare the value of the objects, instead of the pointers.
-template <typename Object>
-using ContentLessObjectCache =
-    std::unordered_set<Object*, typename Object::HashFunc, typename Object::EqualityFunc>;
-
 struct DeviceBase::Caches {
-    ~Caches() {
-        ASSERT(attachmentStates.empty());
-        ASSERT(bindGroupLayouts.empty());
-        ASSERT(computePipelines.empty());
-        ASSERT(pipelineLayouts.empty());
-        ASSERT(renderPipelines.empty());
-        ASSERT(samplers.empty());
-        ASSERT(shaderModules.empty());
-    }
-
-    ContentLessObjectCache<AttachmentStateBlueprint> attachmentStates;
+    ContentLessObjectCache<AttachmentState, AttachmentStateBlueprint> attachmentStates;
     ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
     ContentLessObjectCache<ComputePipelineBase> computePipelines;
     ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
@@ -86,6 +70,44 @@
     ContentLessObjectCache<ShaderModuleBase> shaderModules;
 };
 
+// Tries to find the blueprint in the cache, creating the object and inserting it if not found.
+template <typename RefCountedT, typename BlueprintT, typename CreateFn>
+auto GetOrCreate(ContentLessObjectCache<RefCountedT, BlueprintT>& cache,
+                 BlueprintT* blueprint,
+                 CreateFn createFn) {
+    using ReturnType = decltype(createFn());
+
+    // If we find the blueprint in the cache we can just return it.
+    Ref<RefCountedT> result = cache.Find(blueprint);
+    if (result != nullptr) {
+        return ReturnType(result);
+    }
+
+    using UnwrappedReturnType = typename detail::UnwrapResultOrError<ReturnType>::type;
+    static_assert(std::is_same_v<UnwrappedReturnType, Ref<RefCountedT>>,
+                  "CreateFn should return an unwrapped type that is the same as Ref<RefCountedT>.");
+
+    // Create the result and try inserting it. Note that inserts can race because the critical
+    // sections here are disjoint, hence the checks to verify whether this thread inserted.
+    if constexpr (!detail::IsResultOrError<ReturnType>::value) {
+        result = createFn();
+    } else {
+        auto resultOrError = createFn();
+        if (DAWN_UNLIKELY(resultOrError.IsError())) {
+            return ReturnType(std::move(resultOrError.AcquireError()));
+        }
+        result = resultOrError.AcquireSuccess();
+    }
+    ASSERT(result.Get() != nullptr);
+
+    bool inserted = false;
+    std::tie(result, inserted) = cache.Insert(result.Get());
+    if (inserted) {
+        result->SetIsCachedReference();
+    }
+    return ReturnType(result);
+}
+
 struct DeviceBase::DeprecationWarnings {
     std::unordered_set<std::string> emitted;
     size_t count = 0;
@@ -823,24 +845,19 @@
     const size_t blueprintHash = blueprint.ComputeContentHash();
     blueprint.SetContentHash(blueprintHash);
 
-    Ref<BindGroupLayoutBase> result;
-    auto iter = mCaches->bindGroupLayouts.find(&blueprint);
-    if (iter != mCaches->bindGroupLayouts.end()) {
-        result = *iter;
-    } else {
-        DAWN_TRY_ASSIGN(result, CreateBindGroupLayoutImpl(descriptor, pipelineCompatibilityToken));
-        result->SetIsCachedReference();
-        result->SetContentHash(blueprintHash);
-        mCaches->bindGroupLayouts.insert(result.Get());
-    }
-
-    return std::move(result);
+    return GetOrCreate(
+        mCaches->bindGroupLayouts, &blueprint, [&]() -> ResultOrError<Ref<BindGroupLayoutBase>> {
+            Ref<BindGroupLayoutBase> result;
+            DAWN_TRY_ASSIGN(result,
+                            CreateBindGroupLayoutImpl(descriptor, pipelineCompatibilityToken));
+            result->SetContentHash(blueprintHash);
+            return result;
+        });
 }
 
 void DeviceBase::UncacheBindGroupLayout(BindGroupLayoutBase* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->bindGroupLayouts.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->bindGroupLayouts.Erase(obj);
 }
 
 // Private function used at initialization
@@ -872,53 +889,41 @@
 
 Ref<ComputePipelineBase> DeviceBase::GetCachedComputePipeline(
     ComputePipelineBase* uninitializedComputePipeline) {
-    Ref<ComputePipelineBase> cachedPipeline;
-    auto iter = mCaches->computePipelines.find(uninitializedComputePipeline);
-    if (iter != mCaches->computePipelines.end()) {
-        cachedPipeline = *iter;
-    }
-
-    return cachedPipeline;
+    return mCaches->computePipelines.Find(uninitializedComputePipeline);
 }
 
 Ref<RenderPipelineBase> DeviceBase::GetCachedRenderPipeline(
     RenderPipelineBase* uninitializedRenderPipeline) {
-    Ref<RenderPipelineBase> cachedPipeline;
-    auto iter = mCaches->renderPipelines.find(uninitializedRenderPipeline);
-    if (iter != mCaches->renderPipelines.end()) {
-        cachedPipeline = *iter;
-    }
-    return cachedPipeline;
+    return mCaches->renderPipelines.Find(uninitializedRenderPipeline);
 }
 
 Ref<ComputePipelineBase> DeviceBase::AddOrGetCachedComputePipeline(
     Ref<ComputePipelineBase> computePipeline) {
     ASSERT(IsLockedByCurrentThreadIfNeeded());
-    auto [cachedPipeline, inserted] = mCaches->computePipelines.insert(computePipeline.Get());
+    auto [cachedPipeline, inserted] = mCaches->computePipelines.Insert(computePipeline.Get());
     if (inserted) {
         computePipeline->SetIsCachedReference();
         return computePipeline;
     } else {
-        return *cachedPipeline;
+        return std::move(cachedPipeline);
     }
 }
 
 Ref<RenderPipelineBase> DeviceBase::AddOrGetCachedRenderPipeline(
     Ref<RenderPipelineBase> renderPipeline) {
     ASSERT(IsLockedByCurrentThreadIfNeeded());
-    auto [cachedPipeline, inserted] = mCaches->renderPipelines.insert(renderPipeline.Get());
+    auto [cachedPipeline, inserted] = mCaches->renderPipelines.Insert(renderPipeline.Get());
     if (inserted) {
         renderPipeline->SetIsCachedReference();
         return renderPipeline;
     } else {
-        return *cachedPipeline;
+        return std::move(cachedPipeline);
     }
 }
 
 void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->computePipelines.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->computePipelines.Erase(obj);
 }
 
 ResultOrError<Ref<TextureViewBase>>
@@ -957,30 +962,23 @@
     const size_t blueprintHash = blueprint.ComputeContentHash();
     blueprint.SetContentHash(blueprintHash);
 
-    Ref<PipelineLayoutBase> result;
-    auto iter = mCaches->pipelineLayouts.find(&blueprint);
-    if (iter != mCaches->pipelineLayouts.end()) {
-        result = *iter;
-    } else {
-        DAWN_TRY_ASSIGN(result, CreatePipelineLayoutImpl(descriptor));
-        result->SetIsCachedReference();
-        result->SetContentHash(blueprintHash);
-        mCaches->pipelineLayouts.insert(result.Get());
-    }
-
-    return std::move(result);
+    return GetOrCreate(mCaches->pipelineLayouts, &blueprint,
+                       [&]() -> ResultOrError<Ref<PipelineLayoutBase>> {
+                           Ref<PipelineLayoutBase> result;
+                           DAWN_TRY_ASSIGN(result, CreatePipelineLayoutImpl(descriptor));
+                           result->SetContentHash(blueprintHash);
+                           return result;
+                       });
 }
 
 void DeviceBase::UncachePipelineLayout(PipelineLayoutBase* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->pipelineLayouts.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->pipelineLayouts.Erase(obj);
 }
 
 void DeviceBase::UncacheRenderPipeline(RenderPipelineBase* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->renderPipelines.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->renderPipelines.Erase(obj);
 }
 
 ResultOrError<Ref<SamplerBase>> DeviceBase::GetOrCreateSampler(
@@ -990,24 +988,17 @@
     const size_t blueprintHash = blueprint.ComputeContentHash();
     blueprint.SetContentHash(blueprintHash);
 
-    Ref<SamplerBase> result;
-    auto iter = mCaches->samplers.find(&blueprint);
-    if (iter != mCaches->samplers.end()) {
-        result = *iter;
-    } else {
+    return GetOrCreate(mCaches->samplers, &blueprint, [&]() -> ResultOrError<Ref<SamplerBase>> {
+        Ref<SamplerBase> result;
         DAWN_TRY_ASSIGN(result, CreateSamplerImpl(descriptor));
-        result->SetIsCachedReference();
         result->SetContentHash(blueprintHash);
-        mCaches->samplers.insert(result.Get());
-    }
-
-    return std::move(result);
+        return result;
+    });
 }
 
 void DeviceBase::UncacheSampler(SamplerBase* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->samplers.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->samplers.Erase(obj);
 }
 
 ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
@@ -1021,46 +1012,35 @@
     const size_t blueprintHash = blueprint.ComputeContentHash();
     blueprint.SetContentHash(blueprintHash);
 
-    Ref<ShaderModuleBase> result;
-    auto iter = mCaches->shaderModules.find(&blueprint);
-    if (iter != mCaches->shaderModules.end()) {
-        result = *iter;
-    } else {
-        if (!parseResult->HasParsedShader()) {
-            // We skip the parse on creation if validation isn't enabled which let's us quickly
-            // lookup in the cache without validating and parsing. We need the parsed module
-            // now.
-            ASSERT(!IsValidationEnabled());
-            DAWN_TRY(
-                ValidateAndParseShaderModule(this, descriptor, parseResult, compilationMessages));
-        }
-        DAWN_TRY_ASSIGN(result,
-                        CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
-        result->SetIsCachedReference();
-        result->SetContentHash(blueprintHash);
-        mCaches->shaderModules.insert(result.Get());
-    }
-
-    return std::move(result);
+    return GetOrCreate(
+        mCaches->shaderModules, &blueprint, [&]() -> ResultOrError<Ref<ShaderModuleBase>> {
+            Ref<ShaderModuleBase> result;
+            if (!parseResult->HasParsedShader()) {
+                // We skip the parse on creation if validation isn't enabled, which lets us do a
+                // quick cache lookup without validating and parsing. We need the parsed module
+                // now.
+                ASSERT(!IsValidationEnabled());
+                DAWN_TRY(ValidateAndParseShaderModule(this, descriptor, parseResult,
+                                                      compilationMessages));
+            }
+            DAWN_TRY_ASSIGN(result,
+                            CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
+            result->SetContentHash(blueprintHash);
+            return result;
+        });
 }
 
 void DeviceBase::UncacheShaderModule(ShaderModuleBase* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->shaderModules.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->shaderModules.Erase(obj);
 }
 
 Ref<AttachmentState> DeviceBase::GetOrCreateAttachmentState(AttachmentStateBlueprint* blueprint) {
-    auto iter = mCaches->attachmentStates.find(blueprint);
-    if (iter != mCaches->attachmentStates.end()) {
-        return static_cast<AttachmentState*>(*iter);
-    }
-
-    Ref<AttachmentState> attachmentState = AcquireRef(new AttachmentState(this, *blueprint));
-    attachmentState->SetIsCachedReference();
-    attachmentState->SetContentHash(attachmentState->ComputeContentHash());
-    mCaches->attachmentStates.insert(attachmentState.Get());
-    return attachmentState;
+    return GetOrCreate(mCaches->attachmentStates, blueprint, [&]() -> Ref<AttachmentState> {
+        Ref<AttachmentState> attachmentState = AcquireRef(new AttachmentState(this, *blueprint));
+        attachmentState->SetContentHash(attachmentState->ComputeContentHash());
+        return attachmentState;
+    });
 }
 
 Ref<AttachmentState> DeviceBase::GetOrCreateAttachmentState(
@@ -1083,8 +1063,7 @@
 
 void DeviceBase::UncacheAttachmentState(AttachmentState* obj) {
     ASSERT(obj->IsCachedReference());
-    size_t removedCount = mCaches->attachmentStates.erase(obj);
-    ASSERT(removedCount == 1);
+    mCaches->attachmentStates.Erase(obj);
 }
 
 Ref<PipelineCacheBase> DeviceBase::GetOrCreatePipelineCache(const CacheKey& key) {
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index d059fae..4193830 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -21,6 +21,7 @@
 #include <utility>
 #include <vector>
 
+#include "dawn/common/ContentLessObjectCache.h"
 #include "dawn/common/Mutex.h"
 #include "dawn/native/CacheKey.h"
 #include "dawn/native/Commands.h"
diff --git a/src/dawn/native/Error.h b/src/dawn/native/Error.h
index 7cfe284..058a0e5 100644
--- a/src/dawn/native/Error.h
+++ b/src/dawn/native/Error.h
@@ -41,6 +41,30 @@
 template <typename T>
 using ResultOrError = Result<T, ErrorData>;
 
+namespace detail {
+
+template <typename T>
+struct UnwrapResultOrError {
+    using type = T;
+};
+
+template <typename T>
+struct UnwrapResultOrError<ResultOrError<T>> {
+    using type = T;
+};
+
+template <typename T>
+struct IsResultOrError {
+    static constexpr bool value = false;
+};
+
+template <typename T>
+struct IsResultOrError<ResultOrError<T>> {
+    static constexpr bool value = true;
+};
+
+}  // namespace detail
+
 // Returning a success is done like so:
 //   return {}; // for Error
 //   return SomethingOfTypeT; // for ResultOrError<T>
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 1f7b5aa..c291118 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -281,6 +281,7 @@
     "unittests/ChainUtilsTests.cpp",
     "unittests/CommandAllocatorTests.cpp",
     "unittests/ConcurrentCacheTests.cpp",
+    "unittests/ContentLessObjectCacheTests.cpp",
     "unittests/EnumClassBitmasksTests.cpp",
     "unittests/EnumMaskIteratorTests.cpp",
     "unittests/ErrorTests.cpp",
diff --git a/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp b/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
new file mode 100644
index 0000000..b117814
--- /dev/null
+++ b/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
@@ -0,0 +1,275 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "dawn/common/ContentLessObjectCache.h"
+#include "gtest/gtest.h"
+
+namespace dawn {
+namespace {
+
+using ::testing::Test;
+using ::testing::Types;
+
+class BlueprintT {
+  public:
+    explicit BlueprintT(size_t value) : mValue(value) {}
+
+    size_t GetValue() const { return mValue; }
+
+    struct HashFunc {
+        size_t operator()(const BlueprintT* x) const { return x->mValue; }
+    };
+
+    struct EqualityFunc {
+        bool operator()(const BlueprintT* l, const BlueprintT* r) const {
+            return l->mValue == r->mValue;
+        }
+    };
+
+  protected:
+    size_t mValue;
+};
+
+class RefCountedT : public BlueprintT, public RefCounted {
+  public:
+    explicit RefCountedT(size_t value) : BlueprintT(value) {}
+    RefCountedT(size_t value, std::function<void(RefCountedT*)> deleteFn)
+        : BlueprintT(value), mDeleteFn(deleteFn) {}
+
+    ~RefCountedT() override { mDeleteFn(this); }
+
+    struct HashFunc {
+        size_t operator()(const RefCountedT* x) const { return x->mValue; }
+    };
+
+    struct EqualityFunc {
+        bool operator()(const RefCountedT* l, const RefCountedT* r) const {
+            return l->mValue == r->mValue;
+        }
+    };
+
+  private:
+    std::function<void(RefCountedT*)> mDeleteFn = [](RefCountedT*) -> void {};
+};
+
+template <typename Blueprint>
+class ContentLessObjectCacheTest : public Test {};
+
+class BlueprintTypeNames {
+  public:
+    template <typename T>
+    static std::string GetName(int) {
+        if (std::is_same<T, RefCountedT>()) {
+            return "RefCountedT";
+        }
+        if (std::is_same<T, BlueprintT>()) {
+            return "BlueprintT";
+        }
+        // Avoid falling off the end of a non-void function for unexpected types.
+        return "UnknownType";
+    }
+};
+using BlueprintTypes = Types<RefCountedT, BlueprintT>;
+TYPED_TEST_SUITE(ContentLessObjectCacheTest, BlueprintTypes, BlueprintTypeNames);
+
+// Empty cache returns true on Empty().
+TYPED_TEST(ContentLessObjectCacheTest, Empty) {
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    EXPECT_TRUE(cache.Empty());
+}
+
+// Non-empty cache returns false on Empty().
+TYPED_TEST(ContentLessObjectCacheTest, NonEmpty) {
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object =
+        AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+    EXPECT_TRUE(cache.Insert(object.Get()).second);
+    EXPECT_FALSE(cache.Empty());
+}
+
+// Objects inserted into the cache are findable.
+TYPED_TEST(ContentLessObjectCacheTest, Insert) {
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object =
+        AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+    EXPECT_TRUE(cache.Insert(object.Get()).second);
+
+    TypeParam blueprint(1);
+    Ref<RefCountedT> cached = cache.Find(&blueprint);
+    EXPECT_TRUE(object.Get() == cached.Get());
+}
+
+// Duplicate insert calls on different objects with the same hash only insert the first.
+TYPED_TEST(ContentLessObjectCacheTest, InsertDuplicate) {
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object1 =
+        AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+    EXPECT_TRUE(cache.Insert(object1.Get()).second);
+
+    Ref<RefCountedT> object2 = AcquireRef(new RefCountedT(1));
+    EXPECT_FALSE(cache.Insert(object2.Get()).second);
+
+    TypeParam blueprint(1);
+    Ref<RefCountedT> cached = cache.Find(&blueprint);
+    EXPECT_TRUE(object1.Get() == cached.Get());
+}
+
+// Erasing the only entry leaves the cache empty.
+TYPED_TEST(ContentLessObjectCacheTest, Erase) {
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object = AcquireRef(new RefCountedT(1));
+    EXPECT_TRUE(cache.Insert(object.Get()).second);
+    EXPECT_FALSE(cache.Empty());
+
+    cache.Erase(object.Get());
+    EXPECT_TRUE(cache.Empty());
+}
+
+// Erasing a hash-equivalent but not pointer-equal entry is a no-op.
+TYPED_TEST(ContentLessObjectCacheTest, EraseDuplicate) {
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object1 =
+        AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+    EXPECT_TRUE(cache.Insert(object1.Get()).second);
+    EXPECT_FALSE(cache.Empty());
+
+    Ref<RefCountedT> object2 = AcquireRef(new RefCountedT(1));
+    cache.Erase(object2.Get());
+    EXPECT_FALSE(cache.Empty());
+}
+
+// Helper struct that acts as a binary semaphore to coordinate execution order between threads.
+struct Signal {
+    std::mutex mutex;
+    std::condition_variable cv;
+    bool signaled = false;
+
+    void Fire() {
+        std::lock_guard<std::mutex> lock(mutex);
+        signaled = true;
+        cv.notify_one();
+    }
+    void Wait() {
+        std::unique_lock<std::mutex> lock(mutex);
+        while (!signaled) {
+            cv.wait(lock);
+        }
+        signaled = false;
+    }
+};
+
+// Inserting and finding elements should respect the results from the insert call.
+TYPED_TEST(ContentLessObjectCacheTest, InsertingAndFinding) {
+    constexpr size_t kNumObjects = 100;
+    constexpr size_t kNumThreads = 8;
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    std::vector<Ref<RefCountedT>> objects(kNumObjects);
+
+    auto f = [&]() {
+        for (size_t i = 0; i < kNumObjects; i++) {
+            Ref<RefCountedT> object =
+                AcquireRef(new RefCountedT(i, [&](RefCountedT* x) { cache.Erase(x); }));
+            if (cache.Insert(object.Get()).second) {
+                // This shouldn't race because exactly 1 thread should successfully insert.
+                objects[i] = object;
+            }
+        }
+        for (size_t i = 0; i < kNumObjects; i++) {
+            TypeParam blueprint(i);
+            Ref<RefCountedT> cached = cache.Find(&blueprint);
+            EXPECT_NE(cached.Get(), nullptr);
+            EXPECT_EQ(cached.Get(), objects[i].Get());
+        }
+    };
+
+    std::vector<std::thread> threads;
+    for (size_t t = 0; t < kNumThreads; t++) {
+        threads.emplace_back(f);
+    }
+    for (size_t t = 0; t < kNumThreads; t++) {
+        threads[t].join();
+    }
+}
+
+// Finding an element that is in the process of deletion should return nullptr.
+TYPED_TEST(ContentLessObjectCacheTest, FindDeleting) {
+    Signal signalA, signalB;
+
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object = AcquireRef(new RefCountedT(1, [&](RefCountedT* x) {
+        signalA.Fire();
+        signalB.Wait();
+        cache.Erase(x);
+    }));
+    EXPECT_TRUE(cache.Insert(object.Get()).second);
+
+    // Thread A will release the last reference of the original object.
+    auto threadA = [&]() { object = nullptr; };
+    // Thread B will try to Find the entry before it is completely destroyed.
+    auto threadB = [&]() {
+        signalA.Wait();
+        TypeParam blueprint(1);
+        EXPECT_TRUE(cache.Find(&blueprint) == nullptr);
+        signalB.Fire();
+    };
+
+    std::thread tA(threadA);
+    std::thread tB(threadB);
+    tA.join();
+    tB.join();
+}
+
+// Inserting an object whose hash-equivalent entry is in the process of being deleted should
+// insert the new object.
+TYPED_TEST(ContentLessObjectCacheTest, InsertDeleting) {
+    Signal signalA, signalB;
+
+    ContentLessObjectCache<RefCountedT, TypeParam> cache;
+    Ref<RefCountedT> object1 = AcquireRef(new RefCountedT(1, [&](RefCountedT* x) {
+        signalA.Fire();
+        signalB.Wait();
+        cache.Erase(x);
+    }));
+    EXPECT_TRUE(cache.Insert(object1.Get()).second);
+
+    Ref<RefCountedT> object2 =
+        AcquireRef(new RefCountedT(1, [&](RefCountedT* x) { cache.Erase(x); }));
+
+    // Thread A will release the last reference of the original object.
+    auto threadA = [&]() { object1 = nullptr; };
+    // Thread B will try to Insert a hash equivalent entry before the original is completely
+    // destroyed.
+    auto threadB = [&]() {
+        signalA.Wait();
+        EXPECT_TRUE(cache.Insert(object2.Get()).second);
+        signalB.Fire();
+    };
+
+    std::thread tA(threadA);
+    std::thread tB(threadB);
+    tA.join();
+    tB.join();
+
+    TypeParam blueprint(1);
+    Ref<RefCountedT> cached = cache.Find(&blueprint);
+    EXPECT_TRUE(object2.Get() == cached.Get());
+}
+
+}  // anonymous namespace
+}  // namespace dawn