[dawn][common] Adds MutexProtectedSupport CRTP wrapper.

- Adds a CRTP wrapper that can be used to both declare a
  MutexProtected struct and its internals. This is particularly
  useful when we want to provide additional backdoors to read
  data without acquiring the lock.
- Adds relevant tests for the new wrapper.
- Also makes Guards hold a reference instead of a pointer since
  the pointer should never be null anyways.

Bug: 40643114
Change-Id: I53ce8ac0e4a16c3e3aa1cddb8860d9d1a5727b3f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/244275
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index ab2e6fc..3976d0d 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -35,13 +35,15 @@
 #include "dawn/common/NonMovable.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/StackAllocated.h"
-#include "partition_alloc/pointers/raw_ptr_exclusion.h"
 
 namespace dawn {
 
 template <typename T, template <typename, typename> class Guard>
 class MutexProtected;
 
+template <typename T, template <typename, typename> class Guard>
+class MutexProtectedSupport;
+
 namespace detail {
 
 template <typename T>
@@ -68,6 +70,17 @@
     static const ObjectType* GetObj(const Ref<T>* const obj) { return obj->Get(); }
 };
 
+template <typename T>
+struct MutexProtectedSupportTraits {
+    using MutexType = std::mutex;
+    using LockType = std::unique_lock<std::mutex>;
+
+    static MutexType CreateMutex() { return std::mutex(); }
+    static std::mutex& GetMutex(MutexType& m) { return m; }
+    static auto* GetObj(T* const obj) { return &obj->mImpl; }
+    static const auto* GetObj(const T* const obj) { return &obj->mImpl; }
+};
+
 // Guard class is a wrapping class that gives access to a protected resource after acquiring the
 // lock related to it. For the lifetime of this class, the lock is held.
 template <typename T, typename Traits>
@@ -78,26 +91,50 @@
     auto& operator*() const { return *Get(); }
 
   protected:
-    Guard(T* obj, typename Traits::MutexType& mutex) : mLock(Traits::GetMutex(mutex)), mObj(obj) {}
-    Guard(Guard&& other) : mLock(std::move(other.mLock)), mObj(std::move(other.mObj)) {
-        other.mObj = nullptr;
-    }
+    Guard(T* obj, typename Traits::MutexType& mutex) : mLock(Traits::GetMutex(mutex)), mObj(*obj) {}
+    Guard(Guard&& other) : mLock(std::move(other.mLock)), mObj(other.mObj) {}
 
     Guard(const Guard& other) = delete;
     Guard& operator=(const Guard& other) = delete;
     Guard& operator=(Guard&& other) = delete;
 
-    auto* Get() const { return Traits::GetObj(mObj); }
+    auto* Get() const { return Traits::GetObj(&mObj); }
 
   private:
     using NonConstT = typename std::remove_const<T>::type;
+    friend class MutexProtectedSupport<NonConstT, Guard>;
     friend class MutexProtected<NonConstT, Guard>;
 
     typename Traits::LockType mLock;
-    // RAW_PTR_EXCLUSION: This pointer is created/destroyed on each access to a MutexProtected.
-    // The pointer is always transiently used while the MutexProtected is in scope so it is
-    // unlikely to be used after it is freed.
-    RAW_PTR_EXCLUSION T* mObj = nullptr;
+    T& mObj;
+};
+
+template <typename T, typename Traits, template <typename, typename> class Guard = detail::Guard>
+class MutexProtectedBase {
+  public:
+    using Usage = Guard<T, Traits>;
+    using ConstUsage = Guard<const T, Traits>;
+
+    MutexProtectedBase() : mMutex(Traits::CreateMutex()) {}
+    virtual ~MutexProtectedBase() = default;
+
+    Usage operator->() { return Use(); }
+    ConstUsage operator->() const { return Use(); }
+
+    template <typename Fn>
+    auto Use(Fn&& fn) {
+        return fn(Use());
+    }
+    template <typename Fn>
+    auto Use(Fn&& fn) const {
+        return fn(Use());
+    }
+
+  protected:
+    virtual Usage Use() = 0;
+    virtual ConstUsage Use() const = 0;
+
+    mutable typename Traits::MutexType mMutex;
 };
 
 }  // namespace detail
@@ -135,39 +172,66 @@
 //         MutexProtected<Allocator> mAllocator;
 //     };
 template <typename T, template <typename, typename> class Guard = detail::Guard>
-class MutexProtected {
+class MutexProtected
+    : public detail::MutexProtectedBase<T, detail::MutexProtectedTraits<T>, Guard> {
   public:
     using Traits = detail::MutexProtectedTraits<T>;
-    using Usage = Guard<T, Traits>;
-    using ConstUsage = Guard<const T, Traits>;
-
-    MutexProtected() : mMutex(Traits::CreateMutex()) {}
+    using Base = detail::MutexProtectedBase<T, Traits, Guard>;
+    using typename Base::ConstUsage;
+    using typename Base::Usage;
 
     template <typename... Args>
     // NOLINTNEXTLINE(runtime/explicit) allow implicit construction
-    MutexProtected(Args&&... args)
-        : mMutex(Traits::CreateMutex()), mObj(std::forward<Args>(args)...) {}
+    MutexProtected(Args&&... args) : mObj(std::forward<Args>(args)...) {}
 
-    Usage operator->() { return Use(); }
-    ConstUsage operator->() const { return Use(); }
-
-    template <typename Fn>
-    auto Use(Fn&& fn) {
-        return fn(Use());
-    }
-    template <typename Fn>
-    auto Use(Fn&& fn) const {
-        return fn(Use());
-    }
+    using Base::Use;
 
   private:
-    Usage Use() { return Usage(&mObj, mMutex); }
-    ConstUsage Use() const { return ConstUsage(&mObj, mMutex); }
+    Usage Use() override { return Usage(&mObj, this->mMutex); }
+    ConstUsage Use() const override { return ConstUsage(&mObj, this->mMutex); }
 
-    mutable typename Traits::MutexType mMutex;
     T mObj;
 };
 
+// CRTP wrapper to help create classes that are generally MutexProtected, but may wish to implement
+// specific workarounds to avoid taking the lock in certain scenarios. See the example below and the
+// unittests for more example usages of this wrapper. Example usage:
+//     struct Counter : public MutexProtectedSupport<Counter> {
+//       public:
+//         // Reads the value stored in |mCounter| without acquiring the lock.
+//         int UnsafeRead() {
+//             return mImpl.mCounter;
+//         }
+//
+//       private:
+//         // This friend declaration MUST be included in all classes using this wrapper.
+//         friend typename MutexProtectedSupport<Counter>::Traits;
+//
+//         // Internal struct that wraps all the actual data that we want to be protected. Note that
+//         // this struct currently MUST be named |mImpl| to work.
+//         struct {
+//             int mCounter = 0;
+//         } mImpl;
+//     };
+//     // Other uses of this struct look as if we are using a MutexProtected<mImpl>.
+template <typename T, template <typename, typename> class Guard = detail::Guard>
+class MutexProtectedSupport
+    : public detail::MutexProtectedBase<T, detail::MutexProtectedSupportTraits<T>, Guard> {
+  public:
+    using Traits = detail::MutexProtectedSupportTraits<T>;
+    using Base = detail::MutexProtectedBase<T, Traits, Guard>;
+    using typename Base::ConstUsage;
+    using typename Base::Usage;
+
+    using Base::Use;
+
+  private:
+    Usage Use() override { return Usage(static_cast<T*>(this), this->mMutex); }
+    ConstUsage Use() const override {
+        return ConstUsage(static_cast<const T*>(this), this->mMutex);
+    }
+};
+
 }  // namespace dawn
 
 #endif  // SRC_DAWN_COMMON_MUTEXPROTECTED_H_
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
index de6f20a..74f6b87 100644
--- a/src/dawn/tests/unittests/MutexProtectedTests.cpp
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -41,6 +41,75 @@
 using ::testing::Test;
 using ::testing::Types;
 
+class MutexSupportedCounterT : public MutexProtectedSupport<MutexSupportedCounterT> {
+  public:
+    // This is an unsafe read of the count without acquiring the lock.
+    int ReadCount() { return mImpl.mCount; }
+
+  private:
+    friend typename MutexProtectedSupport<MutexSupportedCounterT>::Traits;
+
+    struct {
+        int mCount = 0;
+    } mImpl;
+};
+
+TEST(MutexProtectedSupportTests, Nominal) {
+    static constexpr int kIncrementCount = 100;
+    static constexpr int kDecrementCount = 50;
+
+    MutexSupportedCounterT counter;
+
+    auto increment = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter->mCount++;
+        }
+    };
+    auto useIncrement = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter.Use([](auto c) { c->mCount++; });
+        }
+    };
+    auto decrement = [&] {
+        for (uint32_t i = 0; i < kDecrementCount; i++) {
+            counter->mCount--;
+        }
+    };
+    auto useDecrement = [&] {
+        for (uint32_t i = 0; i < kDecrementCount; i++) {
+            counter.Use([](auto c) { c->mCount--; });
+        }
+    };
+
+    std::thread incrementThread(increment);
+    std::thread useIncrementThread(useIncrement);
+    std::thread decrementThread(decrement);
+    std::thread useDecrementThread(useDecrement);
+    incrementThread.join();
+    useIncrementThread.join();
+    decrementThread.join();
+    useDecrementThread.join();
+
+    EXPECT_EQ(counter->mCount, 2 * (kIncrementCount - kDecrementCount));
+}
+
+// Verifies that if we call additionally implemented functions when using the MutexProtectedSupport
+// wrapper, that they do not acquire the lock. If the lock was acquired, then this test would
+// deadlock.
+TEST(MutexProtectedSupportTests, UnsafeRead) {
+    MutexSupportedCounterT counter;
+
+    // Acquire the lock via the Use function.
+    counter.Use([&](auto c) {
+        // With the lock acquired, we should be able to call additionally implemented functions that
+        // do not acquire the lock.
+        c->mCount = 1;
+        EXPECT_EQ(counter.ReadCount(), 1);
+        c->mCount = 2;
+        EXPECT_EQ(counter.ReadCount(), 2);
+    });
+}
+
 // Simple thread-unsafe counter class.
 class CounterT : public RefCounted {
   public: