Make GetWeakRef a friend function
This makes it so you can GetWeakRef on a T* as well as a Ref<T>
which can save an unecessary reference+release
Change-Id: I2473d3284b2d4a7ca83be0fa04b7563d643b29de
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/142400
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/common/ContentLessObjectCache.h b/src/dawn/common/ContentLessObjectCache.h
index 6654777..7a990b1 100644
--- a/src/dawn/common/ContentLessObjectCache.h
+++ b/src/dawn/common/ContentLessObjectCache.h
@@ -156,10 +156,10 @@
// Inserts the object into the cache returning a pair where the first is a Ref to the
// inserted or existing object, and the second is a bool that is true if we inserted
// `object` and false otherwise.
- std::pair<Ref<RefCountedT>, bool> Insert(Ref<RefCountedT> obj) {
+ std::pair<Ref<RefCountedT>, bool> Insert(RefCountedT* obj) {
std::lock_guard<std::mutex> lock(mMutex);
detail::WeakRefAndHash<RefCountedT> weakref =
- std::make_pair(obj.GetWeakRef(), typename RefCountedT::HashFunc()(obj.Get()));
+ std::make_pair(GetWeakRef(obj), typename RefCountedT::HashFunc()(obj));
auto [it, inserted] = mCache.insert(weakref);
if (inserted) {
obj->mCache = this;
diff --git a/src/dawn/common/Ref.h b/src/dawn/common/Ref.h
index 693a8fac..f0e947c 100644
--- a/src/dawn/common/Ref.h
+++ b/src/dawn/common/Ref.h
@@ -43,13 +43,6 @@
class Ref : public RefBase<T*, detail::RefCountedTraits<T>> {
public:
using RefBase<T*, detail::RefCountedTraits<T>>::RefBase;
-
- template <
- typename U = T,
- typename = typename std::enable_if<std::is_base_of_v<detail::WeakRefSupportBase, U>>::type>
- WeakRef<T> GetWeakRef() {
- return WeakRef<T>(this->Get());
- }
};
} // namespace dawn
diff --git a/src/dawn/common/WeakRef.h b/src/dawn/common/WeakRef.h
index 26485f6..b14b489 100644
--- a/src/dawn/common/WeakRef.h
+++ b/src/dawn/common/WeakRef.h
@@ -23,6 +23,23 @@
namespace dawn {
template <typename T>
+class WeakRef;
+
+template <
+ typename T,
+ typename = typename std::enable_if<std::is_base_of_v<detail::WeakRefSupportBase, T>>::type>
+WeakRef<T> GetWeakRef(T* obj) {
+ return WeakRef<T>(obj);
+}
+
+template <
+ typename T,
+ typename = typename std::enable_if<std::is_base_of_v<detail::WeakRefSupportBase, T>>::type>
+WeakRef<T> GetWeakRef(const Ref<T>& obj) {
+ return GetWeakRef(obj.Get());
+}
+
+template <typename T>
class WeakRef {
public:
WeakRef() {}
@@ -72,6 +89,9 @@
return nullptr;
}
+ friend WeakRef GetWeakRef<>(T* obj);
+ friend WeakRef GetWeakRef<>(const Ref<T>& obj);
+
private:
// Friend is needed so that we can access the data ref in conversions.
template <typename U>
diff --git a/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp b/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
index 7ecacd5..c23ebda 100644
--- a/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
+++ b/src/dawn/tests/unittests/ContentLessObjectCacheTests.cpp
@@ -61,7 +61,7 @@
TEST(ContentLessObjectCacheTest, NonEmpty) {
ContentLessObjectCache<CacheableT> cache;
Ref<CacheableT> object = AcquireRef(new CacheableT(1, [&](CacheableT* x) { cache.Erase(x); }));
- EXPECT_TRUE(cache.Insert(object).second);
+ EXPECT_TRUE(cache.Insert(object.Get()).second);
EXPECT_FALSE(cache.Empty());
}
diff --git a/src/dawn/tests/unittests/WeakRefTests.cpp b/src/dawn/tests/unittests/WeakRefTests.cpp
index 5ae62f3..6056f22 100644
--- a/src/dawn/tests/unittests/WeakRefTests.cpp
+++ b/src/dawn/tests/unittests/WeakRefTests.cpp
@@ -52,7 +52,7 @@
// When the original refcounted object is destroyed, all WeakRefs are no longer able to Promote.
TEST(WeakRefTests, BasicPromote) {
Ref<WeakRefBaseA> base = AcquireRef(new WeakRefBaseA());
- WeakRef<WeakRefBaseA> weak = base.GetWeakRef();
+ WeakRef<WeakRefBaseA> weak = GetWeakRef(base);
EXPECT_EQ(weak.Promote().Get(), base.Get());
base = nullptr;
@@ -63,9 +63,9 @@
// longer able to Promote.
TEST(WeakRefTests, DerivedPromote) {
Ref<WeakRefDerivedA> base = AcquireRef(new WeakRefDerivedA());
- WeakRef<WeakRefDerivedA> weak1 = base.GetWeakRef();
+ WeakRef<WeakRefDerivedA> weak1 = GetWeakRef(base);
WeakRef<WeakRefBaseA> weak2 = weak1;
- WeakRef<WeakRefBaseA> weak3 = base.GetWeakRef();
+ WeakRef<WeakRefBaseA> weak3 = GetWeakRef(base);
EXPECT_EQ(weak1.Promote().Get(), base.Get());
EXPECT_EQ(weak2.Promote().Get(), base.Get());
EXPECT_EQ(weak3.Promote().Get(), base.Get());
@@ -85,7 +85,7 @@
}));
auto f = [&] {
- WeakRef<WeakRefBaseA> weak = base.GetWeakRef();
+ WeakRef<WeakRefBaseA> weak = GetWeakRef(base);
semA.Release();
semB.Acquire();
EXPECT_EQ(weak.Promote().Get(), nullptr);
@@ -152,8 +152,8 @@
// Helper detection utilty to verify whether GetWeakRef is enabled.
template <typename T>
-using can_get_weakref_t = decltype(std::declval<Ref<T>>().GetWeakRef());
-TEST(WeakRefTests, GetWeakRef) {
+using can_get_weakref_t = decltype(GetWeakRef(std::declval<T*>()));
+TEST(WeakRefTests, GetWeakRefFromPtr) {
// The GetWeakRef function is only available on types that extend WeakRefSupport.
static_assert(std::experimental::is_detected_v<can_get_weakref_t, WeakRefBaseA>,
"GetWeakRef is enabled on classes that directly extend WeakRefSupport.");
@@ -164,6 +164,20 @@
"GetWeakRef is disabled on classes that do not extend WeakRefSupport.");
}
+// Helper detection utilty to verify whether GetWeakRef is enabled.
+template <typename T>
+using can_get_weakref_from_ref_t = decltype(GetWeakRef(std::declval<Ref<T>>()));
+TEST(WeakRefTests, GetWeakRefFromRef) {
+ // The GetWeakRef function is only available on types that extend WeakRefSupport.
+ static_assert(std::experimental::is_detected_v<can_get_weakref_from_ref_t, WeakRefBaseA>,
+ "GetWeakRef is enabled on classes that directly extend WeakRefSupport.");
+ static_assert(std::experimental::is_detected_v<can_get_weakref_from_ref_t, WeakRefDerivedA>,
+ "GetWeakRef is enabled on classes that indirectly extend WeakRefSupport.");
+
+ static_assert(!std::experimental::is_detected_v<can_get_weakref_from_ref_t, RefCountedT>,
+ "GetWeakRef is disabled on classes that do not extend WeakRefSupport.");
+}
+
} // anonymous namespace
} // namespace dawn