Remove special-casing of device reference/release in the wire

The wire's device is externally owned so reference/release were no-ops.
To unify the code paths, remove the special casing and instead
take an extra ref on the device the wire server is created with. This
is functionally equivalent and will allow both the current wire code,
and the incoming change to allow multiple device/adapter creation to
both work.

This CL also makes it possible for the client to destroy the device
before child objects.
A follow-up CL will mitigate this on the server side.

Bug: dawn:384
Change-Id: Ic5427074469012dccf8689ec95a848e6ba2c1fc2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/37001
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/generator/templates/dawn_wire/client/ApiProcs.cpp b/generator/templates/dawn_wire/client/ApiProcs.cpp
index 1b98d02..9273155 100644
--- a/generator/templates/dawn_wire/client/ApiProcs.cpp
+++ b/generator/templates/dawn_wire/client/ApiProcs.cpp
@@ -207,29 +207,27 @@
             }
         {% endfor %}
 
-        {% if not type.name.canonical_case() == "device" %}
-            //* When an object's refcount reaches 0, notify the server side of it and delete it.
-            void Client{{as_MethodSuffix(type.name, Name("release"))}}({{cType}} cObj) {
-                {{Type}}* obj = reinterpret_cast<{{Type}}*>(cObj);
-                obj->refcount --;
+        //* When an object's refcount reaches 0, notify the server side of it and delete it.
+        void Client{{as_MethodSuffix(type.name, Name("release"))}}({{cType}} cObj) {
+            {{Type}}* obj = reinterpret_cast<{{Type}}*>(cObj);
+            obj->refcount --;
 
-                if (obj->refcount > 0) {
-                    return;
-                }
-
-                DestroyObjectCmd cmd;
-                cmd.objectType = ObjectType::{{type.name.CamelCase()}};
-                cmd.objectId = obj->id;
-
-                obj->client->SerializeCommand(cmd);
-                obj->client->{{type.name.CamelCase()}}Allocator().Free(obj);
+            if (obj->refcount > 0) {
+                return;
             }
 
-            void Client{{as_MethodSuffix(type.name, Name("reference"))}}({{cType}} cObj) {
-                {{Type}}* obj = reinterpret_cast<{{Type}}*>(cObj);
-                obj->refcount ++;
-            }
-        {% endif %}
+            DestroyObjectCmd cmd;
+            cmd.objectType = ObjectType::{{type.name.CamelCase()}};
+            cmd.objectId = obj->id;
+
+            obj->client->SerializeCommand(cmd);
+            obj->client->{{type.name.CamelCase()}}Allocator().Free(obj);
+        }
+
+        void Client{{as_MethodSuffix(type.name, Name("reference"))}}({{cType}} cObj) {
+            {{Type}}* obj = reinterpret_cast<{{Type}}*>(cObj);
+            obj->refcount ++;
+        }
     {% endfor %}
 
     namespace {
@@ -238,12 +236,6 @@
             return nullptr;
         }
 
-        void ClientDeviceReference(WGPUDevice) {
-        }
-
-        void ClientDeviceRelease(WGPUDevice) {
-        }
-
         struct ProcEntry {
             WGPUProc proc;
             const char* name;
diff --git a/generator/templates/dawn_wire/server/ServerBase.h b/generator/templates/dawn_wire/server/ServerBase.h
index 19bf1aa..4c488ee 100644
--- a/generator/templates/dawn_wire/server/ServerBase.h
+++ b/generator/templates/dawn_wire/server/ServerBase.h
@@ -32,7 +32,7 @@
       protected:
         void DestroyAllObjects(const DawnProcTable& procs) {
             //* Free all objects when the server is destroyed
-            {% for type in by_category["object"] if type.name.canonical_case() != "device" %}
+            {% for type in by_category["object"] %}
                 {
                     std::vector<{{as_cType(type.name)}}> handles = mKnown{{type.name.CamelCase()}}.AcquireAllHandles();
                     for ({{as_cType(type.name)}} handle : handles) {
diff --git a/generator/templates/dawn_wire/server/ServerDoers.cpp b/generator/templates/dawn_wire/server/ServerDoers.cpp
index 3ebdedf..08b407a 100644
--- a/generator/templates/dawn_wire/server/ServerDoers.cpp
+++ b/generator/templates/dawn_wire/server/ServerDoers.cpp
@@ -73,23 +73,18 @@
         switch(objectType) {
             {% for type in by_category["object"] %}
                 case ObjectType::{{type.name.CamelCase()}}: {
-                    {% if type.name.CamelCase() == "Device" %}
-                        //* Freeing the device has to be done out of band.
+                    auto* data = {{type.name.CamelCase()}}Objects().Get(objectId);
+                    if (data == nullptr) {
                         return false;
-                    {% else %}
-                        auto* data = {{type.name.CamelCase()}}Objects().Get(objectId);
-                        if (data == nullptr) {
-                            return false;
-                        }
-                        {% if type.name.CamelCase() in server_reverse_lookup_objects %}
-                            {{type.name.CamelCase()}}ObjectIdTable().Remove(data->handle);
-                        {% endif %}
-                        if (data->handle != nullptr) {
-                            mProcs.{{as_varName(type.name, Name("release"))}}(data->handle);
-                        }
-                        {{type.name.CamelCase()}}Objects().Free(objectId);
-                        return true;
+                    }
+                    {% if type.name.CamelCase() in server_reverse_lookup_objects %}
+                        {{type.name.CamelCase()}}ObjectIdTable().Remove(data->handle);
                     {% endif %}
+                    if (data->handle != nullptr) {
+                        mProcs.{{as_varName(type.name, Name("release"))}}(data->handle);
+                    }
+                    {{type.name.CamelCase()}}Objects().Free(objectId);
+                    return true;
                 }
             {% endfor %}
             default:
diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp
index d12bc57..51af5a6 100644
--- a/src/dawn_wire/client/Client.cpp
+++ b/src/dawn_wire/client/Client.cpp
@@ -61,12 +61,6 @@
             ObjectType objectType = static_cast<ObjectType>(&objectList - mObjects.data());
             while (!objectList.empty()) {
                 ObjectBase* object = objectList.head()->value();
-                if (object == mDevice) {
-                    // Note: We don't send a DestroyObject command for the device
-                    // since freeing a device object is done out of band.
-                    DeviceAllocator().Free(mDevice);
-                    continue;
-                }
 
                 DestroyObjectCmd cmd;
                 cmd.objectType = objectType;
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index a344581..bb5644b 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -23,6 +23,7 @@
                    MemoryTransferService* memoryTransferService)
         : mSerializer(serializer),
           mProcs(procs),
+          mDeviceOnCreation(device),
           mMemoryTransferService(memoryTransferService),
           mIsAlive(std::make_shared<bool>(true)) {
         if (mMemoryTransferService == nullptr) {
@@ -34,6 +35,10 @@
         auto* deviceData = DeviceObjects().Allocate(1);
         deviceData->handle = device;
 
+        // Take an extra ref. All objects may be freed by the client, but this
+        // one is externally owned.
+        mProcs.deviceReference(device);
+
         // Note: these callbacks are manually inlined here since they do not acquire and
         // free their userdata.
         mProcs.deviceSetUncapturedErrorCallback(
@@ -55,9 +60,8 @@
     Server::~Server() {
         // Un-set the error and lost callbacks since we cannot forward them
         // after the server has been destroyed.
-        WGPUDevice device = DeviceObjects().Get(1)->handle;
-        mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
-        mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
+        mProcs.deviceSetUncapturedErrorCallback(mDeviceOnCreation, nullptr, nullptr);
+        mProcs.deviceSetDeviceLostCallback(mDeviceOnCreation, nullptr, nullptr);
 
         DestroyAllObjects(mProcs);
     }
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index d503e81..867f29b 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -208,6 +208,7 @@
         WireDeserializeAllocator mAllocator;
         ChunkedCommandSerializer mSerializer;
         DawnProcTable mProcs;
+        WGPUDevice mDeviceOnCreation;
         std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
         MemoryTransferService* mMemoryTransferService = nullptr;
 
diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp
index e68fbab..4c99ced 100644
--- a/src/tests/unittests/wire/WireDisconnectTests.cpp
+++ b/src/tests/unittests/wire/WireDisconnectTests.cpp
@@ -145,6 +145,7 @@
     DeleteClient();
 
     // Expect release on all objects created by the client.
+    EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1);
     EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1);
     EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1);
     EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1);
diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
index f51354e..75cedcd 100644
--- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
@@ -58,6 +58,7 @@
             serverDesc.procs = &mockProcs;
             serverDesc.serializer = mS2cBuf.get();
 
+            EXPECT_CALL(mApi, DeviceReference(mServerDevice));
             mWireServer.reset(new WireServer(serverDesc));
             mC2sBuf->SetHandler(mWireServer.get());
 
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index 2106b04..2609511 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -55,6 +55,7 @@
     serverDesc.serializer = mS2cBuf.get();
     serverDesc.memoryTransferService = GetServerMemoryTransferService();
 
+    EXPECT_CALL(api, DeviceReference(mockDevice));
     mWireServer.reset(new WireServer(serverDesc));
     mC2sBuf->SetHandler(mWireServer.get());
 
@@ -117,6 +118,7 @@
 
 void WireTest::DeleteServer() {
     EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1);
+    EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1);
 
     if (mWireServer) {
         // These are called on server destruction to clear the callbacks. They must not be