[dawn][common] Adds MutexCondVarProtected for non-Ref<> types.

- Adds helper class for waiting with conditional variables.
- The helper is used to implement WaitAny on Dawn wire client.

Bug: 441981783
Change-Id: I119d0df4808b311f0a4cb7ea19ca6dafd3ad3b8b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/260460
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Auto-Submit: Loko Kung <lokokung@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index 82da884..48e22e6 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -28,6 +28,8 @@
 #ifndef SRC_DAWN_COMMON_MUTEXPROTECTED_H_
 #define SRC_DAWN_COMMON_MUTEXPROTECTED_H_
 
+#include <chrono>
+#include <condition_variable>
 #include <mutex>
 #include <utility>
 
@@ -37,6 +39,7 @@
 #include "dawn/common/NonMovable.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/StackAllocated.h"
+#include "dawn/common/Time.h"
 #include "partition_alloc/pointers/raw_ptr.h"
 #include "partition_alloc/pointers/raw_ptr_exclusion.h"
 
@@ -46,6 +49,9 @@
 class MutexProtected;
 
 template <typename T, template <typename, typename> class Guard>
+class MutexCondVarProtected;
+
+template <typename T, template <typename, typename> class Guard>
 class MutexProtectedSupport;
 
 namespace detail {
@@ -85,6 +91,11 @@
     static const auto* GetObj(const T* const obj) { return &obj->mImpl; }
 };
 
+template <typename T, typename Traits>
+class Guard;
+template <typename T, typename Traits>
+class CondVarGuard;
+
 // Guard class is a wrapping class that gives access to a protected resource after acquiring the
 // lock related to it. For the lifetime of this class, the lock is held.
 template <typename T, typename Traits>
@@ -117,6 +128,7 @@
 
   private:
     using NonConstT = typename std::remove_const<T>::type;
+    friend class CondVarGuard<T, Traits>;
     friend class MutexProtectedSupport<NonConstT, Guard>;
     friend class MutexProtected<NonConstT, Guard>;
 
@@ -128,6 +140,63 @@
     raw_ptr<class Defer> mDefer = nullptr;
 };
 
+// CondVarGuard is a different guard class that internally holds a Guard, but provides additional
+// functionality w.r.t condition variables. Specifically, the non-const version of this Guard will
+// automatically call notify_all() on the underlying condition variable so that calls to |Wait*()|
+// will unblock when |Pred| is true.
+template <typename T, typename Traits>
+class CondVarGuard : public NonMovable, StackAllocated {
+  public:
+    // It's the programmer's burden to not save the pointer/reference and reuse it without the lock.
+    auto* operator->() const { return mGuard.Get(); }
+    auto& operator*() const { return *mGuard.Get(); }
+
+    template <typename Predicate>
+    void Wait(Predicate pred) {
+        DAWN_ASSERT(mNotifyScope.cv);
+        mNotifyScope.cv->wait(mGuard.mLock, [&] { return pred((*Get())); });
+    }
+    template <typename Predicate>
+    bool WaitFor(Nanoseconds timeout, Predicate pred) {
+        DAWN_ASSERT(mNotifyScope.cv);
+        if (timeout < kMaxDurationNanos) {
+            return mNotifyScope.cv->wait_for(
+                mGuard.mLock, std::chrono::nanoseconds(static_cast<uint64_t>(timeout)),
+                [&] { return pred(*Get()); });
+        } else {
+            Wait(pred);
+            return true;
+        }
+    }
+
+  protected:
+    CondVarGuard(T* obj,
+                 typename Traits::MutexType& mutex,
+                 class Defer* defer = nullptr,
+                 std::condition_variable* cv = nullptr)
+        : mNotifyScope(cv), mGuard(obj, mutex, defer) {}
+
+    auto* Get() const { return mGuard.Get(); }
+
+  private:
+    using NonConstT = typename std::remove_const<T>::type;
+    friend class MutexProtected<NonConstT, CondVarGuard>;
+    friend class MutexCondVarProtected<NonConstT, CondVarGuard>;
+
+    struct NotifyScope {
+        explicit NotifyScope(std::condition_variable* cv) : cv(cv) { DAWN_ASSERT(cv); }
+        ~NotifyScope() {
+            if constexpr (!std::is_const_v<T>) {
+                cv->notify_all();
+            }
+        }
+
+        raw_ptr<std::condition_variable> cv = nullptr;
+    };
+    NotifyScope mNotifyScope;
+    Guard<T, Traits> mGuard;
+};
+
 template <typename T, typename Traits, template <typename, typename> class Guard = detail::Guard>
 class MutexProtectedBase {
   public:
@@ -146,7 +215,11 @@
     }
     template <typename Fn>
     auto Use(Fn&& fn) const {
-        return fn(Use());
+        return fn(ConstUse());
+    }
+    template <typename Fn>
+    auto ConstUse(Fn&& fn) const {
+        return fn(ConstUse());
     }
 
     template <typename Fn>
@@ -158,7 +231,7 @@
   protected:
     virtual Usage Use() = 0;
     virtual Usage UseWithDefer(Defer& defer) = 0;
-    virtual ConstUsage Use() const = 0;
+    virtual ConstUsage ConstUse() const = 0;
 
     mutable typename Traits::MutexType mMutex;
 };
@@ -213,12 +286,57 @@
     using Base::Use;
     using Base::UseWithDefer;
 
+  protected:
+    T mObj;
+
   private:
     Usage Use() override { return Usage(&mObj, this->mMutex); }
     Usage UseWithDefer(Defer& defer) override { return Usage(&mObj, this->mMutex, &defer); }
-    ConstUsage Use() const override { return ConstUsage(&mObj, this->mMutex); }
+    ConstUsage ConstUse() const override { return ConstUsage(&mObj, this->mMutex); }
+};
 
-    T mObj;
+// Wrapping class for object members to provide the protections with a mutex of a MutexProtected
+// with some additional helpers to allow waiting with a conditional variable as well. The general
+// usage should look the same as MutexProtected above, with additional usages like the following
+// example:
+//     class Example {
+//       public:
+//         void Complete() {
+//             mDone.Use([](auto done) {
+//                 // Do something
+//                 mDone = true;
+//             });
+//         }
+//         void WaitUntilDone() {
+//             mDone.Use([](auto done) {
+//                 done.Wait([](auto& done) { return done; });
+//             });
+//         }
+//       private:
+//         MutexCondVarProtected<bool> mDone = false;
+//     };
+template <typename T, template <typename, typename> class Guard = detail::CondVarGuard>
+class MutexCondVarProtected : public MutexProtected<T, Guard> {
+  public:
+    using Base = MutexProtected<T, Guard>;
+    using typename Base::ConstUsage;
+    using typename Base::Usage;
+
+    using Base::Base;
+
+    // Note that unlike in MutexProtected where |Use| and |ConstUse| guarantee the lock for the
+    // entire critical section, if a user calls |Wait| within |Fn|, the lock may be released and
+    // reacquired in order for another thread to update the condition.
+    using Base::Base::ConstUse;
+    using Base::Base::Use;
+
+  private:
+    Usage Use() override { return Usage(&this->mObj, this->mMutex, nullptr, &mCv); }
+    ConstUsage ConstUse() const override {
+        return ConstUsage(&this->mObj, this->mMutex, nullptr, &mCv);
+    }
+
+    mutable std::condition_variable mCv;
 };
 
 // CRTP wrapper to help create classes that are generally MutexProtected, but may wish to implement
@@ -259,7 +377,7 @@
     Usage UseWithDefer(Defer& defer) override {
         return Usage(static_cast<T*>(this), this->mMutex, &defer);
     }
-    ConstUsage Use() const override {
+    ConstUsage ConstUse() const override {
         return ConstUsage(static_cast<const T*>(this), this->mMutex);
     }
 };
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
index 74f6b87..b6ff2f3 100644
--- a/src/dawn/tests/unittests/MutexProtectedTests.cpp
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -33,6 +33,7 @@
 #include "dawn/common/MutexProtected.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/RefCounted.h"
+#include "dawn/common/Time.h"
 #include "gtest/gtest.h"
 
 namespace dawn {
@@ -126,26 +127,26 @@
 };
 
 template <typename T>
-MutexProtected<T> CreateDefault() {
-    if constexpr (IsRef<T>::value) {
-        return MutexProtected<T>(AcquireRef(new typename UnwrapRef<T>::type()));
-    } else {
-        return MutexProtected<T>();
+class MutexProtectedTest : public Test {
+  protected:
+    MutexProtected<T> CreateDefault() {
+        if constexpr (IsRef<T>::value) {
+            return MutexProtected<T>(AcquireRef(new typename UnwrapRef<T>::type()));
+        } else {
+            return MutexProtected<T>();
+        }
     }
-}
 
-template <typename T, typename... Args>
-MutexProtected<T> CreateCustom(Args&&... args) {
-    if constexpr (IsRef<T>::value) {
-        return MutexProtected<T>(
-            AcquireRef(new typename UnwrapRef<T>::type(std::forward<Args>(args)...)));
-    } else {
-        return MutexProtected<T>(std::forward<Args>(args)...);
+    template <typename... Args>
+    MutexProtected<T> CreateCustom(Args&&... args) {
+        if constexpr (IsRef<T>::value) {
+            return MutexProtected<T>(
+                AcquireRef(new typename UnwrapRef<T>::type(std::forward<Args>(args)...)));
+        } else {
+            return MutexProtected<T>(std::forward<Args>(args)...);
+        }
     }
-}
-
-template <typename T>
-class MutexProtectedTest : public Test {};
+};
 
 class MutexProtectedTestTypeNames {
   public:
@@ -166,7 +167,7 @@
     static constexpr int kIncrementCount = 100;
     static constexpr int kDecrementCount = 50;
 
-    MutexProtected<TypeParam> counter = CreateDefault<TypeParam>();
+    auto counter = this->CreateDefault();
 
     auto increment = [&] {
         for (uint32_t i = 0; i < kIncrementCount; i++) {
@@ -206,7 +207,7 @@
     static constexpr int kDecrementCount = 50;
     static constexpr int kStartingcount = -100;
 
-    MutexProtected<TypeParam> counter = CreateCustom<TypeParam>(kStartingcount);
+    auto counter = this->CreateCustom(kStartingcount);
 
     auto increment = [&] {
         for (uint32_t i = 0; i < kIncrementCount; i++) {
@@ -244,8 +245,8 @@
 TYPED_TEST(MutexProtectedTest, MultipleProtected) {
     static constexpr int kIncrementCount = 100;
 
-    MutexProtected<TypeParam> c1 = CreateDefault<TypeParam>();
-    MutexProtected<TypeParam> c2 = CreateDefault<TypeParam>();
+    auto c1 = this->CreateDefault();
+    auto c2 = this->CreateDefault();
 
     auto increment = [&] {
         for (uint32_t i = 0; i < kIncrementCount; i++) {
@@ -268,6 +269,66 @@
     validateThread.join();
 }
 
+TEST(MutexCondVarProtectedTest, Nominal) {
+    static constexpr int kIncrementCount = 100;
+    auto counter = MutexCondVarProtected<CounterT>();
+
+    auto increment = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter->Increment();
+        }
+    };
+    auto useIncrement = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter.Use([](auto c) { c->Increment(); });
+        }
+    };
+    std::thread incrementThread(increment);
+    std::thread useIncrementThread(useIncrement);
+
+    auto expected = 2 * kIncrementCount;
+    counter.Use([&](auto c) {
+        c.WaitFor(kMaxDurationNanos, [&](auto& count) { return count.Get() == expected; });
+        EXPECT_EQ(c->Get(), expected);
+    });
+    EXPECT_EQ(counter->Get(), expected);
+
+    incrementThread.join();
+    useIncrementThread.join();
+}
+
+// WaitFor should timeout and fail if the condition is never met.
+TEST(MutexCondVarProtectedTest, WaitForTimeout) {
+    auto counter = MutexCondVarProtected<CounterT>();
+    counter.Use([](auto c) {
+        EXPECT_FALSE(c.WaitFor(Nanoseconds(5), [](auto& x) { return x.Get() == 1; }));
+    });
+}
+
+// Test that Wait releases the lock, otherwise this test would deadlock.
+TEST(MutexCondVarProtectedTest, WaitDeadlock) {
+    auto c1 = MutexCondVarProtected<CounterT>();
+    auto c2 = MutexCondVarProtected<CounterT>();
+
+    auto t1 = [&] {
+        c1.Use([&](auto x1) {
+            x1.Wait([](auto& x) { return x.Get() == 1; });
+            c2->Increment();
+        });
+    };
+    auto t2 = [&] {
+        c2.Use([&](auto x2) {
+            c1->Increment();
+            x2.Wait([](auto& x) { return x.Get() == 1; });
+        });
+    };
+
+    std::thread thread1(t1);
+    std::thread thread2(t2);
+    thread1.join();
+    thread2.join();
+}
+
 }  // anonymous namespace
 }  // namespace dawn