Introduces MutexProtected wrapper for thread-safety.
Bug: dawn:1662
Change-Id: Ie960665e0688112c5edddc53b428329b5f41a71b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144520
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/common/BUILD.gn b/src/dawn/common/BUILD.gn
index 6b72bb5..cacfd25 100644
--- a/src/dawn/common/BUILD.gn
+++ b/src/dawn/common/BUILD.gn
@@ -253,6 +253,7 @@
"Math.h",
"Mutex.cpp",
"Mutex.h",
+ "MutexProtected.h",
"NSRef.h",
"NonCopyable.h",
"Numeric.h",
diff --git a/src/dawn/common/CMakeLists.txt b/src/dawn/common/CMakeLists.txt
index ca78909..aa20870 100644
--- a/src/dawn/common/CMakeLists.txt
+++ b/src/dawn/common/CMakeLists.txt
@@ -56,6 +56,7 @@
"Math.h"
"Mutex.cpp"
"Mutex.h"
+ "MutexProtected.h"
"NSRef.h"
"NonCopyable.h"
"Numeric.h"
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
new file mode 100644
index 0000000..a14daf1
--- /dev/null
+++ b/src/dawn/common/MutexProtected.h
@@ -0,0 +1,142 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_COMMON_MUTEXPROTECTED_H_
+#define SRC_DAWN_COMMON_MUTEXPROTECTED_H_
+
+#include <mutex>
+#include <utility>
+
+#include "dawn/common/Mutex.h"
+#include "dawn/common/Ref.h"
+
+namespace dawn {
+
+template <typename T>
+class MutexProtected;
+
+namespace detail {
+
+template <typename T>
+struct MutexProtectedTraits {
+ using MutexType = std::mutex;
+ using LockType = std::lock_guard<std::mutex>;
+ using ObjectType = T;
+
+ static MutexType CreateMutex() { return std::mutex(); }
+ static std::mutex& GetMutex(MutexType& m) { return m; }
+ static ObjectType* GetObj(T* const obj) { return obj; }
+ static const ObjectType* GetObj(const T* const obj) { return obj; }
+};
+
+template <typename T>
+struct MutexProtectedTraits<Ref<T>> {
+ using MutexType = Ref<Mutex>;
+ using LockType = Mutex::AutoLock;
+ using ObjectType = T;
+
+ static MutexType CreateMutex() { return AcquireRef(new Mutex()); }
+ static Mutex* GetMutex(MutexType& m) { return m.Get(); }
+ static ObjectType* GetObj(Ref<T>* const obj) { return obj->Get(); }
+ static const ObjectType* GetObj(const Ref<T>* const obj) { return obj->Get(); }
+};
+
+// Guard class is a wrapping class that gives access to a protected resource after acquiring the
+// lock related to it. For the lifetime of this class, the lock is held.
+template <typename T, typename Traits>
+class Guard {
+ public:
+ using ReturnType = typename UnwrapRef<T>::type;
+
+ // It's the programmer's burden to not save the pointer/reference and reuse it without the lock.
+ ReturnType* operator->() { return Traits::GetObj(mObj); }
+ ReturnType& operator*() { return *Traits::GetObj(mObj); }
+ const ReturnType* operator->() const { return Traits::GetObj(mObj); }
+ const ReturnType& operator*() const { return *Traits::GetObj(mObj); }
+
+ private:
+ friend class MutexProtected<T>;
+
+ Guard(T* obj, typename Traits::MutexType& mutex) : mLock(Traits::GetMutex(mutex)), mObj(obj) {}
+
+ typename Traits::LockType mLock;
+ T* const mObj;
+};
+
+} // namespace detail
+
+// Wrapping class used for object members to ensure usage of the resource is protected with a mutex.
+// Example usage:
+// class Allocator {
+// public:
+// Allocation Allocate();
+// void Deallocate(Allocation&);
+// };
+// class AllocatorUser {
+// public:
+// void OnlyAllocate() {
+// auto allocation = mAllocator->Allocate();
+// }
+// void AtomicAllocateDeallocate() {
+// // Operations:
+// // - acquire lock
+// // - Allocate, Deallocate
+// // - release lock
+// mAllocator.Use([](auto allocator) {
+// auto allocation = allocator->Allocate();
+// allocator->Deallocate(allocation);
+// });
+// }
+// void NonAtomicAllocateDeallocate() {
+// // Operations:
+// // - acquire lock, Allocate, release lock
+// // - acquire lock, Deallocate, release lock
+// auto allocation = mAllocator->Allocate();
+// mAllocator->Deallocate(allocation);
+// }
+// private:
+// MutexProtected<Allocator> mAllocator;
+// };
+template <typename T>
+class MutexProtected {
+ public:
+ using Traits = detail::MutexProtectedTraits<T>;
+ using Usage = detail::Guard<T, Traits>;
+ using ConstUsage = detail::Guard<const T, Traits>;
+
+ MutexProtected() : mMutex(Traits::CreateMutex()) {}
+
+ template <typename... Args>
+ explicit MutexProtected(Args&&... args)
+ : mMutex(Traits::CreateMutex()), mObj(std::forward<Args>(args)...) {}
+
+ Usage operator->() { return Use(); }
+ ConstUsage operator->() const { return Use(); }
+
+ template <typename Fn>
+ auto Use(Fn&& fn) {
+ return fn(Use());
+ }
+
+ private:
+ Usage Use() { return Usage(&mObj, mMutex); }
+ ConstUsage Use() const { return ConstUsage(&mObj, mMutex); }
+
+ typename Traits::MutexType mMutex;
+ T mObj;
+};
+
+} // namespace dawn
+
+#endif // SRC_DAWN_COMMON_MUTEXPROTECTED_H_
diff --git a/src/dawn/common/Ref.h b/src/dawn/common/Ref.h
index f0e947c..9dd7dfb 100644
--- a/src/dawn/common/Ref.h
+++ b/src/dawn/common/Ref.h
@@ -16,7 +16,6 @@
#define SRC_DAWN_COMMON_REF_H_
#include <mutex>
-#include <type_traits>
#include "dawn/common/RefBase.h"
#include "dawn/common/RefCounted.h"
@@ -40,6 +39,24 @@
} // namespace detail
template <typename T>
+struct UnwrapRef {
+ using type = T;
+};
+template <typename T>
+struct UnwrapRef<Ref<T>> {
+ using type = T;
+};
+
+template <typename T>
+struct IsRef {
+ static constexpr bool value = false;
+};
+template <typename T>
+struct IsRef<Ref<T>> {
+ static constexpr bool value = true;
+};
+
+template <typename T>
class Ref : public RefBase<T*, detail::RefCountedTraits<T>> {
public:
using RefBase<T*, detail::RefCountedTraits<T>>::RefBase;
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 1aef11a..0004571 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -305,6 +305,7 @@
"unittests/LimitsTests.cpp",
"unittests/LinkedListTests.cpp",
"unittests/MathTests.cpp",
+ "unittests/MutexProtectedTests.cpp",
"unittests/MutexTests.cpp",
"unittests/NumericTests.cpp",
"unittests/ObjectBaseTests.cpp",
diff --git a/src/dawn/tests/unittests/MutexProtectedTests.cpp b/src/dawn/tests/unittests/MutexProtectedTests.cpp
new file mode 100644
index 0000000..9cf1bcc
--- /dev/null
+++ b/src/dawn/tests/unittests/MutexProtectedTests.cpp
@@ -0,0 +1,212 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <thread>
+#include <type_traits>
+#include <utility>
+
+#include "dawn/common/MutexProtected.h"
+#include "dawn/common/Ref.h"
+#include "dawn/common/RefCounted.h"
+#include "gtest/gtest.h"
+
+namespace dawn {
+namespace {
+
+using ::testing::Test;
+using ::testing::Types;
+
+// Simple thread-unsafe counter class.
+class CounterT : public RefCounted {
+ public:
+ CounterT() = default;
+ explicit CounterT(int count) : mCount(count) {}
+
+ int Get() const { return mCount; }
+
+ void Increment() { mCount++; }
+ void Decrement() { mCount--; }
+
+ private:
+ int mCount = 0;
+};
+
+template <typename T>
+MutexProtected<T> CreateDefault() {
+ if constexpr (IsRef<T>::value) {
+ return MutexProtected<T>(AcquireRef(new typename UnwrapRef<T>::type()));
+ } else {
+ return MutexProtected<T>();
+ }
+}
+
+template <typename T, typename... Args>
+MutexProtected<T> CreateCustom(Args&&... args) {
+ if constexpr (IsRef<T>::value) {
+ return MutexProtected<T>(
+ AcquireRef(new typename UnwrapRef<T>::type(std::forward<Args>(args)...)));
+ } else {
+ return MutexProtected<T>(std::forward<Args>(args)...);
+ }
+}
+
+template <typename T>
+class MutexProtectedTest : public Test {};
+
+class MutexProtectedTestTypeNames {
+ public:
+ template <typename T>
+ static std::string GetName(int) {
+ if (std::is_same<T, CounterT>()) {
+ return "CounterT";
+ }
+ if (std::is_same<T, Ref<CounterT>>()) {
+ return "Ref<CounterT>";
+ }
+ }
+};
+using MutexProtectedTestTypes = Types<CounterT, Ref<CounterT>>;
+TYPED_TEST_SUITE(MutexProtectedTest, MutexProtectedTestTypes, MutexProtectedTestTypeNames);
+
+TYPED_TEST(MutexProtectedTest, DefaultCtor) {
+ static constexpr int kIncrementCount = 100;
+ static constexpr int kDecrementCount = 50;
+
+ MutexProtected<TypeParam> counter = CreateDefault<TypeParam>();
+
+ auto increment = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter->Increment();
+ }
+ };
+ auto useIncrement = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter.Use([](auto c) { c->Increment(); });
+ }
+ };
+ auto decrement = [&] {
+ for (uint32_t i = 0; i < kDecrementCount; i++) {
+ counter->Decrement();
+ }
+ };
+ auto useDecrement = [&] {
+ for (uint32_t i = 0; i < kDecrementCount; i++) {
+ counter.Use([](auto c) { c->Decrement(); });
+ }
+ };
+
+ std::thread incrementThread(increment);
+ std::thread useIncrementThread(useIncrement);
+ std::thread decrementThread(decrement);
+ std::thread useDecrementThread(useDecrement);
+ incrementThread.join();
+ useIncrementThread.join();
+ decrementThread.join();
+ useDecrementThread.join();
+
+ EXPECT_EQ(counter->Get(), 2 * (kIncrementCount - kDecrementCount));
+}
+
+TYPED_TEST(MutexProtectedTest, CustomCtor) {
+ static constexpr int kIncrementCount = 100;
+ static constexpr int kDecrementCount = 50;
+ static constexpr int kStartingcount = -100;
+
+ MutexProtected<TypeParam> counter = CreateCustom<TypeParam>(kStartingcount);
+
+ auto increment = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter->Increment();
+ }
+ };
+ auto useIncrement = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ counter.Use([](auto c) { c->Increment(); });
+ }
+ };
+ auto decrement = [&] {
+ for (uint32_t i = 0; i < kDecrementCount; i++) {
+ counter->Decrement();
+ }
+ };
+ auto useDecrement = [&] {
+ for (uint32_t i = 0; i < kDecrementCount; i++) {
+ counter.Use([](auto c) { c->Decrement(); });
+ }
+ };
+
+ std::thread incrementThread(increment);
+ std::thread useIncrementThread(useIncrement);
+ std::thread decrementThread(decrement);
+ std::thread useDecrementThread(useDecrement);
+ incrementThread.join();
+ useIncrementThread.join();
+ decrementThread.join();
+ useDecrementThread.join();
+
+ EXPECT_EQ(counter->Get(), kStartingcount + 2 * (kIncrementCount - kDecrementCount));
+}
+
+TYPED_TEST(MutexProtectedTest, MultipleProtected) {
+ static constexpr int kIncrementCount = 100;
+
+ MutexProtected<TypeParam> c1 = CreateDefault<TypeParam>();
+ MutexProtected<TypeParam> c2 = CreateDefault<TypeParam>();
+
+ auto increment = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ c1.Use([&](auto x1) {
+ c2.Use([&](auto x2) {
+ x1->Increment();
+ x2->Increment();
+ });
+ });
+ }
+ };
+ auto validate = [&] {
+ for (uint32_t i = 0; i < kIncrementCount; i++) {
+ c1.Use([&](auto x1) { c2.Use([&](auto x2) { EXPECT_EQ(x1->Get(), x2->Get()); }); });
+ }
+ };
+ std::thread incrementThread(increment);
+ std::thread validateThread(validate);
+ incrementThread.join();
+ validateThread.join();
+}
+
+} // anonymous namespace
+} // namespace dawn
+
+// Special compilation tests that are only enabled when experimental headers are available.
+#if __has_include(<experimental/type_traits>)
+#include <experimental/type_traits>
+
+namespace dawn {
+namespace {
+
+// MutexProtected types are only copyable when they are wrapping a Ref type.
+template <typename T>
+using mutexprotected_copyable_t =
+ decltype(std::declval<MutexProtected<T>&>() = std::declval<const MutexProtected<T>&>());
+TYPED_TEST(MutexProtectedTest, Copyable) {
+ static_assert(IsRef<TypeParam>::value ==
+ std::experimental::is_detected_v<mutexprotected_copyable_t, TypeParam>,
+ "Copy assignment is only allowed when the wrapping type is a Ref.");
+}
+
+} // anonymous namespace
+} // namespace dawn
+
+#endif