Reland "[dawn][common] Adds MutexProtectedSupport CRTP wrapper."
This is a reland of commit 1e23f1ca224b164294be79b02417174f39ecfb6b
- Uses raw_ptr instead of reference in Guard because d3d11 extends
the templates and with a reference, we lose the ability to null
out the object to prevent destructors from running on an object
that should otherwise have been moved.
Original change's description:
> [dawn][common] Adds MutexProtectedSupport CRTP wrapper.
>
> - Adds a CRTP wrapper that can be used to both declare a
> MutexProtected struct and its internals. This is particularly
> useful when we want to provide additional backdoors to read
> data without acquiring the lock.
> - Adds relevant tests for the new wrapper.
> - Also makes Guards hold a reference instead of a pointer since
> the pointer should never be null anyways.
>
> Bug: 40643114
> Change-Id: I53ce8ac0e4a16c3e3aa1cddb8860d9d1a5727b3f
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/244275
> Commit-Queue: Loko Kung <lokokung@google.com>
> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Bug: 40643114
Change-Id: If015a4b55c0d3be95c0136e2740ba8bd10ae8445
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/245314
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index ab2e6fc..1194d14 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -42,6 +42,9 @@
template <typename T, template <typename, typename> class Guard>
class MutexProtected;
+template <typename T, template <typename, typename> class Guard>
+class MutexProtectedSupport;
+
namespace detail {
template <typename T>
@@ -68,6 +71,17 @@
static const ObjectType* GetObj(const Ref<T>* const obj) { return obj->Get(); }
};
+template <typename T>
+struct MutexProtectedSupportTraits {
+ using MutexType = std::mutex;
+ using LockType = std::unique_lock<std::mutex>;
+
+ static MutexType CreateMutex() { return std::mutex(); }
+ static std::mutex& GetMutex(MutexType& m) { return m; }
+ static auto* GetObj(T* const obj) { return &obj->mImpl; }
+ static const auto* GetObj(const T* const obj) { return &obj->mImpl; }
+};
+
// Guard class is a wrapping class that gives access to a protected resource after acquiring the
// lock related to it. For the lifetime of this class, the lock is held.
template <typename T, typename Traits>
@@ -91,6 +105,7 @@
private:
using NonConstT = typename std::remove_const<T>::type;
+ friend class MutexProtectedSupport<NonConstT, Guard>;
friend class MutexProtected<NonConstT, Guard>;
typename Traits::LockType mLock;
@@ -100,6 +115,34 @@
RAW_PTR_EXCLUSION T* mObj = nullptr;
};
+template <typename T, typename Traits, template <typename, typename> class Guard = detail::Guard>
+class MutexProtectedBase {
+ public:
+ using Usage = Guard<T, Traits>;
+ using ConstUsage = Guard<const T, Traits>;
+
+ MutexProtectedBase() : mMutex(Traits::CreateMutex()) {}
+ virtual ~MutexProtectedBase() = default;
+
+ Usage operator->() { return Use(); }
+ ConstUsage operator->() const { return Use(); }
+
+ template <typename Fn>
+ auto Use(Fn&& fn) {
+ return fn(Use());
+ }
+ template <typename Fn>
+ auto Use(Fn&& fn) const {
+ return fn(Use());
+ }
+
+ protected:
+ virtual Usage Use() = 0;
+ virtual ConstUsage Use() const = 0;
+
+ mutable typename Traits::MutexType mMutex;
+};
+
} // namespace detail
// Wrapping class used for object members to ensure usage of the resource is protected with a mutex.
@@ -135,39 +178,66 @@
// MutexProtected<Allocator> mAllocator;
// };
template <typename T, template <typename, typename> class Guard = detail::Guard>
-class MutexProtected {
+class MutexProtected
+ : public detail::MutexProtectedBase<T, detail::MutexProtectedTraits<T>, Guard> {
public:
using Traits = detail::MutexProtectedTraits<T>;
- using Usage = Guard<T, Traits>;
- using ConstUsage = Guard<const T, Traits>;
-
- MutexProtected() : mMutex(Traits::CreateMutex()) {}
+ using Base = detail::MutexProtectedBase<T, Traits, Guard>;
+ using typename Base::ConstUsage;
+ using typename Base::Usage;
template <typename... Args>
// NOLINTNEXTLINE(runtime/explicit) allow implicit construction
- MutexProtected(Args&&... args)
- : mMutex(Traits::CreateMutex()), mObj(std::forward<Args>(args)...) {}
+ MutexProtected(Args&&... args) : mObj(std::forward<Args>(args)...) {}
- Usage operator->() { return Use(); }
- ConstUsage operator->() const { return Use(); }
-
- template <typename Fn>
- auto Use(Fn&& fn) {
- return fn(Use());
- }
- template <typename Fn>
- auto Use(Fn&& fn) const {
- return fn(Use());
- }
+ using Base::Use;
private:
- Usage Use() { return Usage(&mObj, mMutex); }
- ConstUsage Use() const { return ConstUsage(&mObj, mMutex); }
+ Usage Use() override { return Usage(&mObj, this->mMutex); }
+ ConstUsage Use() const override { return ConstUsage(&mObj, this->mMutex); }
- mutable typename Traits::MutexType mMutex;
T mObj;
};
+// CRTP wrapper to help create classes that are generally MutexProtected, but may wish to implement
+// specific workarounds to avoid taking the lock in certain scenarios. See the example below and the
+// unittests for more example usages of this wrapper. Example usage:
+// struct Counter : public MutexProtectedSupport<Counter> {
+// public:
+// // Reads the value stored in |mCounter| without acquiring the lock.
+// int UnsafeRead() {
+// return mImpl.mCounter;
+// }
+//
+// private:
+// // This friend declaration MUST be included in all classes using this wrapper.
+// friend typename MutexProtectedSupport<Counter>::Traits;
+//
+// // Internal struct that wraps all the actual data that we want to be protected. Note that
+// // this struct currently MUST be named |mImpl| to work.
+// struct {
+// int mCounter = 0;
+// } mImpl;
+// };
+// // Other uses of this struct look as if we are using a MutexProtected<mImpl>.
+template <typename T, template <typename, typename> class Guard = detail::Guard>
+class MutexProtectedSupport
+ : public detail::MutexProtectedBase<T, detail::MutexProtectedSupportTraits<T>, Guard> {
+ public:
+ using Traits = detail::MutexProtectedSupportTraits<T>;
+ using Base = detail::MutexProtectedBase<T, Traits, Guard>;
+ using typename Base::ConstUsage;
+ using typename Base::Usage;
+
+ using Base::Use;
+
+ private:
+ Usage Use() override { return Usage(static_cast<T*>(this), this->mMutex); }
+ ConstUsage Use() const override {
+ return ConstUsage(static_cast<const T*>(this), this->mMutex);
+ }
+};
+
} // namespace dawn
#endif // SRC_DAWN_COMMON_MUTEXPROTECTED_H_
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
index de6f20a..74f6b87 100644
--- a/src/dawn/tests/unittests/MutexProtectedTests.cpp
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -41,6 +41,75 @@
using ::testing::Test;
using ::testing::Types;
+class MutexSupportedCounterT : public MutexProtectedSupport<MutexSupportedCounterT> {
+ public:
+ // This is an unsafe read of the count without acquiring the lock.
+ int ReadCount() { return mImpl.mCount; }
+
+ private:
+ friend typename MutexProtectedSupport<MutexSupportedCounterT>::Traits;
+
+ struct {
+ int mCount = 0;
+ } mImpl;
+};
+
+TEST(MutexProtectedSupportTests, Nominal) {
+ static constexpr int kIncrementCount = 100;
+ static constexpr int kDecrementCount = 50;
+
+ MutexSupportedCounterT counter;
+
+ auto increment = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter->mCount++;
+ }
+ };
+ auto useIncrement = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter.Use([](auto c) { c->mCount++; });
+ }
+ };
+ auto decrement = [&] {
+ for (uint32_t i = 0; i < kDecrementCount; i++) {
+ counter->mCount--;
+ }
+ };
+ auto useDecrement = [&] {
+ for (uint32_t i = 0; i < kDecrementCount; i++) {
+ counter.Use([](auto c) { c->mCount--; });
+ }
+ };
+
+ std::thread incrementThread(increment);
+ std::thread useIncrementThread(useIncrement);
+ std::thread decrementThread(decrement);
+ std::thread useDecrementThread(useDecrement);
+ incrementThread.join();
+ useIncrementThread.join();
+ decrementThread.join();
+ useDecrementThread.join();
+
+ EXPECT_EQ(counter->mCount, 2 * (kIncrementCount - kDecrementCount));
+}
+
+// Verifies that if we call additionally implemented functions when using the MutexProtectedSupport
+// wrapper, that they do not acquire the lock. If the lock was acquired, then this test would
+// deadlock.
+TEST(MutexProtectedSupportTests, UnsafeRead) {
+ MutexSupportedCounterT counter;
+
+ // Acquire the lock via the Use function.
+ counter.Use([&](auto c) {
+ // With the lock acquired, we should be able to call additionally implemented functions that
+ // do not acquire the lock.
+ c->mCount = 1;
+ EXPECT_EQ(counter.ReadCount(), 1);
+ c->mCount = 2;
+ EXPECT_EQ(counter.ReadCount(), 2);
+ });
+}
+
// Simple thread-unsafe counter class.
class CounterT : public RefCounted {
public: