Reland "[wgpu-headers] Align device lost callbacks with webgpu.h and futures."
Relanding after the following fixes to users:
- Chrome: https://chromium-review.googlesource.com/c/chromium/src/+/5441439
- Skia/Graphite: https://skia-review.googlesource.com/c/skia/+/838480
- MLDrift: https://critique.corp.google.com/cl/623262112
Original change's description:
> [wgpu-headers] Align device lost callbacks with webgpu.h and futures.
>
> - Overrides wire::ObjectBase::Release logic for Device. This was
> necessary because the DeviceLostEvent holds a ref to the Device.
> - Introduces non-progressing SystemEvents. This is needed because we
> want to make sure that polling doesn't happen as long as a device
> lost event exists.
>
> Bug: dawn:2021
> Change-Id: I120ca3c1e4b2bfd00b43bb9f26d91cb5c3f2e4d6
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/165527
> Reviewed-by: Corentin Wallez <cwallez@chromium.org>
> Reviewed-by: Austin Eng <enga@chromium.org>
> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
> Commit-Queue: Loko Kung <lokokung@google.com>
Bug: dawn:2450, dawn:2021
Change-Id: I72d6105d719942d1b9745d476e0d07498e659a51
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/183162
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/generator/templates/dawn/wire/server/ServerDoers.cpp b/generator/templates/dawn/wire/server/ServerDoers.cpp
index 3bc147a..4431f14 100644
--- a/generator/templates/dawn/wire/server/ServerDoers.cpp
+++ b/generator/templates/dawn/wire/server/ServerDoers.cpp
@@ -85,16 +85,17 @@
WIRE_TRY({{type.name.CamelCase()}}Objects().Get(objectId, &obj));
if (obj->state == AllocationState::Allocated) {
- DAWN_ASSERT(obj->handle != nullptr);
{% if type.name.get() == "device" %}
if (obj->handle != nullptr) {
//* Deregisters uncaptured error and device lost callbacks since
//* they should not be forwarded if the device no longer exists on the wire.
ClearDeviceCallbacks(obj->handle);
+ mProcs.{{as_varName(type.name, Name("release"))}}(obj->handle);
}
+ {% else %}
+ DAWN_ASSERT(obj->handle != nullptr);
+ mProcs.{{as_varName(type.name, Name("release"))}}(obj->handle);
{% endif %}
-
- mProcs.{{as_varName(type.name, Name("release"))}}(obj->handle);
}
{{type.name.CamelCase()}}Objects().Free(objectId);
return WireResult::Success;
diff --git a/generator/templates/mock_api.cpp b/generator/templates/mock_api.cpp
index ec792af..c7eeb8c 100644
--- a/generator/templates/mock_api.cpp
+++ b/generator/templates/mock_api.cpp
@@ -76,13 +76,15 @@
//* Generate the older Call*Callback if there is no Future call equivalent.
//* Includes:
-//* - setDeviceLostCallback
//* - setUncapturedErrorCallback
//* - setLoggingCallback
-{% set LegacyCallbackFunctions = ['set device lost callback', 'set uncaptured error callback', 'set logging callback'] %}
+{% set LegacyCallbackFunctions = ['set uncaptured error callback', 'set logging callback'] %}
+
+//* Manually implemented mock functions due to incompatibility.
+{% set ManuallyMockedFunctions = ['set device lost callback'] %}
{% for type in by_category["object"] %}
- {% for method in type.methods %}
+ {% for method in type.methods if method.name.get() not in ManuallyMockedFunctions %}
{% set Suffix = as_CppMethodSuffix(type.name, method.name) %}
{% if has_callback_arguments(method) %}
{{as_cType(method.return_type.name)}} ProcTableAsClass::{{Suffix}}(
@@ -186,6 +188,28 @@
{% endfor %}
{% endfor %}
+// Manually implement device lost related callback helpers for testing.
+void ProcTableAsClass::DeviceSetDeviceLostCallback(WGPUDevice device,
+ WGPUDeviceLostCallback callback,
+ void* userdata) {
+ ProcTableAsClass::Object* object = reinterpret_cast<ProcTableAsClass::Object*>(device);
+ object->mDeviceLostOldCallback = callback;
+ object->mDeviceLostUserdata = userdata;
+
+ OnDeviceSetDeviceLostCallback(device, callback, userdata);
+}
+void ProcTableAsClass::CallDeviceSetDeviceLostCallbackCallback(WGPUDevice device,
+ WGPUDeviceLostReason reason,
+ char const* message) {
+ ProcTableAsClass::Object* object = reinterpret_cast<ProcTableAsClass::Object*>(device);
+ // If we have an old callback set, call that one, otherwise call the new one.
+ if (object->mDeviceLostOldCallback != nullptr) {
+ object->mDeviceLostOldCallback(reason, message, object->mDeviceLostUserdata);
+ } else {
+ object->mDeviceLostCallback(&device, reason, message, object->mDeviceLostUserdata);
+ }
+}
+
{% for type in by_category["object"] %}
{{as_cType(type.name)}} ProcTableAsClass::GetNew{{type.name.CamelCase()}}() {
mObjects.emplace_back(new Object);
diff --git a/generator/templates/mock_api.h b/generator/templates/mock_api.h
index e30a33c..77769b0 100644
--- a/generator/templates/mock_api.h
+++ b/generator/templates/mock_api.h
@@ -63,13 +63,16 @@
//* - setDeviceLostCallback
//* - setUncapturedErrorCallback
//* - setLoggingCallback
- {%- set LegacyCallbackFunctions = ['set device lost callback', 'set uncaptured error callback', 'set logging callback'] %}
+ {%- set LegacyCallbackFunctions = ['set uncaptured error callback', 'set logging callback'] %}
+
+ //* Manually implemented mock functions due to incompatibility.
+ {% set ManuallyMockedFunctions = ['set device lost callback'] %}
{%- for type in by_category["object"] %}
virtual void {{as_MethodSuffix(type.name, Name("reference"))}}({{as_cType(type.name)}} self) = 0;
virtual void {{as_MethodSuffix(type.name, Name("release"))}}({{as_cType(type.name)}} self) = 0;
- {% for method in type.methods %}
+ {% for method in type.methods if method.name.get() not in ManuallyMockedFunctions %}
{% set Suffix = as_CppMethodSuffix(type.name, method.name) %}
{% if not has_callback_arguments(method) and not has_callback_info(method) %}
virtual {{as_cType(method.return_type.name)}} {{Suffix}}(
@@ -122,6 +125,17 @@
{% endfor %}
{% endfor %}
+ // Manually implement device lost related callback helpers for testing.
+ void DeviceSetDeviceLostCallback(WGPUDevice device,
+ WGPUDeviceLostCallback callback,
+ void* userdata);
+ virtual void OnDeviceSetDeviceLostCallback(WGPUDevice device,
+ WGPUDeviceLostCallback callback,
+ void* userdata) = 0;
+ void CallDeviceSetDeviceLostCallbackCallback(WGPUDevice device,
+ WGPUDeviceLostReason reason,
+ char const* message);
+
struct Object {
ProcTableAsClass* procs = nullptr;
{% for type in by_category["object"] %}
@@ -138,6 +152,10 @@
{% endfor %}
{% endfor %}
{% endfor %}
+ // Manually implement device lost related callback helpers for testing.
+ WGPUDeviceLostCallback mDeviceLostOldCallback = nullptr;
+ WGPUDeviceLostCallbackNew mDeviceLostCallback = nullptr;
+ void* mDeviceLostUserdata = 0;
};
private:
@@ -178,6 +196,12 @@
), (override));
{% endfor %}
{% endfor %}
+
+ // Manually implement device lost related callback helpers for testing.
+ MOCK_METHOD(void,
+ OnDeviceSetDeviceLostCallback,
+ (WGPUDevice device, WGPUDeviceLostCallback callback, void* userdata),
+ (override));
};
#endif // MOCK_{{API}}_H
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 5a634bf..a50d79c 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -208,8 +208,10 @@
{"name": "required features", "type": "feature name", "annotation": "const*", "length": "required feature count", "default": "nullptr"},
{"name": "required limits", "type": "required limits", "annotation": "const*", "optional": true},
{"name": "default queue", "type": "queue descriptor"},
- {"name": "device lost callback", "type": "device lost callback", "default": "nullptr"},
- {"name": "device lost userdata", "type": "void *", "default": "nullptr"}
+ {"name": "device lost callback", "type": "device lost callback", "default": "nullptr", "tags": ["deprecated"]},
+ {"name": "device lost userdata", "type": "void *", "default": "nullptr", "tags": ["deprecated"]},
+ {"name": "device lost callback info", "type": "device lost callback info"},
+ {"name": "uncaptured error callback info", "type": "uncaptured error callback info"}
]
},
"dawn toggles descriptor": {
@@ -1509,12 +1511,32 @@
{"name": "userdata", "type": "void *"}
]
},
+ "device lost callback new": {
+ "category": "function pointer",
+ "args": [
+ {"name": "device", "type": "device", "annotation": "const*", "length": 1},
+ {"name": "reason", "type": "device lost reason"},
+ {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"},
+ {"name": "userdata", "type": "void *"}
+ ]
+ },
+ "device lost callback info": {
+ "category": "structure",
+ "extensible": "in",
+ "members": [
+ {"name": "mode", "type": "callback mode", "default": "wait any only", "_comment": "TODO(crbug.com/dawn/2458) Default should be removed."},
+ {"name": "callback", "type": "device lost callback new", "default": "nullptr"},
+ {"name": "userdata", "type": "void *", "default": "nullptr"}
+ ]
+ },
"device lost reason": {
"category": "enum",
"emscripten_no_enum_table": true,
"values": [
{"value": 0, "name": "undefined", "jsrepr": "undefined"},
- {"value": 1, "name": "destroyed"}
+ {"value": 1, "name": "destroyed"},
+ {"value": 2, "name": "instance dropped"},
+ {"value": 3, "name": "failed creation"}
]
},
"double": {
@@ -1528,6 +1550,14 @@
{"name": "userdata", "type": "void *"}
]
},
+ "uncaptured error callback info": {
+ "category": "structure",
+ "extensible": "in",
+ "members": [
+ {"name": "callback", "type": "error callback", "default": "nullptr"},
+ {"name": "userdata", "type": "void *", "default": "nullptr"}
+ ]
+ },
"pop error scope status": {
"category": "enum",
"emscripten_no_enum_table": true,
@@ -2373,7 +2403,7 @@
"callback mode": {
"category": "enum",
"emscripten_no_enum_table": true,
- "_comment": "TODO(crbug.com/dawn/2224): Should this be renumbered to reserve 0?",
+ "_comment": "TODO(crbug.com/dawn/2458): Should this be renumbered to reserve 0?",
"values": [
{"value": 0, "name": "wait any only"},
{"value": 1, "name": "allow process events"},
diff --git a/src/dawn/dawn_wire.json b/src/dawn/dawn_wire.json
index 4b230a3..f87a712 100644
--- a/src/dawn/dawn_wire.json
+++ b/src/dawn/dawn_wire.json
@@ -115,6 +115,7 @@
{ "name": "event manager handle", "type": "ObjectHandle" },
{ "name": "future", "type": "future" },
{ "name": "device object handle", "type": "ObjectHandle", "handle_type": "device"},
+ { "name": "device lost future", "type": "future" },
{ "name": "descriptor", "type": "device descriptor", "annotation": "const*" }
]
},
@@ -149,7 +150,8 @@
{ "name": "message", "type": "char", "annotation": "const*", "length": "strlen" }
],
"device lost callback" : [
- { "name": "device", "type": "ObjectHandle", "handle_type": "device" },
+ { "name": "event manager", "type": "ObjectHandle" },
+ { "name": "future", "type": "future" },
{ "name": "reason", "type": "device lost reason" },
{ "name": "message", "type": "char", "annotation": "const*", "length": "strlen" }
],
diff --git a/src/dawn/native/Adapter.cpp b/src/dawn/native/Adapter.cpp
index 7d19538..6441adc 100644
--- a/src/dawn/native/Adapter.cpp
+++ b/src/dawn/native/Adapter.cpp
@@ -29,10 +29,12 @@
#include <algorithm>
#include <memory>
+#include <string>
#include <tuple>
#include <utility>
#include <vector>
+#include "dawn/common/Log.h"
#include "dawn/native/ChainUtils.h"
#include "dawn/native/Device.h"
#include "dawn/native/Instance.h"
@@ -40,6 +42,9 @@
#include "partition_alloc/pointers/raw_ptr.h"
namespace dawn::native {
+namespace {
+static constexpr DeviceDescriptor kDefaultDeviceDesc = {};
+} // anonymous namespace
AdapterBase::AdapterBase(Ref<PhysicalDeviceBase> physicalDevice,
FeatureLevel featureLevel,
@@ -191,20 +196,24 @@
return mSupportedFeatures.EnumerateFeatures(features);
}
+// TODO(https://crbug.com/dawn/2465) Could potentially re-implement via AllowSpontaneous async mode.
DeviceBase* AdapterBase::APICreateDevice(const DeviceDescriptor* descriptor) {
- constexpr DeviceDescriptor kDefaultDesc = {};
if (descriptor == nullptr) {
- descriptor = &kDefaultDesc;
+ descriptor = &kDefaultDeviceDesc;
}
+ auto [lostEvent, result] = CreateDevice(descriptor);
+ mPhysicalDevice->GetInstance()->GetEventManager()->TrackEvent(lostEvent);
Ref<DeviceBase> device;
- if (mPhysicalDevice->GetInstance()->ConsumedError(CreateDevice(descriptor), &device)) {
+ if (mPhysicalDevice->GetInstance()->ConsumedError(std::move(result), &device)) {
return nullptr;
}
return ReturnToAPI(std::move(device));
}
-ResultOrError<Ref<DeviceBase>> AdapterBase::CreateDevice(const DeviceDescriptor* rawDescriptor) {
+ResultOrError<Ref<DeviceBase>> AdapterBase::CreateDeviceInternal(
+ const DeviceDescriptor* rawDescriptor,
+ Ref<DeviceBase::DeviceLostEvent> lostEvent) {
DAWN_ASSERT(rawDescriptor != nullptr);
// Create device toggles state from required toggles descriptor and inherited adapter toggles
@@ -255,29 +264,30 @@
"validating required limits");
}
- return mPhysicalDevice->CreateDevice(this, descriptor, deviceToggles);
+ return mPhysicalDevice->CreateDevice(this, descriptor, deviceToggles, std::move(lostEvent));
+}
+
+std::pair<Ref<DeviceBase::DeviceLostEvent>, ResultOrError<Ref<DeviceBase>>>
+AdapterBase::CreateDevice(const DeviceDescriptor* descriptor) {
+ DAWN_ASSERT(descriptor != nullptr);
+
+ Ref<DeviceBase::DeviceLostEvent> lostEvent = DeviceBase::DeviceLostEvent::Create(descriptor);
+
+ auto result = CreateDeviceInternal(descriptor, lostEvent);
+ if (result.IsError()) {
+ lostEvent->mReason = wgpu::DeviceLostReason::FailedCreation;
+ lostEvent->mMessage = "Failed to create device.";
+ mPhysicalDevice->GetInstance()->GetEventManager()->SetFutureReady(lostEvent.Get());
+ }
+ return {lostEvent, std::move(result)};
}
void AdapterBase::APIRequestDevice(const DeviceDescriptor* descriptor,
WGPURequestDeviceCallback callback,
void* userdata) {
- constexpr DeviceDescriptor kDefaultDescriptor = {};
- if (descriptor == nullptr) {
- descriptor = &kDefaultDescriptor;
- }
- auto result = CreateDevice(descriptor);
- if (result.IsError()) {
- std::unique_ptr<ErrorData> errorData = result.AcquireError();
- // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
- callback(WGPURequestDeviceStatus_Error, nullptr, errorData->GetFormattedMessage().c_str(),
- userdata);
- return;
- }
- Ref<DeviceBase> device = result.AcquireSuccess();
- WGPURequestDeviceStatus status =
- device == nullptr ? WGPURequestDeviceStatus_Unknown : WGPURequestDeviceStatus_Success;
- // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
- callback(status, ToAPI(ReturnToAPI(std::move(device))), nullptr, userdata);
+ // Default legacy callback mode for RequestDevice is spontaneous.
+ APIRequestDeviceF(descriptor,
+ {nullptr, wgpu::CallbackMode::AllowSpontaneous, callback, userdata});
}
Future AdapterBase::APIRequestDeviceF(const DeviceDescriptor* descriptor,
@@ -285,47 +295,53 @@
struct RequestDeviceEvent final : public EventManager::TrackedEvent {
WGPURequestDeviceCallback mCallback;
raw_ptr<void> mUserdata;
- ResultOrError<Ref<DeviceBase>> mDeviceOrError;
- RequestDeviceEvent(const RequestDeviceCallbackInfo& callbackInfo,
- ResultOrError<Ref<DeviceBase>> deviceOrError)
+ WGPURequestDeviceStatus mStatus;
+ Ref<DeviceBase> mDevice = nullptr;
+ std::string mMessage;
+
+ RequestDeviceEvent(const RequestDeviceCallbackInfo& callbackInfo, Ref<DeviceBase> device)
: TrackedEvent(callbackInfo.mode, TrackedEvent::Completed{}),
mCallback(callbackInfo.callback),
mUserdata(callbackInfo.userdata),
- mDeviceOrError(std::move(deviceOrError)) {}
+ mStatus(WGPURequestDeviceStatus_Success),
+ mDevice(std::move(device)) {}
+
+ RequestDeviceEvent(const RequestDeviceCallbackInfo& callbackInfo,
+ const std::string& message)
+ : TrackedEvent(callbackInfo.mode, TrackedEvent::Completed{}),
+ mCallback(callbackInfo.callback),
+ mUserdata(callbackInfo.userdata),
+ mStatus(WGPURequestDeviceStatus_Error),
+ mMessage(message) {}
~RequestDeviceEvent() override { EnsureComplete(EventCompletionType::Shutdown); }
void Complete(EventCompletionType completionType) override {
- WGPURequestDeviceStatus status;
- Ref<DeviceBase> device;
-
if (completionType == EventCompletionType::Shutdown) {
- status = WGPURequestDeviceStatus_InstanceDropped;
- } else {
- if (mDeviceOrError.IsError()) {
- std::unique_ptr<ErrorData> errorData = mDeviceOrError.AcquireError();
- mCallback(WGPURequestDeviceStatus_Error, nullptr,
- errorData->GetFormattedMessage().c_str(),
- mUserdata.ExtractAsDangling());
- return;
- }
- device = mDeviceOrError.AcquireSuccess();
- status = device == nullptr ? WGPURequestDeviceStatus_Unknown
- : WGPURequestDeviceStatus_Success;
+ mStatus = WGPURequestDeviceStatus_InstanceDropped;
+ mDevice = nullptr;
+ mMessage = "A valid external Instance reference no longer exists.";
}
- mCallback(status, ToAPI(ReturnToAPI(std::move(device))), nullptr,
- mUserdata.ExtractAsDangling());
+ mCallback(mStatus, ToAPI(ReturnToAPI(std::move(mDevice))),
+ mMessage.empty() ? nullptr : mMessage.c_str(), mUserdata.ExtractAsDangling());
}
};
- constexpr DeviceDescriptor kDefaultDescriptor = {};
if (descriptor == nullptr) {
- descriptor = &kDefaultDescriptor;
+ descriptor = &kDefaultDeviceDesc;
}
- FutureID futureID = mPhysicalDevice->GetInstance()->GetEventManager()->TrackEvent(
- AcquireRef(new RequestDeviceEvent(callbackInfo, CreateDevice(descriptor))));
+ FutureID futureID = kNullFutureID;
+ auto [lostEvent, result] = CreateDevice(descriptor);
+ if (result.IsSuccess()) {
+ futureID = mPhysicalDevice->GetInstance()->GetEventManager()->TrackEvent(
+ AcquireRef(new RequestDeviceEvent(callbackInfo, result.AcquireSuccess())));
+ } else {
+ futureID = mPhysicalDevice->GetInstance()->GetEventManager()->TrackEvent(AcquireRef(
+ new RequestDeviceEvent(callbackInfo, result.AcquireError()->GetFormattedMessage())));
+ }
+ mPhysicalDevice->GetInstance()->GetEventManager()->TrackEvent(std::move(lostEvent));
return {futureID};
}
diff --git a/src/dawn/native/Adapter.h b/src/dawn/native/Adapter.h
index 5b99069f..038de4b 100644
--- a/src/dawn/native/Adapter.h
+++ b/src/dawn/native/Adapter.h
@@ -28,11 +28,13 @@
#ifndef SRC_DAWN_NATIVE_ADAPTER_H_
#define SRC_DAWN_NATIVE_ADAPTER_H_
+#include <utility>
#include <vector>
#include "dawn/common/Ref.h"
#include "dawn/common/RefCounted.h"
#include "dawn/native/DawnNative.h"
+#include "dawn/native/Device.h"
#include "dawn/native/PhysicalDevice.h"
#include "dawn/native/dawn_platform.h"
@@ -62,7 +64,6 @@
Future APIRequestDeviceF(const DeviceDescriptor* descriptor,
const RequestDeviceCallbackInfo& callbackInfo);
DeviceBase* APICreateDevice(const DeviceDescriptor* descriptor = nullptr);
- ResultOrError<Ref<DeviceBase>> CreateDevice(const DeviceDescriptor* rawDescriptor);
bool APIGetFormatCapabilities(wgpu::TextureFormat format, FormatCapabilities* capabilities);
void SetUseTieredLimits(bool useTieredLimits);
@@ -78,6 +79,11 @@
FeatureLevel GetFeatureLevel() const;
private:
+ std::pair<Ref<DeviceBase::DeviceLostEvent>, ResultOrError<Ref<DeviceBase>>> CreateDevice(
+ const DeviceDescriptor* rawDescriptor);
+ ResultOrError<Ref<DeviceBase>> CreateDeviceInternal(const DeviceDescriptor* rawDescriptor,
+ Ref<DeviceBase::DeviceLostEvent> lostEvent);
+
Ref<PhysicalDeviceBase> mPhysicalDevice;
FeatureLevel mFeatureLevel;
bool mUseTieredLimits = false;
diff --git a/src/dawn/native/Buffer.cpp b/src/dawn/native/Buffer.cpp
index b1084d4..3efce35 100644
--- a/src/dawn/native/Buffer.cpp
+++ b/src/dawn/native/Buffer.cpp
@@ -235,7 +235,6 @@
(*buffer)->mState = BufferState::Mapped;
pendingMapEvent = std::move((*buffer)->mPendingMapEvent);
- (*buffer)->mPendingMapFutureID = kNullFutureID;
}
});
mCallback(ToAPI(status), mUserdata);
@@ -636,9 +635,6 @@
}
FutureID futureID = GetInstance()->GetEventManager()->TrackEvent(std::move(event));
- if (!earlyStatus) {
- mPendingMapFutureID = futureID;
- }
return {futureID};
}
@@ -713,10 +709,8 @@
// state and pending map event needs to be atomic w.r.t. MapAsyncEvent::Complete.
Ref<MapAsyncEvent> pendingMapEvent = std::move(mPendingMapEvent);
if (pendingMapEvent != nullptr) {
- DAWN_ASSERT(mPendingMapFutureID != kNullFutureID);
pendingMapEvent->UnmapEarly(static_cast<wgpu::BufferMapAsyncStatus>(callbackStatus));
- GetInstance()->GetEventManager()->SetFutureReady(mPendingMapFutureID);
- mPendingMapFutureID = kNullFutureID;
+ GetInstance()->GetEventManager()->SetFutureReady(pendingMapEvent.Get());
} else {
GetDevice()->GetCallbackTaskManager()->AddCallbackTask(
PrepareMappingCallback(mLastMapID, callbackStatus));
diff --git a/src/dawn/native/Buffer.h b/src/dawn/native/Buffer.h
index 401ff66..6204fff 100644
--- a/src/dawn/native/Buffer.h
+++ b/src/dawn/native/Buffer.h
@@ -186,7 +186,6 @@
size_t mMapSize = 0;
struct MapAsyncEvent;
- FutureID mPendingMapFutureID = kNullFutureID;
Ref<MapAsyncEvent> mPendingMapEvent;
};
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index c014fa0..46957d0 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -130,6 +130,7 @@
};
namespace {
+
struct LoggingCallbackTask : CallbackTask {
public:
LoggingCallbackTask() = delete;
@@ -163,8 +164,95 @@
std::string mMessage;
raw_ptr<void> mUserdata;
};
+
+static constexpr UncapturedErrorCallbackInfo kEmptyUncapturedErrorCallbackInfo = {nullptr, nullptr,
+ nullptr};
+
} // anonymous namespace
+DeviceBase::DeviceLostEvent::DeviceLostEvent(const DeviceLostCallbackInfo& callbackInfo)
+ : TrackedEvent(callbackInfo.mode, SystemEvent::CreateNonProgressingEvent()),
+ mCallback(callbackInfo.callback),
+ mUserdata(callbackInfo.userdata) {}
+
+DeviceBase::DeviceLostEvent::DeviceLostEvent(wgpu::DeviceLostCallback oldCallback, void* userdata)
+ : TrackedEvent(wgpu::CallbackMode::AllowProcessEvents,
+ SystemEvent::CreateNonProgressingEvent()),
+ mOldCallback(oldCallback),
+ mUserdata(userdata) {}
+
+DeviceBase::DeviceLostEvent::~DeviceLostEvent() {
+ EnsureComplete(EventCompletionType::Shutdown);
+}
+
+// static
+Ref<DeviceBase::DeviceLostEvent> DeviceBase::DeviceLostEvent::Create(
+ const DeviceDescriptor* descriptor) {
+ DAWN_ASSERT(descriptor != nullptr);
+
+#if defined(DAWN_ENABLE_ASSERTS)
+ // TODO(crbug.com/dawn/2465) Make default AllowSpontaneous once SetDeviceLostCallback is gone.
+ static constexpr DeviceLostCallbackInfo kDefaultDeviceLostCallbackInfo = {
+ nullptr, wgpu::CallbackMode::AllowProcessEvents,
+ [](WGPUDevice const*, WGPUDeviceLostReason, char const*, void*) {
+ static bool calledOnce = false;
+ if (!calledOnce) {
+ calledOnce = true;
+ dawn::WarningLog() << "No Dawn device lost callback was set. This is probably not "
+ "intended. If you really want to ignore device lost and "
+ "suppress this message, set the callback explicitly.";
+ }
+ },
+ nullptr};
+#else
+ static constexpr DeviceLostCallbackInfo kDefaultDeviceLostCallbackInfo = {
+ nullptr, wgpu::CallbackMode::AllowProcessEvents, nullptr, nullptr};
+#endif // DAWN_ENABLE_ASSERTS
+
+ Ref<DeviceBase::DeviceLostEvent> lostEvent;
+ if (descriptor->deviceLostCallback != nullptr) {
+ dawn::WarningLog()
+ << "DeviceDescriptor.deviceLostCallback and DeviceDescriptor.deviceLostUserdata are "
+ "deprecated. Use DeviceDescriptor.deviceLostCallbackInfo instead.";
+ lostEvent = AcquireRef(new DeviceBase::DeviceLostEvent(descriptor->deviceLostCallback,
+ descriptor->deviceLostUserdata));
+ } else {
+ DeviceLostCallbackInfo deviceLostCallbackInfo = kDefaultDeviceLostCallbackInfo;
+ if (descriptor->deviceLostCallbackInfo.callback != nullptr ||
+ descriptor->deviceLostCallbackInfo.mode != wgpu::CallbackMode::WaitAnyOnly) {
+ deviceLostCallbackInfo = descriptor->deviceLostCallbackInfo;
+ if (deviceLostCallbackInfo.mode != wgpu::CallbackMode::AllowSpontaneous) {
+ // TODO(dawn:2458) Currently we default the callback mode to ProcessEvents if not
+ // passed for backwards compatibility. We should add warning logging for it though
+ // when available.
+ deviceLostCallbackInfo.mode = wgpu::CallbackMode::AllowProcessEvents;
+ }
+ }
+ lostEvent = AcquireRef(new DeviceBase::DeviceLostEvent(deviceLostCallbackInfo));
+ }
+
+ return lostEvent;
+}
+
+void DeviceBase::DeviceLostEvent::Complete(EventCompletionType completionType) {
+ if (completionType == EventCompletionType::Shutdown) {
+ mReason = wgpu::DeviceLostReason::InstanceDropped;
+ mMessage = "A valid external Instance reference no longer exists.";
+ }
+ if (mReason == wgpu::DeviceLostReason::InstanceDropped ||
+ mReason == wgpu::DeviceLostReason::FailedCreation) {
+ mDevice = nullptr;
+ }
+
+ if (mOldCallback) {
+ mOldCallback(ToAPI(mReason), mMessage.c_str(), mUserdata.ExtractAsDangling());
+ } else if (mCallback) {
+ auto device = ToAPI(mDevice.Get());
+ mCallback(&device, ToAPI(mReason), mMessage.c_str(), mUserdata.ExtractAsDangling());
+ }
+ mDevice = nullptr;
+}
+
ResultOrError<Ref<PipelineLayoutBase>> ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
DeviceBase* device,
const ComputePipelineDescriptor& descriptor,
@@ -210,12 +298,37 @@
DeviceBase::DeviceBase(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles)
- : mAdapter(adapter), mToggles(deviceToggles), mNextPipelineCompatibilityToken(1) {
+ const TogglesState& deviceToggles,
+ Ref<DeviceLostEvent>&& lostEvent)
+ : mLostEvent(std::move(lostEvent)),
+ mAdapter(adapter),
+ mToggles(deviceToggles),
+ mNextPipelineCompatibilityToken(1) {
DAWN_ASSERT(descriptor);
- mDeviceLostCallback = descriptor->deviceLostCallback;
- mDeviceLostUserdata = descriptor->deviceLostUserdata;
+ mLostEvent->mDevice = this;
+
+#if defined(DAWN_ENABLE_ASSERTS)
+ static constexpr UncapturedErrorCallbackInfo kDefaultUncapturedErrorCallbackInfo = {
+ nullptr,
+ [](WGPUErrorType, char const*, void*) {
+ static bool calledOnce = false;
+ if (!calledOnce) {
+ calledOnce = true;
+ dawn::WarningLog() << "No Dawn device uncaptured error callback was set. This is "
+ "probably not intended. If you really want to ignore errors "
+ "and suppress this message, set the callback explicitly.";
+ }
+ },
+ nullptr};
+#else
+ static constexpr UncapturedErrorCallbackInfo kDefaultUncapturedErrorCallbackInfo =
+ kEmptyUncapturedErrorCallbackInfo;
+#endif // DAWN_ENABLE_ASSERTS
+ mUncapturedErrorCallbackInfo = kDefaultUncapturedErrorCallbackInfo;
+ if (descriptor->uncapturedErrorCallbackInfo.callback != nullptr) {
+ mUncapturedErrorCallbackInfo = descriptor->uncapturedErrorCallbackInfo;
+ }
AdapterProperties adapterProperties;
adapter->APIGetProperties(&adapterProperties);
@@ -287,12 +400,18 @@
DeviceBase::DeviceBase() : mState(State::Alive), mToggles(ToggleStage::Device) {
GetDefaultLimits(&mLimits.v1, FeatureLevel::Core);
mFormatTable = BuildFormatTable(this);
+
+ DeviceDescriptor desc = {};
+ desc.deviceLostCallbackInfo = {nullptr, wgpu::CallbackMode::AllowSpontaneous};
+ mLostEvent = DeviceLostEvent::Create(&desc);
+ mLostEvent->mDevice = this;
}
DeviceBase::~DeviceBase() {
// We need to explicitly release the Queue before we complete the destructor so that the
// Queue does not get destroyed after the Device.
mQueue = nullptr;
+ mLostEvent = nullptr;
}
MaybeError DeviceBase::Initialize(Ref<QueueBase> defaultQueue) {
@@ -300,30 +419,6 @@
SetWGSLExtensionAllowList();
-#if defined(DAWN_ENABLE_ASSERTS)
- mUncapturedErrorCallback = [](WGPUErrorType, char const*, void*) {
- static bool calledOnce = false;
- if (!calledOnce) {
- calledOnce = true;
- dawn::WarningLog() << "No Dawn device uncaptured error callback was set. This is "
- "probably not intended. If you really want to ignore errors "
- "and suppress this message, set the callback to null.";
- }
- };
-
- if (!mDeviceLostCallback) {
- mDeviceLostCallback = [](WGPUDeviceLostReason, char const*, void*) {
- static bool calledOnce = false;
- if (!calledOnce) {
- calledOnce = true;
- dawn::WarningLog() << "No Dawn device lost callback was set. This is probably not "
- "intended. If you really want to ignore device lost "
- "and suppress this message, set the callback to null.";
- }
- };
- }
-#endif // DAWN_ENABLE_ASSERTS
-
mCaches = std::make_unique<DeviceBase::Caches>();
mErrorScopeStack = std::make_unique<ErrorScopeStack>();
mDynamicUploader = std::make_unique<DynamicUploader>(this);
@@ -411,17 +506,7 @@
// Reset callbacks since after dropping the last external reference, the application may have
// freed any device-scope memory needed to run the callback.
- mUncapturedErrorCallback = [](WGPUErrorType, char const* message, void*) {
- dawn::WarningLog() << "Uncaptured error after last external device reference dropped.\n"
- << message;
- };
- mUncapturedErrorUserdata = nullptr;
-
- mDeviceLostCallback = [](WGPUDeviceLostReason, char const* message, void*) {
- dawn::WarningLog() << "Device lost after last external device reference dropped.\n"
- << message;
- };
- mDeviceLostUserdata = nullptr;
+ mUncapturedErrorCallbackInfo = kEmptyUncapturedErrorCallbackInfo;
// mAdapter is not set for mock test devices.
// TODO(crbug.com/dawn/1702): using a mock adapter could avoid the null checking.
@@ -495,11 +580,11 @@
// Skip handling device facilities if they haven't even been created (or failed doing so)
if (mState != State::BeingCreated) {
// The device is being destroyed so it will be lost, call the application callback.
- if (mDeviceLostCallback != nullptr) {
- mCallbackTaskManager->AddCallbackTask(
- std::bind(mDeviceLostCallback, WGPUDeviceLostReason_Destroyed,
- "Device was destroyed.", mDeviceLostUserdata));
- mDeviceLostCallback = nullptr;
+ if (mLostEvent != nullptr) {
+ mLostEvent->mReason = wgpu::DeviceLostReason::Destroyed;
+ mLostEvent->mMessage = "Device was destroyed.";
+ GetInstance()->GetEventManager()->SetFutureReady(mLostEvent.Get());
+ mLostEvent = nullptr;
}
// Call all the callbacks immediately as the device is about to shut down.
@@ -639,12 +724,11 @@
// The device was lost, schedule the application callback's execution.
// Note: we don't invoke the callbacks directly here because it could cause re-entrances ->
// possible deadlock.
- if (mDeviceLostCallback != nullptr) {
- mCallbackTaskManager->AddCallbackTask([callback = mDeviceLostCallback, lostReason,
- messageStr, userdata = mDeviceLostUserdata] {
- callback(lostReason, messageStr.c_str(), userdata);
- });
- mDeviceLostCallback = nullptr;
+ if (mLostEvent != nullptr) {
+ mLostEvent->mReason = FromAPI(lostReason);
+ mLostEvent->mMessage = messageStr;
+ GetInstance()->GetEventManager()->SetFutureReady(mLostEvent.Get());
+ mLostEvent = nullptr;
}
mQueue->HandleDeviceLoss();
@@ -663,9 +747,10 @@
if (!captured) {
// Only call the uncaptured error callback if the device is alive. After the
// device is lost, the uncaptured error callback should cease firing.
- if (mUncapturedErrorCallback != nullptr && mState == State::Alive) {
- mUncapturedErrorCallback(static_cast<WGPUErrorType>(ToWGPUErrorType(type)),
- messageStr.c_str(), mUncapturedErrorUserdata);
+ if (mUncapturedErrorCallbackInfo.callback != nullptr && mState == State::Alive) {
+ mUncapturedErrorCallbackInfo.callback(
+ static_cast<WGPUErrorType>(ToWGPUErrorType(type)), messageStr.c_str(),
+ mUncapturedErrorCallbackInfo.userdata);
}
}
}
@@ -694,19 +779,18 @@
// Clearing the callback and userdata is allowed because in Chromium they should be cleared
// after Dawn device is destroyed and before Dawn wire server is destroyed.
if (callback == nullptr) {
- mUncapturedErrorCallback = nullptr;
- mUncapturedErrorUserdata = nullptr;
+ mUncapturedErrorCallbackInfo = kEmptyUncapturedErrorCallbackInfo;
return;
}
if (IsLost()) {
return;
}
- mUncapturedErrorCallback = callback;
- mUncapturedErrorUserdata = userdata;
+ mUncapturedErrorCallbackInfo = {nullptr, callback, userdata};
}
void DeviceBase::APISetDeviceLostCallback(wgpu::DeviceLostCallback callback, void* userdata) {
- // TODO(chromium:1234617): Add a deprecation warning.
+ EmitDeprecationWarning(
+ "SetDeviceLostCallback is deprecated. Pass the callback in the device descriptor instead.");
// The registered callback function and userdata pointer are stored and used by deferred
// callback tasks, and after setting a different callback (especially in the case of
@@ -718,15 +802,16 @@
// Clearing the callback and userdata is allowed because in Chromium they should be cleared
// after Dawn device is destroyed and before Dawn wire server is destroyed.
if (callback == nullptr) {
- mDeviceLostCallback = nullptr;
- mDeviceLostUserdata = nullptr;
+ mLostEvent->mCallback = nullptr;
+ mLostEvent->mOldCallback = nullptr;
+ mLostEvent->mUserdata = nullptr;
return;
}
if (IsLost()) {
return;
}
- mDeviceLostCallback = callback;
- mDeviceLostUserdata = userdata;
+ mLostEvent->mOldCallback = callback;
+ mLostEvent->mUserdata = userdata;
}
void DeviceBase::APIPushErrorScope(wgpu::ErrorFilter filter) {
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 48301a4..715767e 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -79,9 +79,40 @@
class DeviceBase : public ErrorSink, public RefCountedWithExternalCount {
public:
+ struct DeviceLostEvent final : public EventManager::TrackedEvent {
+ // TODO(https://crbug.com/dawn/2465): Pass just the DeviceLostCallbackInfo when setters are
+ // deprecated. Creates and sets the device lost event for the given device if applicable. If
+ // the device is nullptr, an event is still created, but the caller owns the last ref of the
+ // event. When passing a device, note that device construction can be successful but fail
+ // later at initialization, and this should only be called with the device if initialization
+ // was successful.
+ static Ref<DeviceLostEvent> Create(const DeviceDescriptor* descriptor);
+
+ // Event result fields need to be public so that they can easily be updated prior to
+ // completing the event.
+ wgpu::DeviceLostReason mReason;
+ std::string mMessage;
+
+ wgpu::DeviceLostCallbackNew mCallback = nullptr;
+ // TODO(https://crbug.com/dawn/2465): Remove old callback when setters are deprecated, and
+ // move userdata into private.
+ wgpu::DeviceLostCallback mOldCallback = nullptr;
+ raw_ptr<void> mUserdata;
+ // Note that the device is set when the event is passed to construct a device.
+ Ref<DeviceBase> mDevice = nullptr;
+
+ private:
+ explicit DeviceLostEvent(const DeviceLostCallbackInfo& callbackInfo);
+ DeviceLostEvent(wgpu::DeviceLostCallback oldCallback, void* userdata);
+ ~DeviceLostEvent() override;
+
+ void Complete(EventCompletionType completionType) override;
+ };
+
DeviceBase(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceLostEvent>&& lostEvent);
~DeviceBase() override;
// Handles the error, causing a device loss if applicable. Almost always when a device loss
@@ -96,7 +127,9 @@
MaybeError ValidateObject(const ApiObjectBase* object) const;
- InstanceBase* GetInstance() const;
+ // TODO(dawn:1702) Remove virtual when we mock the adapter.
+ virtual InstanceBase* GetInstance() const;
+
AdapterBase* GetAdapter() const;
PhysicalDeviceBase* GetPhysicalDevice() const;
virtual dawn::platform::Platform* GetPlatform() const;
@@ -417,6 +450,11 @@
void DestroyObjects();
void Destroy();
+ // Device lost event needs to be protected for now because mock device needs it.
+ // TODO(dawn:1702) Make this private and move the class in the implementation file when we mock
+ // the adapter.
+ Ref<DeviceLostEvent> mLostEvent = nullptr;
+
private:
void WillDropLastExternalRef() override;
@@ -509,16 +547,12 @@
const TextureCopy& dst,
const Extent3D& copySizePixels) = 0;
- wgpu::ErrorCallback mUncapturedErrorCallback = nullptr;
- raw_ptr<void> mUncapturedErrorUserdata = nullptr;
+ UncapturedErrorCallbackInfo mUncapturedErrorCallbackInfo;
std::shared_mutex mLoggingMutex;
wgpu::LoggingCallback mLoggingCallback = nullptr;
raw_ptr<void> mLoggingUserdata = nullptr;
- wgpu::DeviceLostCallback mDeviceLostCallback = nullptr;
- raw_ptr<void> mDeviceLostUserdata = nullptr;
-
std::unique_ptr<ErrorScopeStack> mErrorScopeStack;
Ref<AdapterBase> mAdapter;
diff --git a/src/dawn/native/EventManager.cpp b/src/dawn/native/EventManager.cpp
index 7037557..62cfb50 100644
--- a/src/dawn/native/EventManager.cpp
+++ b/src/dawn/native/EventManager.cpp
@@ -345,62 +345,61 @@
FutureID EventManager::TrackEvent(Ref<TrackedEvent>&& event) {
FutureID futureID = mNextFutureID++;
- return mEvents.Use([&](auto events) {
- if (!events->has_value()) {
+ event->mFutureID = futureID;
+
+ // Handle the event now if it's spontaneous and ready.
+ if (event->mCallbackMode == wgpu::CallbackMode::AllowSpontaneous) {
+ bool isReady = false;
+ auto completionData = event->GetCompletionData();
+ if (std::holds_alternative<Ref<SystemEvent>>(completionData)) {
+ isReady = std::get<Ref<SystemEvent>>(completionData)->IsSignaled();
+ }
+ if (std::holds_alternative<QueueAndSerial>(completionData)) {
+ auto& queueAndSerial = std::get<QueueAndSerial>(completionData);
+ isReady = queueAndSerial.completionSerial <=
+ queueAndSerial.queue->GetCompletedCommandSerial();
+ }
+ if (isReady) {
+ event->EnsureComplete(EventCompletionType::Ready);
return futureID;
}
+ }
- if (event->mCallbackMode == wgpu::CallbackMode::AllowSpontaneous) {
- bool isReady = false;
- auto completionData = event->GetCompletionData();
- if (std::holds_alternative<Ref<SystemEvent>>(completionData)) {
- isReady = std::get<Ref<SystemEvent>>(completionData)->IsSignaled();
- }
- if (std::holds_alternative<QueueAndSerial>(completionData)) {
- auto& queueAndSerial = std::get<QueueAndSerial>(completionData);
- isReady = queueAndSerial.completionSerial <=
- queueAndSerial.queue->GetCompletedCommandSerial();
- }
- if (isReady) {
- event->EnsureComplete(EventCompletionType::Ready);
- return futureID;
- }
- }
-
- (*events)->emplace(futureID, std::move(event));
- return futureID;
- });
-}
-
-void EventManager::SetFutureReady(FutureID futureID) {
- Ref<TrackedEvent> spontaneousEvent;
mEvents.Use([&](auto events) {
if (!events->has_value()) {
return;
}
-
- if (auto it = (*events)->find(futureID); it != (*events)->end()) {
- auto& event = it->second;
-
- auto completionData = event->GetCompletionData();
- if (std::holds_alternative<Ref<SystemEvent>>(completionData)) {
- std::get<Ref<SystemEvent>>(event->GetCompletionData())->Signal();
- }
- if (std::holds_alternative<QueueAndSerial>(completionData)) {
- auto& queueAndSerial = std::get<QueueAndSerial>(completionData);
- queueAndSerial.completionSerial = queueAndSerial.queue->GetCompletedCommandSerial();
- }
-
- if (event->mCallbackMode == wgpu::CallbackMode::AllowSpontaneous) {
- spontaneousEvent = std::move(event);
- (*events)->erase(futureID);
- }
- }
+ (*events)->emplace(futureID, std::move(event));
});
+ return futureID;
+}
+
+void EventManager::SetFutureReady(TrackedEvent* event) {
+ auto completionData = event->GetCompletionData();
+ if (std::holds_alternative<Ref<SystemEvent>>(completionData)) {
+ std::get<Ref<SystemEvent>>(completionData)->Signal();
+ }
+ if (std::holds_alternative<QueueAndSerial>(completionData)) {
+ auto& queueAndSerial = std::get<QueueAndSerial>(completionData);
+ queueAndSerial.completionSerial = queueAndSerial.queue->GetCompletedCommandSerial();
+ }
+
+ // Sometimes, events might become ready before they are even tracked. This can happen because
+ // tracking is ordered to uphold callback ordering, but events may become ready in any order. If
+ // the event is spontaneous, it will be completed when it is tracked.
+ if (event->mFutureID == kNullFutureID) {
+ return;
+ }
// Handle spontaneous completion now.
- if (spontaneousEvent) {
- spontaneousEvent->EnsureComplete(EventCompletionType::Ready);
+ if (event->mCallbackMode == wgpu::CallbackMode::AllowSpontaneous) {
+ mEvents.Use([&](auto events) {
+ if (!events->has_value()) {
+ return;
+ }
+ (*events)->erase(event->mFutureID);
+ });
+ event->EnsureComplete(EventCompletionType::Ready);
}
}
@@ -409,7 +408,8 @@
std::vector<TrackedEvent::WaitRef> completable;
wgpu::WaitStatus waitStatus;
- auto needFutureProcessEvents = mEvents.Use([&](auto events) {
+ bool hasProgressingEvents = false;
+ auto hasIncompleteEvents = mEvents.Use([&](auto events) {
// Iterate all events and record poll events and spontaneous events since they are both
// allowed to be completed in the ProcessPoll call. Note that spontaneous events are allowed
// to trigger anywhere which is why we include them in the call.
@@ -417,6 +417,17 @@
futures.reserve((*events)->size());
for (auto& [futureID, event] : **events) {
if (event->mCallbackMode != wgpu::CallbackMode::WaitAnyOnly) {
+ // Figure out if there are any progressing events. If we only have non-progressing
+ // events, we need to return false to indicate that there isn't any polling work to
+ // be done.
+ auto completionData = event->GetCompletionData();
+ if (std::holds_alternative<Ref<SystemEvent>>(completionData)) {
+ hasProgressingEvents |=
+ std::get<Ref<SystemEvent>>(completionData)->IsProgressing();
+ } else {
+ hasProgressingEvents = true;
+ }
+
futures.push_back(
TrackedFutureWaitInfo{futureID, TrackedEvent::WaitRef{event.Get()}, 0, false});
}
@@ -429,7 +440,6 @@
waitStatus = WaitImpl(futures, Nanoseconds(0));
if (waitStatus == wgpu::WaitStatus::TimedOut) {
- // Return the beginning to indicate that nothing completed.
return true;
}
DAWN_ASSERT(waitStatus == wgpu::WaitStatus::Success);
@@ -449,7 +459,9 @@
for (auto& event : completable) {
event->EnsureComplete(EventCompletionType::Ready);
}
- return needFutureProcessEvents;
+ // Note that in the event of all progressing events completing, but there exists non-progressing
+ // events, we will return true one extra time.
+ return hasIncompleteEvents && hasProgressingEvents;
}
wgpu::WaitStatus EventManager::WaitAny(size_t count, FutureWaitInfo* infos, Nanoseconds timeout) {
@@ -548,6 +560,7 @@
: TrackedEvent(callbackMode, SystemEvent::CreateSignaled()) {}
EventManager::TrackedEvent::~TrackedEvent() {
+ DAWN_ASSERT(mFutureID != kNullFutureID);
DAWN_ASSERT(mCompleted);
}
diff --git a/src/dawn/native/EventManager.h b/src/dawn/native/EventManager.h
index 024dc37..b824610 100644
--- a/src/dawn/native/EventManager.h
+++ b/src/dawn/native/EventManager.h
@@ -72,8 +72,8 @@
class TrackedEvent;
// Track a TrackedEvent and give it a FutureID.
- [[nodiscard]] FutureID TrackEvent(Ref<TrackedEvent>&&);
- void SetFutureReady(FutureID futureID);
+ FutureID TrackEvent(Ref<TrackedEvent>&&);
+ void SetFutureReady(TrackedEvent* event);
// Returns true if future ProcessEvents is needed.
bool ProcessPollEvents();
@@ -152,6 +152,7 @@
virtual void Complete(EventCompletionType) = 0;
wgpu::CallbackMode mCallbackMode;
+ FutureID mFutureID = kNullFutureID;
#if DAWN_ENABLE_ASSERTS
std::atomic<bool> mCurrentlyBeingWaited = false;
diff --git a/src/dawn/native/ObjectBase.h b/src/dawn/native/ObjectBase.h
index 9fd8f51..6ac6ac9 100644
--- a/src/dawn/native/ObjectBase.h
+++ b/src/dawn/native/ObjectBase.h
@@ -172,6 +172,9 @@
template <class T>
T* ReturnToAPI(Ref<T>&& object) {
+ if (object == nullptr) {
+ return nullptr;
+ }
if constexpr (T::HasExternalRefCount) {
// For an object which has external ref count, just need to increase the external ref count,
// and keep the total ref count unchanged.
diff --git a/src/dawn/native/PhysicalDevice.cpp b/src/dawn/native/PhysicalDevice.cpp
index b90c77c..fa91398 100644
--- a/src/dawn/native/PhysicalDevice.cpp
+++ b/src/dawn/native/PhysicalDevice.cpp
@@ -29,6 +29,7 @@
#include <algorithm>
#include <memory>
+#include <utility>
#include "dawn/common/Constants.h"
#include "dawn/common/GPUInfo.h"
@@ -72,8 +73,9 @@
ResultOrError<Ref<DeviceBase>> PhysicalDeviceBase::CreateDevice(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- return CreateDeviceImpl(adapter, descriptor, deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ return CreateDeviceImpl(adapter, descriptor, deviceToggles, std::move(lostEvent));
}
void PhysicalDeviceBase::InitializeVendorArchitectureImpl() {
diff --git a/src/dawn/native/PhysicalDevice.h b/src/dawn/native/PhysicalDevice.h
index c2ac2a6..2c1861f 100644
--- a/src/dawn/native/PhysicalDevice.h
+++ b/src/dawn/native/PhysicalDevice.h
@@ -37,6 +37,7 @@
#include "dawn/common/Ref.h"
#include "dawn/common/RefCounted.h"
#include "dawn/common/ityp_span.h"
+#include "dawn/native/Device.h"
#include "dawn/native/Error.h"
#include "dawn/native/Features.h"
#include "dawn/native/Forward.h"
@@ -67,7 +68,8 @@
ResultOrError<Ref<DeviceBase>> CreateDevice(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
uint32_t GetVendorId() const;
uint32_t GetDeviceId() const;
@@ -141,7 +143,8 @@
virtual ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) = 0;
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) = 0;
virtual MaybeError InitializeImpl() = 0;
diff --git a/src/dawn/native/SystemEvent.cpp b/src/dawn/native/SystemEvent.cpp
index 6d2a3a6..c22787f 100644
--- a/src/dawn/native/SystemEvent.cpp
+++ b/src/dawn/native/SystemEvent.cpp
@@ -130,6 +130,15 @@
return ev;
}
+// static
+Ref<SystemEvent> SystemEvent::CreateNonProgressingEvent() {
+ return AcquireRef(new SystemEvent(kNonProgressingPayload));
+}
+
+bool SystemEvent::IsProgressing() const {
+ return GetRefCountPayload() != kNonProgressingPayload;
+}
+
bool SystemEvent::IsSignaled() const {
return mSignaled.load(std::memory_order_acquire);
}
diff --git a/src/dawn/native/SystemEvent.h b/src/dawn/native/SystemEvent.h
index 9e75f64..f641cf7 100644
--- a/src/dawn/native/SystemEvent.h
+++ b/src/dawn/native/SystemEvent.h
@@ -103,8 +103,12 @@
class SystemEvent : public RefCounted {
public:
- static Ref<SystemEvent> CreateSignaled();
+ using RefCounted::RefCounted;
+ static Ref<SystemEvent> CreateSignaled();
+ static Ref<SystemEvent> CreateNonProgressingEvent();
+
+ bool IsProgressing() const;
bool IsSignaled() const;
void Signal();
@@ -113,6 +117,10 @@
const SystemEventReceiver& GetOrCreateSystemEventReceiver();
private:
+ // Some SystemEvents may be non-progressing, i.e. DeviceLost. We tag these events so that we can
+ // correctly return whether there is progressing work when users are polling.
+ static constexpr uint64_t kNonProgressingPayload = 1;
+
// mSignaled indicates whether the event has already been signaled.
// It is stored outside the mPipe mutex so its status can quickly be checked without
// acquiring a lock.
diff --git a/src/dawn/native/d3d/DeviceD3D.cpp b/src/dawn/native/d3d/DeviceD3D.cpp
index 312acb4..02d0621 100644
--- a/src/dawn/native/d3d/DeviceD3D.cpp
+++ b/src/dawn/native/d3d/DeviceD3D.cpp
@@ -27,6 +27,8 @@
#include "dawn/native/d3d/DeviceD3D.h"
+#include <utility>
+
#include "dawn/native/d3d/BackendD3D.h"
#include "dawn/native/d3d/ExternalImageDXGIImpl.h"
#include "dawn/native/d3d/Forward.h"
@@ -37,8 +39,9 @@
Device::Device(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles)
- : DeviceBase(adapter, descriptor, deviceToggles) {}
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent)
+ : DeviceBase(adapter, descriptor, deviceToggles, std::move(lostEvent)) {}
Device::~Device() {
Destroy();
diff --git a/src/dawn/native/d3d/DeviceD3D.h b/src/dawn/native/d3d/DeviceD3D.h
index cae7682..272abd9 100644
--- a/src/dawn/native/d3d/DeviceD3D.h
+++ b/src/dawn/native/d3d/DeviceD3D.h
@@ -47,7 +47,8 @@
public:
Device(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
ResultOrError<wgpu::TextureUsage> GetSupportedSurfaceUsageImpl(
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index dc5b66d..c673c0d 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -112,8 +112,10 @@
// static
ResultOrError<Ref<Device>> Device::Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- Ref<Device> device = AcquireRef(new Device(adapter, descriptor, deviceToggles));
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ Ref<Device> device =
+ AcquireRef(new Device(adapter, descriptor, deviceToggles, std::move(lostEvent)));
DAWN_TRY(device->Initialize(descriptor));
return device;
}
diff --git a/src/dawn/native/d3d11/DeviceD3D11.h b/src/dawn/native/d3d11/DeviceD3D11.h
index 21aca83..b9bc161 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.h
+++ b/src/dawn/native/d3d11/DeviceD3D11.h
@@ -44,7 +44,8 @@
public:
static ResultOrError<Ref<Device>> Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor);
diff --git a/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp b/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
index 938bbf5..1c33d0b 100644
--- a/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/PhysicalDeviceD3D11.cpp
@@ -320,8 +320,9 @@
ResultOrError<Ref<DeviceBase>> PhysicalDevice::CreateDeviceImpl(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- return Device::Create(adapter, descriptor, deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ return Device::Create(adapter, descriptor, deviceToggles, std::move(lostEvent));
}
// Resets the backend device and creates a new one. If any D3D11 objects belonging to the
diff --git a/src/dawn/native/d3d11/PhysicalDeviceD3D11.h b/src/dawn/native/d3d11/PhysicalDeviceD3D11.h
index 01e28fc..7f0f18b 100644
--- a/src/dawn/native/d3d11/PhysicalDeviceD3D11.h
+++ b/src/dawn/native/d3d11/PhysicalDeviceD3D11.h
@@ -61,9 +61,11 @@
void SetupBackendAdapterToggles(TogglesState* adapterToggles) const override;
void SetupBackendDeviceToggles(TogglesState* deviceToggles) const override;
- ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(AdapterBase* adapter,
- const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) override;
+ ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
+ AdapterBase* adapter,
+ const UnpackedPtr<DeviceDescriptor>& descriptor,
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) override;
MaybeError ResetInternalDeviceForTestingImpl() override;
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 2bec840..5aea624 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -81,8 +81,10 @@
// static
ResultOrError<Ref<Device>> Device::Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- Ref<Device> device = AcquireRef(new Device(adapter, descriptor, deviceToggles));
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ Ref<Device> device =
+ AcquireRef(new Device(adapter, descriptor, deviceToggles, std::move(lostEvent)));
DAWN_TRY(device->Initialize(descriptor));
return device;
}
@@ -200,8 +202,9 @@
Device::Device(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles)
- : Base(adapter, descriptor, deviceToggles) {}
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent)
+ : Base(adapter, descriptor, deviceToggles, std::move(lostEvent)) {}
Device::~Device() = default;
diff --git a/src/dawn/native/d3d12/DeviceD3D12.h b/src/dawn/native/d3d12/DeviceD3D12.h
index c8040cc..c34751e 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.h
+++ b/src/dawn/native/d3d12/DeviceD3D12.h
@@ -63,7 +63,8 @@
public:
static ResultOrError<Ref<Device>> Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor);
@@ -187,7 +188,8 @@
Device(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
ResultOrError<Ref<BindGroupBase>> CreateBindGroupImpl(
const BindGroupDescriptor* descriptor) override;
diff --git a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
index e7a8b03..cc9775d 100644
--- a/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
@@ -792,8 +792,9 @@
ResultOrError<Ref<DeviceBase>> PhysicalDevice::CreateDeviceImpl(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- return Device::Create(adapter, descriptor, deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ return Device::Create(adapter, descriptor, deviceToggles, std::move(lostEvent));
}
// Resets the backend device and creates a new one. If any D3D12 objects belonging to the
diff --git a/src/dawn/native/d3d12/PhysicalDeviceD3D12.h b/src/dawn/native/d3d12/PhysicalDeviceD3D12.h
index 606f709..c9e910e 100644
--- a/src/dawn/native/d3d12/PhysicalDeviceD3D12.h
+++ b/src/dawn/native/d3d12/PhysicalDeviceD3D12.h
@@ -61,9 +61,11 @@
void SetupBackendAdapterToggles(TogglesState* adapterToggles) const override;
void SetupBackendDeviceToggles(TogglesState* deviceToggles) const override;
- ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(AdapterBase* adapter,
- const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) override;
+ ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
+ AdapterBase* adapter,
+ const UnpackedPtr<DeviceDescriptor>& descriptor,
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) override;
MaybeError ResetInternalDeviceForTestingImpl() override;
diff --git a/src/dawn/native/metal/BackendMTL.mm b/src/dawn/native/metal/BackendMTL.mm
index e9f65b0..41f36ec 100644
--- a/src/dawn/native/metal/BackendMTL.mm
+++ b/src/dawn/native/metal/BackendMTL.mm
@@ -304,10 +304,12 @@
bool SupportsFeatureLevel(FeatureLevel) const override { return true; }
private:
- ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(AdapterBase* adapter,
- const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) override {
- return Device::Create(adapter, mDevice, descriptor, deviceToggles);
+ ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
+ AdapterBase* adapter,
+ const UnpackedPtr<DeviceDescriptor>& descriptor,
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) override {
+ return Device::Create(adapter, mDevice, descriptor, deviceToggles, std::move(lostEvent));
}
void SetupBackendAdapterToggles(TogglesState* adapterToggles) const override {}
diff --git a/src/dawn/native/metal/DeviceMTL.h b/src/dawn/native/metal/DeviceMTL.h
index fa40bbb..2c8df01 100644
--- a/src/dawn/native/metal/DeviceMTL.h
+++ b/src/dawn/native/metal/DeviceMTL.h
@@ -53,7 +53,8 @@
static ResultOrError<Ref<Device>> Create(AdapterBase* adapter,
NSPRef<id<MTLDevice>> mtlDevice,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor);
@@ -90,7 +91,8 @@
Device(AdapterBase* adapter,
NSPRef<id<MTLDevice>> mtlDevice,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
ResultOrError<Ref<BindGroupBase>> CreateBindGroupImpl(
const BindGroupDescriptor* descriptor) override;
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 2ac4c5f..10e9fbb 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -124,10 +124,11 @@
ResultOrError<Ref<Device>> Device::Create(AdapterBase* adapter,
NSPRef<id<MTLDevice>> mtlDevice,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
@autoreleasepool {
- Ref<Device> device =
- AcquireRef(new Device(adapter, std::move(mtlDevice), descriptor, deviceToggles));
+ Ref<Device> device = AcquireRef(new Device(adapter, std::move(mtlDevice), descriptor,
+ deviceToggles, std::move(lostEvent)));
DAWN_TRY(device->Initialize(descriptor));
return device;
}
@@ -136,8 +137,10 @@
Device::Device(AdapterBase* adapter,
NSPRef<id<MTLDevice>> mtlDevice,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles)
- : DeviceBase(adapter, descriptor, deviceToggles), mMtlDevice(std::move(mtlDevice)) {
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent)
+ : DeviceBase(adapter, descriptor, deviceToggles, std::move(lostEvent)),
+ mMtlDevice(std::move(mtlDevice)) {
// On macOS < 11.0, we only can check whether counter sampling is supported, and the counter
// only can be sampled between command boundary using sampleCountersInBuffer API if it's
// supported.
diff --git a/src/dawn/native/null/DeviceNull.cpp b/src/dawn/native/null/DeviceNull.cpp
index db54784..adfc7be 100644
--- a/src/dawn/native/null/DeviceNull.cpp
+++ b/src/dawn/native/null/DeviceNull.cpp
@@ -92,8 +92,9 @@
ResultOrError<Ref<DeviceBase>> PhysicalDevice::CreateDeviceImpl(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- return Device::Create(adapter, descriptor, deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ return Device::Create(adapter, descriptor, deviceToggles, std::move(lostEvent));
}
void PhysicalDevice::PopulateBackendProperties(UnpackedPtr<AdapterProperties>& properties) const {
@@ -165,8 +166,10 @@
// static
ResultOrError<Ref<Device>> Device::Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- Ref<Device> device = AcquireRef(new Device(adapter, descriptor, deviceToggles));
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ Ref<Device> device =
+ AcquireRef(new Device(adapter, descriptor, deviceToggles, std::move(lostEvent)));
DAWN_TRY(device->Initialize(descriptor));
return device;
}
diff --git a/src/dawn/native/null/DeviceNull.h b/src/dawn/native/null/DeviceNull.h
index db6b596..694e049 100644
--- a/src/dawn/native/null/DeviceNull.h
+++ b/src/dawn/native/null/DeviceNull.h
@@ -104,7 +104,8 @@
public:
static ResultOrError<Ref<Device>> Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor);
@@ -208,9 +209,11 @@
void SetupBackendAdapterToggles(TogglesState* adapterToggles) const override;
void SetupBackendDeviceToggles(TogglesState* deviceToggles) const override;
- ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(AdapterBase* adapter,
- const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) override;
+ ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
+ AdapterBase* adapter,
+ const UnpackedPtr<DeviceDescriptor>& descriptor,
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) override;
void PopulateBackendProperties(UnpackedPtr<AdapterProperties>& properties) const override;
};
diff --git a/src/dawn/native/opengl/DeviceGL.cpp b/src/dawn/native/opengl/DeviceGL.cpp
index 7c5b95f..7293d47 100644
--- a/src/dawn/native/opengl/DeviceGL.cpp
+++ b/src/dawn/native/opengl/DeviceGL.cpp
@@ -123,9 +123,10 @@
const UnpackedPtr<DeviceDescriptor>& descriptor,
const OpenGLFunctions& functions,
std::unique_ptr<Context> context,
- const TogglesState& deviceToggles) {
- Ref<Device> device =
- AcquireRef(new Device(adapter, descriptor, functions, std::move(context), deviceToggles));
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ Ref<Device> device = AcquireRef(new Device(adapter, descriptor, functions, std::move(context),
+ deviceToggles, std::move(lostEvent)));
DAWN_TRY(device->Initialize(descriptor));
return device;
}
@@ -134,8 +135,9 @@
const UnpackedPtr<DeviceDescriptor>& descriptor,
const OpenGLFunctions& functions,
std::unique_ptr<Context> context,
- const TogglesState& deviceToggles)
- : DeviceBase(adapter, descriptor, deviceToggles),
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent)
+ : DeviceBase(adapter, descriptor, deviceToggles, std::move(lostEvent)),
mGL(functions),
mContext(std::move(context)) {}
diff --git a/src/dawn/native/opengl/DeviceGL.h b/src/dawn/native/opengl/DeviceGL.h
index 2ad1e2e..1302ef7 100644
--- a/src/dawn/native/opengl/DeviceGL.h
+++ b/src/dawn/native/opengl/DeviceGL.h
@@ -55,7 +55,8 @@
const UnpackedPtr<DeviceDescriptor>& descriptor,
const OpenGLFunctions& functions,
std::unique_ptr<Context> context,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor);
@@ -110,7 +111,8 @@
const UnpackedPtr<DeviceDescriptor>& descriptor,
const OpenGLFunctions& functions,
std::unique_ptr<Context> context,
- const TogglesState& deviceToggless);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
ResultOrError<Ref<BindGroupBase>> CreateBindGroupImpl(
const BindGroupDescriptor* descriptor) override;
diff --git a/src/dawn/native/opengl/PhysicalDeviceGL.cpp b/src/dawn/native/opengl/PhysicalDeviceGL.cpp
index 25bf2cc..b9a14a7 100644
--- a/src/dawn/native/opengl/PhysicalDeviceGL.cpp
+++ b/src/dawn/native/opengl/PhysicalDeviceGL.cpp
@@ -430,7 +430,8 @@
ResultOrError<Ref<DeviceBase>> PhysicalDevice::CreateDeviceImpl(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
EGLenum api =
GetBackendType() == wgpu::BackendType::OpenGL ? EGL_OPENGL_API : EGL_OPENGL_ES_API;
std::unique_ptr<Device::Context> context;
@@ -443,7 +444,8 @@
DAWN_TRY_ASSIGN(context,
ContextEGL::Create(mEGLFunctions, api, mDisplay, useANGLETextureSharing));
- return Device::Create(adapter, descriptor, mFunctions, std::move(context), deviceToggles);
+ return Device::Create(adapter, descriptor, mFunctions, std::move(context), deviceToggles,
+ std::move(lostEvent));
}
bool PhysicalDevice::SupportsFeatureLevel(FeatureLevel featureLevel) const {
diff --git a/src/dawn/native/opengl/PhysicalDeviceGL.h b/src/dawn/native/opengl/PhysicalDeviceGL.h
index e2f4771..2dda6e9 100644
--- a/src/dawn/native/opengl/PhysicalDeviceGL.h
+++ b/src/dawn/native/opengl/PhysicalDeviceGL.h
@@ -61,9 +61,11 @@
void SetupBackendAdapterToggles(TogglesState* adapterToggles) const override;
void SetupBackendDeviceToggles(TogglesState* deviceToggles) const override;
- ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(AdapterBase* adapter,
- const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) override;
+ ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
+ AdapterBase* adapter,
+ const UnpackedPtr<DeviceDescriptor>& descriptor,
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) override;
void PopulateBackendProperties(UnpackedPtr<AdapterProperties>& properties) const override;
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index d88ee8f..265185c 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -77,16 +77,20 @@
// static
ResultOrError<Ref<Device>> Device::Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- Ref<Device> device = AcquireRef(new Device(adapter, descriptor, deviceToggles));
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ Ref<Device> device =
+ AcquireRef(new Device(adapter, descriptor, deviceToggles, std::move(lostEvent)));
DAWN_TRY(device->Initialize(descriptor));
return device;
}
Device::Device(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles)
- : DeviceBase(adapter, descriptor, deviceToggles), mDebugPrefix(GetNextDeviceDebugPrefix()) {}
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent)
+ : DeviceBase(adapter, descriptor, deviceToggles, std::move(lostEvent)),
+ mDebugPrefix(GetNextDeviceDebugPrefix()) {}
MaybeError Device::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor) {
// Copy the adapter's device info to the device so that we can change the "knobs"
diff --git a/src/dawn/native/vulkan/DeviceVk.h b/src/dawn/native/vulkan/DeviceVk.h
index 82856ec..ab716ff 100644
--- a/src/dawn/native/vulkan/DeviceVk.h
+++ b/src/dawn/native/vulkan/DeviceVk.h
@@ -59,7 +59,8 @@
public:
static ResultOrError<Ref<Device>> Create(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
~Device() override;
MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor);
@@ -126,7 +127,8 @@
private:
Device(AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent);
ResultOrError<Ref<BindGroupBase>> CreateBindGroupImpl(
const BindGroupDescriptor* descriptor) override;
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
index f5894dd..ce58ea0 100644
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
@@ -29,6 +29,7 @@
#include <algorithm>
#include <string>
+#include <utility>
#include "dawn/common/GPUInfo.h"
#include "dawn/native/ChainUtils.h"
@@ -767,8 +768,9 @@
ResultOrError<Ref<DeviceBase>> PhysicalDevice::CreateDeviceImpl(
AdapterBase* adapter,
const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) {
- return Device::Create(adapter, descriptor, deviceToggles);
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) {
+ return Device::Create(adapter, descriptor, deviceToggles, std::move(lostEvent));
}
FeatureValidationResult PhysicalDevice::ValidateFeatureSupportedWithTogglesImpl(
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.h b/src/dawn/native/vulkan/PhysicalDeviceVk.h
index bbb66c5..b40f0b1 100644
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.h
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.h
@@ -72,9 +72,11 @@
void SetupBackendAdapterToggles(TogglesState* adapterToggles) const override;
void SetupBackendDeviceToggles(TogglesState* deviceToggles) const override;
- ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(AdapterBase* adapter,
- const UnpackedPtr<DeviceDescriptor>& descriptor,
- const TogglesState& deviceToggles) override;
+ ResultOrError<Ref<DeviceBase>> CreateDeviceImpl(
+ AdapterBase* adapter,
+ const UnpackedPtr<DeviceDescriptor>& descriptor,
+ const TogglesState& deviceToggles,
+ Ref<DeviceBase::DeviceLostEvent>&& lostEvent) override;
uint32_t FindDefaultComputeSubgroupSize() const;
bool CheckSemaphoreSupport(DeviceExt deviceExt,
diff --git a/src/dawn/tests/DawnTest.cpp b/src/dawn/tests/DawnTest.cpp
index b58fb64..6896f61 100644
--- a/src/dawn/tests/DawnTest.cpp
+++ b/src/dawn/tests/DawnTest.cpp
@@ -67,6 +67,9 @@
namespace dawn {
namespace {
+using testing::_;
+using testing::AtMost;
+
struct MapReadUserdata {
raw_ptr<DawnTestBase> test;
size_t slot;
@@ -1089,6 +1092,16 @@
deviceDescriptor.requiredFeatures = requiredFeatures.data();
deviceDescriptor.requiredFeatureCount = requiredFeatures.size();
+ // Set up the mocks for device loss.
+ void* deviceUserdata = GetUniqueUserdata();
+ deviceDescriptor.deviceLostCallbackInfo.mode = wgpu::CallbackMode::AllowSpontaneous;
+ deviceDescriptor.deviceLostCallbackInfo.callback = mDeviceLostCallback.Callback();
+ deviceDescriptor.deviceLostCallbackInfo.userdata =
+ mDeviceLostCallback.MakeUserdata(deviceUserdata);
+ // The loss of the device is expected to happen at the end of the test so at it directly.
+ EXPECT_CALL(mDeviceLostCallback, Call(_, WGPUDeviceLostReason_Destroyed, _, deviceUserdata))
+ .Times(AtMost(1));
+
wgpu::DawnCacheDeviceDescriptor cacheDesc = {};
deviceDescriptor.nextInChain = &cacheDesc;
cacheDesc.isolationKey = isolationKey.c_str();
@@ -1111,11 +1124,6 @@
// RequestDevice is overriden by CreateDeviceImpl and device descriptor is ignored by it.
wgpu::DeviceDescriptor deviceDesc = {};
- // Set up the mocks for device loss.
- void* deviceUserdata = GetUniqueUserdata();
- deviceDesc.deviceLostCallback = mDeviceLostCallback.Callback();
- deviceDesc.deviceLostUserdata = mDeviceLostCallback.MakeUserdata(deviceUserdata);
-
adapter.RequestDevice(
&deviceDesc,
[](WGPURequestDeviceStatus, WGPUDevice cDevice, const char*, void* userdata) {
@@ -1129,11 +1137,6 @@
apiDevice.SetUncapturedErrorCallback(mDeviceErrorCallback.Callback(),
mDeviceErrorCallback.MakeUserdata(apiDevice.Get()));
- // The loss of the device is expected to happen at the end of the test so at it directly.
- EXPECT_CALL(mDeviceLostCallback,
- Call(WGPUDeviceLostReason_Destroyed, testing::_, deviceUserdata))
- .Times(testing::AtMost(1));
-
apiDevice.SetLoggingCallback(
[](WGPULoggingType type, char const* message, void*) {
switch (type) {
@@ -1222,8 +1225,7 @@
resolvedDevice = device;
}
- EXPECT_CALL(mDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, testing::_, testing::_))
- .Times(1);
+ EXPECT_CALL(mDeviceLostCallback, Call(_, WGPUDeviceLostReason_Undefined, _, _)).Times(1);
resolvedDevice.ForceLoss(wgpu::DeviceLostReason::Undefined, "Device lost for testing");
resolvedDevice.Tick();
}
diff --git a/src/dawn/tests/DawnTest.h b/src/dawn/tests/DawnTest.h
index 03378f0..e58f2f4 100644
--- a/src/dawn/tests/DawnTest.h
+++ b/src/dawn/tests/DawnTest.h
@@ -336,7 +336,7 @@
// device loss that aren't expected should result in test failures and not just some warnings
// printed to stdout.
testing::StrictMock<testing::MockCallback<WGPUErrorCallback>> mDeviceErrorCallback;
- testing::StrictMock<testing::MockCallback<WGPUDeviceLostCallback>> mDeviceLostCallback;
+ testing::StrictMock<testing::MockCallback<WGPUDeviceLostCallbackNew>> mDeviceLostCallback;
// Helper methods to implement the EXPECT_ macros
std::ostringstream& AddBufferExpectation(const char* file,
diff --git a/src/dawn/tests/end2end/DeviceLostTests.cpp b/src/dawn/tests/end2end/DeviceLostTests.cpp
index 5c8ebad..98bf703 100644
--- a/src/dawn/tests/end2end/DeviceLostTests.cpp
+++ b/src/dawn/tests/end2end/DeviceLostTests.cpp
@@ -436,12 +436,6 @@
// First LoseDeviceForTesting call should occur normally. The callback is already set in SetUp.
LoseDeviceForTesting();
- // Second LoseDeviceForTesting call should result in no callbacks. Note we also reset the
- // callback first since by default the device clears the callback after the device is lost.
- device.SetDeviceLostCallback(mDeviceLostCallback.Callback(),
- mDeviceLostCallback.MakeUserdata(device.Get()));
- EXPECT_CALL(mDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, testing::_, device.Get()))
- .Times(0);
device.ForceLoss(wgpu::DeviceLostReason::Undefined, "Device lost for testing");
FlushWire();
testing::Mock::VerifyAndClearExpectations(&mDeviceLostCallback);
diff --git a/src/dawn/tests/end2end/EventTests.cpp b/src/dawn/tests/end2end/EventTests.cpp
index 30acd1c..f4b0342 100644
--- a/src/dawn/tests/end2end/EventTests.cpp
+++ b/src/dawn/tests/end2end/EventTests.cpp
@@ -152,7 +152,7 @@
void LoseTestDevice() {
EXPECT_CALL(mDeviceLostCallback,
- Call(WGPUDeviceLostReason_Undefined, testing::_, testing::_))
+ Call(testing::_, WGPUDeviceLostReason_Undefined, testing::_, testing::_))
.Times(1);
testDevice.ForceLoss(wgpu::DeviceLostReason::Undefined, "Device lost for testing");
testInstance.ProcessEvents();
@@ -427,12 +427,15 @@
},
&status});
- // Callback should have been called immediately because we leaked it since there's no way to
- // call WaitAny or ProcessEvents anymore.
- //
- // TODO(crbug.com/dawn/2059): Once Spontaneous is implemented, this should no longer expect the
- // callback to be cleaned up immediately (and should expect it to happen on a future Tick).
- ASSERT_EQ(status, WGPUQueueWorkDoneStatus_InstanceDropped);
+ if (IsSpontaneous()) {
+ // TODO(crbug.com/dawn/2059): Once Spontaneous is implemented, this should no longer expect
+ // the callback to be cleaned up immediately (and should expect it to happen on a future
+ // Tick).
+ ASSERT_THAT(status, AnyOf(Eq(WGPUQueueWorkDoneStatus_Success),
+ Eq(WGPUQueueWorkDoneStatus_InstanceDropped)));
+ } else {
+ ASSERT_EQ(status, WGPUQueueWorkDoneStatus_InstanceDropped);
+ }
}
TEST_P(EventCompletionTests, WorkDoneDropInstanceAfterEvent) {
@@ -449,8 +452,6 @@
},
&status});
- // For spontaneous cases, it is possible that since there is no work to be done, the serial can
- // already be caught up and hence the callback fires immediately.
if (IsSpontaneous()) {
testInstance = nullptr; // Drop the last external ref to the instance.
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp b/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
index 07d4c4f..16b33bf 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
@@ -123,6 +123,7 @@
ON_CALL(*this, TickImpl).WillByDefault([]() -> MaybeError { return {}; });
// Initialize the device.
+ GetInstance()->GetEventManager()->TrackEvent(mLostEvent);
QueueDescriptor desc = {};
EXPECT_FALSE(Initialize(AcquireRef(new NiceMock<QueueMock>(this, &desc))).IsError());
}
@@ -133,6 +134,10 @@
return mInstance->GetPlatform();
}
+dawn::native::InstanceBase* DeviceMock::GetInstance() const {
+ return mInstance.Get();
+}
+
QueueMock* DeviceMock::GetQueueMock() {
return reinterpret_cast<QueueMock*>(GetQueue());
}
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.h b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
index d49e9a8..64b405b 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.h
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.h
@@ -55,6 +55,8 @@
~DeviceMock() override;
dawn::platform::Platform* GetPlatform() const override;
+ dawn::native::InstanceBase* GetInstance() const override;
+
// Mock specific functionality.
QueueMock* GetQueueMock();
diff --git a/src/dawn/tests/unittests/validation/ValidationTest.cpp b/src/dawn/tests/unittests/validation/ValidationTest.cpp
index 7afa6cb..fc2432c 100644
--- a/src/dawn/tests/unittests/validation/ValidationTest.cpp
+++ b/src/dawn/tests/unittests/validation/ValidationTest.cpp
@@ -332,11 +332,13 @@
// Reinitialize the device.
mExpectDestruction = true;
wgpu::DeviceDescriptor deviceDescriptor = {};
- deviceDescriptor.deviceLostCallback = ValidationTest::OnDeviceLost;
- deviceDescriptor.deviceLostUserdata = this;
+ deviceDescriptor.deviceLostCallbackInfo = {nullptr, wgpu::CallbackMode::AllowSpontaneous,
+ ValidationTest::OnDeviceLost, this};
+ deviceDescriptor.uncapturedErrorCallbackInfo.callback = ValidationTest::OnDeviceError;
+ deviceDescriptor.uncapturedErrorCallbackInfo.userdata = this;
+
device = RequestDeviceSync(deviceDescriptor);
backendDevice = mLastCreatedBackendDevice;
- device.SetUncapturedErrorCallback(ValidationTest::OnDeviceError, this);
mExpectDestruction = false;
}
@@ -362,7 +364,8 @@
self->mError = true;
}
-void ValidationTest::OnDeviceLost(WGPUDeviceLostReason reason,
+void ValidationTest::OnDeviceLost(WGPUDevice const* device,
+ WGPUDeviceLostReason reason,
const char* message,
void* userdata) {
auto* self = static_cast<ValidationTest*>(userdata);
diff --git a/src/dawn/tests/unittests/validation/ValidationTest.h b/src/dawn/tests/unittests/validation/ValidationTest.h
index 8708e7a..6eb7d26 100644
--- a/src/dawn/tests/unittests/validation/ValidationTest.h
+++ b/src/dawn/tests/unittests/validation/ValidationTest.h
@@ -179,7 +179,10 @@
wgpu::Device RequestDeviceSync(const wgpu::DeviceDescriptor& deviceDesc);
static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata);
- static void OnDeviceLost(WGPUDeviceLostReason reason, const char* message, void* userdata);
+ static void OnDeviceLost(WGPUDevice const* device,
+ WGPUDeviceLostReason reason,
+ const char* message,
+ void* userdata);
virtual bool UseCompatibilityMode() const;
diff --git a/src/dawn/tests/unittests/wire/WireAdapterTests.cpp b/src/dawn/tests/unittests/wire/WireAdapterTests.cpp
index 5588561..b865b3e 100644
--- a/src/dawn/tests/unittests/wire/WireAdapterTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireAdapterTests.cpp
@@ -114,15 +114,18 @@
});
}
-static void DeviceLostCallback(WGPUDeviceLostReason reason, const char* message, void* userdata) {}
+static void DeviceLostCallback(WGPUDevice const* device,
+ WGPUDeviceLostReason reason,
+ const char* message,
+ void* userdata) {}
// Test that the DeviceDescriptor is not allowed to pass a device lost callback from the client to
// the server.
TEST_P(WireAdapterTests, RequestDeviceAssertsOnLostCallbackPointer) {
int userdata = 1337;
wgpu::DeviceDescriptor desc = {};
- desc.deviceLostCallback = DeviceLostCallback;
- desc.deviceLostUserdata = &userdata;
+ desc.deviceLostCallbackInfo.callback = DeviceLostCallback;
+ desc.deviceLostCallbackInfo.userdata = &userdata;
AdapterRequestDevice(adapter, &desc);
@@ -130,9 +133,11 @@
.WillOnce(WithArg<1>(Invoke([&](const WGPUDeviceDescriptor* apiDesc) {
EXPECT_STREQ(apiDesc->label, desc.label);
- // The callback should not be passed through to the server.
- ASSERT_EQ(apiDesc->deviceLostCallback, nullptr);
- ASSERT_EQ(apiDesc->deviceLostUserdata, nullptr);
+ // The callback should not be passed through to the server, and it should be overridden.
+ ASSERT_NE(apiDesc->deviceLostCallbackInfo.callback, nullptr);
+ ASSERT_NE(apiDesc->deviceLostCallbackInfo.callback, &DeviceLostCallback);
+ ASSERT_NE(apiDesc->deviceLostCallbackInfo.userdata, nullptr);
+ ASSERT_NE(apiDesc->deviceLostCallbackInfo.userdata, &userdata);
// Call the callback so the test doesn't wait indefinitely.
api.CallAdapterRequestDeviceCallback(apiAdapter, WGPURequestDeviceStatus_Error, nullptr,
@@ -172,8 +177,6 @@
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, NotNull(), NotNull()))
.Times(1);
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, NotNull(), NotNull())).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, NotNull(), NotNull()))
- .Times(1);
EXPECT_CALL(api, DeviceGetLimits(apiDevice, NotNull()))
.WillOnce(WithArg<1>(Invoke([&](WGPUSupportedLimits* limits) {
@@ -244,7 +247,6 @@
// Cleared when the device is destroyed.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, DeviceRelease(apiDevice));
// Server has not recevied the release yet, so the device should be known.
@@ -349,8 +351,6 @@
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, NotNull(), NotNull()))
.Times(1);
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, NotNull(), NotNull())).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, NotNull(), NotNull()))
- .Times(1);
EXPECT_CALL(api, DeviceGetLimits(apiDevice, NotNull()))
.WillOnce(WithArg<1>(Invoke([&](WGPUSupportedLimits* limits) {
@@ -389,7 +389,6 @@
// Cleared when the device is destroyed.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, DeviceRelease(apiDevice));
FlushClient();
}
diff --git a/src/dawn/tests/unittests/wire/WireDeviceLifetimeTests.cpp b/src/dawn/tests/unittests/wire/WireDeviceLifetimeTests.cpp
index 564b444..cd14dd9 100644
--- a/src/dawn/tests/unittests/wire/WireDeviceLifetimeTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireDeviceLifetimeTests.cpp
@@ -109,10 +109,6 @@
ASSERT_TRUE(wireHelper->FlushServer());
ASSERT_NE(device, nullptr);
- wgpu::BufferDescriptor bufferDesc = {};
- bufferDesc.size = 128;
- bufferDesc.usage = wgpu::BufferUsage::Uniform;
-
// Destroy the device.
device.Destroy();
diff --git a/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp b/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp
index 8220400..dc9a06d 100644
--- a/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp
@@ -81,12 +81,8 @@
// Check that disconnecting the wire client calls the device lost callback exacty once.
TEST_F(WireDisconnectTests, CallsDeviceLostCallback) {
- MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
- wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(),
- mockDeviceLostCallback.MakeUserdata(this));
-
// Disconnect the wire client. We should receive device lost only once.
- EXPECT_CALL(mockDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, _, this))
+ EXPECT_CALL(deviceLostCallback, Call(_, WGPUDeviceLostReason_InstanceDropped, _, this))
.Times(Exactly(1));
GetWireClient()->Disconnect();
GetWireClient()->Disconnect();
@@ -95,21 +91,17 @@
// Check that disconnecting the wire client after a device loss does not trigger the callback
// again.
TEST_F(WireDisconnectTests, ServerLostThenDisconnect) {
- MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
- wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(),
- mockDeviceLostCallback.MakeUserdata(this));
-
api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Undefined,
"some reason");
// Flush the device lost return command.
- EXPECT_CALL(mockDeviceLostCallback,
- Call(WGPUDeviceLostReason_Undefined, StrEq("some reason"), this))
+ EXPECT_CALL(deviceLostCallback,
+ Call(_, WGPUDeviceLostReason_Undefined, StrEq("some reason"), this))
.Times(Exactly(1));
FlushServer();
// Disconnect the client. We shouldn't see the lost callback again.
- EXPECT_CALL(mockDeviceLostCallback, Call(_, _, _)).Times(Exactly(0));
+ EXPECT_CALL(deviceLostCallback, Call).Times(Exactly(0));
GetWireClient()->Disconnect();
}
@@ -141,7 +133,7 @@
mockDeviceLostCallback.MakeUserdata(this));
// Disconnect the client. We should see the callback once.
- EXPECT_CALL(mockDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, _, this))
+ EXPECT_CALL(mockDeviceLostCallback, Call(WGPUDeviceLostReason_InstanceDropped, _, this))
.Times(Exactly(1));
GetWireClient()->Disconnect();
@@ -180,9 +172,6 @@
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
- .Times(1)
- .InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2, s3, s4, s5);
EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1);
EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1).InSequence(s2);
diff --git a/src/dawn/tests/unittests/wire/WireQueueTests.cpp b/src/dawn/tests/unittests/wire/WireQueueTests.cpp
index 87e2603..69142e8 100644
--- a/src/dawn/tests/unittests/wire/WireQueueTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireQueueTests.cpp
@@ -176,7 +176,6 @@
// These set X callback methods are called before the device is released.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)).Times(1);
FlushClient();
// Indicate to the fixture that the device was already released.
@@ -198,7 +197,6 @@
// These set X callback methods are called before the device is released.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)).Times(1);
FlushClient();
// Release the external queue reference. The queue should be released.
diff --git a/src/dawn/tests/unittests/wire/WireTest.cpp b/src/dawn/tests/unittests/wire/WireTest.cpp
index 1fb4be0..d96a4cb 100644
--- a/src/dawn/tests/unittests/wire/WireTest.cpp
+++ b/src/dawn/tests/unittests/wire/WireTest.cpp
@@ -35,6 +35,7 @@
using testing::_;
using testing::AnyNumber;
+using testing::AtMost;
using testing::Exactly;
using testing::Invoke;
using testing::Mock;
@@ -129,29 +130,41 @@
// Create the device for testing.
apiDevice = api.GetNewDevice();
WGPUDeviceDescriptor deviceDesc = {};
+ deviceDesc.deviceLostCallbackInfo.callback = deviceLostCallback.Callback();
+ deviceDesc.deviceLostCallbackInfo.userdata = deviceLostCallback.MakeUserdata(this);
+ EXPECT_CALL(deviceLostCallback, Call).Times(AtMost(1));
+ deviceDesc.uncapturedErrorCallbackInfo.callback = uncapturedErrorCallback.Callback();
+ deviceDesc.uncapturedErrorCallbackInfo.userdata = uncapturedErrorCallback.MakeUserdata(this);
MockCallback<WGPURequestDeviceCallback> deviceCb;
wgpuAdapterRequestDevice(adapter.Get(), &deviceDesc, deviceCb.Callback(),
deviceCb.MakeUserdata(this));
- EXPECT_CALL(api, OnAdapterRequestDevice(apiAdapter, NotNull(), _)).WillOnce([&]() {
- // Set on device creation to forward callbacks to the client.
- EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, NotNull(), NotNull()))
- .Times(1);
- EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, NotNull(), NotNull())).Times(1);
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, NotNull(), NotNull())).Times(1);
+ EXPECT_CALL(api, OnAdapterRequestDevice(apiAdapter, NotNull(), _))
+ .WillOnce(WithArg<1>([&](const WGPUDeviceDescriptor* desc) {
+ // Set on device creation to forward callbacks to the client.
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, NotNull(), NotNull()))
+ .Times(1);
+ EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, NotNull(), NotNull())).Times(1);
- EXPECT_CALL(api, DeviceGetLimits(apiDevice, NotNull()))
- .WillOnce(WithArg<1>(Invoke([&](WGPUSupportedLimits* limits) {
- *limits = {};
- return true;
- })));
+ // The mock objects currently require us to manually set the callbacks because we are no
+ // longer explicitly calling SetDeviceLostCallback anymore.
+ ProcTableAsClass::Object* object =
+ reinterpret_cast<ProcTableAsClass::Object*>(apiDevice);
+ object->mDeviceLostCallback = desc->deviceLostCallbackInfo.callback;
+ object->mDeviceLostUserdata = desc->deviceLostCallbackInfo.userdata;
- EXPECT_CALL(api, DeviceEnumerateFeatures(apiDevice, nullptr))
- .WillOnce(Return(0))
- .WillOnce(Return(0));
+ EXPECT_CALL(api, DeviceGetLimits(apiDevice, NotNull()))
+ .WillOnce(WithArg<1>(Invoke([&](WGPUSupportedLimits* limits) {
+ *limits = {};
+ return true;
+ })));
- api.CallAdapterRequestDeviceCallback(apiAdapter, WGPURequestDeviceStatus_Success, apiDevice,
- nullptr);
- });
+ EXPECT_CALL(api, DeviceEnumerateFeatures(apiDevice, nullptr))
+ .WillOnce(Return(0))
+ .WillOnce(Return(0));
+
+ api.CallAdapterRequestDeviceCallback(apiAdapter, WGPURequestDeviceStatus_Success,
+ apiDevice, nullptr);
+ }));
FlushClient();
EXPECT_CALL(deviceCb, Call(WGPURequestDeviceStatus_Success, NotNull(), nullptr, this))
.WillOnce(SaveArg<1>(&device));
@@ -189,8 +202,6 @@
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(Exactly(1));
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
- .Times(Exactly(1));
}
mWireServer = nullptr;
}
@@ -238,8 +249,6 @@
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetLoggingCallback(apiDevice, nullptr, nullptr)).Times(Exactly(1));
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
- .Times(Exactly(1));
}
mWireServer = nullptr;
}
diff --git a/src/dawn/tests/unittests/wire/WireTest.h b/src/dawn/tests/unittests/wire/WireTest.h
index 7b4074b..c9c86a7 100644
--- a/src/dawn/tests/unittests/wire/WireTest.h
+++ b/src/dawn/tests/unittests/wire/WireTest.h
@@ -32,6 +32,7 @@
#include "dawn/common/Log.h"
#include "dawn/mock_webgpu.h"
+#include "dawn/tests/MockCallback.h"
#include "gtest/gtest.h"
#include "webgpu/webgpu_cpp.h"
@@ -148,6 +149,10 @@
void DefaultApiAdapterWasReleased();
testing::StrictMock<MockProcTable> api;
+
+ testing::MockCallback<WGPUDeviceLostCallbackNew> deviceLostCallback;
+ testing::MockCallback<WGPUErrorCallback> uncapturedErrorCallback;
+
WGPUInstance instance;
WGPUInstance apiInstance;
wgpu::Adapter adapter;
diff --git a/src/dawn/wire/client/Adapter.cpp b/src/dawn/wire/client/Adapter.cpp
index ae378a8..287acee 100644
--- a/src/dawn/wire/client/Adapter.cpp
+++ b/src/dawn/wire/client/Adapter.cpp
@@ -69,18 +69,30 @@
private:
void CompleteImpl(FutureID futureID, EventCompletionType completionType) override {
+ Device* device = mDevice;
if (completionType == EventCompletionType::Shutdown) {
mStatus = WGPURequestDeviceStatus_InstanceDropped;
mMessage = "A valid external Instance reference no longer exists.";
}
- if (mStatus != WGPURequestDeviceStatus_Success && mDevice != nullptr) {
- // If there was an error, we may need to reclaim the device allocation, otherwise the
- // device is returned to the user who owns it.
- mDevice->GetClient()->Free(mDevice.get());
- mDevice = nullptr;
+ if (mStatus != WGPURequestDeviceStatus_Success) {
+ device = nullptr;
}
if (mCallback) {
- mCallback(mStatus, ToAPI(mDevice), mMessage ? mMessage->c_str() : nullptr, mUserdata);
+ mCallback(mStatus, ToAPI(device), mMessage ? mMessage->c_str() : nullptr, mUserdata);
+ }
+
+ if (mStatus != WGPURequestDeviceStatus_Success) {
+ // If there was an error, we may need to call the device lost callback and reclaim the
+ // device allocation, otherwise the device is returned to the user who owns it.
+ if (mStatus == WGPURequestDeviceStatus_InstanceDropped) {
+ mDevice->HandleDeviceLost(WGPUDeviceLostReason_InstanceDropped,
+ "A valid external Instance reference no longer exists.");
+ } else {
+ mDevice->HandleDeviceLost(WGPUDeviceLostReason_FailedCreation,
+ "Device failed at creation.");
+ }
+ mDevice->Release();
+ mDevice = nullptr;
}
}
@@ -257,13 +269,17 @@
return {futureIDInternal};
}
- // Ensure the device lost callback isn't serialized as part of the command, as it cannot be
- // passed between processes.
+ // Ensure callbacks are not serialized as part of the command, as they cannot be passed between
+ // processes.
WGPUDeviceDescriptor wireDescriptor = {};
if (descriptor) {
wireDescriptor = *descriptor;
wireDescriptor.deviceLostCallback = nullptr;
wireDescriptor.deviceLostUserdata = nullptr;
+ wireDescriptor.deviceLostCallbackInfo.callback = nullptr;
+ wireDescriptor.deviceLostCallbackInfo.userdata = nullptr;
+ wireDescriptor.uncapturedErrorCallbackInfo.callback = nullptr;
+ wireDescriptor.uncapturedErrorCallbackInfo.userdata = nullptr;
}
AdapterRequestDeviceCmd cmd;
@@ -271,6 +287,7 @@
cmd.eventManagerHandle = GetEventManagerHandle();
cmd.future = {futureIDInternal};
cmd.deviceObjectHandle = device->GetWireHandle();
+ cmd.deviceLostFuture = device->GetDeviceLostFuture();
cmd.descriptor = &wireDescriptor;
client->SerializeCommand(cmd);
diff --git a/src/dawn/wire/client/Client.cpp b/src/dawn/wire/client/Client.cpp
index 2bc76dc..3d9aac3 100644
--- a/src/dawn/wire/client/Client.cpp
+++ b/src/dawn/wire/client/Client.cpp
@@ -60,6 +60,11 @@
}
Client::~Client() {
+ // Transition all event managers to ClientDropped state.
+ for (auto& [_, eventManager] : mEventManagers) {
+ eventManager->TransitionTo(EventManager::State::ClientDropped);
+ }
+
DestroyAllObjects();
}
@@ -124,8 +129,7 @@
}
// Reserve an EventManager for the given instance and make the association in the map.
- mEventManagers.emplace(ObjectHandle(instance->GetWireId(), instance->GetWireGeneration()),
- std::make_unique<EventManager>());
+ mEventManagers.emplace(instance->GetWireHandle(), std::make_unique<EventManager>());
ReservedInstance result;
result.instance = ToAPI(instance);
@@ -159,6 +163,11 @@
mDisconnected = true;
mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance());
+ // Transition all event managers to ClientDropped state.
+ for (auto& [_, eventManager] : mEventManagers) {
+ eventManager->TransitionTo(EventManager::State::ClientDropped);
+ }
+
auto& deviceList = mObjects[ObjectType::Device];
{
for (LinkNode<ObjectBase>* device = deviceList.head(); device != deviceList.end();
@@ -173,11 +182,6 @@
object->value()->CancelCallbacksForDisconnect();
}
}
-
- // Transition all event managers to ClientDropped state.
- for (auto& [_, eventManager] : mEventManagers) {
- eventManager->TransitionTo(EventManager::State::ClientDropped);
- }
}
bool Client::IsDisconnected() const {
diff --git a/src/dawn/wire/client/ClientDoers.cpp b/src/dawn/wire/client/ClientDoers.cpp
index 5dfdb7d..0d74781 100644
--- a/src/dawn/wire/client/ClientDoers.cpp
+++ b/src/dawn/wire/client/ClientDoers.cpp
@@ -66,15 +66,4 @@
return WireResult::Success;
}
-WireResult Client::DoDeviceLostCallback(Device* device,
- WGPUDeviceLostReason reason,
- char const* message) {
- if (device == nullptr) {
- // The device might have been deleted or recreated so this isn't an error.
- return WireResult::Success;
- }
- device->HandleDeviceLost(reason, message);
- return WireResult::Success;
-}
-
} // namespace dawn::wire::client
diff --git a/src/dawn/wire/client/Device.cpp b/src/dawn/wire/client/Device.cpp
index 3762b98..5196d82 100644
--- a/src/dawn/wire/client/Device.cpp
+++ b/src/dawn/wire/client/Device.cpp
@@ -162,45 +162,128 @@
EventType::CreateRenderPipeline,
WGPUCreateRenderPipelineAsyncCallbackInfo>;
+static constexpr WGPUUncapturedErrorCallbackInfo kEmptyUncapturedErrorCallbackInfo = {
+ nullptr, nullptr, nullptr};
+
} // namespace
+class Device::DeviceLostEvent : public TrackedEvent {
+ public:
+ static constexpr EventType kType = EventType::DeviceLost;
+
+ DeviceLostEvent(const WGPUDeviceLostCallbackInfo& callbackInfo, Device* device)
+ : TrackedEvent(callbackInfo.mode), mDevice(device) {
+ DAWN_ASSERT(device != nullptr);
+ mDevice->Reference();
+ }
+
+ ~DeviceLostEvent() override { mDevice->Release(); }
+
+ EventType GetType() override { return kType; }
+
+ WireResult ReadyHook(FutureID futureID, WGPUDeviceLostReason reason, const char* message) {
+ mReason = reason;
+ if (message != nullptr) {
+ mMessage = message;
+ }
+ mDevice->mDeviceLostInfo.futureID = kNullFutureID;
+ return WireResult::Success;
+ }
+
+ private:
+ void CompleteImpl(FutureID futureID, EventCompletionType completionType) override {
+ if (completionType == EventCompletionType::Shutdown) {
+ mReason = WGPUDeviceLostReason_InstanceDropped;
+ mMessage = "A valid external Instance reference no longer exists.";
+ }
+
+ if (mDevice->mDeviceLostInfo.oldCallback != nullptr) {
+ mDevice->mDeviceLostInfo.oldCallback(mReason, mMessage ? mMessage->c_str() : nullptr,
+ mDevice->mDeviceLostInfo.userdata);
+ } else if (mDevice->mDeviceLostInfo.callback != nullptr) {
+ auto device = mReason != WGPUDeviceLostReason_FailedCreation ? ToAPI(mDevice) : nullptr;
+ mDevice->mDeviceLostInfo.callback(&device, mReason,
+ mMessage ? mMessage->c_str() : nullptr,
+ mDevice->mDeviceLostInfo.userdata);
+ }
+ mDevice->mUncapturedErrorCallbackInfo = kEmptyUncapturedErrorCallbackInfo;
+ }
+
+ WGPUDeviceLostReason mReason;
+ // Note that the message is optional because we want to return nullptr when it wasn't set
+ // instead of a pointer to an empty string.
+ std::optional<std::string> mMessage;
+
+ // Strong reference to the device so that when we call the callback we can pass the device.
+ // TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire.
+ raw_ptr<Device, DanglingUntriaged> mDevice;
+};
+
Device::Device(const ObjectBaseParams& params,
const ObjectHandle& eventManagerHandle,
const WGPUDeviceDescriptor* descriptor)
- : ObjectWithEventsBase(params, eventManagerHandle), mIsAlive(std::make_shared<bool>()) {
- if (descriptor && descriptor->deviceLostCallback) {
- mDeviceLostCallback = descriptor->deviceLostCallback;
- mDeviceLostUserdata = descriptor->deviceLostUserdata;
- }
-
+ : ObjectWithEventsBase(params, eventManagerHandle), mIsAlive(std::make_shared<bool>(true)) {
#if defined(DAWN_ENABLE_ASSERTS)
- mErrorCallback = [](WGPUErrorType, char const*, void*) {
- static bool calledOnce = false;
- if (!calledOnce) {
- calledOnce = true;
- dawn::WarningLog() << "No Dawn device uncaptured error callback was set. This is "
- "probably not intended. If you really want to ignore errors "
- "and suppress this message, set the callback to null.";
- }
- };
-
- if (!mDeviceLostCallback) {
- mDeviceLostCallback = [](WGPUDeviceLostReason, char const*, void*) {
+ static constexpr WGPUDeviceLostCallbackInfo kDefaultDeviceLostCallbackInfo = {
+ nullptr, WGPUCallbackMode_AllowSpontaneous,
+ [](WGPUDevice const*, WGPUDeviceLostReason, char const*, void*) {
static bool calledOnce = false;
if (!calledOnce) {
calledOnce = true;
dawn::WarningLog() << "No Dawn device lost callback was set. This is probably not "
"intended. If you really want to ignore device lost "
- "and suppress this message, set the callback to null.";
+ "and suppress this message, set the callback explicitly.";
}
- };
- }
+ },
+ nullptr};
+ static constexpr WGPUUncapturedErrorCallbackInfo kDefaultUncapturedErrorCallbackInfo = {
+ nullptr,
+ [](WGPUErrorType, char const*, void*) {
+ static bool calledOnce = false;
+ if (!calledOnce) {
+ calledOnce = true;
+ dawn::WarningLog() << "No Dawn device uncaptured error callback was set. This is "
+ "probably not intended. If you really want to ignore errors "
+ "and suppress this message, set the callback explicitly.";
+ }
+ },
+ nullptr};
+#else
+ static constexpr WGPUDeviceLostCallbackInfo kDefaultDeviceLostCallbackInfo = {
+ nullptr, WGPUCallbackMode_AllowSpontaneous, nullptr, nullptr};
+ static constexpr WGPUUncapturedErrorCallbackInfo kDefaultUncapturedErrorCallbackInfo =
+ kEmptyUncapturedErrorCallbackInfo;
#endif // DAWN_ENABLE_ASSERTS
+
+ WGPUDeviceLostCallbackInfo deviceLostCallbackInfo = kDefaultDeviceLostCallbackInfo;
+ if (descriptor != nullptr) {
+ if (descriptor->deviceLostCallbackInfo.callback != nullptr) {
+ deviceLostCallbackInfo = descriptor->deviceLostCallbackInfo;
+ if (deviceLostCallbackInfo.mode == WGPUCallbackMode_WaitAnyOnly) {
+ // TODO(dawn:2458) Currently we default the callback mode to Spontaneous if not
+ // passed for backwards compatibility. We should add warning logging for it though
+ // when available. Update this when we have WGPUCallbackMode_Undefined.
+ deviceLostCallbackInfo.mode = WGPUCallbackMode_AllowSpontaneous;
+ }
+ mDeviceLostInfo.callback = deviceLostCallbackInfo.callback;
+ mDeviceLostInfo.userdata = deviceLostCallbackInfo.userdata;
+ } else if (descriptor->deviceLostCallback != nullptr) {
+ deviceLostCallbackInfo = {nullptr, WGPUCallbackMode_AllowSpontaneous, nullptr, nullptr};
+ mDeviceLostInfo.oldCallback = descriptor->deviceLostCallback;
+ mDeviceLostInfo.userdata = descriptor->deviceLostUserdata;
+ }
+ }
+ mDeviceLostInfo.event = std::make_unique<DeviceLostEvent>(deviceLostCallbackInfo, this);
+
+ mUncapturedErrorCallbackInfo = kDefaultUncapturedErrorCallbackInfo;
+ if (descriptor && descriptor->uncapturedErrorCallbackInfo.callback != nullptr) {
+ mUncapturedErrorCallbackInfo = descriptor->uncapturedErrorCallbackInfo;
+ }
}
Device::~Device() {
if (mQueue != nullptr) {
- GetProcs().queueRelease(ToAPI(mQueue.ExtractAsDangling()));
+ mQueue.ExtractAsDangling()->Release();
}
}
@@ -208,6 +291,17 @@
return ObjectType::Device;
}
+uint32_t Device::Release() {
+ // The device always has a reference in it's DeviceLossEvent which is created at construction,
+ // so when we drop to 1, we want to set the event so that the device can be loss according to
+ // the callback mode.
+ uint32_t refCount = ObjectBase::Release();
+ if (refCount == 1) {
+ HandleDeviceLost(WGPUDeviceLostReason_Destroyed, "Device was destroyed.");
+ }
+ return refCount;
+}
+
bool Device::GetLimits(WGPUSupportedLimits* limits) const {
return mLimitsAndFeatures.GetLimits(limits);
}
@@ -229,8 +323,9 @@
}
void Device::HandleError(WGPUErrorType errorType, const char* message) {
- if (mErrorCallback) {
- mErrorCallback(errorType, message, mErrorUserdata);
+ if (mUncapturedErrorCallbackInfo.callback) {
+ mUncapturedErrorCallbackInfo.callback(errorType, message,
+ mUncapturedErrorCallbackInfo.userdata);
}
}
@@ -242,19 +337,33 @@
}
void Device::HandleDeviceLost(WGPUDeviceLostReason reason, const char* message) {
- if (mDeviceLostCallback && !mDidRunLostCallback) {
- mDidRunLostCallback = true;
- mDeviceLostCallback(reason, message, mDeviceLostUserdata);
+ FutureID futureID = GetDeviceLostFuture().id;
+ if (futureID != kNullFutureID) {
+ DAWN_CHECK(GetEventManager().SetFutureReady<DeviceLostEvent>(futureID, reason, message) ==
+ WireResult::Success);
}
}
+WGPUFuture Device::GetDeviceLostFuture() {
+ // Lazily track the device lost event so that event ordering w.r.t RequestDevice is correct.
+ if (mDeviceLostInfo.event != nullptr) {
+ auto [deviceLostFutureIDInternal, tracked] =
+ GetEventManager().TrackEvent(std::move(mDeviceLostInfo.event));
+ if (tracked) {
+ mDeviceLostInfo.futureID = deviceLostFutureIDInternal;
+ }
+ }
+ return {mDeviceLostInfo.futureID};
+}
+
std::weak_ptr<bool> Device::GetAliveWeakPtr() {
return mIsAlive;
}
void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) {
- mErrorCallback = errorCallback;
- mErrorUserdata = errorUserdata;
+ if (mDeviceLostInfo.futureID != kNullFutureID) {
+ mUncapturedErrorCallbackInfo = {nullptr, errorCallback, errorUserdata};
+ }
}
void Device::SetLoggingCallback(WGPULoggingCallback callback, void* userdata) {
@@ -263,8 +372,18 @@
}
void Device::SetDeviceLostCallback(WGPUDeviceLostCallback callback, void* userdata) {
- mDeviceLostCallback = callback;
- mDeviceLostUserdata = userdata;
+ if (mDeviceLostInfo.futureID != kNullFutureID) {
+ mDeviceLostInfo.oldCallback = callback;
+ mDeviceLostInfo.userdata = userdata;
+ }
+}
+
+WireResult Client::DoDeviceLostCallback(ObjectHandle eventManager,
+ WGPUFuture future,
+ WGPUDeviceLostReason reason,
+ char const* message) {
+ return GetEventManager(eventManager)
+ .SetFutureReady<Device::DeviceLostEvent>(future.id, reason, message);
}
void Device::PopErrorScope(WGPUErrorCallback callback, void* userdata) {
diff --git a/src/dawn/wire/client/Device.h b/src/dawn/wire/client/Device.h
index 6d16647..4467c4b 100644
--- a/src/dawn/wire/client/Device.h
+++ b/src/dawn/wire/client/Device.h
@@ -52,6 +52,9 @@
ObjectType GetObjectType() const override;
+ // Override the default Release implementation to handle the device lost event.
+ uint32_t Release();
+
void SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata);
void SetLoggingCallback(WGPULoggingCallback errorCallback, void* errorUserdata);
void SetDeviceLostCallback(WGPUDeviceLostCallback errorCallback, void* errorUserdata);
@@ -83,9 +86,12 @@
void SetFeatures(const WGPUFeatureName* features, uint32_t featuresCount);
WGPUQueue GetQueue();
+ WGPUFuture GetDeviceLostFuture();
std::weak_ptr<bool> GetAliveWeakPtr();
+ class DeviceLostEvent;
+
private:
template <typename Event,
typename Cmd,
@@ -95,14 +101,21 @@
LimitsAndFeatures mLimitsAndFeatures;
- WGPUErrorCallback mErrorCallback = nullptr;
- WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
+ // TODO(crbug.com/dawn/2465): This can probably just be the future id once SetDeviceLostCallback
+ // is deprecated, and the callback and userdata moved into the DeviceLostEvent.
+ struct DeviceLostInfo {
+ FutureID futureID = kNullFutureID;
+ std::unique_ptr<TrackedEvent> event = nullptr;
+ WGPUDeviceLostCallbackNew callback = nullptr;
+ WGPUDeviceLostCallback oldCallback = nullptr;
+ // TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire:
+ raw_ptr<void, DanglingUntriaged> userdata = nullptr;
+ };
+ DeviceLostInfo mDeviceLostInfo;
+
+ WGPUUncapturedErrorCallbackInfo mUncapturedErrorCallbackInfo;
WGPULoggingCallback mLoggingCallback = nullptr;
- bool mDidRunLostCallback = false;
// TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire:
- raw_ptr<void, DanglingUntriaged> mErrorUserdata = nullptr;
- // TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire:
- raw_ptr<void, DanglingUntriaged> mDeviceLostUserdata = nullptr;
raw_ptr<void> mLoggingUserdata = nullptr;
raw_ptr<Queue> mQueue = nullptr;
diff --git a/src/dawn/wire/client/EventManager.h b/src/dawn/wire/client/EventManager.h
index de30fb3..ff5663e 100644
--- a/src/dawn/wire/client/EventManager.h
+++ b/src/dawn/wire/client/EventManager.h
@@ -49,6 +49,7 @@
CompilationInfo,
CreateComputePipeline,
CreateRenderPipeline,
+ DeviceLost,
MapAsync,
PopErrorScope,
RequestAdapter,
diff --git a/src/dawn/wire/client/ObjectBase.cpp b/src/dawn/wire/client/ObjectBase.cpp
index cc992fc..df76a08 100644
--- a/src/dawn/wire/client/ObjectBase.cpp
+++ b/src/dawn/wire/client/ObjectBase.cpp
@@ -59,11 +59,11 @@
mRefcount++;
}
-void ObjectBase::Release() {
+uint32_t ObjectBase::Release() {
DAWN_ASSERT(mRefcount != 0);
- mRefcount--;
- if (mRefcount == 0) {
+ uint32_t refCount = --mRefcount;
+ if (refCount == 0) {
DestroyObjectCmd cmd;
cmd.objectType = GetObjectType();
cmd.objectId = GetWireId();
@@ -72,6 +72,8 @@
client->SerializeCommand(cmd);
client->Free(this, GetObjectType());
}
+
+ return refCount;
}
ObjectWithEventsBase::ObjectWithEventsBase(const ObjectBaseParams& params,
diff --git a/src/dawn/wire/client/ObjectBase.h b/src/dawn/wire/client/ObjectBase.h
index 4bfaa66..16d4fbf6 100644
--- a/src/dawn/wire/client/ObjectBase.h
+++ b/src/dawn/wire/client/ObjectBase.h
@@ -64,7 +64,7 @@
Client* GetClient() const;
void Reference();
- void Release();
+ uint32_t Release();
protected:
uint32_t GetRefcount() const { return mRefcount; }
diff --git a/src/dawn/wire/server/Server.cpp b/src/dawn/wire/server/Server.cpp
index bf04246..7ed991e 100644
--- a/src/dawn/wire/server/Server.cpp
+++ b/src/dawn/wire/server/Server.cpp
@@ -162,21 +162,13 @@
info->server->OnLogging(info->self, type, message);
},
device->info.get());
- mProcs.deviceSetDeviceLostCallback(
- device->handle,
- [](WGPUDeviceLostReason reason, const char* message, void* userdata) {
- DeviceInfo* info = static_cast<DeviceInfo*>(userdata);
- info->server->OnDeviceLost(info->self, reason, message);
- },
- device->info.get());
}
void Server::ClearDeviceCallbacks(WGPUDevice device) {
- // Un-set the error and lost callbacks since we cannot forward them
+ // Un-set the error and logging callbacks since we cannot forward them
// after the server has been destroyed.
mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
mProcs.deviceSetLoggingCallback(device, nullptr, nullptr);
- mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
}
} // namespace dawn::wire::server
diff --git a/src/dawn/wire/server/Server.h b/src/dawn/wire/server/Server.h
index 1e15f1b..f908b3d 100644
--- a/src/dawn/wire/server/Server.h
+++ b/src/dawn/wire/server/Server.h
@@ -160,6 +160,14 @@
ObjectHandle eventManager;
WGPUFuture future;
ObjectId deviceObjectId;
+ WGPUFuture deviceLostFuture;
+};
+
+struct DeviceLostUserdata : CallbackUserdata {
+ using CallbackUserdata::CallbackUserdata;
+
+ ObjectHandle eventManager;
+ WGPUFuture future;
};
class Server : public ServerBase {
@@ -207,8 +215,13 @@
// Error callbacks
void OnUncapturedError(ObjectHandle device, WGPUErrorType type, const char* message);
- void OnDeviceLost(ObjectHandle device, WGPUDeviceLostReason reason, const char* message);
void OnLogging(ObjectHandle device, WGPULoggingType type, const char* message);
+
+ // Async event callbacks
+ void OnDeviceLost(DeviceLostUserdata* userdata,
+ WGPUDevice const* device,
+ WGPUDeviceLostReason reason,
+ const char* message);
void OnDevicePopErrorScope(ErrorScopeUserdata* userdata,
WGPUErrorType type,
const char* message);
diff --git a/src/dawn/wire/server/ServerAdapter.cpp b/src/dawn/wire/server/ServerAdapter.cpp
index a95e78c..c5bbba1 100644
--- a/src/dawn/wire/server/ServerAdapter.cpp
+++ b/src/dawn/wire/server/ServerAdapter.cpp
@@ -36,6 +36,7 @@
ObjectHandle eventManager,
WGPUFuture future,
ObjectHandle deviceHandle,
+ WGPUFuture deviceLostFuture,
const WGPUDeviceDescriptor* descriptor) {
Known<WGPUDevice> device;
WIRE_TRY(DeviceObjects().Allocate(&device, deviceHandle, AllocationState::Reserved));
@@ -44,8 +45,19 @@
userdata->eventManager = eventManager;
userdata->future = future;
userdata->deviceObjectId = device.id;
+ userdata->deviceLostFuture = deviceLostFuture;
- mProcs.adapterRequestDevice(adapter->handle, descriptor,
+ // Update the descriptor with the device lost callback associated with this request.
+ auto deviceLostUserdata = MakeUserdata<DeviceLostUserdata>();
+ deviceLostUserdata->eventManager = eventManager;
+ deviceLostUserdata->future = deviceLostFuture;
+
+ WGPUDeviceDescriptor desc = *descriptor;
+ desc.deviceLostCallbackInfo.mode = WGPUCallbackMode_AllowProcessEvents;
+ desc.deviceLostCallbackInfo.callback = ForwardToServer<&Server::OnDeviceLost>;
+ desc.deviceLostCallbackInfo.userdata = deviceLostUserdata.release();
+
+ mProcs.adapterRequestDevice(adapter->handle, &desc,
ForwardToServer<&Server::OnRequestDeviceCallback>,
userdata.release());
return WireResult::Success;
@@ -61,11 +73,19 @@
cmd.status = status;
cmd.message = message;
- if (status != WGPURequestDeviceStatus_Success) {
- // Free the ObjectId which will make it unusable.
- DeviceObjects().Free(data->deviceObjectId);
- DAWN_ASSERT(device == nullptr);
+ // We always fill the reservation once we complete so that the client is the one to release it.
+ auto FillReservation = [&]() {
+ Known<WGPUDevice> reservation =
+ DeviceObjects().FillReservation(data->deviceObjectId, device);
+ reservation->info->server = this;
+ reservation->info->self = reservation.AsHandle();
SerializeCommand(cmd);
+ return reservation;
+ };
+
+ if (status != WGPURequestDeviceStatus_Success) {
+ DAWN_ASSERT(device == nullptr);
+ FillReservation();
return;
}
@@ -83,12 +103,11 @@
if (!IsFeatureSupported(f)) {
// Release the device.
mProcs.deviceRelease(device);
- // Free the ObjectId which will make it unusable.
- DeviceObjects().Free(data->deviceObjectId);
+ device = nullptr;
cmd.status = WGPURequestDeviceStatus_Error;
cmd.message = "Requested feature not supported.";
- SerializeCommand(cmd);
+ FillReservation();
return;
}
}
@@ -105,13 +124,9 @@
cmd.limits = &limits;
// Assign the handle and allocated status if the device is created successfully.
- Known<WGPUDevice> reservation = DeviceObjects().FillReservation(data->deviceObjectId, device);
+ Known<WGPUDevice> reservation = FillReservation();
DAWN_ASSERT(reservation.data != nullptr);
- reservation->info->server = this;
- reservation->info->self = reservation.AsHandle();
SetForwardingDeviceCallbacks(reservation);
-
- SerializeCommand(cmd);
}
} // namespace dawn::wire::server
diff --git a/src/dawn/wire/server/ServerDevice.cpp b/src/dawn/wire/server/ServerDevice.cpp
index 3d9fd40..f9fa66b 100644
--- a/src/dawn/wire/server/ServerDevice.cpp
+++ b/src/dawn/wire/server/ServerDevice.cpp
@@ -56,9 +56,13 @@
SerializeCommand(cmd);
}
-void Server::OnDeviceLost(ObjectHandle device, WGPUDeviceLostReason reason, const char* message) {
+void Server::OnDeviceLost(DeviceLostUserdata* userdata,
+ WGPUDevice const* device,
+ WGPUDeviceLostReason reason,
+ const char* message) {
ReturnDeviceLostCallbackCmd cmd;
- cmd.device = device;
+ cmd.eventManager = userdata->eventManager;
+ cmd.future = userdata->future;
cmd.reason = reason;
cmd.message = message;