Use transparent comparators in the frontend cache

Transparent comparators are a cleaner way to do Find and Erase
without WeakRef promotion, and cleaner than the existing variant
hack.

Bug: dawn:1513
Change-Id: Id01247bccc93f1e5399a83d5fa6a06207630165c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/185260
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/common/ContentLessObjectCache.h b/src/dawn/common/ContentLessObjectCache.h
index ffaf447..91633bc 100644
--- a/src/dawn/common/ContentLessObjectCache.h
+++ b/src/dawn/common/ContentLessObjectCache.h
@@ -29,10 +29,8 @@
 #define SRC_DAWN_COMMON_CONTENTLESSOBJECTCACHE_H_
 
 #include <mutex>
-#include <tuple>
 #include <type_traits>
 #include <utility>
-#include <variant>
 
 #include "absl/container/flat_hash_set.h"
 #include "dawn/common/ContentLessObjectCacheable.h"
@@ -55,8 +53,8 @@
 // dropped already.
 template <typename RefCountedT>
 struct ForErase {
-    explicit ForErase(RefCountedT* value) : mValue(value) {}
-    raw_ptr<RefCountedT> mValue;
+    explicit ForErase(RefCountedT* value) : value(value) {}
+    raw_ptr<RefCountedT> value;
 };
 
 // All cached WeakRefs must have an immutable hash value determined at insertion. This ensures that
@@ -65,105 +63,61 @@
 template <typename RefCountedT>
 using WeakRefAndHash = std::pair<WeakRef<RefCountedT>, size_t>;
 
-// The cache always holds WeakRefs internally, however, to enable lookups using pointers and special
-// Erase equality, we use a variant type to branch.
-template <typename RefCountedT>
-using ContentLessObjectCacheKey =
-    std::variant<RefCountedT*, WeakRefAndHash<RefCountedT>, ForErase<RefCountedT>>;
-
-enum class KeyType : size_t { Pointer = 0, WeakRef = 1, ForErase = 2 };
-
-template <typename RefCountedT>
-struct ContentLessObjectCacheHashVisitor {
-    using BaseHashFunc = typename RefCountedT::HashFunc;
-
-    size_t operator()(const RefCountedT* ptr) const { return BaseHashFunc()(ptr); }
-    size_t operator()(const WeakRefAndHash<RefCountedT>& weakref) const { return weakref.second; }
-    size_t operator()(const ForErase<RefCountedT>& forErase) const {
-        return BaseHashFunc()(forErase.mValue);
-    }
-};
-
 template <typename RefCountedT>
 struct ContentLessObjectCacheKeyFuncs {
-    static_assert(
-        std::is_same_v<RefCountedT*,
-                       std::variant_alternative_t<static_cast<size_t>(KeyType::Pointer),
-                                                  ContentLessObjectCacheKey<RefCountedT>>>);
-    static_assert(
-        std::is_same_v<WeakRefAndHash<RefCountedT>,
-                       std::variant_alternative_t<static_cast<size_t>(KeyType::WeakRef),
-                                                  ContentLessObjectCacheKey<RefCountedT>>>);
-    static_assert(
-        std::is_same_v<ForErase<RefCountedT>,
-                       std::variant_alternative_t<static_cast<size_t>(KeyType::ForErase),
-                                                  ContentLessObjectCacheKey<RefCountedT>>>);
+    using BaseHashFunc = typename RefCountedT::HashFunc;
+    using BaseEqualityFunc = typename RefCountedT::EqualityFunc;
 
     struct HashFunc {
-        size_t operator()(const ContentLessObjectCacheKey<RefCountedT>& key) const {
-            return std::visit(ContentLessObjectCacheHashVisitor<RefCountedT>(), key);
+        using is_transparent = void;
+
+        size_t operator()(const RefCountedT* ptr) const { return BaseHashFunc()(ptr); }
+        size_t operator()(const WeakRefAndHash<RefCountedT>& weakref) const {
+            return weakref.second;
+        }
+        size_t operator()(const ForErase<RefCountedT>& forErase) const {
+            return BaseHashFunc()(forErase.value);
         }
     };
 
     struct EqualityFunc {
+        using is_transparent = void;
+
         explicit EqualityFunc(ContentLessObjectCache<RefCountedT>* cache) : mCache(cache) {}
 
-        bool operator()(const ContentLessObjectCacheKey<RefCountedT>& a,
-                        const ContentLessObjectCacheKey<RefCountedT>& b) const {
-            // First check if we are in the erasing scenario. We need to determine this early
-            // because we handle the actual equality differently. Note that if either a or b is
-            // a ForErase, it is safe to use UnsafeGet for both a and b because either:
+        bool operator()(const WeakRefAndHash<RefCountedT>& a,
+                        const WeakRefAndHash<RefCountedT>& b) const {
+            Ref<RefCountedT> aRef = a.first.Promote();
+            Ref<RefCountedT> bRef = b.first.Promote();
+
+            bool equal = (aRef && bRef && BaseEqualityFunc()(aRef.Get(), bRef.Get()));
+            if (aRef) {
+                mCache->TrackTemporaryRef(std::move(aRef));
+            }
+            if (bRef) {
+                mCache->TrackTemporaryRef(std::move(bRef));
+            }
+            return equal;
+        }
+
+        bool operator()(const WeakRefAndHash<RefCountedT>& a,
+                        const ForErase<RefCountedT>& b) const {
+            // An object is being erased. In this scenario, UnsafeGet is OK because either:
             //   (1) a == b, in which case that means we are destroying the last copy and must be
             //       valid because cached objects must uncache themselves before being completely
             //       destroyed.
             //   (2) a != b, in which case the lock on the cache guarantees that the element in the
             //       cache has not been erased yet and hence cannot have been destroyed.
-            bool erasing = std::holds_alternative<ForErase<RefCountedT>>(a) ||
-                           std::holds_alternative<ForErase<RefCountedT>>(b);
+            return a.first.UnsafeGet() == b.value;
+        }
 
-            auto ExtractKey = [](bool erasing, const ContentLessObjectCacheKey<RefCountedT>& x)
-                -> std::pair<RefCountedT*, Ref<RefCountedT>> {
-                RefCountedT* xPtr = nullptr;
-                Ref<RefCountedT> xRef;
-                switch (static_cast<KeyType>(x.index())) {
-                    case KeyType::Pointer:
-                        xPtr = std::get<RefCountedT*>(x);
-                        break;
-                    case KeyType::WeakRef:
-                        if (erasing) {
-                            xPtr = std::get<WeakRefAndHash<RefCountedT>>(x).first.UnsafeGet();
-                        } else {
-                            xRef = std::get<WeakRefAndHash<RefCountedT>>(x).first.Promote();
-                            xPtr = xRef.Get();
-                        }
-                        break;
-                    case KeyType::ForErase:
-                        xPtr = std::get<ForErase<RefCountedT>>(x).mValue;
-                        break;
-                    default:
-                        DAWN_UNREACHABLE();
-                }
-                return {xPtr, xRef};
-            };
-            auto [aPtr, aRef] = ExtractKey(erasing, a);
-            auto [bPtr, bRef] = ExtractKey(erasing, b);
-
-            bool result = false;
-            if (aPtr == nullptr || bPtr == nullptr) {
-                result = false;
-            } else if (erasing) {
-                result = aPtr == bPtr;
-            } else {
-                result = typename RefCountedT::EqualityFunc()(aPtr, bPtr);
-            }
-
-            if (aRef != nullptr) {
+        bool operator()(const WeakRefAndHash<RefCountedT>& a, const RefCountedT* b) const {
+            Ref<RefCountedT> aRef = a.first.Promote();
+            bool equal = aRef && BaseEqualityFunc()(aRef.Get(), b);
+            if (aRef) {
                 mCache->TrackTemporaryRef(std::move(aRef));
             }
-            if (bRef != nullptr) {
-                mCache->TrackTemporaryRef(std::move(bRef));
-            }
-            return result;
+            return equal;
         }
 
         raw_ptr<ContentLessObjectCache<RefCountedT>> mCache = nullptr;
@@ -205,10 +159,9 @@
             } else {
                 // Try to promote the found WeakRef to a Ref. If promotion fails, remove the old Key
                 // and insert this one.
-                Ref<RefCountedT> ref =
-                    std::get<detail::WeakRefAndHash<RefCountedT>>(*it).first.Promote();
+                Ref<RefCountedT> ref = it->first.Promote();
                 if (ref != nullptr) {
-                    return {ref, false};
+                    return {std::move(ref), false};
                 } else {
                     mCache.erase(it);
                     auto result = mCache.insert(weakref);
@@ -225,7 +178,7 @@
         return WithLockAndCleanup([&]() -> Ref<RefCountedT> {
             auto it = mCache.find(blueprint);
             if (it != mCache.end()) {
-                return std::get<detail::WeakRefAndHash<RefCountedT>>(*it).first.Promote();
+                return it->first.Promote();
             }
             return nullptr;
         });
@@ -235,13 +188,15 @@
     // modify the cache. Since Erase never Promotes any WeakRefs, it does not need to be wrapped by
     // a WithLockAndCleanup, and a simple lock is enough.
     void Erase(RefCountedT* obj) {
-        std::lock_guard<std::mutex> lock(mMutex);
-        auto it = mCache.find(detail::ForErase<RefCountedT>(obj));
-        if (it == mCache.end()) {
+        size_t count;
+        {
+            std::lock_guard<std::mutex> lock(mMutex);
+            count = mCache.erase(detail::ForErase<RefCountedT>(obj));
+        }
+        if (count == 0) {
             return;
         }
         obj->mCache = nullptr;
-        mCache.erase(it);
     }
 
     // Returns true iff the cache is empty.
@@ -271,7 +226,7 @@
     }
 
     std::mutex mMutex;
-    absl::flat_hash_set<detail::ContentLessObjectCacheKey<RefCountedT>,
+    absl::flat_hash_set<detail::WeakRefAndHash<RefCountedT>,
                         typename CacheKeyFuncs::HashFunc,
                         typename CacheKeyFuncs::EqualityFunc>
         mCache;