Introduces MutexProtected wrapper for thread-safety.

Bug: dawn:1662
Change-Id: Ie960665e0688112c5edddc53b428329b5f41a71b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144520
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/common/BUILD.gn b/src/dawn/common/BUILD.gn
index 6b72bb5..cacfd25 100644
--- a/src/dawn/common/BUILD.gn
+++ b/src/dawn/common/BUILD.gn
@@ -253,6 +253,7 @@
       "Math.h",
       "Mutex.cpp",
       "Mutex.h",
+      "MutexProtected.h",
       "NSRef.h",
       "NonCopyable.h",
       "Numeric.h",
diff --git a/src/dawn/common/CMakeLists.txt b/src/dawn/common/CMakeLists.txt
index ca78909..aa20870 100644
--- a/src/dawn/common/CMakeLists.txt
+++ b/src/dawn/common/CMakeLists.txt
@@ -56,6 +56,7 @@
     "Math.h"
     "Mutex.cpp"
     "Mutex.h"
+    "MutexProtected.h"
     "NSRef.h"
     "NonCopyable.h"
     "Numeric.h"
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
new file mode 100644
index 0000000..a14daf1
--- /dev/null
+++ b/src/dawn/common/MutexProtected.h
@@ -0,0 +1,142 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_COMMON_MUTEXPROTECTED_H_
+#define SRC_DAWN_COMMON_MUTEXPROTECTED_H_
+
+#include <mutex>
+#include <utility>
+
+#include "dawn/common/Mutex.h"
+#include "dawn/common/Ref.h"
+
+namespace dawn {
+
+template <typename T>
+class MutexProtected;
+
+namespace detail {
+
+template <typename T>
+struct MutexProtectedTraits {
+    using MutexType = std::mutex;
+    using LockType = std::lock_guard<std::mutex>;
+    using ObjectType = T;
+
+    static MutexType CreateMutex() { return std::mutex(); }
+    static std::mutex& GetMutex(MutexType& m) { return m; }
+    static ObjectType* GetObj(T* const obj) { return obj; }
+    static const ObjectType* GetObj(const T* const obj) { return obj; }
+};
+
+template <typename T>
+struct MutexProtectedTraits<Ref<T>> {
+    using MutexType = Ref<Mutex>;
+    using LockType = Mutex::AutoLock;
+    using ObjectType = T;
+
+    static MutexType CreateMutex() { return AcquireRef(new Mutex()); }
+    static Mutex* GetMutex(MutexType& m) { return m.Get(); }
+    static ObjectType* GetObj(Ref<T>* const obj) { return obj->Get(); }
+    static const ObjectType* GetObj(const Ref<T>* const obj) { return obj->Get(); }
+};
+
+// Guard class is a wrapping class that gives access to a protected resource after acquiring the
+// lock related to it. For the lifetime of this class, the lock is held.
+template <typename T, typename Traits>
+class Guard {
+  public:
+    using ReturnType = typename UnwrapRef<T>::type;
+
+    // It's the programmer's burden to not save the pointer/reference and reuse it without the lock.
+    ReturnType* operator->() { return Traits::GetObj(mObj); }
+    ReturnType& operator*() { return *Traits::GetObj(mObj); }
+    const ReturnType* operator->() const { return Traits::GetObj(mObj); }
+    const ReturnType& operator*() const { return *Traits::GetObj(mObj); }
+
+  private:
+    friend class MutexProtected<T>;
+
+    Guard(T* obj, typename Traits::MutexType& mutex) : mLock(Traits::GetMutex(mutex)), mObj(obj) {}
+
+    typename Traits::LockType mLock;
+    T* const mObj;
+};
+
+}  // namespace detail
+
+// Wrapping class used for object members to ensure usage of the resource is protected with a mutex.
+// Example usage:
+//     class Allocator {
+//       public:
+//         Allocation Allocate();
+//         void Deallocate(Allocation&);
+//     };
+//     class AllocatorUser {
+//       public:
+//         void OnlyAllocate() {
+//             auto allocation = mAllocator->Allocate();
+//         }
+//         void AtomicAllocateDeallocate() {
+//             // Operations:
+//             //   - acquire lock
+//             //   - Allocate, Deallocate
+//             //   - release lock
+//             mAllocator.Use([](auto allocator) {
+//                 auto allocation = allocator->Allocate();
+//                 allocator->Deallocate(allocation);
+//             });
+//         }
+//         void NonAtomicAllocateDeallocate() {
+//             // Operations:
+//             //   - acquire lock, Allocate, release lock
+//             //   - acquire lock, Deallocate, release lock
+//             auto allocation = mAllocator->Allocate();
+//             mAllocator->Deallocate(allocation);
+//         }
+//       private:
+//         MutexProtected<Allocator> mAllocator;
+//     };
+template <typename T>
+class MutexProtected {
+  public:
+    using Traits = detail::MutexProtectedTraits<T>;
+    using Usage = detail::Guard<T, Traits>;
+    using ConstUsage = detail::Guard<const T, Traits>;
+
+    MutexProtected() : mMutex(Traits::CreateMutex()) {}
+
+    template <typename... Args>
+    explicit MutexProtected(Args&&... args)
+        : mMutex(Traits::CreateMutex()), mObj(std::forward<Args>(args)...) {}
+
+    Usage operator->() { return Use(); }
+    ConstUsage operator->() const { return Use(); }
+
+    template <typename Fn>
+    auto Use(Fn&& fn) {
+        return fn(Use());
+    }
+
+  private:
+    Usage Use() { return Usage(&mObj, mMutex); }
+    ConstUsage Use() const { return ConstUsage(&mObj, mMutex); }
+
+    typename Traits::MutexType mMutex;
+    T mObj;
+};
+
+}  // namespace dawn
+
+#endif  // SRC_DAWN_COMMON_MUTEXPROTECTED_H_
diff --git a/src/dawn/common/Ref.h b/src/dawn/common/Ref.h
index f0e947c..9dd7dfb 100644
--- a/src/dawn/common/Ref.h
+++ b/src/dawn/common/Ref.h
@@ -16,7 +16,6 @@
 #define SRC_DAWN_COMMON_REF_H_
 
 #include <mutex>
-#include <type_traits>
 
 #include "dawn/common/RefBase.h"
 #include "dawn/common/RefCounted.h"
@@ -40,6 +39,24 @@
 }  // namespace detail
 
 template <typename T>
+struct UnwrapRef {
+    using type = T;
+};
+template <typename T>
+struct UnwrapRef<Ref<T>> {
+    using type = T;
+};
+
+template <typename T>
+struct IsRef {
+    static constexpr bool value = false;
+};
+template <typename T>
+struct IsRef<Ref<T>> {
+    static constexpr bool value = true;
+};
+
+template <typename T>
 class Ref : public RefBase<T*, detail::RefCountedTraits<T>> {
   public:
     using RefBase<T*, detail::RefCountedTraits<T>>::RefBase;
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 1aef11a..0004571 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -305,6 +305,7 @@
     "unittests/LimitsTests.cpp",
     "unittests/LinkedListTests.cpp",
     "unittests/MathTests.cpp",
+    "unittests/MutexProtectedTests.cpp",
     "unittests/MutexTests.cpp",
     "unittests/NumericTests.cpp",
     "unittests/ObjectBaseTests.cpp",
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
new file mode 100644
index 0000000..9cf1bcc
--- /dev/null
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -0,0 +1,212 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <thread>
+#include <type_traits>
+#include <utility>
+
+#include "dawn/common/MutexProtected.h"
+#include "dawn/common/Ref.h"
+#include "dawn/common/RefCounted.h"
+#include "gtest/gtest.h"
+
+namespace dawn {
+namespace {
+
+using ::testing::Test;
+using ::testing::Types;
+
+// Simple thread-unsafe counter class.
+class CounterT : public RefCounted {
+  public:
+    CounterT() = default;
+    explicit CounterT(int count) : mCount(count) {}
+
+    int Get() const { return mCount; }
+
+    void Increment() { mCount++; }
+    void Decrement() { mCount--; }
+
+  private:
+    int mCount = 0;
+};
+
+template <typename T>
+MutexProtected<T> CreateDefault() {
+    if constexpr (IsRef<T>::value) {
+        return MutexProtected<T>(AcquireRef(new typename UnwrapRef<T>::type()));
+    } else {
+        return MutexProtected<T>();
+    }
+}
+
+template <typename T, typename... Args>
+MutexProtected<T> CreateCustom(Args&&... args) {
+    if constexpr (IsRef<T>::value) {
+        return MutexProtected<T>(
+            AcquireRef(new typename UnwrapRef<T>::type(std::forward<Args>(args)...)));
+    } else {
+        return MutexProtected<T>(std::forward<Args>(args)...);
+    }
+}
+
+template <typename T>
+class MutexProtectedTest : public Test {};
+
+class MutexProtectedTestTypeNames {
+  public:
+    template <typename T>
+    static std::string GetName(int) {
+        if (std::is_same<T, CounterT>()) {
+            return "CounterT";
+        }
+        if (std::is_same<T, Ref<CounterT>>()) {
+            return "Ref<CounterT>";
+        }
+    }
+};
+using MutexProtectedTestTypes = Types<CounterT, Ref<CounterT>>;
+TYPED_TEST_SUITE(MutexProtectedTest, MutexProtectedTestTypes, MutexProtectedTestTypeNames);
+
+TYPED_TEST(MutexProtectedTest, DefaultCtor) {
+    static constexpr int kIncrementCount = 100;
+    static constexpr int kDecrementCount = 50;
+
+    MutexProtected<TypeParam> counter = CreateDefault<TypeParam>();
+
+    auto increment = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter->Increment();
+        }
+    };
+    auto useIncrement = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter.Use([](auto c) { c->Increment(); });
+        }
+    };
+    auto decrement = [&] {
+        for (uint32_t i = 0; i < kDecrementCount; i++) {
+            counter->Decrement();
+        }
+    };
+    auto useDecrement = [&] {
+        for (uint32_t i = 0; i < kDecrementCount; i++) {
+            counter.Use([](auto c) { c->Decrement(); });
+        }
+    };
+
+    std::thread incrementThread(increment);
+    std::thread useIncrementThread(useIncrement);
+    std::thread decrementThread(decrement);
+    std::thread useDecrementThread(useDecrement);
+    incrementThread.join();
+    useIncrementThread.join();
+    decrementThread.join();
+    useDecrementThread.join();
+
+    EXPECT_EQ(counter->Get(), 2 * (kIncrementCount - kDecrementCount));
+}
+
+TYPED_TEST(MutexProtectedTest, CustomCtor) {
+    static constexpr int kIncrementCount = 100;
+    static constexpr int kDecrementCount = 50;
+    static constexpr int kStartingcount = -100;
+
+    MutexProtected<TypeParam> counter = CreateCustom<TypeParam>(kStartingcount);
+
+    auto increment = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter->Increment();
+        }
+    };
+    auto useIncrement = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            counter.Use([](auto c) { c->Increment(); });
+        }
+    };
+    auto decrement = [&] {
+        for (uint32_t i = 0; i < kDecrementCount; i++) {
+            counter->Decrement();
+        }
+    };
+    auto useDecrement = [&] {
+        for (uint32_t i = 0; i < kDecrementCount; i++) {
+            counter.Use([](auto c) { c->Decrement(); });
+        }
+    };
+
+    std::thread incrementThread(increment);
+    std::thread useIncrementThread(useIncrement);
+    std::thread decrementThread(decrement);
+    std::thread useDecrementThread(useDecrement);
+    incrementThread.join();
+    useIncrementThread.join();
+    decrementThread.join();
+    useDecrementThread.join();
+
+    EXPECT_EQ(counter->Get(), kStartingcount + 2 * (kIncrementCount - kDecrementCount));
+}
+
+TYPED_TEST(MutexProtectedTest, MultipleProtected) {
+    static constexpr int kIncrementCount = 100;
+
+    MutexProtected<TypeParam> c1 = CreateDefault<TypeParam>();
+    MutexProtected<TypeParam> c2 = CreateDefault<TypeParam>();
+
+    auto increment = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            c1.Use([&](auto x1) {
+                c2.Use([&](auto x2) {
+                    x1->Increment();
+                    x2->Increment();
+                });
+            });
+        }
+    };
+    auto validate = [&] {
+        for (uint32_t i = 0; i < kIncrementCount; i++) {
+            c1.Use([&](auto x1) { c2.Use([&](auto x2) { EXPECT_EQ(x1->Get(), x2->Get()); }); });
+        }
+    };
+    std::thread incrementThread(increment);
+    std::thread validateThread(validate);
+    incrementThread.join();
+    validateThread.join();
+}
+
+}  // anonymous namespace
+}  // namespace dawn
+
+// Special compilation tests that are only enabled when experimental headers are available.
+#if __has_include(<experimental/type_traits>)
+#include <experimental/type_traits>
+
+namespace dawn {
+namespace {
+
+// MutexProtected types are only copyable when they are wrapping a Ref type.
+template <typename T>
+using mutexprotected_copyable_t =
+    decltype(std::declval<MutexProtected<T>&>() = std::declval<const MutexProtected<T>&>());
+TYPED_TEST(MutexProtectedTest, Copyable) {
+    static_assert(IsRef<TypeParam>::value ==
+                      std::experimental::is_detected_v<mutexprotected_copyable_t, TypeParam>,
+                  "Copy assignment is only allowed when the wrapping type is a Ref.");
+}
+
+}  // anonymous namespace
+}  // namespace dawn
+
+#endif