[dawn][common] Adds different Notify modes for MutexCondVarProtected.

- Adds an optional template argument for MutexCondVarProtected users
  to explicitly specify notify mode. Sometimes, we may modify the
  inner object but may want to not notify, or notify only one thread.
- Additionally, this change has some cleanups to remove the unused
  MutexProtectedSupport class that I initially added a while back
  thinking it would be useful, but we have since found other ways
  to address the initial use case in MapAsync. It also greatly
  simplifies the code by removing the inheritance that made it quite
  hard to update.

Change-Id: I3c22ecd4ab3fab4149a999ed78474b175c52b1c3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/290375
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index 13bc766..1c7d4d9 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -47,14 +47,19 @@
 
 namespace dawn {
 
-template <typename T, template <typename, typename> class Guard>
+template <typename T, template <typename, typename> class Guard, typename Traits>
 class MutexProtected;
 
-template <typename T, template <typename, typename> class Guard>
-class MutexCondVarProtected;
+// Used by MutexCondVarProtected below where sometimes, it's useful to be able to specify which type
+// of notify scope we want.
+enum class NotifyType {
+    All,
+    One,
+    None,
+};
 
-template <typename T, template <typename, typename> class Guard>
-class MutexProtectedSupport;
+template <typename T, template <typename, typename, NotifyType> class Guard, typename Traits>
+class MutexCondVarProtected;
 
 namespace detail {
 
@@ -94,30 +99,9 @@
     static const ObjectType* GetObj(const Ref<T>* const obj) { return obj->Get(); }
 };
 
-template <typename T>
-struct MutexProtectedSupportTraits {
-    using MutexType = std::mutex;
-    using LockType = std::unique_lock<std::mutex>;
-
-    static constexpr bool kSupportsTryLock = true;
-
-    static MutexType CreateMutex() { return std::mutex(); }
-    static std::mutex& GetMutex(MutexType& m) { return m; }
-    static auto* GetObj(T* const obj) { return &obj->mImpl; }
-    static const auto* GetObj(const T* const obj) { return &obj->mImpl; }
-
-    static std::optional<LockType> TryLock(MutexType& mutex) {
-        LockType lock(GetMutex(mutex), std::try_to_lock);
-        if (!lock.owns_lock()) {
-            return std::nullopt;
-        }
-        return lock;
-    }
-};
-
 template <typename T, typename Traits>
 class Guard;
-template <typename T, typename Traits>
+template <typename T, typename Traits, NotifyType NotifyT>
 class CondVarGuard;
 
 // Guard class is a wrapping class that gives access to a protected resource after acquiring the
@@ -165,9 +149,13 @@
 
   private:
     using NonConstT = typename std::remove_const<T>::type;
-    friend class CondVarGuard<T, Traits>;
-    friend class MutexProtectedSupport<NonConstT, Guard>;
-    friend class MutexProtected<NonConstT, Guard>;
+    friend class MutexProtected<NonConstT, Guard, Traits>;
+
+    // Currently need to explicitly list the notify types because we can't partially specialize
+    // friend classes.
+    friend class CondVarGuard<T, Traits, NotifyType::All>;
+    friend class CondVarGuard<T, Traits, NotifyType::One>;
+    friend class CondVarGuard<T, Traits, NotifyType::None>;
 
     typename Traits::LockType mLock;
     // RAW_PTR_EXCLUSION: This pointer is created/destroyed on each access to a MutexProtected.
@@ -179,9 +167,9 @@
 
 // CondVarGuard is a different guard class that internally holds a Guard, but provides additional
 // functionality w.r.t condition variables. Specifically, the non-const version of this Guard will
-// automatically call notify_all() on the underlying condition variable so that calls to |Wait*()|
-// will unblock when |Pred| is true.
-template <typename T, typename Traits>
+// automatically call a notify function on the underlying condition variable so that calls to
+// |Wait*()| will unblock when |Pred| is true.
+template <typename T, typename Traits, NotifyType NotifyT = NotifyType::All>
 class CondVarGuard : public NonMovable, StackAllocated {
   public:
     // It's the programmer's burden to not save the pointer/reference and reuse it without the lock.
@@ -207,86 +195,47 @@
     }
 
   protected:
-    CondVarGuard(T* obj,
-                 typename Traits::MutexType& mutex,
-                 class Defer* defer = nullptr,
-                 std::condition_variable* cv = nullptr)
-        : mNotifyScope(cv), mGuard(obj, mutex, defer) {}
-    CondVarGuard(T* obj,
-                 Traits::LockType&& lock,
-                 class Defer* defer = nullptr,
-                 std::condition_variable* cv = nullptr)
-        : mNotifyScope(cv), mGuard(obj, std::move(lock), defer) {}
+    CondVarGuard(T* obj, Traits::MutexType& mutex, std::condition_variable* cv)
+        : mNotifyScope(cv), mGuard(obj, mutex) {}
 
     auto* Get() const { return mGuard.Get(); }
 
   private:
     using NonConstT = typename std::remove_const<T>::type;
-    friend class MutexProtected<NonConstT, CondVarGuard>;
-    friend class MutexCondVarProtected<NonConstT, CondVarGuard>;
+    friend class MutexCondVarProtected<NonConstT, CondVarGuard, Traits>;
 
-    struct NotifyScope {
-        explicit NotifyScope(std::condition_variable* cv) : cv(cv) { DAWN_ASSERT(cv); }
-        ~NotifyScope() {
-            if constexpr (!std::is_const_v<T>) {
-                cv->notify_all();
-            }
-        }
-
+    struct NotifyScopeBase {
+        explicit NotifyScopeBase(std::condition_variable* cv) : cv(cv) { DAWN_ASSERT(cv); }
         raw_ptr<std::condition_variable> cv = nullptr;
     };
-    NotifyScope mNotifyScope;
-    Guard<T, Traits> mGuard;
-};
 
-template <typename T, typename Traits, template <typename, typename> class Guard = detail::Guard>
-class MutexProtectedBase {
-  public:
-    using Usage = Guard<T, Traits>;
-    using ConstUsage = Guard<const T, Traits>;
-
-    MutexProtectedBase() : mMutex(Traits::CreateMutex()) {}
-    virtual ~MutexProtectedBase() = default;
-
-    Usage operator->() { return Use(); }
-    ConstUsage operator->() const { return Use(); }
-
-    template <typename Fn>
-    auto Use(Fn&& fn) {
-        return fn(Use());
-    }
-    template <typename Fn>
-    auto Use(Fn&& fn) const {
-        return fn(ConstUse());
-    }
-    template <typename Fn>
-    auto ConstUse(Fn&& fn) const {
-        return fn(ConstUse());
-    }
-
-    template <typename Fn>
-    auto UseWithDefer(Fn&& fn) {
-        Defer defer;
-        return fn(UseWithDefer(defer));
-    }
-
-    std::optional<Usage> TryUse()
-        requires Traits::kSupportsTryLock
-    {
-        auto maybeLock = Traits::TryLock(this->mMutex);
-        if (!maybeLock.has_value()) {
-            return std::nullopt;
+    template <NotifyType U>
+    struct NotifyScope : NotifyScopeBase {
+        using NotifyScopeBase::NotifyScopeBase;
+    };
+    template <>
+    struct NotifyScope<NotifyType::All> : NotifyScopeBase {
+        using NotifyScopeBase::NotifyScopeBase;
+        ~NotifyScope() {
+            if constexpr (!std::is_const_v<T>) {
+                this->cv->notify_all();
+            }
         }
-        return Use(std::move(*maybeLock));
-    }
+    };
+    template <>
+    struct NotifyScope<NotifyType::One> : NotifyScopeBase {
+        using NotifyScopeBase::NotifyScopeBase;
+        ~NotifyScope() {
+            if constexpr (!std::is_const_v<T>) {
+                this->cv->notify_one();
+            }
+        }
+    };
 
-  protected:
-    virtual Usage Use() = 0;
-    virtual Usage Use(Traits::LockType&& lock) = 0;
-    virtual Usage UseWithDefer(Defer& defer) = 0;
-    virtual ConstUsage ConstUse() const = 0;
-
-    mutable typename Traits::MutexType mMutex;
+    NotifyScope<NotifyT> mNotifyScope;
+    // Note that this class needs to hold a Guard member instead of extending it because we want the
+    // lock to be released before we notify.
+    Guard<T, Traits> mGuard;
 };
 
 }  // namespace detail
@@ -323,31 +272,55 @@
 //       private:
 //         MutexProtected<Allocator> mAllocator;
 //     };
-template <typename T, template <typename, typename> class Guard = detail::Guard>
-class MutexProtected
-    : public detail::MutexProtectedBase<T, detail::MutexProtectedTraits<T>, Guard> {
+template <typename T,
+          template <typename, typename> class Guard = detail::Guard,
+          typename Traits = detail::MutexProtectedTraits<T>>
+class MutexProtected {
   public:
-    using Traits = detail::MutexProtectedTraits<T>;
-    using Base = detail::MutexProtectedBase<T, Traits, Guard>;
-    using typename Base::ConstUsage;
-    using typename Base::Usage;
+    using Usage = Guard<T, Traits>;
+    using ConstUsage = Guard<const T, Traits>;
 
     template <typename... Args>
     // NOLINTNEXTLINE(runtime/explicit) allow implicit construction
-    MutexProtected(Args&&... args) : mObj(std::forward<Args>(args)...) {}
+    MutexProtected(Args&&... args)
+        : mMutex(Traits::CreateMutex()), mObj(std::forward<Args>(args)...) {}
+    virtual ~MutexProtected() = default;
 
-    using Base::TryUse;
-    using Base::Use;
-    using Base::UseWithDefer;
+    Usage operator->() { return Usage(&mObj, mMutex); }
+    template <typename Fn>
+    auto Use(Fn&& fn) {
+        return fn(Usage(&mObj, mMutex));
+    }
 
-  protected:
-    T mObj;
+    ConstUsage operator->() const { return ConstUsage(&mObj, mMutex); }
+    template <typename Fn>
+    auto ConstUse(Fn&& fn) const {
+        return fn(ConstUsage(&mObj, mMutex));
+    }
+    template <typename Fn>
+    auto Use(Fn&& fn) const {
+        return ConstUse(fn);
+    }
+
+    std::optional<Usage> TryUse()
+        requires Traits::kSupportsTryLock
+    {
+        auto maybeLock = Traits::TryLock(mMutex);
+        if (!maybeLock.has_value()) {
+            return std::nullopt;
+        }
+        return Usage(&mObj, std::move(*maybeLock), nullptr);
+    }
+
+    template <typename Fn>
+    auto UseWithDefer(Fn&& fn) {
+        Defer defer;
+        return fn(Usage(&mObj, mMutex, &defer));
+    }
 
   private:
-    Usage Use() override { return Usage(&mObj, this->mMutex); }
-    Usage Use(Traits::LockType&& lock) override { return Usage(&mObj, std::move(lock), nullptr); }
-    Usage UseWithDefer(Defer& defer) override { return Usage(&mObj, this->mMutex, &defer); }
-    ConstUsage ConstUse() const override { return ConstUsage(&mObj, this->mMutex); }
+    mutable Traits::MutexType mMutex;
+    T mObj;
 };
 
 // Wrapping class for object members to provide the protections with a mutex of a MutexProtected
@@ -370,79 +343,43 @@
 //       private:
 //         MutexCondVarProtected<bool> mDone = false;
 //     };
-template <typename T, template <typename, typename> class Guard = detail::CondVarGuard>
-class MutexCondVarProtected : public MutexProtected<T, Guard> {
+template <typename T,
+          template <typename, typename, NotifyType> class Guard = detail::CondVarGuard,
+          typename Traits = detail::MutexProtectedTraits<T>>
+class MutexCondVarProtected {
   public:
-    using Base = MutexProtected<T, Guard>;
-    using typename Base::ConstUsage;
-    using typename Base::Usage;
+    using Usage = Guard<T, Traits, NotifyType::All>;
+    using ConstUsage = Guard<const T, Traits, NotifyType::None>;
 
-    using Base::Base;
+    template <typename... Args>
+    // NOLINTNEXTLINE(runtime/explicit) allow implicit construction
+    MutexCondVarProtected(Args&&... args)
+        : mMutex(Traits::CreateMutex()), mObj(std::forward<Args>(args)...) {}
+    virtual ~MutexCondVarProtected() = default;
+
+    Usage operator->() { return Usage(&mObj, mMutex, &mCv); }
+    template <NotifyType NotifyT = NotifyType::All, typename Fn>
+    auto Use(Fn&& fn) {
+        return fn(Guard<T, Traits, NotifyT>(&mObj, mMutex, &mCv));
+    }
 
     // Note that unlike in MutexProtected where |Use| and |ConstUse| guarantee the lock for the
     // entire critical section, if a user calls |Wait| within |Fn|, the lock may be released and
     // reacquired in order for another thread to update the condition.
-    using Base::Base::ConstUse;
-    using Base::Base::TryUse;
-    using Base::Base::Use;
+    ConstUsage operator->() const { return ConstUsage(&mObj, mMutex, &mCv); }
+    template <typename Fn>
+    auto ConstUse(Fn&& fn) const {
+        return fn(ConstUsage(&mObj, mMutex, &mCv));
+    }
+    template <typename Fn>
+    auto Use(Fn&& fn) const {
+        return ConstUse(fn);
+    }
 
   private:
-    Usage Use() override { return Usage(&this->mObj, this->mMutex, nullptr, &mCv); }
-    Usage Use(Base::Traits::LockType&& lock) override {
-        return Usage(&this->mObj, std::move(lock), nullptr, &mCv);
-    }
-    ConstUsage ConstUse() const override {
-        return ConstUsage(&this->mObj, this->mMutex, nullptr, &mCv);
-    }
-
+    mutable Traits::MutexType mMutex;
     mutable std::condition_variable mCv;
-};
-
-// CRTP wrapper to help create classes that are generally MutexProtected, but may wish to implement
-// specific workarounds to avoid taking the lock in certain scenarios. See the example below and the
-// unittests for more example usages of this wrapper. Example usage:
-//     struct Counter : public MutexProtectedSupport<Counter> {
-//       public:
-//         // Reads the value stored in |mCounter| without acquiring the lock.
-//         int UnsafeRead() {
-//             return mImpl.mCounter;
-//         }
-//
-//       private:
-//         // This friend declaration MUST be included in all classes using this wrapper.
-//         friend typename MutexProtectedSupport<Counter>::Traits;
-//
-//         // Internal struct that wraps all the actual data that we want to be protected. Note that
-//         // this struct currently MUST be named |mImpl| to work.
-//         struct {
-//             int mCounter = 0;
-//         } mImpl;
-//     };
-//     // Other uses of this struct look as if we are using a MutexProtected<mImpl>.
-template <typename T, template <typename, typename> class Guard = detail::Guard>
-class MutexProtectedSupport
-    : public detail::MutexProtectedBase<T, detail::MutexProtectedSupportTraits<T>, Guard> {
-  public:
-    using Traits = detail::MutexProtectedSupportTraits<T>;
-    using Base = detail::MutexProtectedBase<T, Traits, Guard>;
-    using typename Base::ConstUsage;
-    using typename Base::Usage;
-
-    using Base::TryUse;
-    using Base::Use;
-    using Base::UseWithDefer;
-
-  private:
-    Usage Use() override { return Usage(static_cast<T*>(this), this->mMutex); }
-    Usage Use(Traits::LockType&& lock) override {
-        return Usage(static_cast<T*>(this), std::move(lock), nullptr);
-    }
-    Usage UseWithDefer(Defer& defer) override {
-        return Usage(static_cast<T*>(this), this->mMutex, &defer);
-    }
-    ConstUsage ConstUse() const override {
-        return ConstUsage(static_cast<const T*>(this), this->mMutex);
-    }
+    T mObj;
 };
 
 }  // namespace dawn
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
index b6ff2f3..47d56dc 100644
--- a/src/dawn/tests/unittests/MutexProtectedTests.cpp
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -29,11 +29,13 @@
 #include <thread>
 #include <type_traits>
 #include <utility>
+#include <vector>
 
 #include "dawn/common/MutexProtected.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/RefCounted.h"
 #include "dawn/common/Time.h"
+#include "dawn/utils/SystemUtils.h"
 #include "gtest/gtest.h"
 
 namespace dawn {
@@ -42,75 +44,6 @@
 using ::testing::Test;
 using ::testing::Types;
 
-class MutexSupportedCounterT : public MutexProtectedSupport<MutexSupportedCounterT> {
-  public:
-    // This is an unsafe read of the count without acquiring the lock.
-    int ReadCount() { return mImpl.mCount; }
-
-  private:
-    friend typename MutexProtectedSupport<MutexSupportedCounterT>::Traits;
-
-    struct {
-        int mCount = 0;
-    } mImpl;
-};
-
-TEST(MutexProtectedSupportTests, Nominal) {
-    static constexpr int kIncrementCount = 100;
-    static constexpr int kDecrementCount = 50;
-
-    MutexSupportedCounterT counter;
-
-    auto increment = [&] {
-        for (uint32_t i = 0; i < kIncrementCount; i++) {
-            counter->mCount++;
-        }
-    };
-    auto useIncrement = [&] {
-        for (uint32_t i = 0; i < kIncrementCount; i++) {
-            counter.Use([](auto c) { c->mCount++; });
-        }
-    };
-    auto decrement = [&] {
-        for (uint32_t i = 0; i < kDecrementCount; i++) {
-            counter->mCount--;
-        }
-    };
-    auto useDecrement = [&] {
-        for (uint32_t i = 0; i < kDecrementCount; i++) {
-            counter.Use([](auto c) { c->mCount--; });
-        }
-    };
-
-    std::thread incrementThread(increment);
-    std::thread useIncrementThread(useIncrement);
-    std::thread decrementThread(decrement);
-    std::thread useDecrementThread(useDecrement);
-    incrementThread.join();
-    useIncrementThread.join();
-    decrementThread.join();
-    useDecrementThread.join();
-
-    EXPECT_EQ(counter->mCount, 2 * (kIncrementCount - kDecrementCount));
-}
-
-// Verifies that if we call additionally implemented functions when using the MutexProtectedSupport
-// wrapper, that they do not acquire the lock. If the lock was acquired, then this test would
-// deadlock.
-TEST(MutexProtectedSupportTests, UnsafeRead) {
-    MutexSupportedCounterT counter;
-
-    // Acquire the lock via the Use function.
-    counter.Use([&](auto c) {
-        // With the lock acquired, we should be able to call additionally implemented functions that
-        // do not acquire the lock.
-        c->mCount = 1;
-        EXPECT_EQ(counter.ReadCount(), 1);
-        c->mCount = 2;
-        EXPECT_EQ(counter.ReadCount(), 2);
-    });
-}
-
 // Simple thread-unsafe counter class.
 class CounterT : public RefCounted {
   public:
@@ -329,6 +262,47 @@
     thread2.join();
 }
 
+// Test that if we specifically ask for only one thread to be notified, then only one thread should
+// wake up from waiting.
+TEST(MutexCondVarProtectedTest, NotifyTypes) {
+    auto counter = MutexCondVarProtected<CounterT>();
+    std::atomic<int> woken = 0;
+
+    // Multiple threads both waiting on the condition variable, only one of them should actually be
+    // woken up on the first increment.
+    static constexpr int kNumThreads = 5;
+    std::vector<std::thread> threads;
+    threads.reserve(kNumThreads);
+    for (auto i = 0; i < kNumThreads; i++) {
+        threads.emplace_back([&] {
+            counter.ConstUse([&](auto c) {
+                c.Wait([](auto& x) { return x.Get() >= 1; });
+                woken += 1;
+            });
+        });
+    }
+
+    // Don't notify any threads.
+    counter.Use<NotifyType::None>([](auto c) { EXPECT_EQ(c->Get(), 0); });
+    EXPECT_EQ(woken, 0);
+
+    // Notify one of the threads only. This is currently racy w.r.t to the increment below in that
+    // it's possible that the increment happens before the threads start waiting. As a result, we
+    // only verify that at least once thread was woken. In practice, it is very difficult to verify
+    // that exactly one thread is woken.
+    counter.Use<NotifyType::One>([](auto c) { c->Increment(); });
+    while (woken == 0) {
+        utils::USleep(1000);
+    }
+
+    // Notify the rest of the threads via a NotifyAll, then wait for all the threads to join.
+    counter.Use<NotifyType::All>([](auto c) { c->Increment(); });
+    for (auto& t : threads) {
+        t.join();
+    }
+    EXPECT_EQ(woken, kNumThreads);
+}
+
 }  // anonymous namespace
 }  // namespace dawn