Add DeviceLostCallback to dawn.json and dawn_wire
Bug: dawn:68
Change-Id: I6d8dd071be4ec612c67245bfde218e31e7a998b8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14660
Commit-Queue: Natasha Lee <natlee@microsoft.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 9c136be..3926a89 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -128,6 +128,11 @@
mRootErrorScope->SetCallback(callback, userdata);
}
+ // Records the device-lost callback together with its userdata pointer, replacing any
+ // previously registered pair.
+ void DeviceBase::SetDeviceLostCallback(wgpu::DeviceLostCallback callback, void* userdata) {
+ mDeviceLostUserdata = userdata;
+ mDeviceLostCallback = callback;
+ }
+
void DeviceBase::PushErrorScope(wgpu::ErrorFilter filter) {
if (ConsumedError(ValidateErrorFilter(filter))) {
return;
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 7f22df5..aa6471f 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -163,6 +163,7 @@
void Tick();
+ // Stores the device-lost callback and its userdata; a later call replaces both.
+ void SetDeviceLostCallback(wgpu::DeviceLostCallback callback, void* userdata);
void SetUncapturedErrorCallback(wgpu::ErrorCallback callback, void* userdata);
void PushErrorScope(wgpu::ErrorFilter filter);
bool PopErrorScope(wgpu::ErrorCallback callback, void* userdata);
@@ -262,6 +263,9 @@
// resources.
virtual MaybeError WaitForIdleForDestruction() = 0;
+ // Set together by SetDeviceLostCallback(). Both default to null so the userdata is
+ // well-defined even before any callback has been registered.
+ wgpu::DeviceLostCallback mDeviceLostCallback = nullptr;
+ void* mDeviceLostUserdata = nullptr;
+
AdapterBase* mAdapter = nullptr;
Ref<ErrorScope> mRootErrorScope;
diff --git a/src/dawn_wire/client/ApiProcs.cpp b/src/dawn_wire/client/ApiProcs.cpp
index 14f25e4..128bde7 100644
--- a/src/dawn_wire/client/ApiProcs.cpp
+++ b/src/dawn_wire/client/ApiProcs.cpp
@@ -431,5 +431,11 @@
Device* device = reinterpret_cast<Device*>(cSelf);
device->SetUncapturedErrorCallback(callback, userdata);
}
+
+ // Wire-client entry point for wgpuDeviceSetDeviceLostCallback: hands the callback to the
+ // client-side Device, which keeps it locally.
+ void ClientDeviceSetDeviceLostCallback(WGPUDevice cSelf,
+ WGPUDeviceLostCallback callback,
+ void* userdata) {
+ reinterpret_cast<Device*>(cSelf)->SetDeviceLostCallback(callback, userdata);
+ }
}} // namespace dawn_wire::client
diff --git a/src/dawn_wire/client/ClientDoers.cpp b/src/dawn_wire/client/ClientDoers.cpp
index 1be0f1d..dd90406 100644
--- a/src/dawn_wire/client/ClientDoers.cpp
+++ b/src/dawn_wire/client/ClientDoers.cpp
@@ -35,6 +35,11 @@
return true;
}
+ // Handler for the device-lost return command: passes the server's message on to the
+ // client-side Device. Always reports success.
+ bool Client::DoDeviceLostCallback(const char* message) {
+ mDevice->HandleDeviceLost(message);
+ return true;
+ }
+
bool Client::DoDevicePopErrorScopeCallback(uint64_t requestSerial,
WGPUErrorType errorType,
const char* message) {
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index 8577a4f..175617f 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -42,11 +42,22 @@
}
}
+ // Invokes the registered device-lost callback (if any) with the message received from the
+ // server and the userdata supplied at registration time.
+ void Device::HandleDeviceLost(const char* message) {
+ if (mDeviceLostCallback == nullptr) {
+ return;
+ }
+ mDeviceLostCallback(message, mDeviceLostUserdata);
+ }
+
+ // Stores the uncaptured-error callback and its userdata on the client-side Device.
void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) {
mErrorCallback = errorCallback;
mErrorUserdata = errorUserdata;
}
+ // Records the device-lost callback and its userdata on the client only; nothing is
+ // serialized. HandleDeviceLost() reads them back when the server reports a lost device.
+ void Device::SetDeviceLostCallback(WGPUDeviceLostCallback callback, void* userdata) {
+ mDeviceLostUserdata = userdata;
+ mDeviceLostCallback = callback;
+ }
+
void Device::PushErrorScope(WGPUErrorFilter filter) {
mErrorScopeStackSize++;
diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h
index 9c1bb2f..af5934e 100644
--- a/src/dawn_wire/client/Device.h
+++ b/src/dawn_wire/client/Device.h
@@ -32,7 +32,9 @@
Client* GetClient();
void HandleError(WGPUErrorType errorType, const char* message);
+ // Invokes the device-lost callback, if one is registered, with the server's message.
+ void HandleDeviceLost(const char* message);
void SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata);
+ // Stores the device-lost callback and its userdata (client-side only).
+ void SetDeviceLostCallback(WGPUDeviceLostCallback callback, void* userdata);
void PushErrorScope(WGPUErrorFilter filter);
bool RequestPopErrorScope(WGPUErrorCallback callback, void* userdata);
@@ -49,7 +51,9 @@
Client* mClient = nullptr;
WGPUErrorCallback mErrorCallback = nullptr;
void* mErrorUserdata;
+ // Device-lost callback/userdata pair; both default to null so the userdata is never
+ // left indeterminate.
+ WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
+ void* mDeviceLostUserdata = nullptr;
};
}} // namespace dawn_wire::client
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index d03980a..d05285d 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -32,6 +32,7 @@
deviceData->handle = device;
mProcs.deviceSetUncapturedErrorCallback(device, ForwardUncapturedError, this);
+ mProcs.deviceSetDeviceLostCallback(device, ForwardDeviceLost, this);
}
Server::~Server() {
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index 28f4c81..effe69e 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -62,6 +62,7 @@
// Forwarding callbacks
static void ForwardUncapturedError(WGPUErrorType type, const char* message, void* userdata);
+ // Registered via deviceSetDeviceLostCallback in the Server constructor; userdata is the Server*.
+ static void ForwardDeviceLost(const char* message, void* userdata);
static void ForwardPopErrorScope(WGPUErrorType type, const char* message, void* userdata);
static void ForwardBufferMapReadAsync(WGPUBufferMapAsyncStatus status,
const void* ptr,
@@ -75,6 +76,7 @@
// Error callbacks
void OnUncapturedError(WGPUErrorType type, const char* message);
+ // Serializes a ReturnDeviceLostCallbackCmd carrying `message` back to the client.
+ void OnDeviceLost(const char* message);
void OnDevicePopErrorScope(WGPUErrorType type,
const char* message,
ErrorScopeUserdata* userdata);
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index 6f27867..66c0d70 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -21,6 +21,11 @@
server->OnUncapturedError(type, message);
}
+ // Static trampoline given to deviceSetDeviceLostCallback: recovers the Server from the
+ // userdata pointer and forwards the message to OnDeviceLost().
+ void Server::ForwardDeviceLost(const char* message, void* userdata) {
+ static_cast<Server*>(userdata)->OnDeviceLost(message);
+ }
+
void Server::OnUncapturedError(WGPUErrorType type, const char* message) {
ReturnDeviceUncapturedErrorCallbackCmd cmd;
cmd.type = type;
@@ -31,6 +36,15 @@
cmd.Serialize(allocatedBuffer);
}
+ // Reports device loss to the client: builds a ReturnDeviceLostCallbackCmd holding the
+ // message and serializes it into space obtained from GetCmdSpace().
+ void Server::OnDeviceLost(const char* message) {
+ ReturnDeviceLostCallbackCmd lostCmd;
+ lostCmd.message = message;
+
+ char* cmdSpace = static_cast<char*>(GetCmdSpace(lostCmd.GetRequiredSize()));
+ lostCmd.Serialize(cmdSpace);
+ }
+
bool Server::DoDevicePopErrorScope(WGPUDevice cDevice, uint64_t requestSerial) {
ErrorScopeUserdata* userdata = new ErrorScopeUserdata;
userdata->server = this;
diff --git a/src/tests/unittests/wire/WireErrorCallbackTests.cpp b/src/tests/unittests/wire/WireErrorCallbackTests.cpp
index 18436fc..93b5342 100644
--- a/src/tests/unittests/wire/WireErrorCallbackTests.cpp
+++ b/src/tests/unittests/wire/WireErrorCallbackTests.cpp
@@ -42,6 +42,16 @@
mockDevicePopErrorScopeCallback->Call(type, message, userdata);
}
+ // gMock target that records device-lost callbacks delivered on the client.
+ struct MockDeviceLostCallback {
+ MOCK_METHOD2(Call, void(const char* message, void* userdata));
+ };
+
+ // Created and reset by the WireErrorCallbackTests fixture.
+ std::unique_ptr<StrictMock<MockDeviceLostCallback>> mockDeviceLostCallback;
+
+ // C-style callback passed to wgpuDeviceSetDeviceLostCallback; forwards to the mock.
+ void ToMockDeviceLostCallback(const char* message, void* userdata) {
+ mockDeviceLostCallback->Call(message, userdata);
+ }
+
} // anonymous namespace
class WireErrorCallbackTests : public WireTest {
@@ -55,6 +65,7 @@
mockDeviceErrorCallback = std::make_unique<StrictMock<MockDeviceErrorCallback>>();
mockDevicePopErrorScopeCallback = std::make_unique<StrictMock<MockDevicePopErrorScopeCallback>>();
+ mockDeviceLostCallback = std::make_unique<StrictMock<MockDeviceLostCallback>>();
}
void TearDown() override {
@@ -62,6 +73,7 @@
mockDeviceErrorCallback = nullptr;
mockDevicePopErrorScopeCallback = nullptr;
+ mockDeviceLostCallback = nullptr;
}
void FlushServer() {
@@ -232,3 +244,19 @@
FlushServer();
}
}
+
+// Test the return wire for device lost callback
+TEST_F(WireErrorCallbackTests, DeviceLostCallback) {
+ wgpuDeviceSetDeviceLostCallback(device, ToMockDeviceLostCallback, this);
+
+ // Setting the device lost callback should stay on the client side and do nothing
+ FlushClient();
+
+ // Calling the callback on the server side will result in the callback being called on the
+ // client side
+ api.CallDeviceLostCallback(apiDevice, "Some error message");
+
+ EXPECT_CALL(*mockDeviceLostCallback, Call(StrEq("Some error message"), this)).Times(1);
+
+ FlushServer();
+}
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index d48c085..9dce5e7 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -43,6 +43,7 @@
// This SetCallback call cannot be ignored because it is done as soon as we start the server
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, _, _)).Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, _, _)).Times(Exactly(1));
SetupIgnoredCallExpectations();
mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();