dawn_wire: Fix a bug with multiple injected devices

Device child objects were storing an *unstable* pointer to device
specific tracking information. Fix this by moving the tracking
information to a stable heap allocation.

Bug: dawn:565
Change-Id: I00ad72563ac66e29314603e77698718953fcbf15
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/38280
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/generator/templates/dawn_wire/server/ServerDoers.cpp b/generator/templates/dawn_wire/server/ServerDoers.cpp
index 0252336..a16ff9e 100644
--- a/generator/templates/dawn_wire/server/ServerDoers.cpp
+++ b/generator/templates/dawn_wire/server/ServerDoers.cpp
@@ -77,9 +77,8 @@
                     if (data == nullptr) {
                         return false;
                     }
-                    if (data->device != nullptr) {
-                        auto* device = static_cast<ObjectData<WGPUDevice>*>(data->device);
-                        if (!UntrackDeviceChild(device, objectType, objectId)) {
+                    if (data->deviceInfo != nullptr) {
+                        if (!UntrackDeviceChild(data->deviceInfo, objectType, objectId)) {
                             return false;
                         }
                     }
@@ -91,11 +90,11 @@
                         //* are destroyed before their device. We should have a solution in
                         //* Dawn native that makes all child objects internally null if their
                         //* Device is destroyed.
-                        while (data->childObjectTypesAndIds.size() > 0) {
+                        while (data->info->childObjectTypesAndIds.size() > 0) {
                             ObjectType childObjectType;
                             ObjectId childObjectId;
                             std::tie(childObjectType, childObjectId) = UnpackObjectTypeAndId(
-                                *data->childObjectTypesAndIds.begin());
+                                *data->info->childObjectTypesAndIds.begin());
                             DoDestroyObject(childObjectType, childObjectId);
                         }
                         if (data->handle != nullptr) {
diff --git a/generator/templates/dawn_wire/server/ServerHandlers.cpp b/generator/templates/dawn_wire/server/ServerHandlers.cpp
index 1341e41..f23a684 100644
--- a/generator/templates/dawn_wire/server/ServerHandlers.cpp
+++ b/generator/templates/dawn_wire/server/ServerHandlers.cpp
@@ -59,13 +59,13 @@
                 {% if command.derived_object %}
                     {% set type = command.derived_object %}
                     {% if type.name.get() == "device" %}
-                        {{name}}Data->device = DeviceObjects().Get(cmd.selfId);
+                        {{name}}Data->deviceInfo = DeviceObjects().Get(cmd.selfId)->info.get();
                     {% else %}
                         auto* selfData = {{type.name.CamelCase()}}Objects().Get(cmd.selfId);
-                        {{name}}Data->device = selfData->device;
+                        {{name}}Data->deviceInfo = selfData->deviceInfo;
                     {% endif %}
-                    if ({{name}}Data->device != nullptr) {
-                        if (!TrackDeviceChild({{name}}Data->device, ObjectType::{{Type}}, cmd.{{name}}.id)) {
+                    if ({{name}}Data->deviceInfo != nullptr) {
+                        if (!TrackDeviceChild({{name}}Data->deviceInfo, ObjectType::{{Type}}, cmd.{{name}}.id)) {
                             return false;
                         }
                     }
diff --git a/src/dawn_wire/server/ObjectStorage.h b/src/dawn_wire/server/ObjectStorage.h
index c803f53..1595bda 100644
--- a/src/dawn_wire/server/ObjectStorage.h
+++ b/src/dawn_wire/server/ObjectStorage.h
@@ -24,6 +24,10 @@
 
 namespace dawn_wire { namespace server {
 
+    struct DeviceInfo {
+        std::unordered_set<uint64_t> childObjectTypesAndIds;
+    };
+
     template <typename T>
     struct ObjectDataBase {
         // The backend-provided handle and generation to this object.
@@ -34,7 +38,8 @@
         // TODO(cwallez@chromium.org): make this an internal bit vector in KnownObjects.
         bool allocated;
 
-        ObjectDataBase<WGPUDevice>* device = nullptr;
+        // This points to an allocation that is owned by the device.
+        DeviceInfo* deviceInfo = nullptr;
     };
 
     // Stores what the backend knows about the type.
@@ -68,7 +73,9 @@
 
     template <>
     struct ObjectData<WGPUDevice> : public ObjectDataBase<WGPUDevice> {
-        std::unordered_set<uint64_t> childObjectTypesAndIds;
+        // Store |info| as a separate allocation so that its address does not move.
+        // The pointer to |info| is stored in device child objects.
+        std::unique_ptr<DeviceInfo> info = std::make_unique<DeviceInfo>();
     };
 
     // Keeps track of the mapping between client IDs and backend objects.
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index 39e50ea..25a22f8 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -64,14 +64,14 @@
             return false;
         }
 
-        if (!TrackDeviceChild(device, ObjectType::Texture, id)) {
-            return false;
-        }
-
         data->handle = texture;
         data->generation = generation;
         data->allocated = true;
-        data->device = device;
+        data->deviceInfo = device->info.get();
+
+        if (!TrackDeviceChild(data->deviceInfo, ObjectType::Texture, id)) {
+            return false;
+        }
 
         // The texture is externally owned so it shouldn't be destroyed when we receive a destroy
         // message from the client. Add a reference to counterbalance the eventual release.
@@ -131,9 +131,8 @@
         mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
     }
 
-    bool TrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id) {
-        auto it = static_cast<ObjectData<WGPUDevice>*>(device)->childObjectTypesAndIds.insert(
-            PackObjectTypeAndId(type, id));
+    bool TrackDeviceChild(DeviceInfo* info, ObjectType type, ObjectId id) {
+        auto it = info->childObjectTypesAndIds.insert(PackObjectTypeAndId(type, id));
         if (!it.second) {
             // An object of this type and id already exists.
             return false;
@@ -141,8 +140,8 @@
         return true;
     }
 
-    bool UntrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id) {
-        auto& children = static_cast<ObjectData<WGPUDevice>*>(device)->childObjectTypesAndIds;
+    bool UntrackDeviceChild(DeviceInfo* info, ObjectType type, ObjectId id) {
+        auto& children = info->childObjectTypesAndIds;
         auto it = children.find(PackObjectTypeAndId(type, id));
         if (it == children.end()) {
             // An object of this type and id was already deleted.
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index 4056896..ffd7d22 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -223,8 +223,8 @@
         std::shared_ptr<bool> mIsAlive;
     };
 
-    bool TrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id);
-    bool UntrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id);
+    bool TrackDeviceChild(DeviceInfo* device, ObjectType type, ObjectId id);
+    bool UntrackDeviceChild(DeviceInfo* device, ObjectType type, ObjectId id);
 
     std::unique_ptr<MemoryTransferService> CreateInlineMemoryTransferService();
 
diff --git a/src/dawn_wire/server/ServerBuffer.cpp b/src/dawn_wire/server/ServerBuffer.cpp
index fdc850b..7cc5c9b 100644
--- a/src/dawn_wire/server/ServerBuffer.cpp
+++ b/src/dawn_wire/server/ServerBuffer.cpp
@@ -137,8 +137,8 @@
         }
         resultData->generation = bufferResult.generation;
         resultData->handle = mProcs.deviceCreateBuffer(device->handle, descriptor);
-        resultData->device = device;
-        if (!TrackDeviceChild(device, ObjectType::Buffer, bufferResult.id)) {
+        resultData->deviceInfo = device->info.get();
+        if (!TrackDeviceChild(resultData->deviceInfo, ObjectType::Buffer, bufferResult.id)) {
             return false;
         }
 
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index 8884dd1..c200725 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -74,8 +74,9 @@
         }
 
         resultData->generation = pipelineObjectHandle.generation;
-        resultData->device = device;
-        if (!TrackDeviceChild(device, ObjectType::ComputePipeline, pipelineObjectHandle.id)) {
+        resultData->deviceInfo = device->info.get();
+        if (!TrackDeviceChild(resultData->deviceInfo, ObjectType::ComputePipeline,
+                              pipelineObjectHandle.id)) {
             return false;
         }
 
@@ -140,8 +141,9 @@
         }
 
         resultData->generation = pipelineObjectHandle.generation;
-        resultData->device = device;
-        if (!TrackDeviceChild(device, ObjectType::RenderPipeline, pipelineObjectHandle.id)) {
+        resultData->deviceInfo = device->info.get();
+        if (!TrackDeviceChild(resultData->deviceInfo, ObjectType::RenderPipeline,
+                              pipelineObjectHandle.id)) {
             return false;
         }
 
diff --git a/src/tests/unittests/wire/WireInjectDeviceTests.cpp b/src/tests/unittests/wire/WireInjectDeviceTests.cpp
index 8f1dda3..897af4d 100644
--- a/src/tests/unittests/wire/WireInjectDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireInjectDeviceTests.cpp
@@ -182,3 +182,49 @@
     EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, nullptr, nullptr)).Times(1);
     EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, nullptr, nullptr)).Times(1);
 }
+
+// This is a regression test where a second device reservation invalidated pointers into the
+// KnownObjects std::vector of devices. The fix was to store pointers to heap allocated
+// objects instead.
+TEST_F(WireInjectDeviceTests, TrackChildObjectsWithTwoReservedDevices) {
+    // Reserve one device, inject it, and get the default queue.
+    ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
+
+    WGPUDevice serverDevice1 = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice1));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice1, reservation1.id, reservation1.generation));
+
+    WGPUCommandEncoder commandEncoder =
+        wgpuDeviceCreateCommandEncoder(reservation1.device, nullptr);
+
+    WGPUCommandEncoder serverCommandEncoder = api.GetNewCommandEncoder();
+    EXPECT_CALL(api, DeviceCreateCommandEncoder(serverDevice1, _))
+        .WillOnce(Return(serverCommandEncoder));
+    FlushClient();
+
+    // Reserve a second device, and inject it.
+    ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
+
+    WGPUDevice serverDevice2 = api.GetNewDevice();
+    EXPECT_CALL(api, DeviceReference(serverDevice2));
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, _, _));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, _, _));
+    ASSERT_TRUE(
+        GetWireServer()->InjectDevice(serverDevice2, reservation2.id, reservation2.generation));
+
+    // Release the encoder. This should work without error because it stores a stable
+    // pointer to its device's list of child objects. On destruction, it removes itself from the
+    // list.
+    wgpuCommandEncoderRelease(commandEncoder);
+    EXPECT_CALL(api, CommandEncoderRelease(serverCommandEncoder));
+    FlushClient();
+
+    // Called on shutdown.
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, nullptr, nullptr)).Times(1);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, nullptr, nullptr)).Times(1);
+    EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, nullptr, nullptr)).Times(1);
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, nullptr, nullptr)).Times(1);
+}