[dawn][common] Adds MutexCondVarProtected for non-Ref<> types.
- Adds helper class for waiting with conditional variables.
- The helper is used to implement WaitAny on Dawn wire client.
Bug: 441981783
Change-Id: I119d0df4808b311f0a4cb7ea19ca6dafd3ad3b8b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/260460
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Auto-Submit: Loko Kung <lokokung@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index 82da884..48e22e6 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -28,6 +28,8 @@
#ifndef SRC_DAWN_COMMON_MUTEXPROTECTED_H_
#define SRC_DAWN_COMMON_MUTEXPROTECTED_H_
+#include <chrono>
+#include <condition_variable>
#include <mutex>
#include <utility>
@@ -37,6 +39,7 @@
#include "dawn/common/NonMovable.h"
#include "dawn/common/Ref.h"
#include "dawn/common/StackAllocated.h"
+#include "dawn/common/Time.h"
#include "partition_alloc/pointers/raw_ptr.h"
#include "partition_alloc/pointers/raw_ptr_exclusion.h"
@@ -46,6 +49,9 @@
class MutexProtected;
template <typename T, template <typename, typename> class Guard>
+class MutexCondVarProtected;
+
+template <typename T, template <typename, typename> class Guard>
class MutexProtectedSupport;
namespace detail {
@@ -85,6 +91,11 @@
static const auto* GetObj(const T* const obj) { return &obj->mImpl; }
};
+template <typename T, typename Traits>
+class Guard;
+template <typename T, typename Traits>
+class CondVarGuard;
+
// Guard class is a wrapping class that gives access to a protected resource after acquiring the
// lock related to it. For the lifetime of this class, the lock is held.
template <typename T, typename Traits>
@@ -117,6 +128,7 @@
private:
using NonConstT = typename std::remove_const<T>::type;
+ friend class CondVarGuard<T, Traits>;
friend class MutexProtectedSupport<NonConstT, Guard>;
friend class MutexProtected<NonConstT, Guard>;
@@ -128,6 +140,63 @@
raw_ptr<class Defer> mDefer = nullptr;
};
+// CondVarGuard is a different guard class that internally holds a Guard, but provides additional
+// functionality w.r.t condition variables. Specifically, the non-const version of this Guard will
+// automatically call notify_all() on the underlying condition variable so that calls to |Wait*()|
+// will unblock when |Pred| is true.
+template <typename T, typename Traits>
+class CondVarGuard : public NonMovable, StackAllocated {
+ public:
+ // It's the programmer's burden to not save the pointer/reference and reuse it without the lock.
+ auto* operator->() const { return mGuard.Get(); }
+ auto& operator*() const { return *mGuard.Get(); }
+
+ template <typename Predicate>
+ void Wait(Predicate pred) {
+ DAWN_ASSERT(mNotifyScope.cv);
+ mNotifyScope.cv->wait(mGuard.mLock, [&] { return pred((*Get())); });
+ }
+ template <typename Predicate>
+ bool WaitFor(Nanoseconds timeout, Predicate pred) {
+ DAWN_ASSERT(mNotifyScope.cv);
+ if (timeout < kMaxDurationNanos) {
+ return mNotifyScope.cv->wait_for(
+ mGuard.mLock, std::chrono::nanoseconds(static_cast<uint64_t>(timeout)),
+ [&] { return pred(*Get()); });
+ } else {
+ Wait(pred);
+ return true;
+ }
+ }
+
+ protected:
+ CondVarGuard(T* obj,
+ typename Traits::MutexType& mutex,
+ class Defer* defer = nullptr,
+ std::condition_variable* cv = nullptr)
+ : mNotifyScope(cv), mGuard(obj, mutex, defer) {}
+
+ auto* Get() const { return mGuard.Get(); }
+
+ private:
+ using NonConstT = typename std::remove_const<T>::type;
+ friend class MutexProtected<NonConstT, CondVarGuard>;
+ friend class MutexCondVarProtected<NonConstT, CondVarGuard>;
+
+ struct NotifyScope {
+ explicit NotifyScope(std::condition_variable* cv) : cv(cv) { DAWN_ASSERT(cv); }
+ ~NotifyScope() {
+ if constexpr (!std::is_const_v<T>) {
+ cv->notify_all();
+ }
+ }
+
+ raw_ptr<std::condition_variable> cv = nullptr;
+ };
+ NotifyScope mNotifyScope;
+ Guard<T, Traits> mGuard;
+};
+
template <typename T, typename Traits, template <typename, typename> class Guard = detail::Guard>
class MutexProtectedBase {
public:
@@ -146,7 +215,11 @@
}
template <typename Fn>
auto Use(Fn&& fn) const {
- return fn(Use());
+ return fn(ConstUse());
+ }
+ template <typename Fn>
+ auto ConstUse(Fn&& fn) const {
+ return fn(ConstUse());
}
template <typename Fn>
@@ -158,7 +231,7 @@
protected:
virtual Usage Use() = 0;
virtual Usage UseWithDefer(Defer& defer) = 0;
- virtual ConstUsage Use() const = 0;
+ virtual ConstUsage ConstUse() const = 0;
mutable typename Traits::MutexType mMutex;
};
@@ -213,12 +286,57 @@
using Base::Use;
using Base::UseWithDefer;
+ protected:
+ T mObj;
+
private:
Usage Use() override { return Usage(&mObj, this->mMutex); }
Usage UseWithDefer(Defer& defer) override { return Usage(&mObj, this->mMutex, &defer); }
- ConstUsage Use() const override { return ConstUsage(&mObj, this->mMutex); }
+ ConstUsage ConstUse() const override { return ConstUsage(&mObj, this->mMutex); }
+};
- T mObj;
+// Wrapping class for object members to provide the protections with a mutex of a MutexProtected
+// with some additional helpers to allow waiting with a conditional variable as well. The general
+// usage should look the same as MutexProtected above, with additional usages like the following
+// example:
+// class Example {
+// public:
+// void Complete() {
+// mDone.Use([](auto done) {
+// // Do something
+// mDone = true;
+// });
+// }
+// void WaitUntilDone() {
+// mDone.Use([](auto done) {
+// done.Wait([](auto& done) { return done; });
+// });
+// }
+// private:
+// MutexCondVarProtected<bool> mDone = false;
+// };
+template <typename T, template <typename, typename> class Guard = detail::CondVarGuard>
+class MutexCondVarProtected : public MutexProtected<T, Guard> {
+ public:
+ using Base = MutexProtected<T, Guard>;
+ using typename Base::ConstUsage;
+ using typename Base::Usage;
+
+ using Base::Base;
+
+ // Note that unlike in MutexProtected where |Use| and |ConstUse| guarantee the lock for the
+ // entire critical section, if a user calls |Wait| within |Fn|, the lock may be released and
+ // reacquired in order for another thread to update the condition.
+ using Base::Base::ConstUse;
+ using Base::Base::Use;
+
+ private:
+ Usage Use() override { return Usage(&this->mObj, this->mMutex, nullptr, &mCv); }
+ ConstUsage ConstUse() const override {
+ return ConstUsage(&this->mObj, this->mMutex, nullptr, &mCv);
+ }
+
+ mutable std::condition_variable mCv;
};
// CRTP wrapper to help create classes that are generally MutexProtected, but may wish to implement
@@ -259,7 +377,7 @@
Usage UseWithDefer(Defer& defer) override {
return Usage(static_cast<T*>(this), this->mMutex, &defer);
}
- ConstUsage Use() const override {
+ ConstUsage ConstUse() const override {
return ConstUsage(static_cast<const T*>(this), this->mMutex);
}
};
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
index 74f6b87..b6ff2f3 100644
--- a/src/dawn/tests/unittests/MutexProtectedTests.cpp
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -33,6 +33,7 @@
#include "dawn/common/MutexProtected.h"
#include "dawn/common/Ref.h"
#include "dawn/common/RefCounted.h"
+#include "dawn/common/Time.h"
#include "gtest/gtest.h"
namespace dawn {
@@ -126,26 +127,26 @@
};
template <typename T>
-MutexProtected<T> CreateDefault() {
- if constexpr (IsRef<T>::value) {
- return MutexProtected<T>(AcquireRef(new typename UnwrapRef<T>::type()));
- } else {
- return MutexProtected<T>();
+class MutexProtectedTest : public Test {
+ protected:
+ MutexProtected<T> CreateDefault() {
+ if constexpr (IsRef<T>::value) {
+ return MutexProtected<T>(AcquireRef(new typename UnwrapRef<T>::type()));
+ } else {
+ return MutexProtected<T>();
+ }
}
-}
-template <typename T, typename... Args>
-MutexProtected<T> CreateCustom(Args&&... args) {
- if constexpr (IsRef<T>::value) {
- return MutexProtected<T>(
- AcquireRef(new typename UnwrapRef<T>::type(std::forward<Args>(args)...)));
- } else {
- return MutexProtected<T>(std::forward<Args>(args)...);
+ template <typename... Args>
+ MutexProtected<T> CreateCustom(Args&&... args) {
+ if constexpr (IsRef<T>::value) {
+ return MutexProtected<T>(
+ AcquireRef(new typename UnwrapRef<T>::type(std::forward<Args>(args)...)));
+ } else {
+ return MutexProtected<T>(std::forward<Args>(args)...);
+ }
}
-}
-
-template <typename T>
-class MutexProtectedTest : public Test {};
+};
class MutexProtectedTestTypeNames {
public:
@@ -166,7 +167,7 @@
static constexpr int kIncrementCount = 100;
static constexpr int kDecrementCount = 50;
- MutexProtected<TypeParam> counter = CreateDefault<TypeParam>();
+ auto counter = this->CreateDefault();
auto increment = [&] {
for (uint32_t i = 0; i < kIncrementCount; i++) {
@@ -206,7 +207,7 @@
static constexpr int kDecrementCount = 50;
static constexpr int kStartingcount = -100;
- MutexProtected<TypeParam> counter = CreateCustom<TypeParam>(kStartingcount);
+ auto counter = this->CreateCustom(kStartingcount);
auto increment = [&] {
for (uint32_t i = 0; i < kIncrementCount; i++) {
@@ -244,8 +245,8 @@
TYPED_TEST(MutexProtectedTest, MultipleProtected) {
static constexpr int kIncrementCount = 100;
- MutexProtected<TypeParam> c1 = CreateDefault<TypeParam>();
- MutexProtected<TypeParam> c2 = CreateDefault<TypeParam>();
+ auto c1 = this->CreateDefault();
+ auto c2 = this->CreateDefault();
auto increment = [&] {
for (uint32_t i = 0; i < kIncrementCount; i++) {
@@ -268,6 +269,66 @@
validateThread.join();
}
+TEST(MutexCondVarProtectedTest, Nominal) {
+ static constexpr int kIncrementCount = 100;
+ auto counter = MutexCondVarProtected<CounterT>();
+
+ auto increment = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter->Increment();
+ }
+ };
+ auto useIncrement = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter.Use([](auto c) { c->Increment(); });
+ }
+ };
+ std::thread incrementThread(increment);
+ std::thread useIncrementThread(useIncrement);
+
+ auto expected = 2 * kIncrementCount;
+ counter.Use([&](auto c) {
+ c.WaitFor(kMaxDurationNanos, [&](auto& count) { return count.Get() == expected; });
+ EXPECT_EQ(c->Get(), expected);
+ });
+ EXPECT_EQ(counter->Get(), expected);
+
+ incrementThread.join();
+ useIncrementThread.join();
+}
+
+// WaitFor should timeout and fail if the condition is never met.
+TEST(MutexCondVarProtectedTest, WaitForTimeout) {
+ auto counter = MutexCondVarProtected<CounterT>();
+ counter.Use([](auto c) {
+ EXPECT_FALSE(c.WaitFor(Nanoseconds(5), [](auto& x) { return x.Get() == 1; }));
+ });
+}
+
+// Test that Wait releases the lock, otherwise this test would deadlock.
+TEST(MutexCondVarProtectedTest, WaitDeadlock) {
+ auto c1 = MutexCondVarProtected<CounterT>();
+ auto c2 = MutexCondVarProtected<CounterT>();
+
+ auto t1 = [&] {
+ c1.Use([&](auto x1) {
+ x1.Wait([](auto& x) { return x.Get() == 1; });
+ c2->Increment();
+ });
+ };
+ auto t2 = [&] {
+ c2.Use([&](auto x2) {
+ c1->Increment();
+ x2.Wait([](auto& x) { return x.Get() == 1; });
+ });
+ };
+
+ std::thread thread1(t1);
+ std::thread thread2(t2);
+ thread1.join();
+ thread2.join();
+}
+
} // anonymous namespace
} // namespace dawn