[dawn][native] Adds a way to fake an error after Device creation.

- This is a follow up to:
  https://dawn-review.googlesource.com/c/dawn/+/245734 where the
  issues in the bugs below should be fixed.
- The fuzz tests linked below were a result of an error happening
  after a Device was successfully allocated, but failed to initialize.
  This is possible in the fuzz tests because we allow Vulkan failure
  fuzzing.
- Updates the device lost event to be updated appropriately when the
  device failed to create. Otherwise, when a device was created but
  initialize failed, the device lost would actually look like the
  device was destroyed. Now only the first |SetLost| on the event
  will be reflected, and hence users will see |FailedCreation|.
- In order to better protect against this, this change adds a way to
  fake a failure in the frontend and adds relevant regression tests.
- Note that as I was making this change, I also filed
  https://github.com/webgpu-native/webgpu-headers/issues/546 upstream
  because of the interaction between RequestDevice and the
  DeviceLostCallback when we have different callback modes for them.

Bug: 421746158, 415350701
Change-Id: I3f2943c66d435b4bbbdca5fbb26c6c2ac388e160
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/245777
Commit-Queue: Brandon Jones <bajones@chromium.org>
Auto-Submit: Loko Kung <lokokung@google.com>
Reviewed-by: Brandon Jones <bajones@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 6a80502..a544399 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -2158,6 +2158,13 @@
             {"name": "fake OOM at device", "type": "bool"}
         ]
     },
+    "dawn fake device initialize error for testing": {
+        "category": "structure",
+        "chained": "in",
+        "chain roots": ["device descriptor"],
+        "tags": ["dawn"],
+        "members": []
+    },
     "shared fence type": {
         "category": "enum",
         "tags": ["dawn", "native"],
@@ -3864,7 +3871,8 @@
             {"value": 66, "name": "dawn device allocator control", "tags": ["dawn"]},
             {"value": 67, "name": "dawn host mapped pointer limits", "tags": ["dawn"]},
             {"value": 68, "name": "render pass descriptor resolve rect", "tags": ["dawn"]},
-            {"value": 69, "name": "request adapter WebGPU backend options", "tags": ["dawn", "native"]}
+            {"value": 69, "name": "request adapter WebGPU backend options", "tags": ["dawn", "native"]},
+            {"value": 70, "name": "dawn fake device initialize error for testing", "tags": ["dawn"]}
         ]
     },
     "texture": {
diff --git a/src/dawn/native/Adapter.cpp b/src/dawn/native/Adapter.cpp
index 6ebd897..a6ad007 100644
--- a/src/dawn/native/Adapter.cpp
+++ b/src/dawn/native/Adapter.cpp
@@ -318,15 +318,13 @@
     // Catch any errors to directly complete the device lost event with the error message.
     if (result.IsError()) {
         auto error = result.AcquireError();
-        lostEvent->mReason = wgpu::DeviceLostReason::FailedCreation;
-        lostEvent->mMessage = "Failed to create device:\n" + error->GetFormattedMessage();
+        lostEvent->SetLost(mInstance->GetEventManager(), wgpu::DeviceLostReason::FailedCreation,
+                           "Failed to create device:\n" + error->GetFormattedMessage());
 
         // When the device fails to initialize, we need to both promote the device ref to an
         // external ref to clean up resources, and drop it, so we acquire it in this scope.
         APIRef<DeviceBase> device;
         device.Acquire(ReturnToAPI(std::move(lostEvent->mDevice)));
-
-        mInstance->GetEventManager()->SetFutureReady(lostEvent.Get());
         return {lostEvent, std::move(error)};
     }
 
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 4530de9..aa2b382 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -194,6 +194,20 @@
     return AcquireRef(new DeviceBase::DeviceLostEvent(deviceLostCallbackInfo));
 }
 
+void DeviceBase::DeviceLostEvent::SetLost(EventManager* eventManager,
+                                          wgpu::DeviceLostReason reason,
+                                          std::string_view message) {
+    mReason = reason;
+    mMessage = message;
+    eventManager->SetFutureReady(this);
+    if (mDevice) {
+        // If the device was already set, then the device must be associated with this event. Since
+        // the event should only be set and triggered once, unset the event in the device now.
+        mDevice->mLostFuture = GetFuture();
+        mDevice->mLostEvent = nullptr;
+    }
+}
+
 void DeviceBase::DeviceLostEvent::Complete(EventCompletionType completionType) {
     if (completionType == EventCompletionType::Shutdown) {
         mReason = wgpu::DeviceLostReason::CallbackCancelled;
@@ -420,7 +434,8 @@
     mQueue = nullptr;
 }
 
-MaybeError DeviceBase::Initialize(Ref<QueueBase> defaultQueue) {
+MaybeError DeviceBase::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor,
+                                  Ref<QueueBase> defaultQueue) {
     mQueue = std::move(defaultQueue);
 
     SetWGSLExtensionAllowList();
@@ -438,6 +453,11 @@
     // alive.
     mState = State::Alive;
 
+    // Fake an error after the creation of a device here for testing.
+    if (descriptor.Get<DawnFakeDeviceInitializeErrorForTesting>() != nullptr) {
+        return DAWN_INTERNAL_ERROR("DawnFakeDeviceInitialzeErrorForTesting");
+    }
+
     DAWN_TRY_ASSIGN(mEmptyBindGroupLayout, CreateEmptyBindGroupLayout());
     DAWN_TRY_ASSIGN(mEmptyPipelineLayout, CreateEmptyPipelineLayout());
 
@@ -447,13 +467,13 @@
         constexpr char kEmptyFragmentShader[] = R"(
                 @fragment fn fs_empty_main() {}
             )";
-        ShaderModuleDescriptor descriptor;
+        ShaderModuleDescriptor shaderDesc;
         ShaderSourceWGSL wgslDesc;
         wgslDesc.code = kEmptyFragmentShader;
-        descriptor.nextInChain = &wgslDesc;
+        shaderDesc.nextInChain = &wgslDesc;
 
         DAWN_TRY_ASSIGN(mInternalPipelineStore->placeholderFragmentShader,
-                        CreateShaderModule(&descriptor, /* internalExtensions */ {}));
+                        CreateShaderModule(&shaderDesc, /* internalExtensions */ {}));
     }
 
     if (HasFeature(Feature::ImplicitDeviceSynchronization)) {
@@ -673,11 +693,7 @@
 
 void DeviceBase::HandleDeviceLost(wgpu::DeviceLostReason reason, std::string_view message) {
     if (mLostEvent != nullptr) {
-        mLostEvent->mReason = reason;
-        mLostEvent->mMessage = message;
-        GetInstance()->GetEventManager()->SetFutureReady(mLostEvent.Get());
-        mLostFuture = mLostEvent->GetFuture();
-        mLostEvent = nullptr;
+        mLostEvent->SetLost(GetInstance()->GetEventManager(), reason, message);
     }
 }
 
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index c286bf8..5e09a96 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -88,14 +88,11 @@
     struct DeviceLostEvent final : public EventManager::TrackedEvent {
         static Ref<DeviceLostEvent> Create(const DeviceDescriptor* descriptor);
 
-        // Event result fields need to be public so that they can easily be updated prior to
-        // completing the event.
-        wgpu::DeviceLostReason mReason;
-        std::string mMessage;
+        // Sets the device lost event's fields and sets the event to be ready.
+        void SetLost(EventManager* eventManager,
+                     wgpu::DeviceLostReason reason,
+                     std::string_view message);
 
-        WGPUDeviceLostCallback mCallback = nullptr;
-        raw_ptr<void> mUserdata1;
-        raw_ptr<void> mUserdata2;
         // Note that the device is set when the event is passed to construct a device.
         Ref<DeviceBase> mDevice = nullptr;
 
@@ -104,6 +101,13 @@
         ~DeviceLostEvent() override;
 
         void Complete(EventCompletionType completionType) override;
+
+        wgpu::DeviceLostReason mReason;
+        std::string mMessage;
+
+        WGPUDeviceLostCallback mCallback = nullptr;
+        raw_ptr<void> mUserdata1;
+        raw_ptr<void> mUserdata2;
     };
 
     DeviceBase(AdapterBase* adapter,
@@ -442,7 +446,8 @@
 
     void ForceEnableFeatureForTesting(Feature feature);
 
-    MaybeError Initialize(Ref<QueueBase> defaultQueue);
+    MaybeError Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor,
+                          Ref<QueueBase> defaultQueue);
     void DestroyObjects();
     void Destroy();
 
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index 1a743d1..163192e 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -203,7 +203,7 @@
     Ref<Queue> queue;
     DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue));
 
-    DAWN_TRY(DeviceBase::Initialize(queue));
+    DAWN_TRY(DeviceBase::Initialize(descriptor, queue));
     DAWN_TRY(queue->InitializePendingContext());
 
     SetLabelImpl();
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 4396c0f..2247e3c 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -176,7 +176,7 @@
     GetD3D12Device()->CreateCommandSignature(&programDesc, nullptr,
                                              IID_PPV_ARGS(&mDrawIndexedIndirectSignature));
 
-    DAWN_TRY(DeviceBase::Initialize(std::move(queue)));
+    DAWN_TRY(DeviceBase::Initialize(descriptor, std::move(queue)));
 
     // Ensure DXC if use_dxc toggle is set.
     DAWN_TRY(EnsureDXCIfRequired());
diff --git a/src/dawn/native/metal/DeviceMTL.mm b/src/dawn/native/metal/DeviceMTL.mm
index 935ac58..40db656 100644
--- a/src/dawn/native/metal/DeviceMTL.mm
+++ b/src/dawn/native/metal/DeviceMTL.mm
@@ -187,7 +187,7 @@
         }
     }
 
-    return DeviceBase::Initialize(std::move(queue));
+    return DeviceBase::Initialize(descriptor, std::move(queue));
 }
 
 ResultOrError<Ref<BindGroupBase>> Device::CreateBindGroupImpl(
diff --git a/src/dawn/native/null/DeviceNull.cpp b/src/dawn/native/null/DeviceNull.cpp
index e089c8f..80d1946 100644
--- a/src/dawn/native/null/DeviceNull.cpp
+++ b/src/dawn/native/null/DeviceNull.cpp
@@ -192,7 +192,8 @@
 }
 
 MaybeError Device::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor) {
-    return DeviceBase::Initialize(AcquireRef(new Queue(this, &descriptor->defaultQueue)));
+    return DeviceBase::Initialize(descriptor,
+                                  AcquireRef(new Queue(this, &descriptor->defaultQueue)));
 }
 
 ResultOrError<Ref<BindGroupBase>> Device::CreateBindGroupImpl(
diff --git a/src/dawn/native/opengl/DeviceGL.cpp b/src/dawn/native/opengl/DeviceGL.cpp
index 86e46ab..696175a 100644
--- a/src/dawn/native/opengl/DeviceGL.cpp
+++ b/src/dawn/native/opengl/DeviceGL.cpp
@@ -210,7 +210,7 @@
         DAWN_GL_TRY(gl, GetIntegerv(GL_MAX_TEXTURE_MAX_ANISOTROPY, &mMaxTextureMaxAnisotropy));
     }
 
-    DAWN_TRY(DeviceBase::Initialize(std::move(queue)));
+    DAWN_TRY(DeviceBase::Initialize(descriptor, std::move(queue)));
 
     // Create internal buffers needed for workarounds.
     if (mTextureBuiltinsBuffer.Get() == nullptr) {
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 299a4e4..ca84294 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -188,7 +188,7 @@
     Ref<Queue> queue;
     DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue, mMainQueueFamily));
 
-    return DeviceBase::Initialize(std::move(queue));
+    return DeviceBase::Initialize(descriptor, std::move(queue));
 }
 
 Device::~Device() {
diff --git a/src/dawn/native/webgpu/DeviceWGPU.cpp b/src/dawn/native/webgpu/DeviceWGPU.cpp
index 259592c..30dd2bb 100644
--- a/src/dawn/native/webgpu/DeviceWGPU.cpp
+++ b/src/dawn/native/webgpu/DeviceWGPU.cpp
@@ -125,7 +125,7 @@
 MaybeError Device::Initialize(const UnpackedPtr<DeviceDescriptor>& descriptor) {
     Ref<Queue> queue;
     DAWN_TRY_ASSIGN(queue, Queue::Create(this, &descriptor->defaultQueue));
-    return DeviceBase::Initialize(std::move(queue));
+    return DeviceBase::Initialize(descriptor, std::move(queue));
 }
 
 ResultOrError<Ref<BindGroupBase>> Device::CreateBindGroupImpl(
diff --git a/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp b/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
index bf24b73..09f25ed 100644
--- a/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
+++ b/src/dawn/tests/unittests/native/mocks/DeviceMock.cpp
@@ -128,7 +128,8 @@
     // Initialize the device.
     GetInstance()->GetEventManager()->TrackEvent(mLostEvent);
     QueueDescriptor desc = {};
-    EXPECT_FALSE(Initialize(AcquireRef(new NiceMock<QueueMock>(this, &desc))).IsError());
+    EXPECT_FALSE(
+        Initialize(descriptor, AcquireRef(new NiceMock<QueueMock>(this, &desc))).IsError());
 }
 
 DeviceMock::~DeviceMock() = default;
diff --git a/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp b/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp
index 2b0ff80..3657761 100644
--- a/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/DeviceValidationTests.cpp
@@ -37,7 +37,9 @@
 namespace dawn {
 namespace {
 
+using testing::_;
 using testing::EmptySizedString;
+using testing::InSequence;
 using testing::IsNull;
 using testing::MockCppCallback;
 using testing::NonEmptySizedString;
@@ -46,13 +48,14 @@
 
 class RequestDeviceValidationTest : public ValidationTest {
   protected:
+    using MockDeviceLostCallback = MockCppCallback<wgpu::DeviceLostCallback<void>*>;
+
     void SetUp() override {
         ValidationTest::SetUp();
         DAWN_SKIP_TEST_IF(UsesWire());
     }
 
-    MockCppCallback<void (*)(wgpu::RequestDeviceStatus, wgpu::Device, wgpu::StringView)>
-        mRequestDeviceCallback;
+    MockCppCallback<wgpu::RequestDeviceCallback<void>*> mRequestDeviceCallback;
 };
 
 // Test that requesting a device without specifying limits is valid.
@@ -224,6 +227,61 @@
                           mRequestDeviceCallback.Callback());
 }
 
+// Test that if an error occurs when requesting a device, the device lost callback is called
+// appropriately.
+TEST_F(RequestDeviceValidationTest, ErrorTriggersDeviceLost) {
+    // Invalid descriptor chains:
+    //   - ChainedStruct: This should cause an early validation error.
+    //   - DawnFakeDeviceInitializeErrorForTesting: This should cause an internal device error.
+    wgpu::ChainedStruct chain1;
+    wgpu::DawnFakeDeviceInitializeErrorForTesting chain2;
+    std::array<wgpu::ChainedStruct*, 2> chains = {&chain1, &chain2};
+
+    for (const auto* chain : chains) {
+        SCOPED_TRACE(absl::StrFormat("Chain SType: %s", chain->sType));
+        {
+            wgpu::DeviceDescriptor descriptor;
+            descriptor.nextInChain = chain;
+
+            // Device lost callback mode: AllowSpontaneous.
+            MockDeviceLostCallback lostCb;
+            descriptor.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
+                                             lostCb.Callback());
+
+            // When in spontaneous mode, the request device callback should fire immediately before
+            // the device lost callback.
+            InSequence s;
+            EXPECT_CALL(mRequestDeviceCallback,
+                        Call(wgpu::RequestDeviceStatus::Error, IsNull(), NonEmptySizedString()))
+                .Times(1);
+            EXPECT_CALL(lostCb, Call(_, wgpu::DeviceLostReason::FailedCreation, _)).Times(1);
+            adapter.RequestDevice(&descriptor, wgpu::CallbackMode::AllowSpontaneous,
+                                  mRequestDeviceCallback.Callback());
+        }
+        {
+            wgpu::DeviceDescriptor descriptor;
+            descriptor.nextInChain = chain;
+
+            // Device lost callback mode: AllowProcessEvents.
+            MockDeviceLostCallback lostCb;
+            descriptor.SetDeviceLostCallback(wgpu::CallbackMode::AllowProcessEvents,
+                                             lostCb.Callback());
+
+            EXPECT_CALL(mRequestDeviceCallback,
+                        Call(wgpu::RequestDeviceStatus::Error, IsNull(), NonEmptySizedString()))
+                .Times(1);
+            adapter.RequestDevice(&descriptor, wgpu::CallbackMode::AllowSpontaneous,
+                                  mRequestDeviceCallback.Callback());
+
+            // When in a non-spontaneous mode for the device lost, the request device callback
+            // should fire, but the device lost callback should only fire when the Instance level
+            // API is called.
+            EXPECT_CALL(lostCb, Call(_, wgpu::DeviceLostReason::FailedCreation, _)).Times(1);
+            instance.ProcessEvents();
+        }
+    }
+}
+
 class DeviceTickValidationTest : public ValidationTest {};
 
 // Device destroy before API-level Tick should always result in no-op and false.