Add DeviceLostCallback to dawn.json and dawn_wire

Bug: dawn:68
Change-Id: I6d8dd071be4ec612c67245bfde218e31e7a998b8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14660
Commit-Queue: Natasha Lee <natlee@microsoft.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/dawn.json b/dawn.json
index 44fce28..30dbbef 100644
--- a/dawn.json
+++ b/dawn.json
@@ -590,6 +590,13 @@
                 ]
             },
             {
+                "name": "set device lost callback",
+                "args": [
+                    {"name": "callback", "type": "device lost callback"},
+                    {"name": "userdata", "type": "void", "annotation": "*"}
+                ]
+            },
+            {
                 "name": "push error scope",
                 "args": [
                     {"name": "filter", "type": "error filter"}
@@ -605,6 +612,13 @@
             }
         ]
     },
+    "device lost callback": {
+        "category": "callback",
+        "args" : [
+            {"name": "message", "type": "char", "annotation": "const*"},
+            {"name": "userdata", "type": "void", "annotation": "*"}
+        ]
+    },
     "device properties": {
         "category": "structure",
         "extensible": false,
diff --git a/dawn_wire.json b/dawn_wire.json
index cad1d41..a8cffcb 100644
--- a/dawn_wire.json
+++ b/dawn_wire.json
@@ -74,6 +74,9 @@
             { "name": "type", "type": "error type"},
             { "name": "message", "type": "char", "annotation": "const*", "length": "strlen" }
         ],
+        "device lost callback" : [
+            { "name": "message", "type": "char", "annotation": "const*", "length": "strlen" }
+        ],
         "device pop error scope callback": [
             { "name": "request serial", "type": "uint64_t" },
             { "name": "type", "type": "error type" },
@@ -94,6 +97,7 @@
             "BufferSetSubData",
             "DeviceCreateBufferMappedAsync",
             "DevicePopErrorScope",
+            "DeviceSetDeviceLostCallback",
             "DeviceSetUncapturedErrorCallback",
             "FenceGetCompletedValue",
             "FenceOnCompletion"
diff --git a/generator/templates/mock_webgpu.cpp b/generator/templates/mock_webgpu.cpp
index 788290b..edb9689 100644
--- a/generator/templates/mock_webgpu.cpp
+++ b/generator/templates/mock_webgpu.cpp
@@ -55,11 +55,21 @@
                                                         void* userdata) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(self);
     object->deviceErrorCallback = callback;
-    object->userdata1 = userdata;
+    object->userdata = userdata;
 
     OnDeviceSetUncapturedErrorCallback(self, callback, userdata);
 }
 
+void ProcTableAsClass::DeviceSetDeviceLostCallback(WGPUDevice self,
+                                                   WGPUDeviceLostCallback callback,
+                                                   void* userdata) {
+    auto object = reinterpret_cast<ProcTableAsClass::Object*>(self);
+    object->deviceLostCallback = callback;
+    object->userdata = userdata;
+
+    OnDeviceSetDeviceLostCallback(self, callback, userdata);
+}
+
 bool ProcTableAsClass::DevicePopErrorScope(WGPUDevice self,
                                            WGPUErrorCallback callback,
                                            void* userdata) {
@@ -72,7 +82,7 @@
                                                      void* userdata) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(self);
     object->createBufferMappedCallback = callback;
-    object->userdata1 = userdata;
+    object->userdata = userdata;
 
     OnDeviceCreateBufferMappedAsyncCallback(self, descriptor, callback, userdata);
 }
@@ -82,7 +92,7 @@
                                           void* userdata) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(self);
     object->mapReadCallback = callback;
-    object->userdata1 = userdata;
+    object->userdata = userdata;
 
     OnBufferMapReadAsyncCallback(self, callback, userdata);
 }
@@ -92,7 +102,7 @@
                                            void* userdata) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(self);
     object->mapWriteCallback = callback;
-    object->userdata1 = userdata;
+    object->userdata = userdata;
 
     OnBufferMapWriteAsyncCallback(self, callback, userdata);
 }
@@ -103,7 +113,7 @@
                                          void* userdata) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(self);
     object->fenceOnCompletionCallback = callback;
-    object->userdata1 = userdata;
+    object->userdata = userdata;
 
     OnFenceOnCompletionCallback(self, value, callback, userdata);
 }
@@ -112,20 +122,26 @@
                                                WGPUErrorType type,
                                                const char* message) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(device);
-    object->deviceErrorCallback(type, message, object->userdata1);
+    object->deviceErrorCallback(type, message, object->userdata);
 }
+
+void ProcTableAsClass::CallDeviceLostCallback(WGPUDevice device, const char* message) {
+    auto object = reinterpret_cast<ProcTableAsClass::Object*>(device);
+    object->deviceLostCallback(message, object->userdata);
+}
+
 void ProcTableAsClass::CallCreateBufferMappedCallback(WGPUDevice device,
                                                       WGPUBufferMapAsyncStatus status,
                                                       WGPUCreateBufferMappedResult result) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(device);
-    object->createBufferMappedCallback(status, result, object->userdata1);
+    object->createBufferMappedCallback(status, result, object->userdata);
 }
 void ProcTableAsClass::CallMapReadCallback(WGPUBuffer buffer,
                                            WGPUBufferMapAsyncStatus status,
                                            const void* data,
                                            uint64_t dataLength) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(buffer);
-    object->mapReadCallback(status, data, dataLength, object->userdata1);
+    object->mapReadCallback(status, data, dataLength, object->userdata);
 }
 
 void ProcTableAsClass::CallMapWriteCallback(WGPUBuffer buffer,
@@ -133,13 +149,13 @@
                                             void* data,
                                             uint64_t dataLength) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(buffer);
-    object->mapWriteCallback(status, data, dataLength, object->userdata1);
+    object->mapWriteCallback(status, data, dataLength, object->userdata);
 }
 
 void ProcTableAsClass::CallFenceOnCompletionCallback(WGPUFence fence,
                                                      WGPUFenceCompletionStatus status) {
     auto object = reinterpret_cast<ProcTableAsClass::Object*>(fence);
-    object->fenceOnCompletionCallback(status, object->userdata1);
+    object->fenceOnCompletionCallback(status, object->userdata);
 }
 
 {% for type in by_category["object"] %}
diff --git a/generator/templates/mock_webgpu.h b/generator/templates/mock_webgpu.h
index 57d043d..f21e1dc 100644
--- a/generator/templates/mock_webgpu.h
+++ b/generator/templates/mock_webgpu.h
@@ -55,6 +55,9 @@
         void DeviceSetUncapturedErrorCallback(WGPUDevice self,
                                     WGPUErrorCallback callback,
                                     void* userdata);
+        void DeviceSetDeviceLostCallback(WGPUDevice self,
+                                         WGPUDeviceLostCallback callback,
+                                         void* userdata);
         bool DevicePopErrorScope(WGPUDevice self, WGPUErrorCallback callback, void* userdata);
         void DeviceCreateBufferMappedAsync(WGPUDevice self,
                                            const WGPUBufferDescriptor* descriptor,
@@ -75,6 +78,9 @@
         virtual void OnDeviceSetUncapturedErrorCallback(WGPUDevice device,
                                               WGPUErrorCallback callback,
                                               void* userdata) = 0;
+        virtual void OnDeviceSetDeviceLostCallback(WGPUDevice device,
+                                                   WGPUDeviceLostCallback callback,
+                                                   void* userdata) = 0;
         virtual bool OnDevicePopErrorScopeCallback(WGPUDevice device,
                                               WGPUErrorCallback callback,
                                               void* userdata) = 0;
@@ -95,6 +101,7 @@
 
         // Calls the stored callbacks
         void CallDeviceErrorCallback(WGPUDevice device, WGPUErrorType type, const char* message);
+        void CallDeviceLostCallback(WGPUDevice device, const char* message);
         void CallCreateBufferMappedCallback(WGPUDevice device, WGPUBufferMapAsyncStatus status, WGPUCreateBufferMappedResult result);
         void CallMapReadCallback(WGPUBuffer buffer, WGPUBufferMapAsyncStatus status, const void* data, uint64_t dataLength);
         void CallMapWriteCallback(WGPUBuffer buffer, WGPUBufferMapAsyncStatus status, void* data, uint64_t dataLength);
@@ -103,12 +110,12 @@
         struct Object {
             ProcTableAsClass* procs = nullptr;
             WGPUErrorCallback deviceErrorCallback = nullptr;
+            WGPUDeviceLostCallback deviceLostCallback = nullptr;
             WGPUBufferCreateMappedCallback createBufferMappedCallback = nullptr;
             WGPUBufferMapReadCallback mapReadCallback = nullptr;
             WGPUBufferMapWriteCallback mapWriteCallback = nullptr;
             WGPUFenceOnCompletionCallback fenceOnCompletionCallback = nullptr;
-            void* userdata1 = 0;
-            void* userdata2 = 0;
+            void* userdata = 0;
         };
 
     private:
@@ -139,6 +146,8 @@
         {% endfor %}
 
         MOCK_METHOD3(OnDeviceSetUncapturedErrorCallback, void(WGPUDevice device, WGPUErrorCallback callback, void* userdata));
+        MOCK_METHOD3(OnDeviceSetDeviceLostCallback,
+                     void(WGPUDevice device, WGPUDeviceLostCallback callback, void* userdata));
         MOCK_METHOD3(OnDevicePopErrorScopeCallback, bool(WGPUDevice device, WGPUErrorCallback callback, void* userdata));
         MOCK_METHOD4(OnDeviceCreateBufferMappedAsyncCallback, void(WGPUDevice device, const WGPUBufferDescriptor* descriptor, WGPUBufferCreateMappedCallback callback, void* userdata));
         MOCK_METHOD3(OnBufferMapReadAsyncCallback, void(WGPUBuffer buffer, WGPUBufferMapReadCallback callback, void* userdata));
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 9c136be..3926a89 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -128,6 +128,11 @@
         mRootErrorScope->SetCallback(callback, userdata);
     }
 
+    void DeviceBase::SetDeviceLostCallback(wgpu::DeviceLostCallback callback, void* userdata) {
+        mDeviceLostCallback = callback;
+        mDeviceLostUserdata = userdata;
+    }
+
     void DeviceBase::PushErrorScope(wgpu::ErrorFilter filter) {
         if (ConsumedError(ValidateErrorFilter(filter))) {
             return;
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 7f22df5..aa6471f 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -163,6 +163,7 @@
 
         void Tick();
 
+        void SetDeviceLostCallback(wgpu::DeviceLostCallback callback, void* userdata);
         void SetUncapturedErrorCallback(wgpu::ErrorCallback callback, void* userdata);
         void PushErrorScope(wgpu::ErrorFilter filter);
         bool PopErrorScope(wgpu::ErrorCallback callback, void* userdata);
@@ -262,6 +263,9 @@
         // resources.
         virtual MaybeError WaitForIdleForDestruction() = 0;
 
+        wgpu::DeviceLostCallback mDeviceLostCallback = nullptr;
+        void* mDeviceLostUserdata;
+
         AdapterBase* mAdapter = nullptr;
 
         Ref<ErrorScope> mRootErrorScope;
diff --git a/src/dawn_wire/client/ApiProcs.cpp b/src/dawn_wire/client/ApiProcs.cpp
index 14f25e4..128bde7 100644
--- a/src/dawn_wire/client/ApiProcs.cpp
+++ b/src/dawn_wire/client/ApiProcs.cpp
@@ -431,5 +431,11 @@
         Device* device = reinterpret_cast<Device*>(cSelf);
         device->SetUncapturedErrorCallback(callback, userdata);
     }
+    void ClientDeviceSetDeviceLostCallback(WGPUDevice cSelf,
+                                           WGPUDeviceLostCallback callback,
+                                           void* userdata) {
+        Device* device = reinterpret_cast<Device*>(cSelf);
+        device->SetDeviceLostCallback(callback, userdata);
+    }
 
 }}  // namespace dawn_wire::client
diff --git a/src/dawn_wire/client/ClientDoers.cpp b/src/dawn_wire/client/ClientDoers.cpp
index 1be0f1d..dd90406 100644
--- a/src/dawn_wire/client/ClientDoers.cpp
+++ b/src/dawn_wire/client/ClientDoers.cpp
@@ -35,6 +35,11 @@
         return true;
     }
 
+    bool Client::DoDeviceLostCallback(char const* message) {
+        mDevice->HandleDeviceLost(message);
+        return true;
+    }
+
     bool Client::DoDevicePopErrorScopeCallback(uint64_t requestSerial,
                                                WGPUErrorType errorType,
                                                const char* message) {
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index 8577a4f..175617f 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -42,11 +42,22 @@
         }
     }
 
+    void Device::HandleDeviceLost(const char* message) {
+        if (mDeviceLostCallback) {
+            mDeviceLostCallback(message, mDeviceLostUserdata);
+        }
+    }
+
     void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) {
         mErrorCallback = errorCallback;
         mErrorUserdata = errorUserdata;
     }
 
+    void Device::SetDeviceLostCallback(WGPUDeviceLostCallback callback, void* userdata) {
+        mDeviceLostCallback = callback;
+        mDeviceLostUserdata = userdata;
+    }
+
     void Device::PushErrorScope(WGPUErrorFilter filter) {
         mErrorScopeStackSize++;
 
diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h
index 9c1bb2f..af5934e 100644
--- a/src/dawn_wire/client/Device.h
+++ b/src/dawn_wire/client/Device.h
@@ -32,7 +32,9 @@
 
         Client* GetClient();
         void HandleError(WGPUErrorType errorType, const char* message);
+        void HandleDeviceLost(const char* message);
         void SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata);
+        void SetDeviceLostCallback(WGPUDeviceLostCallback errorCallback, void* errorUserdata);
 
         void PushErrorScope(WGPUErrorFilter filter);
         bool RequestPopErrorScope(WGPUErrorCallback callback, void* userdata);
@@ -49,7 +51,9 @@
 
         Client* mClient = nullptr;
         WGPUErrorCallback mErrorCallback = nullptr;
+        WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
         void* mErrorUserdata;
+        void* mDeviceLostUserdata;
     };
 
 }}  // namespace dawn_wire::client
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index d03980a..d05285d 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -32,6 +32,7 @@
         deviceData->handle = device;
 
         mProcs.deviceSetUncapturedErrorCallback(device, ForwardUncapturedError, this);
+        mProcs.deviceSetDeviceLostCallback(device, ForwardDeviceLost, this);
     }
 
     Server::~Server() {
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index 28f4c81..effe69e 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -62,6 +62,7 @@
 
         // Forwarding callbacks
         static void ForwardUncapturedError(WGPUErrorType type, const char* message, void* userdata);
+        static void ForwardDeviceLost(const char* message, void* userdata);
         static void ForwardPopErrorScope(WGPUErrorType type, const char* message, void* userdata);
         static void ForwardBufferMapReadAsync(WGPUBufferMapAsyncStatus status,
                                               const void* ptr,
@@ -75,6 +76,7 @@
 
         // Error callbacks
         void OnUncapturedError(WGPUErrorType type, const char* message);
+        void OnDeviceLost(const char* message);
         void OnDevicePopErrorScope(WGPUErrorType type,
                                    const char* message,
                                    ErrorScopeUserdata* userdata);
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index 6f27867..66c0d70 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -21,6 +21,11 @@
         server->OnUncapturedError(type, message);
     }
 
+    void Server::ForwardDeviceLost(const char* message, void* userdata) {
+        auto server = static_cast<Server*>(userdata);
+        server->OnDeviceLost(message);
+    }
+
     void Server::OnUncapturedError(WGPUErrorType type, const char* message) {
         ReturnDeviceUncapturedErrorCallbackCmd cmd;
         cmd.type = type;
@@ -31,6 +36,15 @@
         cmd.Serialize(allocatedBuffer);
     }
 
+    void Server::OnDeviceLost(const char* message) {
+        ReturnDeviceLostCallbackCmd cmd;
+        cmd.message = message;
+
+        size_t requiredSize = cmd.GetRequiredSize();
+        char* allocatedBuffer = static_cast<char*>(GetCmdSpace(requiredSize));
+        cmd.Serialize(allocatedBuffer);
+    }
+
     bool Server::DoDevicePopErrorScope(WGPUDevice cDevice, uint64_t requestSerial) {
         ErrorScopeUserdata* userdata = new ErrorScopeUserdata;
         userdata->server = this;
diff --git a/src/tests/unittests/wire/WireErrorCallbackTests.cpp b/src/tests/unittests/wire/WireErrorCallbackTests.cpp
index 18436fc..93b5342 100644
--- a/src/tests/unittests/wire/WireErrorCallbackTests.cpp
+++ b/src/tests/unittests/wire/WireErrorCallbackTests.cpp
@@ -42,6 +42,16 @@
         mockDevicePopErrorScopeCallback->Call(type, message, userdata);
     }
 
+    class MockDeviceLostCallback {
+      public:
+        MOCK_METHOD2(Call, void(const char* message, void* userdata));
+    };
+
+    std::unique_ptr<StrictMock<MockDeviceLostCallback>> mockDeviceLostCallback;
+    void ToMockDeviceLostCallback(const char* message, void* userdata) {
+        mockDeviceLostCallback->Call(message, userdata);
+    }
+
 }  // anonymous namespace
 
 class WireErrorCallbackTests : public WireTest {
@@ -55,6 +65,7 @@
 
         mockDeviceErrorCallback = std::make_unique<StrictMock<MockDeviceErrorCallback>>();
         mockDevicePopErrorScopeCallback = std::make_unique<StrictMock<MockDevicePopErrorScopeCallback>>();
+        mockDeviceLostCallback = std::make_unique<StrictMock<MockDeviceLostCallback>>();
     }
 
     void TearDown() override {
@@ -62,6 +73,7 @@
 
         mockDeviceErrorCallback = nullptr;
         mockDevicePopErrorScopeCallback = nullptr;
+        mockDeviceLostCallback = nullptr;
     }
 
     void FlushServer() {
@@ -232,3 +244,19 @@
         FlushServer();
     }
 }
+
+// Test the return wire for device lost callback
+TEST_F(WireErrorCallbackTests, DeviceLostCallback) {
+    wgpuDeviceSetDeviceLostCallback(device, ToMockDeviceLostCallback, this);
+
+    // Setting the error callback should stay on the client side and do nothing
+    FlushClient();
+
+    // Calling the callback on the server side will result in the callback being called on the
+    // client side
+    api.CallDeviceLostCallback(apiDevice, "Some error message");
+
+    EXPECT_CALL(*mockDeviceLostCallback, Call(StrEq("Some error message"), this)).Times(1);
+
+    FlushServer();
+}
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index d48c085..9dce5e7 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -43,6 +43,7 @@
 
     // This SetCallback call cannot be ignored because it is done as soon as we start the server
     EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, _, _)).Times(Exactly(1));
+    EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, _, _)).Times(Exactly(1));
     SetupIgnoredCallExpectations();
 
     mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();