[dawn][emscripten] Adds Ref<> helpers for refcounting.

- Based on the RefBase from dawn/common, adds Ref<> to
  help with refcounting.
- Adds RefCountedWithExternalCount to deal in preparations
  for internally and externally ref counted objects, i.e.
  the Device.
- Use Ref<> for existing places instead of explicitly
  dealing with the refcounting functions.

Change-Id: Id2871ff87cb0870198e1c74d7c969482d9376ab8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/203674
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Shrek Shao <shrekshao@google.com>
diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp
index 56b7953..d529881 100644
--- a/third_party/emdawnwebgpu/webgpu.cpp
+++ b/third_party/emdawnwebgpu/webgpu.cpp
@@ -74,8 +74,7 @@
 
 class RefCounted : NonMovable {
  public:
-  RefCounted() = default;
-  virtual ~RefCounted() = default;
+  static constexpr bool HasExternalRefCount = false;
 
   void AddRef() {
     assert(mRefCount.fetch_add(1u, std::memory_order_relaxed) >= 1);
@@ -93,6 +92,129 @@
   std::atomic<uint64_t> mRefCount = 1;
 };
 
+class RefCountedWithExternalCount : public RefCounted {
+ public:
+  static constexpr bool HasExternalRefCount = true;
+
+  virtual ~RefCountedWithExternalCount() = default;
+
+  void AddRef() {
+    AddExternalRef();
+    RefCounted::AddRef();
+  }
+
+  void Release() {
+    if (mExternalRefCount.fetch_sub(1u, std::memory_order_release) == 1u) {
+      std::atomic_thread_fence(std::memory_order_acquire);
+      WillDropLastExternalRef();
+    }
+    RefCounted::Release();
+  }
+
+  void AddExternalRef() {
+    mExternalRefCount.fetch_add(1u, std::memory_order_relaxed);
+  }
+
+ private:
+  virtual void WillDropLastExternalRef() = 0;
+
+  std::atomic<uint64_t> mExternalRefCount = 0;
+};
+
+template <typename T>
+class Ref {
+ public:
+  static_assert(std::is_convertible_v<T, RefCounted*>,
+                "Cannot make a Ref<T> when T is not a Refcounted type.");
+
+  Ref() : mValue(nullptr) {}
+  ~Ref() {
+    if (mValue) {
+      mValue->Release();
+    }
+  }
+
+  // Constructors from nullptr.
+  // NOLINTNEXTLINE(runtime/explicit)
+  constexpr Ref(std::nullptr_t) : Ref() {}
+
+  // Constructors from T.
+  // NOLINTNEXTLINE(runtime/explicit)
+  Ref(T value) : mValue(value) { AddRef(value); }
+  Ref<T>& operator=(const T& value) {
+    Set(value);
+    return *this;
+  }
+
+  // Constructors from a Ref<T>.
+  Ref(const Ref<T>& other) : mValue(other.mValue) { AddRef(other.mValue); }
+  Ref<T>& operator=(const Ref<T>& other) {
+    Set(other.mValue);
+    return *this;
+  }
+  Ref(Ref<T>&& other) { mValue = other.Detach(); }
+  Ref<T>& operator=(Ref<T>&& other) {
+    if (&other != this) {
+      Release(mValue);
+      mValue = other.Detach();
+    }
+    return *this;
+  }
+
+  explicit operator bool() const { return !!mValue; }
+
+  // Smart pointer methods.
+  const T& Get() const { return mValue; }
+  T& Get() { return mValue; }
+  const T operator->() const { return mValue; }
+  T operator->() { return mValue; }
+
+  [[nodiscard]] T Detach() {
+    T value = mValue;
+    mValue = nullptr;
+    return value;
+  }
+
+  void Acquire(T value) {
+    Release(mValue);
+    mValue = value;
+  }
+
+ private:
+  static void AddRef(T value) {
+    if (value != nullptr) {
+      value->RefCounted::AddRef();
+    }
+  }
+  static void Release(T value) {
+    if (value != nullptr) {
+      value->RefCounted::Release();
+    }
+  }
+
+  void Set(T value) {
+    if (mValue != value) {
+      // Ensure that the new value is referenced before the old is released to
+      // prevent any transitive frees that may affect the new value.
+      AddRef(value);
+      Release(mValue);
+      mValue = value;
+    }
+  }
+
+  T mValue;
+};
+
+template <typename T>
+auto ReturnToAPI(Ref<T*>&& object) {
+  if constexpr (T::HasExternalRefCount) {
+    // For an object which has external ref count, just need to increase the
+    // external ref count, and keep the total ref count unchanged.
+    object->AddExternalRef();
+  }
+  return object.Detach();
+}
+
 // clang-format off
 // X Macro to help generate boilerplate code for all refcounted object types.
 #define WGPU_REFCOUNTED_OBJECTS(X) \
@@ -412,6 +534,15 @@
 }
 
 // ----------------------------------------------------------------------------
+// WGPU struct declarations.
+// ----------------------------------------------------------------------------
+
+// Default struct declarations.
+#define DEFINE_WGPU_DEFAULT_STRUCT(Name) \
+  struct WGPU##Name##Impl final : public RefCounted {};
+WGPU_PASSTHROUGH_OBJECTS(DEFINE_WGPU_DEFAULT_STRUCT)
+
+// ----------------------------------------------------------------------------
 // Future events.
 // ----------------------------------------------------------------------------
 
@@ -432,7 +563,7 @@
                  WGPUAdapter adapter,
                  const char* message) {
     mStatus = status;
-    mAdapter = adapter;
+    mAdapter.Acquire(adapter);
     mMessage = message;
   }
 
@@ -442,10 +573,11 @@
       mMessage = "A valid external Instance reference no longer exists.";
     }
     if (mCallback) {
-      mCallback(
-          mStatus,
-          mStatus == WGPURequestAdapterStatus_Success ? mAdapter : nullptr,
-          mMessage ? mMessage->c_str() : nullptr, mUserdata1, mUserdata2);
+      mCallback(mStatus,
+                mStatus == WGPURequestAdapterStatus_Success
+                    ? ReturnToAPI(std::move(mAdapter))
+                    : nullptr,
+                mMessage ? mMessage->c_str() : nullptr, mUserdata1, mUserdata2);
     }
   }
 
@@ -455,7 +587,7 @@
   void* mUserdata2 = nullptr;
 
   WGPURequestAdapterStatus mStatus;
-  WGPUAdapter mAdapter = nullptr;
+  Ref<WGPUAdapter> mAdapter;
   std::optional<std::string> mMessage = std::nullopt;
 };
 
@@ -463,11 +595,6 @@
 // WGPU struct implementations.
 // ----------------------------------------------------------------------------
 
-// Default struct implementations.
-#define DEFINE_WGPU_DEFAULT_STRUCT(Name) \
-  struct WGPU##Name##Impl : public RefCounted {};
-WGPU_PASSTHROUGH_OBJECTS(DEFINE_WGPU_DEFAULT_STRUCT)
-
 // Instance is specially implemented in order to handle Futures implementation.
 struct WGPUInstanceImpl : public RefCounted {
  public:
@@ -475,7 +602,7 @@
     mId = GetNextInstanceId();
     GetEventManager().RegisterInstance(mId);
   }
-  ~WGPUInstanceImpl() override { GetEventManager().UnregisterInstance(mId); }
+  ~WGPUInstanceImpl() { GetEventManager().UnregisterInstance(mId); }
   InstanceID GetId() const { return mId; }
 
   void ProcessEvents() { GetEventManager().ProcessEvents(mId); }
@@ -498,17 +625,15 @@
 // Device is specially implemented in order to handle refcounting the Queue.
 struct WGPUDeviceImpl : public RefCounted {
  public:
-  WGPUDeviceImpl(WGPUQueue queue) : mQueue(queue) {
-    // TODO(lokokung) Currently we are manually doing the ref counting for
-    //                the Queue. We should probably have some RAII helpers.
-    mQueue->AddRef();
-  }
-  ~WGPUDeviceImpl() override { mQueue->Release(); }
+  WGPUDeviceImpl(WGPUQueue queue) { mQueue.Acquire(queue); }
 
-  WGPUQueue GetQueue() { return mQueue; }
+  WGPUQueue GetQueue() {
+    auto queue = mQueue;
+    return ReturnToAPI(std::move(queue));
+  }
 
  private:
-  WGPUQueue mQueue;
+  Ref<WGPUQueue> mQueue;
 };
 
 // ----------------------------------------------------------------------------
@@ -618,7 +743,6 @@
 // ----------------------------------------------------------------------------
 
 WGPUQueue wgpuDeviceGetQueue(WGPUDevice device) {
-  device->GetQueue()->AddRef();
   return device->GetQueue();
 }