dawn_wire: Implement device-related callbacks for multiple devices

Bug: dawn:565
Change-Id: Ic80a3bc1bbfd479af04e77afa0eb3f4ca3387ecd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/38282
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp
index 7665a70..4ceaaba 100644
--- a/src/dawn_wire/client/Client.cpp
+++ b/src/dawn_wire/client/Client.cpp
@@ -129,14 +129,18 @@
     void Client::Disconnect() {
         mDisconnected = true;
         mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance());
-        if (mDevice != nullptr) {
-            mDevice->HandleDeviceLost("GPU connection lost");
+
+        auto& deviceList = mObjects[ObjectType::Device];
+        {
+            for (LinkNode<ObjectBase>* device = deviceList.head(); device != deviceList.end();
+                 device = device->next()) {
+                static_cast<Device*>(device->value())->HandleDeviceLost("GPU connection lost");
+            }
         }
         for (auto& objectList : mObjects) {
-            LinkNode<ObjectBase>* object = objectList.head();
-            while (object != objectList.end()) {
+            for (LinkNode<ObjectBase>* object = objectList.head(); object != objectList.end();
+                 object = object->next()) {
                 object->value()->CancelCallbacksForDisconnect();
-                object = object->next();
             }
         }
     }
diff --git a/src/dawn_wire/client/ClientDoers.cpp b/src/dawn_wire/client/ClientDoers.cpp
index cd9b5ab..75688de 100644
--- a/src/dawn_wire/client/ClientDoers.cpp
+++ b/src/dawn_wire/client/ClientDoers.cpp
@@ -20,7 +20,9 @@
 
 namespace dawn_wire { namespace client {
 
-    bool Client::DoDeviceUncapturedErrorCallback(WGPUErrorType errorType, const char* message) {
+    bool Client::DoDeviceUncapturedErrorCallback(Device* device,
+                                                 WGPUErrorType errorType,
+                                                 const char* message) {
         switch (errorType) {
             case WGPUErrorType_NoError:
             case WGPUErrorType_Validation:
@@ -31,19 +33,32 @@
             default:
                 return false;
         }
-        mDevice->HandleError(errorType, message);
+        // The device might have been deleted or recreated so this isn't an error.
+        if (device == nullptr) {
+            return true;
+        }
+        device->HandleError(errorType, message);
         return true;
     }
 
-    bool Client::DoDeviceLostCallback(char const* message) {
-        mDevice->HandleDeviceLost(message);
+    bool Client::DoDeviceLostCallback(Device* device, char const* message) {
+        // The device might have been deleted or recreated so this isn't an error.
+        if (device == nullptr) {
+            return true;
+        }
+        device->HandleDeviceLost(message);
         return true;
     }
 
-    bool Client::DoDevicePopErrorScopeCallback(uint64_t requestSerial,
+    bool Client::DoDevicePopErrorScopeCallback(Device* device,
+                                               uint64_t requestSerial,
                                                WGPUErrorType errorType,
                                                const char* message) {
-        return mDevice->OnPopErrorScopeCallback(requestSerial, errorType, message);
+        // The device might have been deleted or recreated so this isn't an error.
+        if (device == nullptr) {
+            return true;
+        }
+        return device->OnPopErrorScopeCallback(requestSerial, errorType, message);
     }
 
     bool Client::DoBufferMapAsyncCallback(Buffer* buffer,
@@ -82,16 +97,26 @@
         return true;
     }
 
-    bool Client::DoDeviceCreateReadyComputePipelineCallback(uint64_t requestSerial,
+    bool Client::DoDeviceCreateReadyComputePipelineCallback(Device* device,
+                                                            uint64_t requestSerial,
                                                             WGPUCreateReadyPipelineStatus status,
                                                             const char* message) {
-        return mDevice->OnCreateReadyComputePipelineCallback(requestSerial, status, message);
+        // The device might have been deleted or recreated so this isn't an error.
+        if (device == nullptr) {
+            return true;
+        }
+        return device->OnCreateReadyComputePipelineCallback(requestSerial, status, message);
     }
 
-    bool Client::DoDeviceCreateReadyRenderPipelineCallback(uint64_t requestSerial,
+    bool Client::DoDeviceCreateReadyRenderPipelineCallback(Device* device,
+                                                           uint64_t requestSerial,
                                                            WGPUCreateReadyPipelineStatus status,
                                                            const char* message) {
-        return mDevice->OnCreateReadyRenderPipelineCallback(requestSerial, status, message);
+        // The device might have been deleted or recreated so this isn't an error.
+        if (device == nullptr) {
+            return true;
+        }
+        return device->OnCreateReadyRenderPipelineCallback(requestSerial, status, message);
     }
 
 }}  // namespace dawn_wire::client
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index 2d643cb..9ce73d5 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -146,7 +146,7 @@
         mErrorScopes[serial] = {callback, userdata};
 
         DevicePopErrorScopeCmd cmd;
-        cmd.device = ToAPI(this);
+        cmd.deviceId = this->id;
         cmd.requestSerial = serial;
 
         client->SerializeCommand(cmd);
diff --git a/src/dawn_wire/server/ObjectStorage.h b/src/dawn_wire/server/ObjectStorage.h
index 1595bda..bffe8e8 100644
--- a/src/dawn_wire/server/ObjectStorage.h
+++ b/src/dawn_wire/server/ObjectStorage.h
@@ -26,6 +26,10 @@
 
     struct DeviceInfo {
         std::unordered_set<uint64_t> childObjectTypesAndIds;
+        // Server this device is registered with; read back from callback userdata.
+        Server* server = nullptr;
+        // Wire handle (id and generation) identifying this device to the client.
+        ObjectHandle self;
     };
 
     template <typename T>
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index 25a22f8..8c97bfc 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -90,6 +90,8 @@
         data->handle = device;
         data->generation = generation;
         data->allocated = true;
+        data->info->server = this;
+        data->info->self = ObjectHandle{id, generation};
 
         // The device is externally owned so it shouldn't be destroyed when we receive a destroy
         // message from the client. Add a reference to counterbalance the eventual release.
@@ -97,21 +99,23 @@
 
         // Set callbacks to forward errors to the client.
         // Note: these callbacks are manually inlined here since they do not acquire and
-        // free their userdata.
+        // free their userdata. Also unlike other callbacks, these are cleared and unset when
+        // the server is destroyed, so we don't need to check if the server is still alive
+        // inside them.
         mProcs.deviceSetUncapturedErrorCallback(
             device,
             [](WGPUErrorType type, const char* message, void* userdata) {
-                Server* server = static_cast<Server*>(userdata);
-                server->OnUncapturedError(type, message);
+                DeviceInfo* info = static_cast<DeviceInfo*>(userdata);
+                info->server->OnUncapturedError(info->self, type, message);
             },
-            this);
+            data->info.get());
         mProcs.deviceSetDeviceLostCallback(
             device,
             [](const char* message, void* userdata) {
-                Server* server = static_cast<Server*>(userdata);
-                server->OnDeviceLost(message);
+                DeviceInfo* info = static_cast<DeviceInfo*>(userdata);
+                info->server->OnDeviceLost(info->self, message);
             },
-            this);
+            data->info.get());
 
         return true;
     }
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index ffd7d22..d073009 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -123,8 +123,7 @@
     struct ErrorScopeUserdata : CallbackUserdata {
         using CallbackUserdata::CallbackUserdata;
 
-        // TODO(enga): ObjectHandle device;
-        // when the wire supports multiple devices.
+        ObjectHandle device;
         uint64_t requestSerial;
     };
 
@@ -145,6 +144,7 @@
     struct CreateReadyPipelineUserData : CallbackUserdata {
         using CallbackUserdata::CallbackUserdata;
 
+        ObjectHandle device;
         uint64_t requestSerial;
         ObjectId pipelineObjectID;
     };
@@ -193,8 +193,8 @@
         void ClearDeviceCallbacks(WGPUDevice device);
 
         // Error callbacks
-        void OnUncapturedError(WGPUErrorType type, const char* message);
-        void OnDeviceLost(const char* message);
+        void OnUncapturedError(ObjectHandle device, WGPUErrorType type, const char* message);
+        void OnDeviceLost(ObjectHandle device, const char* message);
         void OnDevicePopErrorScope(WGPUErrorType type,
                                    const char* message,
                                    ErrorScopeUserdata* userdata);
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index c200725..dc0b86b 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -16,28 +16,62 @@
 
 namespace dawn_wire { namespace server {
 
-    void Server::OnUncapturedError(WGPUErrorType type, const char* message) {
+    namespace {
+
+        template <ObjectType objectType, typename Pipeline>
+        void HandleCreateReadyRenderPipelineCallbackResult(KnownObjects<Pipeline>* knownObjects,
+                                                           WGPUCreateReadyPipelineStatus status,
+                                                           Pipeline pipeline,
+                                                           const char* message,
+                                                           CreateReadyPipelineUserData* data) {
+            auto* pipelineObject = knownObjects->Get(data->pipelineObjectID);
+
+            if (status == WGPUCreateReadyPipelineStatus_Success) {
+                ASSERT(pipelineObject != nullptr);
+                pipelineObject->handle = pipeline;
+            } else if (pipelineObject != nullptr) {
+                // May be null if the device was destroyed. Device destruction destroys child
+                // objects on the wire.
+                if (!UntrackDeviceChild(pipelineObject->deviceInfo, objectType,
+                                        data->pipelineObjectID)) {
+                    UNREACHABLE();
+                }
+                knownObjects->Free(data->pipelineObjectID);
+            }
+        }
+
+    }  // anonymous namespace
+
+    void Server::OnUncapturedError(ObjectHandle device, WGPUErrorType type, const char* message) {
         ReturnDeviceUncapturedErrorCallbackCmd cmd;
+        cmd.device = device;
         cmd.type = type;
         cmd.message = message;
 
         SerializeCommand(cmd);
     }
 
-    void Server::OnDeviceLost(const char* message) {
+    void Server::OnDeviceLost(ObjectHandle device, const char* message) {
         ReturnDeviceLostCallbackCmd cmd;
+        cmd.device = device;
         cmd.message = message;
 
         SerializeCommand(cmd);
     }
 
-    bool Server::DoDevicePopErrorScope(WGPUDevice cDevice, uint64_t requestSerial) {
+    bool Server::DoDevicePopErrorScope(ObjectId deviceId, uint64_t requestSerial) {
+        auto* device = DeviceObjects().Get(deviceId);
+        if (device == nullptr) {
+            return false;
+        }
+
         auto userdata = MakeUserdata<ErrorScopeUserdata>();
         userdata->requestSerial = requestSerial;
+        userdata->device = ObjectHandle{deviceId, device->generation};
 
         ErrorScopeUserdata* unownedUserdata = userdata.release();
         bool success = mProcs.devicePopErrorScope(
-            cDevice,
+            device->handle,
             ForwardToServer<decltype(
                 &Server::OnDevicePopErrorScope)>::Func<&Server::OnDevicePopErrorScope>(),
             unownedUserdata);
@@ -51,6 +85,7 @@
                                        const char* message,
                                        ErrorScopeUserdata* userdata) {
         ReturnDevicePopErrorScopeCallbackCmd cmd;
+        cmd.device = userdata->device;
         cmd.requestSerial = userdata->requestSerial;
         cmd.type = type;
         cmd.message = message;
@@ -81,6 +116,7 @@
         }
 
         auto userdata = MakeUserdata<CreateReadyPipelineUserData>();
+        userdata->device = ObjectHandle{deviceId, device->generation};
         userdata->requestSerial = requestSerial;
         userdata->pipelineObjectID = pipelineObjectHandle.id;
 
@@ -96,29 +132,11 @@
                                                       WGPUComputePipeline pipeline,
                                                       const char* message,
                                                       CreateReadyPipelineUserData* data) {
-        auto* computePipelineObject = ComputePipelineObjects().Get(data->pipelineObjectID);
-        ASSERT(computePipelineObject != nullptr);
-
-        switch (status) {
-            case WGPUCreateReadyPipelineStatus_Success:
-                computePipelineObject->handle = pipeline;
-                break;
-
-            case WGPUCreateReadyPipelineStatus_Error:
-                ComputePipelineObjects().Free(data->pipelineObjectID);
-                break;
-
-            // Currently this code is unreachable because WireServer is always deleted before the
-            // removal of the device. In the future this logic may be changed when we decide to
-            // support sharing one pair of WireServer/WireClient to multiple devices.
-            case WGPUCreateReadyPipelineStatus_DeviceLost:
-            case WGPUCreateReadyPipelineStatus_DeviceDestroyed:
-            case WGPUCreateReadyPipelineStatus_Unknown:
-            default:
-                UNREACHABLE();
-        }
+        HandleCreateReadyRenderPipelineCallbackResult<ObjectType::ComputePipeline>(
+            &ComputePipelineObjects(), status, pipeline, message, data);
 
         ReturnDeviceCreateReadyComputePipelineCallbackCmd cmd;
+        cmd.device = data->device;
         cmd.status = status;
         cmd.requestSerial = data->requestSerial;
         cmd.message = message;
@@ -148,6 +166,7 @@
         }
 
         auto userdata = MakeUserdata<CreateReadyPipelineUserData>();
+        userdata->device = ObjectHandle{deviceId, device->generation};
         userdata->requestSerial = requestSerial;
         userdata->pipelineObjectID = pipelineObjectHandle.id;
 
@@ -163,29 +182,11 @@
                                                      WGPURenderPipeline pipeline,
                                                      const char* message,
                                                      CreateReadyPipelineUserData* data) {
-        auto* renderPipelineObject = RenderPipelineObjects().Get(data->pipelineObjectID);
-        ASSERT(renderPipelineObject != nullptr);
-
-        switch (status) {
-            case WGPUCreateReadyPipelineStatus_Success:
-                renderPipelineObject->handle = pipeline;
-                break;
-
-            case WGPUCreateReadyPipelineStatus_Error:
-                RenderPipelineObjects().Free(data->pipelineObjectID);
-                break;
-
-            // Currently this code is unreachable because WireServer is always deleted before the
-            // removal of the device. In the future this logic may be changed when we decide to
-            // support sharing one pair of WireServer/WireClient to multiple devices.
-            case WGPUCreateReadyPipelineStatus_DeviceLost:
-            case WGPUCreateReadyPipelineStatus_DeviceDestroyed:
-            case WGPUCreateReadyPipelineStatus_Unknown:
-            default:
-                UNREACHABLE();
-        }
+        HandleCreateReadyRenderPipelineCallbackResult<ObjectType::RenderPipeline>(
+            &RenderPipelineObjects(), status, pipeline, message, data);
 
         ReturnDeviceCreateReadyRenderPipelineCallbackCmd cmd;
+        cmd.device = data->device;
         cmd.status = status;
         cmd.requestSerial = data->requestSerial;
         cmd.message = message;
diff --git a/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp b/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp
index 9d90348..2e7ae0f 100644
--- a/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp
+++ b/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp
@@ -328,3 +328,46 @@
     wgpuDeviceCreateReadyComputePipeline(device, &descriptor,
                                          ToMockCreateReadyComputePipelineCallback, this);
 }
+
+TEST_F(WireCreateReadyPipelineTest, DeviceDeletedBeforeCallback) {
+    WGPUShaderModuleDescriptor vertexDescriptor = {};
+    WGPUShaderModule module = wgpuDeviceCreateShaderModule(device, &vertexDescriptor);
+    WGPUShaderModule apiModule = api.GetNewShaderModule();
+    EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiModule));
+
+    WGPURenderPipelineDescriptor pipelineDescriptor{};
+    pipelineDescriptor.vertexStage.module = module;
+    pipelineDescriptor.vertexStage.entryPoint = "main";
+
+    WGPUProgrammableStageDescriptor fragmentStage = {};
+    fragmentStage.module = module;
+    fragmentStage.entryPoint = "main";
+    pipelineDescriptor.fragmentStage = &fragmentStage;
+
+    wgpuDeviceCreateReadyRenderPipeline(device, &pipelineDescriptor,
+                                        ToMockCreateReadyRenderPipelineCallback, this);
+
+    EXPECT_CALL(api, OnDeviceCreateReadyRenderPipeline(apiDevice, _, _, _));
+    FlushClient();
+
+    EXPECT_CALL(*mockCreateReadyRenderPipelineCallback,
+                Call(WGPUCreateReadyPipelineStatus_DeviceDestroyed, nullptr, _, this))
+        .Times(1);
+
+    wgpuDeviceRelease(device);
+
+    // Expect release on all objects created by the client.
+    Sequence s1, s2;
+    EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1);
+    EXPECT_CALL(api, ShaderModuleRelease(apiModule)).Times(1).InSequence(s2);
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+        .Times(1)
+        .InSequence(s1, s2);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+        .Times(1)
+        .InSequence(s1, s2);
+    EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2);
+
+    FlushClient();
+    DefaultApiDeviceWasReleased();
+}