[dawn][native] Adds a way to fake an error after Device creation.
- This is a follow up to:
https://dawn-review.googlesource.com/c/dawn/+/245734 where the
issues in the bugs below should be fixed.
- The fuzz tests linked below were a result of an error happening
after a Device was successfully allocated, but failed to initialize.
This is possible in the fuzz tests because we allow Vulkan failure
fuzzing.
- Updates the device lost event to be updated appropriately when the
device failed to create. Otherwise, when a device was created but
initialize failed, the device lost would actually look like the
device was destroyed. Now only the first |SetLost| on the event
will be reflected, and hence users will see |FailedCreation|.
- In order to better protect against this, this change adds a way to
fake a failure in the frontend and adds relevant regression tests.
- Note that as I was making this change, I also filed
https://github.com/webgpu-native/webgpu-headers/issues/546 upstream
because of the interaction between RequestDevice and the
DeviceLostCallback when we have different callback modes for them.
Bug: 421746158, 415350701
Change-Id: I3f2943c66d435b4bbbdca5fbb26c6c2ac388e160
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/245777
Commit-Queue: Brandon Jones <bajones@chromium.org>
Auto-Submit: Loko Kung <lokokung@google.com>
Reviewed-by: Brandon Jones <bajones@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 6a80502..a544399 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -2158,6 +2158,13 @@
{"name": "fake OOM at device", "type": "bool"}
]
},
+ "dawn fake device initialize error for testing": {
+ "category": "structure",
+ "chained": "in",
+ "chain roots": ["device descriptor"],
+ "tags": ["dawn"],
+ "members": []
+ },
"shared fence type": {
"category": "enum",
"tags": ["dawn", "native"],
@@ -3864,7 +3871,8 @@
{"value": 66, "name": "dawn device allocator control", "tags": ["dawn"]},
{"value": 67, "name": "dawn host mapped pointer limits", "tags": ["dawn"]},
{"value": 68, "name": "render pass descriptor resolve rect", "tags": ["dawn"]},
- {"value": 69, "name": "request adapter WebGPU backend options", "tags": ["dawn", "native"]}
+ {"value": 69, "name": "request adapter WebGPU backend options", "tags": ["dawn", "native"]},
+ {"value": 70, "name": "dawn fake device initialize error for testing", "tags": ["dawn"]}
]
},
"texture": {
diff --git a/src/dawn/native/Adapter.cpp b/src/dawn/native/Adapter.cpp
index 6ebd897..a6ad007 100644
--- a/src/dawn/native/Adapter.cpp
+++ b/src/dawn/native/Adapter.cpp
@@ -318,15 +318,13 @@
// Catch any errors to directly complete the device lost event with the error message.
if (result.IsError()) {
auto error = result.AcquireError();
- lostEvent->mReason = wgpu::DeviceLostReason::FailedCreation;
- lostEvent->mMessage = "Failed to create device:\n" + error->GetFormattedMessage();
+ lostEvent->SetLost(mInstance->GetEventManager(), wgpu::DeviceLostReason::FailedCreation,
+ "Failed to create device:\n" + error->GetFormattedMessage());
// When the device fails to initialize, we need to both promote the device ref to an
// external ref to clean up resources, and drop it, so we acquire it in this scope.
APIRef<DeviceBase> device;
device.Acquire(ReturnToAPI(std::move(lostEvent->mDevice)));
-
- mInstance->GetEventManager()->SetFutureReady(lostEvent.Get());
return {lostEvent, std::move(error)};
}
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 4530de9..aa2b382 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -194,6 +194,20 @@
return AcquireRef(new DeviceBase::DeviceLostEvent(deviceLostCallbackInfo));
}
+void DeviceBase::DeviceLostEvent::SetLost(EventManager* eventManager,
+ wgpu::DeviceLostReason reason,
+ std::string_view message) {
+ mReason = reason;
+ mMessage = message;
+ eventManager->SetFutureReady(this);
+ if (mDevice) {
+ // If the device was already set, then the device must be associated with this event. Since
+ // the event should only be set and triggered once, unset the event in the device now.
+ mDevice->mLostFuture = GetFuture();
+ mDevice->mLostEvent = nullptr;
+ }
+}
+
void DeviceBase::DeviceLostEvent::Complete(EventCompletionType completionType) {
if (completionType == EventCompletionType::Shutdown) {
mReason = wgpu::DeviceLostReason::CallbackCancelled;
@@ -420,7 +434,8 @@
mQueue = nullptr;
}
-MaybeError DeviceBase::Initialize(Ref<QueueBase> defaultQueue) {
+MaybeError DeviceBase::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor,
+ Ref<QueueBase> defaultQueue) {
mQueue = std::move(defaultQueue);
SetWGSLExtensionAllowList();
@@ -438,6 +453,11 @@
// alive.
mState = State::Alive;
+ // Fake an error after the creation of a device here for testing.
+ if (descriptor.Get<DawnFakeDeviceInitializeErrorForTesting>() != nullptr) {
+ return DAWN_INTERNAL_ERROR("DawnFakeDeviceInitialzeErrorForTesting");
+ }
+
DAWN_TRY_ASSIGN(mEmptyBindGroupLayout, CreateEmptyBindGroupLayout());
DAWN_TRY_ASSIGN(mEmptyPipelineLayout, CreateEmptyPipelineLayout());
@@ -447,13 +467,13 @@
constexpr char kEmptyFragmentShader[] = R"(
@fragment fn fs_empty_main() {}
)";
- ShaderModuleDescriptor descriptor;
+ ShaderModuleDescriptor shaderDesc;
ShaderSourceWGSL wgslDesc;
wgslDesc.code = kEmptyFragmentShader;
- descriptor.nextInChain = &wgslDesc;
+ shaderDesc.nextInChain = &wgslDesc;
DAWN_TRY_ASSIGN(mInternalPipelineStore->placeholderFragmentShader,
- CreateShaderModule(&descriptor, /* internalExtensions */ {}));
+ CreateShaderModule(&shaderDesc, /* internalExtensions */ {}));
}
if (HasFeature(Feature::ImplicitDeviceSynchronization)) {
@@ -673,11 +693,7 @@
void DeviceBase::HandleDeviceLost(wgpu::DeviceLostReason reason, std::string_view message) {
if (mLostEvent != nullptr) {
- mLostEvent->mReason = reason;
- mLostEvent->mMessage = message;
- GetInstance()->GetEventManager()->SetFutureReady(mLostEvent.Get());
- mLostFuture = mLostEvent->GetFuture();
- mLostEvent = nullptr;
+ mLostEvent->SetLost(GetInstance()->GetEventManager(), reason, message);
}
}
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index c286bf8..5e09a96 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -88,14 +88,11 @@
struct DeviceLostEvent final : public EventManager::TrackedEvent {
static Ref<DeviceLostEvent> Create(const DeviceDescriptor* descriptor);
- // Event result fields need to be public so that they can easily be updated prior to
- // completing the event.
- wgpu::DeviceLostReason mReason;
- std::string mMessage;
+ // Sets the device lost event's fields and sets the event to be ready.
+ void SetLost(EventManager* eventManager,
+ wgpu::DeviceLostReason reason,
+ std::string_view message);
- WGPUDeviceLostCallback mCallback = nullptr;
- raw_ptr<void> mUserdata1;
- raw_ptr<void> mUserdata2;
// Note that the device is set when the event is passed to construct a device.
Ref<DeviceBase> mDevice = nullptr;
@@ -104,6 +101,13 @@
~DeviceLostEvent() override;
void Complete(EventCompletionType completionType) override;
+
+ wgpu::DeviceLostReason mReason;
+ std::string mMessage;
+
+ WGPUDeviceLostCallback mCallback = nullptr;
+ raw_ptr<void> mUserdata1;
+ raw_ptr<void> mUserdata2;
};
DeviceBase(AdapterBase* adapter,
@@ -442,7 +446,8 @@
void ForceEnableFeatureForTesting(Feature feature);
- MaybeError Initialize(Ref<QueueBase> defaultQueue);
+ MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor,
+ Ref<QueueBase> defaultQueue);
void DestroyObjects();
void Destroy();
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index 1a743d1..163192e 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -203,7 +203,7 @@
Ref<Queue> queue;
DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue));
- DAWN_TRY(DeviceBase::Initialize(queue));
+ DAWN_TRY(DeviceBase::Initialize(descriptor, queue));
DAWN_TRY(queue->InitializePendingContext());
SetLabelImpl();
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 4396c0f..2247e3c 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -176,7 +176,7 @@
GetD3D12Device()->CreateCommandSignature(&programDesc, nullptr,
IID_PPV_ARGS(&mDrawIndexedIndirectSignature));
- DAWN_TRY(DeviceBase::Initialize(std::move(queue)));
+ DAWN_TRY(DeviceBase::Initialize(descriptor, std::move(queue)));
// Ensure DXC if use_dxc toggle is set.
DAWN_TRY(EnsureDXCIfRequired());
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 935ac58..40db656 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -187,7 +187,7 @@
}
}
- return DeviceBase::Initialize(std::move(queue));
+ return DeviceBase::Initialize(descriptor, std::move(queue));
}
ResultOrError<Ref<BindGroupBase>> Device::CreateBindGroupImpl(
diff --git a/src/dawn/native/null/DeviceNull.cpp b/src/dawn/native/null/DeviceNull.cpp
index e089c8f..80d1946 100644
--- a/src/dawn/native/null/DeviceNull.cpp
+++ b/src/dawn/native/null/DeviceNull.cpp
@@ -192,7 +192,8 @@
}
MaybeError Device::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor) {
- return DeviceBase::Initialize(AcquireRef(new Queue(this, &descriptor->defaultQueue)));
+ return DeviceBase::Initialize(descriptor,
+ AcquireRef(new Queue(this, &descriptor->defaultQueue)));
}
ResultOrError<Ref<BindGroupBase>> Device::CreateBindGroupImpl(
diff --git a/src/dawn/native/opengl/DeviceGL.cpp b/src/dawn/native/opengl/DeviceGL.cpp
index 86e46ab..696175a 100644
--- a/src/dawn/native/opengl/DeviceGL.cpp
+++ b/src/dawn/native/opengl/DeviceGL.cpp
@@ -210,7 +210,7 @@
DAWN_GL_TRY(gl, GetIntegerv(GL_MAX_TEXTURE_MAX_ANISOTROPY, &mMaxTextureMaxAnisotropy));
}
- DAWN_TRY(DeviceBase::Initialize(std::move(queue)));
+ DAWN_TRY(DeviceBase::Initialize(descriptor, std::move(queue)));
// Create internal buffers needed for workarounds.
if (mTextureBuiltinsBuffer.Get() == nullptr) {
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 299a4e4..ca84294 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -188,7 +188,7 @@
Ref<Queue> queue;
DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue, mMainQueueFamily));
- return DeviceBase::Initialize(std::move(queue));
+ return DeviceBase::Initialize(descriptor, std::move(queue));
}
Device::~Device() {
diff --git a/src/dawn/native/webgpu/DeviceWGPU.cpp b/src/dawn/native/webgpu/DeviceWGPU.cpp
index 259592c..30dd2bb 100644
--- a/src/dawn/native/webgpu/DeviceWGPU.cpp
+++ b/src/dawn/native/webgpu/DeviceWGPU.cpp
@@ -125,7 +125,7 @@
MaybeError Device::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor) {
Ref<Queue> queue;
DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue));
- return DeviceBase::Initialize(std::move(queue));
+ return DeviceBase::Initialize(descriptor, std::move(queue));
}
ResultOrError<Ref<BindGroupBase>> Device::CreateBindGroupImpl(
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp b/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
index bf24b73..09f25ed 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
@@ -128,7 +128,8 @@
// Initialize the device.
GetInstance()->GetEventManager()->TrackEvent(mLostEvent);
QueueDescriptor desc = {};
- EXPECT_FALSE(Initialize(AcquireRef(new NiceMock<QueueMock>(this, &desc))).IsError());
+ EXPECT_FALSE(
+ Initialize(descriptor, AcquireRef(new NiceMock<QueueMock>(this, &desc))).IsError());
}
DeviceMock::~DeviceMock() = default;
diff --git a/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp b/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp
index 2b0ff80..3657761 100644
--- a/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp
@@ -37,7 +37,9 @@
namespace dawn {
namespace {
+using testing::_;
using testing::EmptySizedString;
+using testing::InSequence;
using testing::IsNull;
using testing::MockCppCallback;
using testing::NonEmptySizedString;
@@ -46,13 +48,14 @@
class RequestDeviceValidationTest : public ValidationTest {
protected:
+ using MockDeviceLostCallback = MockCppCallback<wgpu::DeviceLostCallback<void>*>;
+
void SetUp() override {
ValidationTest::SetUp();
DAWN_SKIP_TEST_IF(UsesWire());
}
- MockCppCallback<void (*)(wgpu::RequestDeviceStatus, wgpu::Device, wgpu::StringView)>
- mRequestDeviceCallback;
+ MockCppCallback<wgpu::RequestDeviceCallback<void>*> mRequestDeviceCallback;
};
// Test that requesting a device without specifying limits is valid.
@@ -224,6 +227,61 @@
mRequestDeviceCallback.Callback());
}
+// Test that if an error occurs when requesting a device, the device lost callback is called
+// appropriately.
+TEST_F(RequestDeviceValidationTest, ErrorTriggersDeviceLost) {
+ // Invalid descriptor chains:
+ // - ChainedStruct: This should cause an early validation error.
+ // - DawnFakeDeviceInitializeErrorForTesting: This should cause an internal device error.
+ wgpu::ChainedStruct chain1;
+ wgpu::DawnFakeDeviceInitializeErrorForTesting chain2;
+ std::array<wgpu::ChainedStruct*, 2> chains = {&chain1, &chain2};
+
+ for (const auto* chain : chains) {
+ SCOPED_TRACE(absl::StrFormat("Chain SType: %s", chain->sType));
+ {
+ wgpu::DeviceDescriptor descriptor;
+ descriptor.nextInChain = chain;
+
+ // Device lost callback mode: AllowSpontaneous.
+ MockDeviceLostCallback lostCb;
+ descriptor.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
+ lostCb.Callback());
+
+ // When in spontaneous mode, the request device callback should fire immediately before
+ // the device lost callback.
+ InSequence s;
+ EXPECT_CALL(mRequestDeviceCallback,
+ Call(wgpu::RequestDeviceStatus::Error, IsNull(), NonEmptySizedString()))
+ .Times(1);
+ EXPECT_CALL(lostCb, Call(_, wgpu::DeviceLostReason::FailedCreation, _)).Times(1);
+ adapter.RequestDevice(&descriptor, wgpu::CallbackMode::AllowSpontaneous,
+ mRequestDeviceCallback.Callback());
+ }
+ {
+ wgpu::DeviceDescriptor descriptor;
+ descriptor.nextInChain = chain;
+
+ // Device lost callback mode: AllowProcessEvents.
+ MockDeviceLostCallback lostCb;
+ descriptor.SetDeviceLostCallback(wgpu::CallbackMode::AllowProcessEvents,
+ lostCb.Callback());
+
+ EXPECT_CALL(mRequestDeviceCallback,
+ Call(wgpu::RequestDeviceStatus::Error, IsNull(), NonEmptySizedString()))
+ .Times(1);
+ adapter.RequestDevice(&descriptor, wgpu::CallbackMode::AllowSpontaneous,
+ mRequestDeviceCallback.Callback());
+
+ // When in a non-spontaneous mode for the device lost, the request device callback
+ // should fire, but the device lost callback should only fire when the Instance level
+ // API is called.
+ EXPECT_CALL(lostCb, Call(_, wgpu::DeviceLostReason::FailedCreation, _)).Times(1);
+ instance.ProcessEvents();
+ }
+ }
+}
+
class DeviceTickValidationTest : public ValidationTest {};
// Device destroy before API-level Tick should always result in no-op and false.