Handle Device Lost for Buffer

Bug: dawn:68, chromium:1042998, chromium:1043468
Change-Id: I4faa46b0d2e8f814b9d353a75489d3c8ca0b2e89
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/15340
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Natasha Lee <natlee@microsoft.com>
diff --git a/src/dawn_native/Buffer.cpp b/src/dawn_native/Buffer.cpp
index 8e72bd7..e387240 100644
--- a/src/dawn_native/Buffer.cpp
+++ b/src/dawn_native/Buffer.cpp
@@ -202,11 +202,17 @@
         ASSERT(!IsError());
         if (mMapReadCallback != nullptr && serial == mMapSerial) {
             ASSERT(mMapWriteCallback == nullptr);
+
             // Tag the callback as fired before firing it, otherwise it could fire a second time if
             // for example buffer.Unmap() is called inside the application-provided callback.
             WGPUBufferMapReadCallback callback = mMapReadCallback;
             mMapReadCallback = nullptr;
-            callback(status, pointer, dataLength, mMapUserdata);
+
+            if (GetDevice()->IsLost()) {
+                callback(WGPUBufferMapAsyncStatus_DeviceLost, nullptr, 0, mMapUserdata);
+            } else {
+                callback(status, pointer, dataLength, mMapUserdata);
+            }
         }
     }
 
@@ -217,11 +223,17 @@
         ASSERT(!IsError());
         if (mMapWriteCallback != nullptr && serial == mMapSerial) {
             ASSERT(mMapReadCallback == nullptr);
+
             // Tag the callback as fired before firing it, otherwise it could fire a second time if
             // for example buffer.Unmap() is called inside the application-provided callback.
             WGPUBufferMapWriteCallback callback = mMapWriteCallback;
             mMapWriteCallback = nullptr;
-            callback(status, pointer, dataLength, mMapUserdata);
+
+            if (GetDevice()->IsLost()) {
+                callback(WGPUBufferMapAsyncStatus_DeviceLost, nullptr, 0, mMapUserdata);
+            } else {
+                callback(status, pointer, dataLength, mMapUserdata);
+            }
         }
     }
 
@@ -237,8 +249,9 @@
     }
 
     void BufferBase::MapReadAsync(WGPUBufferMapReadCallback callback, void* userdata) {
-        if (GetDevice()->ConsumedError(ValidateMap(wgpu::BufferUsage::MapRead))) {
-            callback(WGPUBufferMapAsyncStatus_Error, nullptr, 0, userdata);
+        WGPUBufferMapAsyncStatus status;
+        if (GetDevice()->ConsumedError(ValidateMap(wgpu::BufferUsage::MapRead, &status))) {
+            callback(status, nullptr, 0, userdata);
             return;
         }
         ASSERT(!IsError());
@@ -273,8 +286,9 @@
     }
 
     void BufferBase::MapWriteAsync(WGPUBufferMapWriteCallback callback, void* userdata) {
-        if (GetDevice()->ConsumedError(ValidateMap(wgpu::BufferUsage::MapWrite))) {
-            callback(WGPUBufferMapAsyncStatus_Error, nullptr, 0, userdata);
+        WGPUBufferMapAsyncStatus status;
+        if (GetDevice()->ConsumedError(ValidateMap(wgpu::BufferUsage::MapWrite, &status))) {
+            callback(status, nullptr, 0, userdata);
             return;
         }
         ASSERT(!IsError());
@@ -389,8 +403,12 @@
         return {};
     }
 
-    MaybeError BufferBase::ValidateMap(wgpu::BufferUsage requiredUsage) const {
+    MaybeError BufferBase::ValidateMap(wgpu::BufferUsage requiredUsage,
+                                       WGPUBufferMapAsyncStatus* status) const {
+        *status = WGPUBufferMapAsyncStatus_DeviceLost;
         DAWN_TRY(GetDevice()->ValidateIsAlive());
+
+        *status = WGPUBufferMapAsyncStatus_Error;
         DAWN_TRY(GetDevice()->ValidateObject(this));
 
         switch (mState) {
@@ -406,6 +424,7 @@
             return DAWN_VALIDATION_ERROR("Buffer needs the correct map usage bit");
         }
 
+        *status = WGPUBufferMapAsyncStatus_Success;
         return {};
     }
 
diff --git a/src/dawn_native/Buffer.h b/src/dawn_native/Buffer.h
index 054e555..6a0a0e1 100644
--- a/src/dawn_native/Buffer.h
+++ b/src/dawn_native/Buffer.h
@@ -94,7 +94,8 @@
         MaybeError CopyFromStagingBuffer();
 
         MaybeError ValidateSetSubData(uint32_t start, uint32_t count) const;
-        MaybeError ValidateMap(wgpu::BufferUsage requiredUsage) const;
+        MaybeError ValidateMap(wgpu::BufferUsage requiredUsage,
+                               WGPUBufferMapAsyncStatus* status) const;
         MaybeError ValidateUnmap() const;
         MaybeError ValidateDestroy() const;
 
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 81a3177..abf8d60 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -198,6 +198,10 @@
         HandleError(wgpu::ErrorType::DeviceLost, "Device lost for testing");
     }
 
+    bool DeviceBase::IsLost() const {
+        return mLossStatus != LossStatus::Alive;
+    }
+
     AdapterBase* DeviceBase::GetAdapter() const {
         return mAdapter;
     }
@@ -478,7 +482,9 @@
         WGPUCreateBufferMappedResult result = CreateBufferMapped(descriptor);
 
         WGPUBufferMapAsyncStatus status = WGPUBufferMapAsyncStatus_Success;
-        if (result.data == nullptr || result.dataLength != descriptor->size) {
+        if (IsLost()) {
+            status = WGPUBufferMapAsyncStatus_DeviceLost;
+        } else if (result.data == nullptr || result.dataLength != descriptor->size) {
             status = WGPUBufferMapAsyncStatus_Error;
         }
 
@@ -594,6 +600,14 @@
     // Other Device API methods
 
     void DeviceBase::Tick() {
+        // We need to do the deferred callback even if Device is lost since Buffer Map Async will
+        // send callback with device lost status when device is lost.
+        {
+            auto deferredResults = std::move(mDeferredCreateBufferMappedAsyncResults);
+            for (const auto& deferred : deferredResults) {
+                deferred.callback(deferred.status, deferred.result, deferred.userdata);
+            }
+        }
         if (ConsumedError(ValidateIsAlive())) {
             return;
         }
@@ -601,12 +615,6 @@
             return;
         }
 
-        {
-            auto deferredResults = std::move(mDeferredCreateBufferMappedAsyncResults);
-            for (const auto& deferred : deferredResults) {
-                deferred.callback(deferred.status, deferred.result, deferred.userdata);
-            }
-        }
         mErrorScopeTracker->Tick(GetCompletedCommandSerial());
         mFenceSignalTracker->Tick(GetCompletedCommandSerial());
     }
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 622c436..6474a64 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -190,6 +190,7 @@
         size_t GetLazyClearCountForTesting();
         void IncrementLazyClearCountForTesting();
         void LoseForTesting();
+        bool IsLost() const;
 
       protected:
         void SetToggle(Toggle toggle, bool isEnabled);
diff --git a/src/dawn_native/null/DeviceNull.cpp b/src/dawn_native/null/DeviceNull.cpp
index 2d29183..c7586eb 100644
--- a/src/dawn_native/null/DeviceNull.cpp
+++ b/src/dawn_native/null/DeviceNull.cpp
@@ -88,6 +88,9 @@
 
     Device::~Device() {
         BaseDestructor();
+        // This assert is in the destructor rather than Device::Destroy() because it needs to make
+        // sure buffers have been destroyed before the device.
+        ASSERT(mMemoryUsage == 0);
     }
 
     ResultOrError<BindGroupBase*> Device::CreateBindGroupImpl(
@@ -181,7 +184,6 @@
         mDynamicUploader = nullptr;
 
         mPendingOperations.clear();
-        ASSERT(mMemoryUsage == 0);
     }
 
     MaybeError Device::WaitForIdleForDestruction() {
diff --git a/src/tests/end2end/DeviceLostTests.cpp b/src/tests/end2end/DeviceLostTests.cpp
index f247de8..c5471f1 100644
--- a/src/tests/end2end/DeviceLostTests.cpp
+++ b/src/tests/end2end/DeviceLostTests.cpp
@@ -34,6 +34,8 @@
     self->StartExpectDeviceError();
 }
 
+static const int fakeUserData = 0;
+
 class DeviceLostTest : public DawnTest {
   protected:
     void TestSetUp() override {
@@ -52,6 +54,26 @@
         EXPECT_CALL(*mockDeviceLostCallback, Call(_, this)).Times(1);
         device.LoseForTesting();
     }
+
+    static void CheckMapWriteFail(WGPUBufferMapAsyncStatus status,
+                                  void* data,
+                                  uint64_t datalength,
+                                  void* userdata) {
+        EXPECT_EQ(WGPUBufferMapAsyncStatus_DeviceLost, status);
+        EXPECT_EQ(nullptr, data);
+        EXPECT_EQ(0u, datalength);
+        EXPECT_EQ(&fakeUserData, userdata);
+    }
+
+    static void CheckMapReadFail(WGPUBufferMapAsyncStatus status,
+                                 const void* data,
+                                 uint64_t datalength,
+                                 void* userdata) {
+        EXPECT_EQ(WGPUBufferMapAsyncStatus_DeviceLost, status);
+        EXPECT_EQ(nullptr, data);
+        EXPECT_EQ(0u, datalength);
+        EXPECT_EQ(&fakeUserData, userdata);
+    }
 };
 
 // Test that DeviceLostCallback is invoked when LostForTestimg is called
@@ -212,4 +234,127 @@
     ASSERT_DEVICE_ERROR(device.Tick());
 }
 
-DAWN_INSTANTIATE_TEST(DeviceLostTest, D3D12Backend, VulkanBackend);
+// Test that CreateBuffer fails when device is lost
+TEST_P(DeviceLostTest, CreateBufferFails) {
+    SetCallbackAndLoseForTesting();
+
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::CopySrc;
+    ASSERT_DEVICE_ERROR(device.CreateBuffer(&bufferDescriptor));
+}
+
+// Test that buffer.MapWriteAsync fails after device is lost
+TEST_P(DeviceLostTest, BufferMapWriteAsyncFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapWrite;
+    wgpu::Buffer buffer = device.CreateBuffer(&bufferDescriptor);
+
+    SetCallbackAndLoseForTesting();
+    ASSERT_DEVICE_ERROR(buffer.MapWriteAsync(CheckMapWriteFail, const_cast<int*>(&fakeUserData)));
+}
+
+// Test that buffer.MapWriteAsync calls back with device loss status
+TEST_P(DeviceLostTest, BufferMapWriteAsyncBeforeLossFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapWrite;
+    wgpu::Buffer buffer = device.CreateBuffer(&bufferDescriptor);
+
+    buffer.MapWriteAsync(CheckMapWriteFail, const_cast<int*>(&fakeUserData));
+    SetCallbackAndLoseForTesting();
+}
+
+// Test that buffer.Unmap fails after device is lost
+TEST_P(DeviceLostTest, BufferUnmapFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapWrite;
+    wgpu::Buffer buffer = device.CreateBuffer(&bufferDescriptor);
+    wgpu::CreateBufferMappedResult result = device.CreateBufferMapped(&bufferDescriptor);
+
+    SetCallbackAndLoseForTesting();
+    ASSERT_DEVICE_ERROR(result.buffer.Unmap());
+}
+
+// Test that CreateBufferMapped fails after device is lost
+TEST_P(DeviceLostTest, CreateBufferMappedFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapWrite;
+
+    SetCallbackAndLoseForTesting();
+    ASSERT_DEVICE_ERROR(device.CreateBufferMapped(&bufferDescriptor));
+}
+
+// Test that CreateBufferMappedAsync fails after device is lost
+TEST_P(DeviceLostTest, CreateBufferMappedAsyncFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapWrite;
+
+    SetCallbackAndLoseForTesting();
+    struct ResultInfo {
+        wgpu::CreateBufferMappedResult result;
+        bool done = false;
+    } resultInfo;
+
+    ASSERT_DEVICE_ERROR(device.CreateBufferMappedAsync(
+        &bufferDescriptor,
+        [](WGPUBufferMapAsyncStatus status, WGPUCreateBufferMappedResult result, void* userdata) {
+            auto* resultInfo = static_cast<ResultInfo*>(userdata);
+            EXPECT_EQ(WGPUBufferMapAsyncStatus_DeviceLost, status);
+            EXPECT_NE(nullptr, result.data);
+            resultInfo->result.buffer = wgpu::Buffer::Acquire(result.buffer);
+            resultInfo->result.data = result.data;
+            resultInfo->result.dataLength = result.dataLength;
+            resultInfo->done = true;
+        },
+        &resultInfo));
+
+    while (!resultInfo.done) {
+        ASSERT_DEVICE_ERROR(WaitABit());
+    }
+
+    ASSERT_DEVICE_ERROR(resultInfo.result.buffer.Unmap());
+}
+
+// Test that BufferMapReadAsync fails after device is lost
+TEST_P(DeviceLostTest, BufferMapReadAsyncFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst;
+
+    wgpu::Buffer buffer = device.CreateBuffer(&bufferDescriptor);
+
+    SetCallbackAndLoseForTesting();
+    ASSERT_DEVICE_ERROR(buffer.MapReadAsync(CheckMapReadFail, const_cast<int*>(&fakeUserData)));
+}
+
+// Test that BufferMapReadAsync calls back with device lost status when device lost after map read
+TEST_P(DeviceLostTest, BufferMapReadAsyncBeforeLossFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst;
+
+    wgpu::Buffer buffer = device.CreateBuffer(&bufferDescriptor);
+
+    buffer.MapReadAsync(CheckMapReadFail, const_cast<int*>(&fakeUserData));
+    SetCallbackAndLoseForTesting();
+}
+
+// Test that SetSubData fails after device is lost
+TEST_P(DeviceLostTest, SetSubDataFails) {
+    wgpu::BufferDescriptor bufferDescriptor;
+    bufferDescriptor.size = sizeof(float);
+    bufferDescriptor.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst;
+
+    wgpu::Buffer buffer = device.CreateBuffer(&bufferDescriptor);
+
+    SetCallbackAndLoseForTesting();
+    std::array<float, 1> data = {12};
+    ASSERT_DEVICE_ERROR(buffer.SetSubData(0, sizeof(float), data.data()));
+}
+
+DAWN_INSTANTIATE_TEST(DeviceLostTest, D3D12Backend, VulkanBackend);
\ No newline at end of file