dawn_wire: Add Reserve/InjectDevice

Now that the wire does enough tracking to prevent a malicious client
from freeing a device before its child objects, and the device is no
longer a "special" object with regard to reference/release, it is
safe to support multiple devices on the wire. The simplest way to
use this in WebGPU (to fix createReadyRenderPipeline validation)
is to add a reserve/inject device API similar to the one we use for
swapchain textures.

Bug: dawn:565
Change-Id: Ie956aff528c5610c9ecc5c189dab2d22185cb572
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/37800
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/generator/templates/dawn_wire/server/ServerDoers.cpp b/generator/templates/dawn_wire/server/ServerDoers.cpp
index 0c6ce42..0252336 100644
--- a/generator/templates/dawn_wire/server/ServerDoers.cpp
+++ b/generator/templates/dawn_wire/server/ServerDoers.cpp
@@ -98,6 +98,11 @@
                                 *data->childObjectTypesAndIds.begin());
                             DoDestroyObject(childObjectType, childObjectId);
                         }
+                        if (data->handle != nullptr) {
+                            //* Deregisters uncaptured error and device lost callbacks since
+                            //* they should not be forwarded if the device no longer exists on the wire.
+                            ClearDeviceCallbacks(data->handle);
+                        }
                     {% endif %}
                     if (data->handle != nullptr) {
                         mProcs.{{as_varName(type.name, Name("release"))}}(data->handle);
diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp
index de73a84..0dcea37 100644
--- a/src/dawn_wire/WireClient.cpp
+++ b/src/dawn_wire/WireClient.cpp
@@ -37,6 +37,10 @@
         return mImpl->ReserveTexture(device);
     }
 
+    ReservedDevice WireClient::ReserveDevice() {
+        return mImpl->ReserveDevice();
+    }
+
     void WireClient::Disconnect() {
         mImpl->Disconnect();
     }
diff --git a/src/dawn_wire/WireServer.cpp b/src/dawn_wire/WireServer.cpp
index 723f691..763b5fc 100644
--- a/src/dawn_wire/WireServer.cpp
+++ b/src/dawn_wire/WireServer.cpp
@@ -40,6 +40,14 @@
         return mImpl->InjectTexture(texture, id, generation, deviceId, deviceGeneration);
     }
 
+    bool WireServer::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) {
+        return mImpl->InjectDevice(device, id, generation);
+    }
+
+    WGPUDevice WireServer::GetDevice(uint32_t id, uint32_t generation) {
+        return mImpl->GetDevice(id, generation);
+    }
+
     namespace server {
         MemoryTransferService::MemoryTransferService() = default;
 
diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp
index 0ca5f61..b566529 100644
--- a/src/dawn_wire/client/Client.cpp
+++ b/src/dawn_wire/client/Client.cpp
@@ -85,8 +85,13 @@
     }
 
     WGPUDevice Client::GetDevice() {
+        // This function is deprecated. The concept of a "default" device on the wire
+        // will be removed in favor of ReserveDevice/InjectDevice.
         if (mDevice == nullptr) {
-            mDevice = DeviceAllocator().New(this)->object.get();
+            ReservedDevice reservation = ReserveDevice();
+            mDevice = FromAPI(reservation.device);
+            ASSERT(reservation.id == 1);
+            ASSERT(reservation.generation == 0);
         }
         return reinterpret_cast<WGPUDeviceImpl*>(mDevice);
     }
@@ -103,6 +108,16 @@
         return result;
     }
 
+    ReservedDevice Client::ReserveDevice() {
+        auto* allocation = DeviceAllocator().New(this);
+
+        ReservedDevice result;
+        result.device = ToAPI(allocation->object.get());
+        result.id = allocation->object->id;
+        result.generation = allocation->generation;
+        return result;
+    }
+
     void Client::Disconnect() {
         mDisconnected = true;
         mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance());
diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h
index 4902df8..dd7ac76 100644
--- a/src/dawn_wire/client/Client.h
+++ b/src/dawn_wire/client/Client.h
@@ -46,6 +46,7 @@
         }
 
         ReservedTexture ReserveTexture(WGPUDevice device);
+        ReservedDevice ReserveDevice();
 
         template <typename Cmd>
         void SerializeCommand(const Cmd& cmd) {
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index b6ee4fc..2d643cb 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -45,15 +45,6 @@
             }
         };
 #endif  // DAWN_ENABLE_ASSERTS
-        // Get the default queue for this device.
-        auto* allocation = client->QueueAllocator().New(client);
-        mDefaultQueue = allocation->object.get();
-
-        DeviceGetDefaultQueueCmd cmd;
-        cmd.self = ToAPI(this);
-        cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
-
-        client->SerializeCommand(cmd);
     }
 
     Device::~Device() {
@@ -206,6 +197,22 @@
     }
 
     WGPUQueue Device::GetDefaultQueue() {
+        // The queue is lazily created because if a Device is created by
+        // Reserve/Inject, we cannot send the getDefaultQueue message until
+        // it has been injected on the Server. It cannot happen immediately
+        // on construction.
+        if (mDefaultQueue == nullptr) {
+            // Get the default queue for this device.
+            auto* allocation = client->QueueAllocator().New(client);
+            mDefaultQueue = allocation->object.get();
+
+            DeviceGetDefaultQueueCmd cmd;
+            cmd.self = ToAPI(this);
+            cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
+
+            client->SerializeCommand(cmd);
+        }
+
         mDefaultQueue->refcount++;
         return ToAPI(mDefaultQueue);
     }
diff --git a/src/dawn_wire/server/ObjectStorage.h b/src/dawn_wire/server/ObjectStorage.h
index 74cc5a7..c803f53 100644
--- a/src/dawn_wire/server/ObjectStorage.h
+++ b/src/dawn_wire/server/ObjectStorage.h
@@ -160,6 +160,17 @@
             return objects;
         }
 
+        std::vector<T> GetAllHandles() {
+            std::vector<T> objects;
+            for (Data& data : mKnown) {
+                if (data.allocated && data.handle != nullptr) {
+                    objects.push_back(data.handle);
+                }
+            }
+
+            return objects;
+        }
+
       private:
         std::vector<Data> mKnown;
     };
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index 67a9dd5..39e50ea 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -23,7 +23,6 @@
                    MemoryTransferService* memoryTransferService)
         : mSerializer(serializer),
           mProcs(procs),
-          mDeviceOnCreation(device),
           mMemoryTransferService(memoryTransferService),
           mIsAlive(std::make_shared<bool>(true)) {
         if (mMemoryTransferService == nullptr) {
@@ -31,38 +30,21 @@
             mOwnedMemoryTransferService = CreateInlineMemoryTransferService();
             mMemoryTransferService = mOwnedMemoryTransferService.get();
         }
-        // The client-server knowledge is bootstrapped with device 1.
-        auto* deviceData = DeviceObjects().Allocate(1);
-        deviceData->handle = device;
 
-        // Take an extra ref. All objects may be freed by the client, but this
-        // one is externally owned.
-        mProcs.deviceReference(device);
-
-        // Note: these callbacks are manually inlined here since they do not acquire and
-        // free their userdata.
-        mProcs.deviceSetUncapturedErrorCallback(
-            device,
-            [](WGPUErrorType type, const char* message, void* userdata) {
-                Server* server = static_cast<Server*>(userdata);
-                server->OnUncapturedError(type, message);
-            },
-            this);
-        mProcs.deviceSetDeviceLostCallback(
-            device,
-            [](const char* message, void* userdata) {
-                Server* server = static_cast<Server*>(userdata);
-                server->OnDeviceLost(message);
-            },
-            this);
+        // For the deprecated initialization path:
+        // The client-server knowledge is bootstrapped with device 1, generation 0.
+        if (device != nullptr) {
+            bool success = InjectDevice(device, 1, 0);
+            ASSERT(success);
+        }
     }
 
     Server::~Server() {
         // Un-set the error and lost callbacks since we cannot forward them
         // after the server has been destroyed.
-        mProcs.deviceSetUncapturedErrorCallback(mDeviceOnCreation, nullptr, nullptr);
-        mProcs.deviceSetDeviceLostCallback(mDeviceOnCreation, nullptr, nullptr);
-
+        for (WGPUDevice device : DeviceObjects().GetAllHandles()) {
+            ClearDeviceCallbacks(device);
+        }
         DestroyAllObjects(mProcs);
     }
 
@@ -71,6 +53,7 @@
                                uint32_t generation,
                                uint32_t deviceId,
                                uint32_t deviceGeneration) {
+        ASSERT(texture != nullptr);
         ObjectData<WGPUDevice>* device = DeviceObjects().Get(deviceId);
         if (device == nullptr || device->generation != deviceGeneration) {
             return false;
@@ -97,6 +80,57 @@
         return true;
     }
 
+    bool Server::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) {
+        ASSERT(device != nullptr);
+        ObjectData<WGPUDevice>* data = DeviceObjects().Allocate(id);
+        if (data == nullptr) {
+            return false;
+        }
+
+        data->handle = device;
+        data->generation = generation;
+        data->allocated = true;
+
+        // The device is externally owned so it shouldn't be destroyed when we receive a destroy
+        // message from the client. Add a reference to counterbalance the eventual release.
+        mProcs.deviceReference(device);
+
+        // Set callbacks to forward errors to the client.
+        // Note: these callbacks are manually inlined here since they do not acquire and
+        // free their userdata.
+        mProcs.deviceSetUncapturedErrorCallback(
+            device,
+            [](WGPUErrorType type, const char* message, void* userdata) {
+                Server* server = static_cast<Server*>(userdata);
+                server->OnUncapturedError(type, message);
+            },
+            this);
+        mProcs.deviceSetDeviceLostCallback(
+            device,
+            [](const char* message, void* userdata) {
+                Server* server = static_cast<Server*>(userdata);
+                server->OnDeviceLost(message);
+            },
+            this);
+
+        return true;
+    }
+
+    WGPUDevice Server::GetDevice(uint32_t id, uint32_t generation) {
+        ObjectData<WGPUDevice>* data = DeviceObjects().Get(id);
+        if (data == nullptr || data->generation != generation) {
+            return nullptr;
+        }
+        return data->handle;
+    }
+
+    void Server::ClearDeviceCallbacks(WGPUDevice device) {
+        // Un-set the error and lost callbacks since we cannot forward them
+        // after the server has been destroyed.
+        mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
+        mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
+    }
+
     bool TrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id) {
         auto it = static_cast<ObjectData<WGPUDevice>*>(device)->childObjectTypesAndIds.insert(
             PackObjectTypeAndId(type, id));
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index f45ed0d..4056896 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -167,6 +167,10 @@
                            uint32_t deviceId,
                            uint32_t deviceGeneration);
 
+        bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation);
+
+        WGPUDevice GetDevice(uint32_t id, uint32_t generation);
+
         template <typename T,
                   typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>>
         std::unique_ptr<T> MakeUserdata() {
@@ -186,6 +190,7 @@
             mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
         }
 
+        void ClearDeviceCallbacks(WGPUDevice device);
 
         // Error callbacks
         void OnUncapturedError(WGPUErrorType type, const char* message);
@@ -212,7 +217,6 @@
         WireDeserializeAllocator mAllocator;
         ChunkedCommandSerializer mSerializer;
         DawnProcTable mProcs;
-        WGPUDevice mDeviceOnCreation;
         std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
         MemoryTransferService* mMemoryTransferService = nullptr;
 
diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h
index 8af02a97..b8f1247 100644
--- a/src/include/dawn_wire/WireClient.h
+++ b/src/include/dawn_wire/WireClient.h
@@ -38,6 +38,12 @@
         uint32_t deviceGeneration;
     };
 
+    struct ReservedDevice {
+        WGPUDevice device;
+        uint32_t id;
+        uint32_t generation;
+    };
+
     struct DAWN_WIRE_EXPORT WireClientDescriptor {
         CommandSerializer* serializer;
         client::MemoryTransferService* memoryTransferService = nullptr;
@@ -53,6 +59,7 @@
                                             size_t size) override final;
 
         ReservedTexture ReserveTexture(WGPUDevice device);
+        ReservedDevice ReserveDevice();
 
         // Disconnects the client.
         // Commands allocated after this point will not be sent.
diff --git a/src/include/dawn_wire/WireServer.h b/src/include/dawn_wire/WireServer.h
index ad36f44..9ff6fed 100644
--- a/src/include/dawn_wire/WireServer.h
+++ b/src/include/dawn_wire/WireServer.h
@@ -50,6 +50,17 @@
                            uint32_t deviceId = 1,
                            uint32_t deviceGeneration = 0);
 
+        bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation);
+
+        // Look up a device by (id, generation) pair. Returns nullptr if the generation
+        // has expired or the id is not found.
+        // The Wire does not have destroy hooks to allow an embedder to observe when an object
+        // has been destroyed, but in Chrome, we need to know the list of live devices so we
+        // can call device.Tick() on all of them periodically to ensure progress on asynchronous
+        // work is made. Getting this list can be done by tracking the (id, generation) of
+        // previously injected devices, and observing if GetDevice(id, generation) returns non-null.
+        WGPUDevice GetDevice(uint32_t id, uint32_t generation);
+
       private:
         std::unique_ptr<server::Server> mImpl;
     };
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 5a3e310..e614db2 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -223,6 +223,7 @@
     "unittests/wire/WireErrorCallbackTests.cpp",
     "unittests/wire/WireExtensionTests.cpp",
     "unittests/wire/WireFenceTests.cpp",
+    "unittests/wire/WireInjectDeviceTests.cpp",
     "unittests/wire/WireInjectTextureTests.cpp",
     "unittests/wire/WireMemoryTransferServiceTests.cpp",
     "unittests/wire/WireMultipleDeviceTests.cpp",
diff --git a/src/tests/unittests/wire/WireDestroyObjectTests.cpp b/src/tests/unittests/wire/WireDestroyObjectTests.cpp
index 34b976d..2c7ddc2 100644
--- a/src/tests/unittests/wire/WireDestroyObjectTests.cpp
+++ b/src/tests/unittests/wire/WireDestroyObjectTests.cpp
@@ -36,10 +36,19 @@
     // The device and child objects should be released.
     EXPECT_CALL(api, CommandEncoderRelease(apiEncoder)).InSequence(s1);
     EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2);
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+        .Times(1)
+        .InSequence(s1, s2);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+        .Times(1)
+        .InSequence(s1, s2);
     EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2);
 
     FlushClient();
 
+    // Signal that we already released and cleared callbacks for |apiDevice|
+    DefaultApiDeviceWasReleased();
+
     // Using the command encoder should be an error.
     wgpuCommandEncoderFinish(encoder, nullptr);
     FlushClient(false);
@@ -82,8 +91,17 @@
         // The device and child objects alre also released.
         EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1);
         EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2);
+        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+            .Times(1)
+            .InSequence(s1, s2);
+        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+            .Times(1)
+            .InSequence(s1, s2);
         EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2);
 
         FlushClient();
+
+        // Signal that we already released and cleared callbacks for |apiDevice|
+        DefaultApiDeviceWasReleased();
     }
 }
diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp
index f44df13..d3f65a9 100644
--- a/src/tests/unittests/wire/WireDisconnectTests.cpp
+++ b/src/tests/unittests/wire/WireDisconnectTests.cpp
@@ -149,6 +149,15 @@
     EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1);
     EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1).InSequence(s2);
     EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1).InSequence(s3);
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+        .Times(1)
+        .InSequence(s1, s2);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+        .Times(1)
+        .InSequence(s1, s2);
     EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2, s3);
     FlushClient();
+
+    // Signal that we already released and cleared callbacks for |apiDevice|
+    DefaultApiDeviceWasReleased();
 }
diff --git a/src/tests/unittests/wire/WireInjectDeviceTests.cpp b/src/tests/unittests/wire/WireInjectDeviceTests.cpp
new file mode 100644
index 0000000..8f1dda3
--- /dev/null
+++ b/src/tests/unittests/wire/WireInjectDeviceTests.cpp
@@ -0,0 +1,184 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/unittests/wire/WireTest.h"
+
+#include "dawn_wire/WireClient.h"
+#include "dawn_wire/WireServer.h"
+
+using namespace testing;
+using namespace dawn_wire;
+
+class WireInjectDeviceTests : public WireTest {
+  public:
+    WireInjectDeviceTests() {
+    }
+    ~WireInjectDeviceTests() override = default;
+};
+
+// Test that reserving and injecting a device makes calls on the client object forward to the
+// server object correctly.
+TEST_F(WireInjectDeviceTests, CallAfterReserveInject) {
+    ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+    WGPUDevice serverDevice = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+    WGPUBufferDescriptor bufferDesc = {};
+    wgpuDeviceCreateBuffer(reservation.device, &bufferDesc);
+    WGPUBuffer serverBuffer = api.GetNewBuffer();
+    EXPECT_CALL(api, DeviceCreateBuffer(serverDevice, _)).WillOnce(Return(serverBuffer));
+    FlushClient();
+
+    // Called on shutdown.
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
+        .Times(Exactly(1));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
+        .Times(Exactly(1));
+}
+
+// Test that reserve correctly returns different IDs each time.
+TEST_F(WireInjectDeviceTests, ReserveDifferentIDs) {
+    ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
+    ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
+
+    ASSERT_NE(reservation1.id, reservation2.id);
+    ASSERT_NE(reservation1.device, reservation2.device);
+}
+
+// Test that injecting the same id without a destroy first fails.
+TEST_F(WireInjectDeviceTests, InjectExistingID) {
+    ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+    WGPUDevice serverDevice = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+    // ID already in use, call fails.
+    ASSERT_FALSE(
+        GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+    // Called on shutdown.
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
+        .Times(Exactly(1));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
+        .Times(Exactly(1));
+}
+
+// Test that the server only borrows the device and does a single reference-release
+TEST_F(WireInjectDeviceTests, InjectedDeviceLifetime) {
+    ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+    // Injecting the device adds a reference
+    WGPUDevice serverDevice = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+    // Releasing the device removes a single reference and clears its error callbacks.
+    wgpuDeviceRelease(reservation.device);
+    EXPECT_CALL(api, DeviceRelease(serverDevice));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)).Times(1);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)).Times(1);
+    FlushClient();
+
+    // Deleting the server doesn't release a second reference.
+    DeleteServer();
+    Mock::VerifyAndClearExpectations(&api);
+}
+
+// Test that it is an error to get the default queue of a device before it has been
+// injected on the server.
+TEST_F(WireInjectDeviceTests, GetQueueBeforeInject) {
+    ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+    wgpuDeviceGetDefaultQueue(reservation.device);
+    FlushClient(false);
+}
+
+// Test that it is valid to get the default queue of a device after it has been
+// injected on the server.
+TEST_F(WireInjectDeviceTests, GetQueueAfterInject) {
+    ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+    WGPUDevice serverDevice = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+    wgpuDeviceGetDefaultQueue(reservation.device);
+
+    WGPUQueue apiQueue = api.GetNewQueue();
+    EXPECT_CALL(api, DeviceGetDefaultQueue(serverDevice)).WillOnce(Return(apiQueue));
+    FlushClient();
+
+    // Called on shutdown.
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
+        .Times(Exactly(1));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
+        .Times(Exactly(1));
+}
+
+// Test that the list of live devices can be reflected using GetDevice.
+TEST_F(WireInjectDeviceTests, ReflectLiveDevices) {
+    // Reserve two devices.
+    ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
+    ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
+
+    // Inject both devices.
+
+    WGPUDevice serverDevice1 = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice1));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice1, reservation1.id, reservation1.generation));
+
+    WGPUDevice serverDevice2 = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice2));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice2, reservation2.id, reservation2.generation));
+
+    // Test that both devices can be reflected.
+    ASSERT_EQ(serverDevice1, GetWireServer()->GetDevice(reservation1.id, reservation1.generation));
+    ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation));
+
+    // Release the first device
+    wgpuDeviceRelease(reservation1.device);
+    EXPECT_CALL(api, DeviceRelease(serverDevice1));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, nullptr, nullptr)).Times(1);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, nullptr, nullptr)).Times(1);
+    FlushClient();
+
+    // The first device should no longer reflect, but the second should
+    ASSERT_EQ(nullptr, GetWireServer()->GetDevice(reservation1.id, reservation1.generation));
+    ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation));
+
+    // Called on shutdown.
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, nullptr, nullptr)).Times(1);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, nullptr, nullptr)).Times(1);
+}
diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
index 75cedcd..216c122 100644
--- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
@@ -83,9 +83,10 @@
 
             // These are called on server destruction to clear the callbacks. They must not be
             // called after the server is destroyed.
-            EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr))
+            EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(mServerDevice, nullptr, nullptr))
                 .Times(Exactly(1));
-            EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+            EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(mServerDevice, nullptr, nullptr))
+                .Times(Exactly(1));
             mWireServer = nullptr;
         }
 
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index 2609511..d3f17c3 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -88,15 +88,23 @@
     api.IgnoreAllReleaseCalls();
     mWireClient = nullptr;
 
-    if (mWireServer) {
+    if (mWireServer && apiDevice) {
         // These are called on server destruction to clear the callbacks. They must not be
         // called after the server is destroyed.
-        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
-        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+            .Times(Exactly(1));
+        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+            .Times(Exactly(1));
     }
     mWireServer = nullptr;
 }
 
+// This should be called if |apiDevice| is no longer exists on the wire.
+// This signals that expectations in |TearDowb| shouldn't be added.
+void WireTest::DefaultApiDeviceWasReleased() {
+    apiDevice = nullptr;
+}
+
 void WireTest::FlushClient(bool success) {
     ASSERT_EQ(mC2sBuf->Flush(), success);
 
@@ -123,8 +131,10 @@
     if (mWireServer) {
         // These are called on server destruction to clear the callbacks. They must not be
         // called after the server is destroyed.
-        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
-        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+            .Times(Exactly(1));
+        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+            .Times(Exactly(1));
     }
     mWireServer = nullptr;
 }
diff --git a/src/tests/unittests/wire/WireTest.h b/src/tests/unittests/wire/WireTest.h
index 95fd307..03ac641 100644
--- a/src/tests/unittests/wire/WireTest.h
+++ b/src/tests/unittests/wire/WireTest.h
@@ -123,6 +123,8 @@
     void FlushClient(bool success = true);
     void FlushServer(bool success = true);
 
+    void DefaultApiDeviceWasReleased();
+
     testing::StrictMock<MockProcTable> api;
     WGPUDevice apiDevice;
     WGPUQueue apiQueue;