dawn_wire: Return early in callbacks after the server is destroyed

After the server is destroyed, the server's can't do anything like
forward callbacks to the client. Track this with a weak_ptr and
return early if it has expired.

It also updates device destruction in dawn_native so the lost
callback is always called, even on graceful destruction. This
is consistent with the rest of WebGPU where all callbacks are
guaranteed to be called in finite time.

Bug: chromium:1147416, chromium:1161943
Change-Id: Ib80dea36517401a2b8eafb01ded255ebbe757aef
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/35840
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index c7e67aa..a344581 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -34,11 +34,31 @@
         auto* deviceData = DeviceObjects().Allocate(1);
         deviceData->handle = device;
 
-        mProcs.deviceSetUncapturedErrorCallback(device, ForwardUncapturedError, this);
-        mProcs.deviceSetDeviceLostCallback(device, ForwardDeviceLost, this);
+        // Note: these callbacks are manually inlined here since they do not acquire and
+        // free their userdata.
+        mProcs.deviceSetUncapturedErrorCallback(
+            device,
+            [](WGPUErrorType type, const char* message, void* userdata) {
+                Server* server = static_cast<Server*>(userdata);
+                server->OnUncapturedError(type, message);
+            },
+            this);
+        mProcs.deviceSetDeviceLostCallback(
+            device,
+            [](const char* message, void* userdata) {
+                Server* server = static_cast<Server*>(userdata);
+                server->OnDeviceLost(message);
+            },
+            this);
     }
 
     Server::~Server() {
+        // Un-set the error and lost callbacks since we cannot forward them
+        // after the server has been destroyed.
+        WGPUDevice device = DeviceObjects().Get(1)->handle;
+        mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
+        mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
+
         DestroyAllObjects(mProcs);
     }
 
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index 08b5452..7b673d8 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -23,8 +23,94 @@
     class Server;
     class MemoryTransferService;
 
-    struct MapUserdata {
-        Server* server;
+    // CallbackUserdata and its derived classes are intended to be created by
+    // Server::MakeUserdata<T> and then passed as the userdata argument for Dawn
+    // callbacks.
+    // It contains a pointer back to the Server so that the callback can call the
+    // Server to perform operations like serialization, and it contains a weak pointer
+    // |serverIsAlive|. If the weak pointer has expired, it means the server has
+    // been destroyed and the callback must not use the Server pointer.
+    // To assist with checking |serverIsAlive| and lifetime management of the userdata,
+    // |ForwardToServer| (defined later in this file) can be used to acquire the userdata,
+    // return early if |serverIsAlive| has expired, and then forward the arguments
+    // to userdata->server->MyCallbackHandler.
+    //
+    // Example Usage:
+    //
+    // struct MyUserdata : CallbackUserdata { uint32_t foo; };
+    //
+    // auto userdata = MakeUserdata<MyUserdata>();
+    // userdata->foo = 2;
+    //
+    // // TODO(enga): Make the template inference for ForwardToServer cleaner with C++17
+    // callMyCallbackHandler(
+    //      ForwardToServer<decltype(&Server::MyCallbackHandler)>::Func<
+    //                      &Server::MyCallbackHandler>(),
+    //      userdata.release());
+    //
+    // void Server::MyCallbackHandler(MyUserdata* userdata) { }
+    struct CallbackUserdata {
+        Server* const server;
+        std::weak_ptr<bool> const serverIsAlive;
+
+      private:
+        friend class Server;
+        CallbackUserdata() = delete;
+        CallbackUserdata(Server* server, const std::shared_ptr<bool>& serverIsAlive)
+            : server(server), serverIsAlive(serverIsAlive) {
+        }
+    };
+
+    template <typename F>
+    class ForwardToServer;
+
+    template <typename R, typename... Args>
+    class ForwardToServer<R (Server::*)(Args...)> {
+      private:
+        // Get the type T of the last argument. It has CallbackUserdata as its base.
+        using UserdataT = typename std::remove_pointer<typename std::decay<decltype(
+            std::get<sizeof...(Args) - 1>(std::declval<std::tuple<Args...>>()))>::type>::type;
+
+        static_assert(std::is_base_of<CallbackUserdata, UserdataT>::value,
+                      "Last argument of callback handler should derive from CallbackUserdata.");
+
+        template <class T, class... Ts>
+        struct UntypedCallbackImpl;
+
+        template <std::size_t... I, class... Ts>
+        struct UntypedCallbackImpl<std::index_sequence<I...>, Ts...> {
+            template <R (Server::*Func)(Args...)>
+            static auto ForwardToServer(
+                // Unpack and forward the types of the parameter pack.
+                // Append void* as the last argument.
+                typename std::tuple_element<I, std::tuple<Ts...>>::type... args,
+                void* userdata) {
+                // Acquire the userdata, and cast it to UserdataT.
+                std::unique_ptr<UserdataT> data(static_cast<UserdataT*>(userdata));
+                if (data->serverIsAlive.expired()) {
+                    // Do nothing if the server has already been destroyed.
+                    return;
+                }
+                // Forward the arguments and the typed userdata to the Server:: member function.
+                (data->server->*Func)(std::forward<decltype(args)>(args)..., data.get());
+            }
+        };
+
+        // Generate a free function which has all of the same arguments, except the last
+        // userdata argument is void* instead of UserdataT*. Dawn's userdata args are void*.
+        using UntypedCallback =
+            UntypedCallbackImpl<std::make_index_sequence<sizeof...(Args) - 1>, Args...>;
+
+      public:
+        template <R (Server::*F)(Args...)>
+        static auto Func() {
+            return UntypedCallback::template ForwardToServer<F>;
+        }
+    };
+
+    struct MapUserdata : CallbackUserdata {
+        using CallbackUserdata::CallbackUserdata;
+
         ObjectHandle buffer;
         WGPUBuffer bufferObj;
         uint32_t requestSerial;
@@ -36,28 +122,31 @@
         std::unique_ptr<MemoryTransferService::WriteHandle> writeHandle = nullptr;
     };
 
-    struct ErrorScopeUserdata {
-        Server* server;
+    struct ErrorScopeUserdata : CallbackUserdata {
+        using CallbackUserdata::CallbackUserdata;
+
         // TODO(enga): ObjectHandle device;
         // when the wire supports multiple devices.
         uint64_t requestSerial;
     };
 
-    struct FenceCompletionUserdata {
-        Server* server;
+    struct FenceCompletionUserdata : CallbackUserdata {
+        using CallbackUserdata::CallbackUserdata;
+
         ObjectHandle fence;
         uint64_t value;
     };
 
-    struct FenceOnCompletionUserdata {
-        Server* server;
+    struct FenceOnCompletionUserdata : CallbackUserdata {
+        using CallbackUserdata::CallbackUserdata;
+
         ObjectHandle fence;
         uint64_t requestSerial;
     };
 
-    struct CreateReadyPipelineUserData {
-        std::weak_ptr<bool> isServerAlive;
-        Server* server;
+    struct CreateReadyPipelineUserData : CallbackUserdata {
+        using CallbackUserdata::CallbackUserdata;
+
         uint64_t requestSerial;
         ObjectId pipelineObjectID;
     };
@@ -76,6 +165,12 @@
 
         bool InjectTexture(WGPUTexture texture, uint32_t id, uint32_t generation);
 
+        template <typename T,
+                  typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>>
+        std::unique_ptr<T> MakeUserdata() {
+            return std::unique_ptr<T>(new T(this, mIsAlive));
+        }
+
       private:
         template <typename Cmd>
         void SerializeCommand(const Cmd& cmd) {
@@ -89,21 +184,6 @@
             mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
         }
 
-        // Forwarding callbacks
-        static void ForwardUncapturedError(WGPUErrorType type, const char* message, void* userdata);
-        static void ForwardDeviceLost(const char* message, void* userdata);
-        static void ForwardPopErrorScope(WGPUErrorType type, const char* message, void* userdata);
-        static void ForwardBufferMapAsync(WGPUBufferMapAsyncStatus status, void* userdata);
-        static void ForwardFenceCompletedValue(WGPUFenceCompletionStatus status, void* userdata);
-        static void ForwardFenceOnCompletion(WGPUFenceCompletionStatus status, void* userdata);
-        static void ForwardCreateReadyComputePipeline(WGPUCreateReadyPipelineStatus status,
-                                                      WGPUComputePipeline pipeline,
-                                                      const char* message,
-                                                      void* userdata);
-        static void ForwardCreateReadyRenderPipeline(WGPUCreateReadyPipelineStatus status,
-                                                     WGPURenderPipeline pipeline,
-                                                     const char* message,
-                                                     void* userdata);
 
         // Error callbacks
         void OnUncapturedError(WGPUErrorType type, const char* message);
diff --git a/src/dawn_wire/server/ServerBuffer.cpp b/src/dawn_wire/server/ServerBuffer.cpp
index 1a8037c..e842bb1 100644
--- a/src/dawn_wire/server/ServerBuffer.cpp
+++ b/src/dawn_wire/server/ServerBuffer.cpp
@@ -77,8 +77,7 @@
             return false;
         }
 
-        std::unique_ptr<MapUserdata> userdata = std::make_unique<MapUserdata>();
-        userdata->server = this;
+        std::unique_ptr<MapUserdata> userdata = MakeUserdata<MapUserdata>();
         userdata->buffer = ObjectHandle{bufferId, buffer->generation};
         userdata->bufferObj = buffer->handle;
         userdata->requestSerial = requestSerial;
@@ -112,8 +111,11 @@
             userdata->readHandle = std::unique_ptr<MemoryTransferService::ReadHandle>(readHandle);
         }
 
-        mProcs.bufferMapAsync(buffer->handle, mode, offset, size, ForwardBufferMapAsync,
-                              userdata.release());
+        mProcs.bufferMapAsync(
+            buffer->handle, mode, offset, size,
+            ForwardToServer<decltype(
+                &Server::OnBufferMapAsyncCallback)>::Func<&Server::OnBufferMapAsyncCallback>(),
+            userdata.release());
 
         return true;
     }
@@ -206,14 +208,7 @@
                                                      static_cast<size_t>(writeFlushInfoLength));
     }
 
-    void Server::ForwardBufferMapAsync(WGPUBufferMapAsyncStatus status, void* userdata) {
-        auto data = static_cast<MapUserdata*>(userdata);
-        data->server->OnBufferMapAsyncCallback(status, data);
-    }
-
-    void Server::OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus status, MapUserdata* userdata) {
-        std::unique_ptr<MapUserdata> data(userdata);
-
+    void Server::OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus status, MapUserdata* data) {
         // Skip sending the callback if the buffer has already been destroyed.
         auto* bufferData = BufferObjects().Get(data->buffer.id);
         if (bufferData == nullptr || bufferData->generation != data->buffer.generation) {
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index 2fd26a2..cd2eba7 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -16,50 +16,6 @@
 
 namespace dawn_wire { namespace server {
 
-    void Server::ForwardUncapturedError(WGPUErrorType type, const char* message, void* userdata) {
-        auto server = static_cast<Server*>(userdata);
-        server->OnUncapturedError(type, message);
-    }
-
-    void Server::ForwardDeviceLost(const char* message, void* userdata) {
-        auto server = static_cast<Server*>(userdata);
-        server->OnDeviceLost(message);
-    }
-
-    void Server::ForwardCreateReadyComputePipeline(WGPUCreateReadyPipelineStatus status,
-                                                   WGPUComputePipeline pipeline,
-                                                   const char* message,
-                                                   void* userdata) {
-        std::unique_ptr<CreateReadyPipelineUserData> createReadyPipelineUserData(
-            static_cast<CreateReadyPipelineUserData*>(userdata));
-
-        // We need to ensure createReadyPipelineUserData->server is still pointing to a valid
-        // object before doing any operations on it.
-        if (createReadyPipelineUserData->isServerAlive.expired()) {
-            return;
-        }
-
-        createReadyPipelineUserData->server->OnCreateReadyComputePipelineCallback(
-            status, pipeline, message, createReadyPipelineUserData.release());
-    }
-
-    void Server::ForwardCreateReadyRenderPipeline(WGPUCreateReadyPipelineStatus status,
-                                                  WGPURenderPipeline pipeline,
-                                                  const char* message,
-                                                  void* userdata) {
-        std::unique_ptr<CreateReadyPipelineUserData> createReadyPipelineUserData(
-            static_cast<CreateReadyPipelineUserData*>(userdata));
-
-        // We need to ensure createReadyPipelineUserData->server is still pointing to a valid
-        // object before doing any operations on it.
-        if (createReadyPipelineUserData->isServerAlive.expired()) {
-            return;
-        }
-
-        createReadyPipelineUserData->server->OnCreateReadyRenderPipelineCallback(
-            status, pipeline, message, createReadyPipelineUserData.release());
-    }
-
     void Server::OnUncapturedError(WGPUErrorType type, const char* message) {
         ReturnDeviceUncapturedErrorCallbackCmd cmd;
         cmd.type = type;
@@ -76,17 +32,32 @@
     }
 
     bool Server::DoDevicePopErrorScope(WGPUDevice cDevice, uint64_t requestSerial) {
-        ErrorScopeUserdata* userdata = new ErrorScopeUserdata;
-        userdata->server = this;
+        auto userdata = MakeUserdata<ErrorScopeUserdata>();
         userdata->requestSerial = requestSerial;
 
-        bool success = mProcs.devicePopErrorScope(cDevice, ForwardPopErrorScope, userdata);
+        ErrorScopeUserdata* unownedUserdata = userdata.release();
+        bool success = mProcs.devicePopErrorScope(
+            cDevice,
+            ForwardToServer<decltype(
+                &Server::OnDevicePopErrorScope)>::Func<&Server::OnDevicePopErrorScope>(),
+            unownedUserdata);
         if (!success) {
-            delete userdata;
+            delete unownedUserdata;
         }
         return success;
     }
 
+    void Server::OnDevicePopErrorScope(WGPUErrorType type,
+                                       const char* message,
+                                       ErrorScopeUserdata* userdata) {
+        ReturnDevicePopErrorScopeCallbackCmd cmd;
+        cmd.requestSerial = userdata->requestSerial;
+        cmd.type = type;
+        cmd.message = message;
+
+        SerializeCommand(cmd);
+    }
+
     bool Server::DoDeviceCreateReadyComputePipeline(
         WGPUDevice cDevice,
         uint64_t requestSerial,
@@ -99,24 +70,22 @@
 
         resultData->generation = pipelineObjectHandle.generation;
 
-        std::unique_ptr<CreateReadyPipelineUserData> userdata =
-            std::make_unique<CreateReadyPipelineUserData>();
-        userdata->isServerAlive = mIsAlive;
-        userdata->server = this;
+        auto userdata = MakeUserdata<CreateReadyPipelineUserData>();
         userdata->requestSerial = requestSerial;
         userdata->pipelineObjectID = pipelineObjectHandle.id;
 
         mProcs.deviceCreateReadyComputePipeline(
-            cDevice, descriptor, ForwardCreateReadyComputePipeline, userdata.release());
+            cDevice, descriptor,
+            ForwardToServer<decltype(&Server::OnCreateReadyComputePipelineCallback)>::Func<
+                &Server::OnCreateReadyComputePipelineCallback>(),
+            userdata.release());
         return true;
     }
 
     void Server::OnCreateReadyComputePipelineCallback(WGPUCreateReadyPipelineStatus status,
                                                       WGPUComputePipeline pipeline,
                                                       const char* message,
-                                                      CreateReadyPipelineUserData* userdata) {
-        std::unique_ptr<CreateReadyPipelineUserData> data(userdata);
-
+                                                      CreateReadyPipelineUserData* data) {
         auto* computePipelineObject = ComputePipelineObjects().Get(data->pipelineObjectID);
         ASSERT(computePipelineObject != nullptr);
 
@@ -158,24 +127,22 @@
 
         resultData->generation = pipelineObjectHandle.generation;
 
-        std::unique_ptr<CreateReadyPipelineUserData> userdata =
-            std::make_unique<CreateReadyPipelineUserData>();
-        userdata->isServerAlive = mIsAlive;
-        userdata->server = this;
+        auto userdata = MakeUserdata<CreateReadyPipelineUserData>();
         userdata->requestSerial = requestSerial;
         userdata->pipelineObjectID = pipelineObjectHandle.id;
 
         mProcs.deviceCreateReadyRenderPipeline(
-            cDevice, descriptor, ForwardCreateReadyRenderPipeline, userdata.release());
+            cDevice, descriptor,
+            ForwardToServer<decltype(&Server::OnCreateReadyRenderPipelineCallback)>::Func<
+                &Server::OnCreateReadyRenderPipelineCallback>(),
+            userdata.release());
         return true;
     }
 
     void Server::OnCreateReadyRenderPipelineCallback(WGPUCreateReadyPipelineStatus status,
                                                      WGPURenderPipeline pipeline,
                                                      const char* message,
-                                                     CreateReadyPipelineUserData* userdata) {
-        std::unique_ptr<CreateReadyPipelineUserData> data(userdata);
-
+                                                     CreateReadyPipelineUserData* data) {
         auto* renderPipelineObject = RenderPipelineObjects().Get(data->pipelineObjectID);
         ASSERT(renderPipelineObject != nullptr);
 
@@ -206,23 +173,4 @@
         SerializeCommand(cmd);
     }
 
-    // static
-    void Server::ForwardPopErrorScope(WGPUErrorType type, const char* message, void* userdata) {
-        auto* data = reinterpret_cast<ErrorScopeUserdata*>(userdata);
-        data->server->OnDevicePopErrorScope(type, message, data);
-    }
-
-    void Server::OnDevicePopErrorScope(WGPUErrorType type,
-                                       const char* message,
-                                       ErrorScopeUserdata* userdata) {
-        std::unique_ptr<ErrorScopeUserdata> data{userdata};
-
-        ReturnDevicePopErrorScopeCallbackCmd cmd;
-        cmd.requestSerial = data->requestSerial;
-        cmd.type = type;
-        cmd.message = message;
-
-        SerializeCommand(cmd);
-    }
-
 }}  // namespace dawn_wire::server
diff --git a/src/dawn_wire/server/ServerFence.cpp b/src/dawn_wire/server/ServerFence.cpp
index cea561e..7d78840 100644
--- a/src/dawn_wire/server/ServerFence.cpp
+++ b/src/dawn_wire/server/ServerFence.cpp
@@ -18,15 +18,8 @@
 
 namespace dawn_wire { namespace server {
 
-    void Server::ForwardFenceCompletedValue(WGPUFenceCompletionStatus status, void* userdata) {
-        auto data = static_cast<FenceCompletionUserdata*>(userdata);
-        data->server->OnFenceCompletedValueUpdated(status, data);
-    }
-
     void Server::OnFenceCompletedValueUpdated(WGPUFenceCompletionStatus status,
-                                              FenceCompletionUserdata* userdata) {
-        std::unique_ptr<FenceCompletionUserdata> data(userdata);
-
+                                              FenceCompletionUserdata* data) {
         if (status != WGPUFenceCompletionStatus_Success) {
             return;
         }
@@ -49,25 +42,20 @@
             return false;
         }
 
-        FenceOnCompletionUserdata* userdata = new FenceOnCompletionUserdata;
-        userdata->server = this;
+        auto userdata = MakeUserdata<FenceOnCompletionUserdata>();
         userdata->fence = ObjectHandle{fenceId, fence->generation};
         userdata->requestSerial = requestSerial;
 
-        mProcs.fenceOnCompletion(fence->handle, value, ForwardFenceOnCompletion, userdata);
+        mProcs.fenceOnCompletion(
+            fence->handle, value,
+            ForwardToServer<decltype(
+                &Server::OnFenceOnCompletion)>::Func<&Server::OnFenceOnCompletion>(),
+            userdata.release());
         return true;
     }
 
-    // static
-    void Server::ForwardFenceOnCompletion(WGPUFenceCompletionStatus status, void* userdata) {
-        auto* data = reinterpret_cast<FenceOnCompletionUserdata*>(userdata);
-        data->server->OnFenceOnCompletion(status, data);
-    }
-
     void Server::OnFenceOnCompletion(WGPUFenceCompletionStatus status,
-                                     FenceOnCompletionUserdata* userdata) {
-        std::unique_ptr<FenceOnCompletionUserdata> data{userdata};
-
+                                     FenceOnCompletionUserdata* data) {
         ReturnFenceOnCompletionCallbackCmd cmd;
         cmd.fence = data->fence;
         cmd.requestSerial = data->requestSerial;
diff --git a/src/dawn_wire/server/ServerQueue.cpp b/src/dawn_wire/server/ServerQueue.cpp
index b6d2903..67f439a 100644
--- a/src/dawn_wire/server/ServerQueue.cpp
+++ b/src/dawn_wire/server/ServerQueue.cpp
@@ -28,12 +28,15 @@
         auto* fence = FenceObjects().Get(fenceId);
         ASSERT(fence != nullptr);
 
-        FenceCompletionUserdata* userdata = new FenceCompletionUserdata;
-        userdata->server = this;
+        auto userdata = MakeUserdata<FenceCompletionUserdata>();
         userdata->fence = ObjectHandle{fenceId, fence->generation};
         userdata->value = signalValue;
 
-        mProcs.fenceOnCompletion(cFence, signalValue, ForwardFenceCompletedValue, userdata);
+        mProcs.fenceOnCompletion(
+            cFence, signalValue,
+            ForwardToServer<decltype(&Server::OnFenceCompletedValueUpdated)>::Func<
+                &Server::OnFenceCompletedValueUpdated>(),
+            userdata.release());
         return true;
     }
 
diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
index 3674c87..f51354e 100644
--- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
@@ -79,6 +79,12 @@
         ~WireHolder() {
             mApi.IgnoreAllReleaseCalls();
             mWireClient = nullptr;
+
+            // These are called on server destruction to clear the callbacks. They must not be
+            // called after the server is destroyed.
+            EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr))
+                .Times(Exactly(1));
+            EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
             mWireServer = nullptr;
         }
 
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index 7c8a4d5..2106b04 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -86,6 +86,13 @@
     // cannot be null.
     api.IgnoreAllReleaseCalls();
     mWireClient = nullptr;
+
+    if (mWireServer) {
+        // These are called on server destruction to clear the callbacks. They must not be
+        // called after the server is destroyed.
+        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
+        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+    }
     mWireServer = nullptr;
 }
 
@@ -110,6 +117,13 @@
 
 void WireTest::DeleteServer() {
     EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1);
+
+    if (mWireServer) {
+        // These are called on server destruction to clear the callbacks. They must not be
+        // called after the server is destroyed.
+        EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
+        EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+    }
     mWireServer = nullptr;
 }