[dawn][emscripten] Adds Ref<> helpers for refcounting.
- Based on the RefBase from dawn/common, adds Ref<> to
help with refcounting.
- Adds RefCountedWithExternalCount to deal in preparations
for internally and externally ref counted objects, i.e.
the Device.
- Use Ref<> for existing places instead of explicitly
dealing with the refcounting functions.
Change-Id: Id2871ff87cb0870198e1c74d7c969482d9376ab8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/203674
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Shrek Shao <shrekshao@google.com>
diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp
index 56b7953..d529881 100644
--- a/third_party/emdawnwebgpu/webgpu.cpp
+++ b/third_party/emdawnwebgpu/webgpu.cpp
@@ -74,8 +74,7 @@
class RefCounted : NonMovable {
public:
- RefCounted() = default;
- virtual ~RefCounted() = default;
+ static constexpr bool HasExternalRefCount = false;
void AddRef() {
assert(mRefCount.fetch_add(1u, std::memory_order_relaxed) >= 1);
@@ -93,6 +92,129 @@
std::atomic<uint64_t> mRefCount = 1;
};
+class RefCountedWithExternalCount : public RefCounted {
+ public:
+ static constexpr bool HasExternalRefCount = true;
+
+ virtual ~RefCountedWithExternalCount() = default;
+
+ void AddRef() {
+ AddExternalRef();
+ RefCounted::AddRef();
+ }
+
+ void Release() {
+ if (mExternalRefCount.fetch_sub(1u, std::memory_order_release) == 1u) {
+ std::atomic_thread_fence(std::memory_order_acquire);
+ WillDropLastExternalRef();
+ }
+ RefCounted::Release();
+ }
+
+ void AddExternalRef() {
+ mExternalRefCount.fetch_add(1u, std::memory_order_relaxed);
+ }
+
+ private:
+ virtual void WillDropLastExternalRef() = 0;
+
+ std::atomic<uint64_t> mExternalRefCount = 0;
+};
+
+template <typename T>
+class Ref {
+ public:
+ static_assert(std::is_convertible_v<T, RefCounted*>,
+ "Cannot make a Ref<T> when T is not a Refcounted type.");
+
+ Ref() : mValue(nullptr) {}
+ ~Ref() {
+ if (mValue) {
+ mValue->Release();
+ }
+ }
+
+ // Constructors from nullptr.
+ // NOLINTNEXTLINE(runtime/explicit)
+ constexpr Ref(std::nullptr_t) : Ref() {}
+
+ // Constructors from T.
+ // NOLINTNEXTLINE(runtime/explicit)
+ Ref(T value) : mValue(value) { AddRef(value); }
+ Ref<T>& operator=(const T& value) {
+ Set(value);
+ return *this;
+ }
+
+ // Constructors from a Ref<T>.
+ Ref(const Ref<T>& other) : mValue(other.mValue) { AddRef(other.mValue); }
+ Ref<T>& operator=(const Ref<T>& other) {
+ Set(other.mValue);
+ return *this;
+ }
+ Ref(Ref<T>&& other) { mValue = other.Detach(); }
+ Ref<T>& operator=(Ref<T>&& other) {
+ if (&other != this) {
+ Release(mValue);
+ mValue = other.Detach();
+ }
+ return *this;
+ }
+
+ explicit operator bool() const { return !!mValue; }
+
+ // Smart pointer methods.
+ const T& Get() const { return mValue; }
+ T& Get() { return mValue; }
+ const T operator->() const { return mValue; }
+ T operator->() { return mValue; }
+
+ [[nodiscard]] T Detach() {
+ T value = mValue;
+ mValue = nullptr;
+ return value;
+ }
+
+ void Acquire(T value) {
+ Release(mValue);
+ mValue = value;
+ }
+
+ private:
+ static void AddRef(T value) {
+ if (value != nullptr) {
+ value->RefCounted::AddRef();
+ }
+ }
+ static void Release(T value) {
+ if (value != nullptr) {
+ value->RefCounted::Release();
+ }
+ }
+
+ void Set(T value) {
+ if (mValue != value) {
+ // Ensure that the new value is referenced before the old is released to
+ // prevent any transitive frees that may affect the new value.
+ AddRef(value);
+ Release(mValue);
+ mValue = value;
+ }
+ }
+
+ T mValue;
+};
+
+template <typename T>
+auto ReturnToAPI(Ref<T*>&& object) {
+ if constexpr (T::HasExternalRefCount) {
+ // For an object which has external ref count, just need to increase the
+ // external ref count, and keep the total ref count unchanged.
+ object->AddExternalRef();
+ }
+ return object.Detach();
+}
+
// clang-format off
// X Macro to help generate boilerplate code for all refcounted object types.
#define WGPU_REFCOUNTED_OBJECTS(X) \
@@ -412,6 +534,15 @@
}
// ----------------------------------------------------------------------------
+// WGPU struct declarations.
+// ----------------------------------------------------------------------------
+
+// Default struct declarations.
+#define DEFINE_WGPU_DEFAULT_STRUCT(Name) \
+ struct WGPU##Name##Impl final : public RefCounted {};
+WGPU_PASSTHROUGH_OBJECTS(DEFINE_WGPU_DEFAULT_STRUCT)
+
+// ----------------------------------------------------------------------------
// Future events.
// ----------------------------------------------------------------------------
@@ -432,7 +563,7 @@
WGPUAdapter adapter,
const char* message) {
mStatus = status;
- mAdapter = adapter;
+ mAdapter.Acquire(adapter);
mMessage = message;
}
@@ -442,10 +573,11 @@
mMessage = "A valid external Instance reference no longer exists.";
}
if (mCallback) {
- mCallback(
- mStatus,
- mStatus == WGPURequestAdapterStatus_Success ? mAdapter : nullptr,
- mMessage ? mMessage->c_str() : nullptr, mUserdata1, mUserdata2);
+ mCallback(mStatus,
+ mStatus == WGPURequestAdapterStatus_Success
+ ? ReturnToAPI(std::move(mAdapter))
+ : nullptr,
+ mMessage ? mMessage->c_str() : nullptr, mUserdata1, mUserdata2);
}
}
@@ -455,7 +587,7 @@
void* mUserdata2 = nullptr;
WGPURequestAdapterStatus mStatus;
- WGPUAdapter mAdapter = nullptr;
+ Ref<WGPUAdapter> mAdapter;
std::optional<std::string> mMessage = std::nullopt;
};
@@ -463,11 +595,6 @@
// WGPU struct implementations.
// ----------------------------------------------------------------------------
-// Default struct implementations.
-#define DEFINE_WGPU_DEFAULT_STRUCT(Name) \
- struct WGPU##Name##Impl : public RefCounted {};
-WGPU_PASSTHROUGH_OBJECTS(DEFINE_WGPU_DEFAULT_STRUCT)
-
// Instance is specially implemented in order to handle Futures implementation.
struct WGPUInstanceImpl : public RefCounted {
public:
@@ -475,7 +602,7 @@
mId = GetNextInstanceId();
GetEventManager().RegisterInstance(mId);
}
- ~WGPUInstanceImpl() override { GetEventManager().UnregisterInstance(mId); }
+ ~WGPUInstanceImpl() { GetEventManager().UnregisterInstance(mId); }
InstanceID GetId() const { return mId; }
void ProcessEvents() { GetEventManager().ProcessEvents(mId); }
@@ -498,17 +625,15 @@
// Device is specially implemented in order to handle refcounting the Queue.
struct WGPUDeviceImpl : public RefCounted {
public:
- WGPUDeviceImpl(WGPUQueue queue) : mQueue(queue) {
- // TODO(lokokung) Currently we are manually doing the ref counting for
- // the Queue. We should probably have some RAII helpers.
- mQueue->AddRef();
- }
- ~WGPUDeviceImpl() override { mQueue->Release(); }
+ WGPUDeviceImpl(WGPUQueue queue) { mQueue.Acquire(queue); }
- WGPUQueue GetQueue() { return mQueue; }
+ WGPUQueue GetQueue() {
+ auto queue = mQueue;
+ return ReturnToAPI(std::move(queue));
+ }
private:
- WGPUQueue mQueue;
+ Ref<WGPUQueue> mQueue;
};
// ----------------------------------------------------------------------------
@@ -618,7 +743,6 @@
// ----------------------------------------------------------------------------
WGPUQueue wgpuDeviceGetQueue(WGPUDevice device) {
- device->GetQueue()->AddRef();
return device->GetQueue();
}