dawn::wire::client: replace the "is alive" weak ptr with Ref<Device>

The weak_ptr mechanism was necessary to avoid the buffer referencing the
device, because the device couldn't make the difference between internal
and external references.

Bug: 344963953
Change-Id: If2cd1ba055bfc520a5cbbde71f60c906abeb9f95
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/197455
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/dawn_wire.json b/src/dawn/dawn_wire.json
index 6fb0cbb..4d7f652 100644
--- a/src/dawn/dawn_wire.json
+++ b/src/dawn/dawn_wire.json
@@ -274,6 +274,7 @@
             "AdapterGetInstance",
             "BufferDestroy",
             "BufferUnmap",
+            "DeviceCreateErrorBuffer",
             "DeviceGetAdapter",
             "DeviceGetQueue",
             "DeviceGetSupportedSurfaceUsage",
diff --git a/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp b/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp
index 41d50bf..a3c541a 100644
--- a/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireBufferMappingTests.cpp
@@ -125,10 +125,11 @@
     }
 
     // Test to exercise client functions that should override server response for callbacks.
-    template <typename ExpFn>
-    void TestEarlyMapCancelled(void (*cancelFn)(WGPUBuffer),
-                               ExpFn cancelExp,
-                               WGPUBufferMapAsyncStatus expected) {
+    template <typename CancelFn, typename ExpFn>
+    void TestEarlyMapCancelled(CancelFn cancelMapping,
+                               ExpFn addExpectations,
+                               WGPUBufferMapAsyncStatus expected,
+                               bool calledInCancelFn) {
         WGPUMapMode mapMode = GetMapMode();
         SetupBuffer(mapMode);
         BufferMapAsync(buffer, mapMode, 0, kBufferSize, nullptr);
@@ -139,22 +140,22 @@
                 api.CallBufferMapAsyncCallback(apiBuffer, WGPUBufferMapAsyncStatus_Success);
             }));
         ExpectMappedRangeCall(kBufferSize, &bufferContent);
-        cancelExp();
+        addExpectations();
 
         // The callback should get called with the expected status, regardless if the server has
         // responded.
-        if (IsSpontaneous()) {
+        if (calledInCancelFn) {
             // In spontaneous mode, the callback gets called as a part of the cancel function.
             ExpectWireCallbacksWhen([&](auto& mockCb) {
                 EXPECT_CALL(mockCb, Call(expected, _)).Times(1);
 
-                cancelFn(buffer);
+                cancelMapping();
             });
             FlushClient();
             FlushCallbacks();
         } else {
             // Otherwise, the callback will fire when we flush them.
-            cancelFn(buffer);
+            cancelMapping();
             FlushClient();
             ExpectWireCallbacksWhen([&](auto& mockCb) {
                 EXPECT_CALL(mockCb, Call(expected, _)).Times(1);
@@ -165,10 +166,11 @@
     }
 
     // Test to exercise client functions that should override server error response for callbacks.
-    template <typename ExpFn>
-    void TestEarlyMapErrorCancelled(void (*cancelFn)(WGPUBuffer),
-                                    ExpFn cancelExp,
-                                    WGPUBufferMapAsyncStatus expected) {
+    template <typename CancelFn, typename ExpFn>
+    void TestEarlyMapErrorCancelled(CancelFn cancelMapping,
+                                    ExpFn addExpectations,
+                                    WGPUBufferMapAsyncStatus expected,
+                                    bool calledInCancelFn) {
         WGPUMapMode mapMode = GetMapMode();
         SetupBuffer(mapMode);
         BufferMapAsync(buffer, mapMode, 0, kBufferSize, nullptr);
@@ -182,22 +184,22 @@
         FlushClient();
         FlushFutures();
 
-        cancelExp();
+        addExpectations();
 
         // The callback should get called with the expected status status, not server-side error,
         // even if the request fails on the server side.
-        if (IsSpontaneous()) {
+        if (calledInCancelFn) {
             // In spontaneous mode, the callback gets called as a part of the cancel function.
             ExpectWireCallbacksWhen([&](auto& mockCb) {
                 EXPECT_CALL(mockCb, Call(expected, _)).Times(1);
 
-                cancelFn(buffer);
+                cancelMapping();
             });
             FlushClient();
             FlushCallbacks();
         } else {
             // Otherwise, the callback will fire when we flush them.
-            cancelFn(buffer);
+            cancelMapping();
             FlushClient();
             ExpectWireCallbacksWhen([&](auto& mockCb) {
                 EXPECT_CALL(mockCb, Call(expected, _)).Times(1);
@@ -276,31 +278,56 @@
 // Check the map callback is called with "UnmappedBeforeCallback" when the map request would have
 // worked, but Unmap() was called.
 TEST_P(WireBufferMappingTests, UnmapCalledTooEarly) {
-    TestEarlyMapCancelled(
-        &wgpuBufferUnmap, [&]() { EXPECT_CALL(api, BufferUnmap(apiBuffer)); },
-        WGPUBufferMapAsyncStatus_UnmappedBeforeCallback);
+    TestEarlyMapCancelled([&]() { wgpuBufferUnmap(buffer); },
+                          [&]() { EXPECT_CALL(api, BufferUnmap(apiBuffer)); },
+                          WGPUBufferMapAsyncStatus_UnmappedBeforeCallback, IsSpontaneous());
 }
 
 // Check that if Unmap() was called early client-side, we disregard server-side validation errors.
 TEST_P(WireBufferMappingTests, UnmapCalledTooEarlyServerSideError) {
-    TestEarlyMapErrorCancelled(
-        &wgpuBufferUnmap, [&]() { EXPECT_CALL(api, BufferUnmap(apiBuffer)); },
-        WGPUBufferMapAsyncStatus_UnmappedBeforeCallback);
+    TestEarlyMapErrorCancelled([&]() { wgpuBufferUnmap(buffer); },
+                               [&]() { EXPECT_CALL(api, BufferUnmap(apiBuffer)); },
+                               WGPUBufferMapAsyncStatus_UnmappedBeforeCallback, IsSpontaneous());
 }
 
 // Check the map callback is called with "DestroyedBeforeCallback" when the map request would have
 // worked, but Destroy() was called.
 TEST_P(WireBufferMappingTests, DestroyCalledTooEarly) {
-    TestEarlyMapCancelled(
-        &wgpuBufferDestroy, [&]() { EXPECT_CALL(api, BufferDestroy(apiBuffer)); },
-        WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
+    TestEarlyMapCancelled([&]() { wgpuBufferDestroy(buffer); },
+                          [&]() { EXPECT_CALL(api, BufferDestroy(apiBuffer)); },
+                          WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, IsSpontaneous());
 }
 
 // Check that if Destroy() was called early client-side, we disregard server-side validation errors.
 TEST_P(WireBufferMappingTests, DestroyCalledTooEarlyServerSideError) {
+    TestEarlyMapErrorCancelled([&]() { wgpuBufferDestroy(buffer); },
+                               [&]() { EXPECT_CALL(api, BufferDestroy(apiBuffer)); },
+                               WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, IsSpontaneous());
+}
+
+// Check the map callback is called with "DestroyedBeforeCallback" when the map request would have
+// worked, but the device was released.
+TEST_P(WireBufferMappingTests, DeviceReleasedTooEarly) {
+    TestEarlyMapCancelled(
+        [&]() { wgpuDeviceRelease(device); },
+        [&]() {
+            EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(1);
+            EXPECT_CALL(api, DeviceRelease(apiDevice));
+        },
+        WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, false);
+    DefaultApiDeviceWasReleased();
+}
+
+// Check that if device is released early client-side, we disregard server-side validation errors.
+TEST_P(WireBufferMappingTests, DeviceReleasedTooEarlyServerSideError) {
     TestEarlyMapErrorCancelled(
-        &wgpuBufferDestroy, [&]() { EXPECT_CALL(api, BufferDestroy(apiBuffer)); },
-        WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
+        [&]() { wgpuDeviceRelease(device); },
+        [&]() {
+            EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(1);
+            EXPECT_CALL(api, DeviceRelease(apiDevice));
+        },
+        WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, false);
+    DefaultApiDeviceWasReleased();
 }
 
 // Test that the callback isn't fired twice when Unmap() is called inside the callback.
diff --git a/src/dawn/wire/client/Buffer.cpp b/src/dawn/wire/client/Buffer.cpp
index 066e120..d08710a 100644
--- a/src/dawn/wire/client/Buffer.cpp
+++ b/src/dawn/wire/client/Buffer.cpp
@@ -50,7 +50,7 @@
     errorInfo.chain.sType = WGPUSType_DawnBufferDescriptorErrorInfoFromWireClient;
     errorInfo.outOfMemory = true;
     errorBufferDescriptor.nextInChain = &errorInfo.chain;
-    return GetProcs().deviceCreateErrorBuffer(ToAPI(device), &errorBufferDescriptor);
+    return device->CreateErrorBuffer(&errorBufferDescriptor);
 }
 }  // anonymous namespace
 
@@ -176,14 +176,12 @@
             return Callback();
         }
 
+        // Device destruction/loss implicitly makes the map requests aborted.
+        if (!mBuffer->mDevice->IsAlive()) {
+            status = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
+        }
+
         if (status == WGPUBufferMapAsyncStatus_Success) {
-            if (mBuffer->mIsDeviceAlive.expired()) {
-                // If the device lost its last ref before this callback was resolved, we want to
-                // overwrite the status. This is necessary because otherwise dropping the last
-                // device reference could race w.r.t what this callback would see.
-                status = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
-                return Callback();
-            }
             DAWN_ASSERT(mBuffer->mPendingMapRequest->type);
             switch (*mBuffer->mPendingMapRequest->type) {
                 case MapRequestType::Read:
@@ -297,20 +295,19 @@
             }
         };
 
+        // The request has been cancelled before completion, return that result.
         if (!IsPendingRequest(futureID)) {
             DAWN_ASSERT(mStatus != WGPUMapAsyncStatus_Success);
             return Callback();
         }
 
         if (mStatus == WGPUMapAsyncStatus_Success) {
-            if (mBuffer->mIsDeviceAlive.expired()) {
-                // If the device lost its last ref before this callback was resolved, we want to
-                // overwrite the status. This is necessary because otherwise dropping the last
-                // device reference could race w.r.t what this callback would see.
+            // Device destruction/loss implicitly makes the map requests aborted.
+            if (!mBuffer->mDevice->IsAlive()) {
                 mStatus = WGPUMapAsyncStatus_Aborted;
-                mMessage = "Buffer was destroyed before mapping was resolved.";
-                return Callback();
+                mMessage = "The Device was lost before mapping was resolved.";
             }
+
             DAWN_ASSERT(mBuffer->mPendingMapRequest->type);
             switch (*mBuffer->mPendingMapRequest->type) {
                 case MapRequestType::Read:
@@ -387,8 +384,8 @@
     // Create the buffer and send the creation command.
     // This must happen after any potential error buffer creation
     // as server expects allocating ids to be monotonically increasing
-    Ref<Buffer> buffer = wireClient->Make<Buffer>(device->GetEventManagerHandle(), descriptor);
-    buffer->mIsDeviceAlive = device->GetAliveWeakPtr();
+    Ref<Buffer> buffer =
+        wireClient->Make<Buffer>(device->GetEventManagerHandle(), device, descriptor);
 
     if (descriptor->mappedAtCreation) {
         // If the buffer is mapped at creation, a write handle is created and will be
@@ -428,8 +425,23 @@
     return ReturnToAPI(std::move(buffer));
 }
 
+// static
+WGPUBuffer Buffer::CreateError(Device* device, const WGPUBufferDescriptor* descriptor) {
+    Client* client = device->GetClient();
+    Ref<Buffer> buffer = client->Make<Buffer>(device->GetEventManagerHandle(), device, descriptor);
+
+    DeviceCreateErrorBufferCmd cmd;
+    cmd.self = ToAPI(device);
+    cmd.descriptor = descriptor;
+    cmd.result = buffer->GetWireHandle();
+    client->SerializeCommand(cmd);
+
+    return ReturnToAPI(std::move(buffer));
+}
+
 Buffer::Buffer(const ObjectBaseParams& params,
                const ObjectHandle& eventManagerHandle,
+               Device* device,
                const WGPUBufferDescriptor* descriptor)
     : ObjectWithEventsBase(params, eventManagerHandle),
       mSize(descriptor->size),
@@ -437,7 +449,8 @@
       // This flag is for the write handle created by mappedAtCreation
       // instead of MapWrite usage. We don't have such a case for read handle.
       mDestructWriteHandleOnUnmap(descriptor->mappedAtCreation &&
-                                  ((descriptor->usage & WGPUBufferUsage_MapWrite) == 0)) {}
+                                  ((descriptor->usage & WGPUBufferUsage_MapWrite) == 0)),
+      mDevice(device) {}
 
 void Buffer::DeleteThis() {
     FreeMappedData();
diff --git a/src/dawn/wire/client/Buffer.h b/src/dawn/wire/client/Buffer.h
index fa166b4..d9c880a 100644
--- a/src/dawn/wire/client/Buffer.h
+++ b/src/dawn/wire/client/Buffer.h
@@ -46,9 +46,11 @@
 class Buffer final : public ObjectWithEventsBase {
   public:
     static WGPUBuffer Create(Device* device, const WGPUBufferDescriptor* descriptor);
+    static WGPUBuffer CreateError(Device* device, const WGPUBufferDescriptor* descriptor);
 
     Buffer(const ObjectBaseParams& params,
            const ObjectHandle& eventManagerHandle,
+           Device* device,
            const WGPUBufferDescriptor* descriptor);
     void DeleteThis() override;
 
@@ -96,8 +98,7 @@
     const uint64_t mSize = 0;
     const WGPUBufferUsage mUsage;
     const bool mDestructWriteHandleOnUnmap;
-
-    std::weak_ptr<bool> mIsDeviceAlive;
+    Ref<Device> mDevice;
 
     // Mapping members are mutable depending on the current map state.
     enum class MapRequestType { Read, Write };
diff --git a/src/dawn/wire/client/Client.cpp b/src/dawn/wire/client/Client.cpp
index 639f5e9..ae371ae 100644
--- a/src/dawn/wire/client/Client.cpp
+++ b/src/dawn/wire/client/Client.cpp
@@ -78,7 +78,8 @@
 }
 
 ReservedBuffer Client::ReserveBuffer(WGPUDevice device, const WGPUBufferDescriptor* descriptor) {
-    Ref<Buffer> buffer = Make<Buffer>(FromAPI(device)->GetEventManagerHandle(), descriptor);
+    Ref<Buffer> buffer =
+        Make<Buffer>(FromAPI(device)->GetEventManagerHandle(), FromAPI(device), descriptor);
 
     ReservedBuffer result;
     result.handle = buffer->GetWireHandle();
diff --git a/src/dawn/wire/client/Device.cpp b/src/dawn/wire/client/Device.cpp
index 0996167..62cc0f4 100644
--- a/src/dawn/wire/client/Device.cpp
+++ b/src/dawn/wire/client/Device.cpp
@@ -253,8 +253,7 @@
 Device::Device(const ObjectBaseParams& params,
                const ObjectHandle& eventManagerHandle,
                const WGPUDeviceDescriptor* descriptor)
-    : RefCountedWithExternalCount<ObjectWithEventsBase>(params, eventManagerHandle),
-      mIsAlive(std::make_shared<bool>(true)) {
+    : RefCountedWithExternalCount<ObjectWithEventsBase>(params, eventManagerHandle) {
 #if defined(DAWN_ENABLE_ASSERTS)
     static constexpr WGPUDeviceLostCallbackInfo2 kDefaultDeviceLostCallbackInfo = {
         nullptr, WGPUCallbackMode_AllowSpontaneous,
@@ -322,6 +321,10 @@
     return ObjectType::Device;
 }
 
+bool Device::IsAlive() const {
+    return mIsAlive;
+}
+
 void Device::WillDropLastExternalRef() {
     HandleDeviceLost(WGPUDeviceLostReason_Destroyed, "Device was destroyed.");
     Unregister();
@@ -369,6 +372,7 @@
         DAWN_CHECK(GetEventManager().SetFutureReady<DeviceLostEvent>(futureID, reason, message) ==
                    WireResult::Success);
     }
+    mIsAlive = false;
 }
 
 WGPUFuture Device::GetDeviceLostFuture() {
@@ -383,10 +387,6 @@
     return {mDeviceLostInfo.futureID};
 }
 
-std::weak_ptr<bool> Device::GetAliveWeakPtr() {
-    return mIsAlive;
-}
-
 void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) {
     if (mDeviceLostInfo.futureID != kNullFutureID) {
         mUncapturedErrorCallbackInfo = {nullptr, &LegacyUncapturedErrorCallback,
@@ -474,6 +474,10 @@
     return Buffer::Create(this, descriptor);
 }
 
+WGPUBuffer Device::CreateErrorBuffer(const WGPUBufferDescriptor* descriptor) {
+    return Buffer::CreateError(this, descriptor);
+}
+
 WGPUQueue Device::GetQueue() {
     // The queue is lazily created because if a Device is created by
     // Reserve/Inject, we cannot send the GetQueue message until
diff --git a/src/dawn/wire/client/Device.h b/src/dawn/wire/client/Device.h
index 7a395ee..7a9e8c1 100644
--- a/src/dawn/wire/client/Device.h
+++ b/src/dawn/wire/client/Device.h
@@ -52,6 +52,18 @@
 
     ObjectType GetObjectType() const override;
 
+    void SetLimits(const WGPUSupportedLimits* limits);
+    void SetFeatures(const WGPUFeatureName* features, uint32_t featuresCount);
+
+    bool IsAlive() const;
+    WGPUFuture GetDeviceLostFuture();
+
+    void HandleError(WGPUErrorType errorType, const char* message);
+    void HandleLogging(WGPULoggingType loggingType, const char* message);
+    void HandleDeviceLost(WGPUDeviceLostReason reason, const char* message);
+    class DeviceLostEvent;
+
+    // WebGPU API
     void SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata);
     void SetLoggingCallback(WGPULoggingCallback errorCallback, void* errorUserdata);
     void SetDeviceLostCallback(WGPUDeviceLostCallback errorCallback, void* errorUserdata);
@@ -59,7 +71,9 @@
     void PopErrorScope(WGPUErrorCallback callback, void* userdata);
     WGPUFuture PopErrorScopeF(const WGPUPopErrorScopeCallbackInfo& callbackInfo);
     WGPUFuture PopErrorScope2(const WGPUPopErrorScopeCallbackInfo2& callbackInfo);
+
     WGPUBuffer CreateBuffer(const WGPUBufferDescriptor* descriptor);
+    WGPUBuffer CreateErrorBuffer(const WGPUBufferDescriptor* descriptor);
     void CreateComputePipelineAsync(WGPUComputePipelineDescriptor const* descriptor,
                                     WGPUCreateComputePipelineAsyncCallback callback,
                                     void* userdata);
@@ -79,22 +93,10 @@
         WGPURenderPipelineDescriptor const* descriptor,
         const WGPUCreateRenderPipelineAsyncCallbackInfo2& callbackInfo);
 
-    void HandleError(WGPUErrorType errorType, const char* message);
-    void HandleLogging(WGPULoggingType loggingType, const char* message);
-    void HandleDeviceLost(WGPUDeviceLostReason reason, const char* message);
-
     WGPUStatus GetLimits(WGPUSupportedLimits* limits) const;
     bool HasFeature(WGPUFeatureName feature) const;
     size_t EnumerateFeatures(WGPUFeatureName* features) const;
-    void SetLimits(const WGPUSupportedLimits* limits);
-    void SetFeatures(const WGPUFeatureName* features, uint32_t featuresCount);
-
     WGPUQueue GetQueue();
-    WGPUFuture GetDeviceLostFuture();
-
-    std::weak_ptr<bool> GetAliveWeakPtr();
-
-    class DeviceLostEvent;
 
   private:
     void WillDropLastExternalRef() override;
@@ -122,8 +124,7 @@
     raw_ptr<void> mLoggingUserdata = nullptr;
 
     Ref<Queue> mQueue;
-
-    std::shared_ptr<bool> mIsAlive;
+    bool mIsAlive = true;
 };
 
 }  // namespace dawn::wire::client