Fix crash when device is removed before CreateReady*Pipeline callback

This patch fixes a crash issue when the device is destroyed before
the callback of CreateReady{Render, Compute}Pipeline is called. Now
when the callback is called in DeviceBase::ShutDown(), the cached
pipeline object will also be destroyed before the callback returns.

BUG=dawn:529
TEST=dawn_end2end_tests

Change-Id: I91ec2608b53591d265c0648f5c02daf7fadac85e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/30744
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/dawn.json b/dawn.json
index 6a3bb9a..5671e26 100644
--- a/dawn.json
+++ b/dawn.json
@@ -516,7 +516,8 @@
             {"value": 0, "name": "success"},
             {"value": 1, "name": "error"},
             {"value": 2, "name": "device lost"},
-            {"value": 3, "name": "unknown"}
+            {"value": 3, "name": "device destroyed"},
+            {"value": 4, "name": "unknown"}
         ]
     },
     "create ready render pipeline callback": {
diff --git a/src/dawn_native/CreateReadyPipelineTracker.cpp b/src/dawn_native/CreateReadyPipelineTracker.cpp
index cfc19cc..be64393 100644
--- a/src/dawn_native/CreateReadyPipelineTracker.cpp
+++ b/src/dawn_native/CreateReadyPipelineTracker.cpp
@@ -14,7 +14,9 @@
 
 #include "dawn_native/CreateReadyPipelineTracker.h"
 
+#include "dawn_native/ComputePipeline.h"
 #include "dawn_native/Device.h"
+#include "dawn_native/RenderPipeline.h"
 
 namespace dawn_native {
 
@@ -33,13 +35,21 @@
           mCreateReadyComputePipelineCallback(callback) {
     }
 
-    void CreateReadyComputePipelineTask::Finish() {
+    void CreateReadyComputePipelineTask::Finish(WGPUCreateReadyPipelineStatus status) {
         ASSERT(mPipeline != nullptr);
         ASSERT(mCreateReadyComputePipelineCallback != nullptr);
 
-        mCreateReadyComputePipelineCallback(WGPUCreateReadyPipelineStatus_Success,
-                                            reinterpret_cast<WGPUComputePipeline>(mPipeline), "",
-                                            mUserData);
+        if (status != WGPUCreateReadyPipelineStatus_Success) {
+            // TODO(jiawei.shao@intel.com): support handling device lost
+            ASSERT(status == WGPUCreateReadyPipelineStatus_DeviceDestroyed);
+            mCreateReadyComputePipelineCallback(WGPUCreateReadyPipelineStatus_DeviceDestroyed,
+                                                nullptr, "Device destroyed before callback",
+                                                mUserData);
+            mPipeline->Release();
+        } else {
+            mCreateReadyComputePipelineCallback(
+                status, reinterpret_cast<WGPUComputePipeline>(mPipeline), "", mUserData);
+        }
 
         // Set mCreateReadyComputePipelineCallback to nullptr in case it is called more than once.
         mCreateReadyComputePipelineCallback = nullptr;
@@ -54,13 +64,21 @@
           mCreateReadyRenderPipelineCallback(callback) {
     }
 
-    void CreateReadyRenderPipelineTask::Finish() {
+    void CreateReadyRenderPipelineTask::Finish(WGPUCreateReadyPipelineStatus status) {
         ASSERT(mPipeline != nullptr);
         ASSERT(mCreateReadyRenderPipelineCallback != nullptr);
 
-        mCreateReadyRenderPipelineCallback(WGPUCreateReadyPipelineStatus_Success,
-                                           reinterpret_cast<WGPURenderPipeline>(mPipeline), "",
-                                           mUserData);
+        if (status != WGPUCreateReadyPipelineStatus_Success) {
+            // TODO(jiawei.shao@intel.com): support handling device lost
+            ASSERT(status == WGPUCreateReadyPipelineStatus_DeviceDestroyed);
+            mCreateReadyRenderPipelineCallback(WGPUCreateReadyPipelineStatus_DeviceDestroyed,
+                                               nullptr, "Device destroyed before callback",
+                                               mUserData);
+            mPipeline->Release();
+        } else {
+            mCreateReadyRenderPipelineCallback(
+                status, reinterpret_cast<WGPURenderPipeline>(mPipeline), "", mUserData);
+        }
 
         // Set mCreateReadyPipelineCallback to nullptr in case it is called more than once.
         mCreateReadyRenderPipelineCallback = nullptr;
@@ -81,9 +99,16 @@
 
     void CreateReadyPipelineTracker::Tick(ExecutionSerial finishedSerial) {
         for (auto& task : mCreateReadyPipelineTasksInFlight.IterateUpTo(finishedSerial)) {
-            task->Finish();
+            task->Finish(WGPUCreateReadyPipelineStatus_Success);
         }
         mCreateReadyPipelineTasksInFlight.ClearUpTo(finishedSerial);
     }
 
+    void CreateReadyPipelineTracker::ClearForShutDown() {
+        for (auto& task : mCreateReadyPipelineTasksInFlight.IterateAll()) {
+            task->Finish(WGPUCreateReadyPipelineStatus_DeviceDestroyed);
+        }
+        mCreateReadyPipelineTasksInFlight.Clear();
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/CreateReadyPipelineTracker.h b/src/dawn_native/CreateReadyPipelineTracker.h
index 0c6b1dc..b1eed08 100644
--- a/src/dawn_native/CreateReadyPipelineTracker.h
+++ b/src/dawn_native/CreateReadyPipelineTracker.h
@@ -32,7 +32,7 @@
         CreateReadyPipelineTaskBase(void* userData);
         virtual ~CreateReadyPipelineTaskBase();
 
-        virtual void Finish() = 0;
+        virtual void Finish(WGPUCreateReadyPipelineStatus status) = 0;
 
       protected:
         void* mUserData;
@@ -43,7 +43,7 @@
                                        WGPUCreateReadyComputePipelineCallback callback,
                                        void* userdata);
 
-        void Finish() final;
+        void Finish(WGPUCreateReadyPipelineStatus status) final;
 
       private:
         ComputePipelineBase* mPipeline;
@@ -55,7 +55,7 @@
                                       WGPUCreateReadyRenderPipelineCallback callback,
                                       void* userdata);
 
-        void Finish() final;
+        void Finish(WGPUCreateReadyPipelineStatus status) final;
 
       private:
         RenderPipelineBase* mPipeline;
@@ -69,6 +69,7 @@
 
         void TrackTask(std::unique_ptr<CreateReadyPipelineTaskBase> task, ExecutionSerial serial);
         void Tick(ExecutionSerial finishedSerial);
+        void ClearForShutDown();
 
       private:
         DeviceBase* mDevice;
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 2c5b3aa..84210db 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -151,7 +151,8 @@
             // pending callbacks.
             mErrorScopeTracker->Tick(GetCompletedCommandSerial());
             GetDefaultQueue()->Tick(GetCompletedCommandSerial());
-            mCreateReadyPipelineTracker->Tick(GetCompletedCommandSerial());
+
+            mCreateReadyPipelineTracker->ClearForShutDown();
 
             // call TickImpl once last time to clean up resources
             // Ignore errors so that we can continue with destruction
@@ -787,6 +788,7 @@
             mDynamicUploader->Deallocate(mCompletedSerial);
             mErrorScopeTracker->Tick(mCompletedSerial);
             GetDefaultQueue()->Tick(mCompletedSerial);
+
             mCreateReadyPipelineTracker->Tick(mCompletedSerial);
         }
 
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index b49f8a4..91d12fc 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -40,20 +40,21 @@
         // Fire pending error scopes
         auto errorScopes = std::move(mErrorScopes);
         for (const auto& it : errorScopes) {
-            it.second.callback(WGPUErrorType_Unknown, "Device destroyed", it.second.userdata);
+            it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback",
+                               it.second.userdata);
         }
 
         auto createReadyPipelineRequests = std::move(mCreateReadyPipelineRequests);
         for (const auto& it : createReadyPipelineRequests) {
             if (it.second.createReadyComputePipelineCallback != nullptr) {
-                it.second.createReadyComputePipelineCallback(WGPUCreateReadyPipelineStatus_Unknown,
-                                                             nullptr, "Device destroyed",
-                                                             it.second.userdata);
+                it.second.createReadyComputePipelineCallback(
+                    WGPUCreateReadyPipelineStatus_DeviceDestroyed, nullptr,
+                    "Device destroyed before callback", it.second.userdata);
             } else {
                 ASSERT(it.second.createReadyRenderPipelineCallback != nullptr);
-                it.second.createReadyRenderPipelineCallback(WGPUCreateReadyPipelineStatus_Unknown,
-                                                            nullptr, "Device destroyed",
-                                                            it.second.userdata);
+                it.second.createReadyRenderPipelineCallback(
+                    WGPUCreateReadyPipelineStatus_DeviceDestroyed, nullptr,
+                    "Device destroyed before callback", it.second.userdata);
             }
         }
 
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index 3a9215e..c7e67aa 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -21,7 +21,10 @@
                    const DawnProcTable& procs,
                    CommandSerializer* serializer,
                    MemoryTransferService* memoryTransferService)
-        : mSerializer(serializer), mProcs(procs), mMemoryTransferService(memoryTransferService) {
+        : mSerializer(serializer),
+          mProcs(procs),
+          mMemoryTransferService(memoryTransferService),
+          mIsAlive(std::make_shared<bool>(true)) {
         if (mMemoryTransferService == nullptr) {
             // If a MemoryTransferService is not provided, fallback to inline memory.
             mOwnedMemoryTransferService = CreateInlineMemoryTransferService();
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index 4e725a8..08b5452 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -56,6 +56,7 @@
     };
 
     struct CreateReadyPipelineUserData {
+        std::weak_ptr<bool> isServerAlive;
         Server* server;
         uint64_t requestSerial;
         ObjectId pipelineObjectID;
@@ -131,6 +132,8 @@
         DawnProcTable mProcs;
         std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
         MemoryTransferService* mMemoryTransferService = nullptr;
+
+        std::shared_ptr<bool> mIsAlive;
     };
 
     std::unique_ptr<MemoryTransferService> CreateInlineMemoryTransferService();
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index b3accf8..2fd26a2 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -30,20 +30,34 @@
                                                    WGPUComputePipeline pipeline,
                                                    const char* message,
                                                    void* userdata) {
-        CreateReadyPipelineUserData* createReadyPipelineUserData =
-            static_cast<CreateReadyPipelineUserData*>(userdata);
+        std::unique_ptr<CreateReadyPipelineUserData> createReadyPipelineUserData(
+            static_cast<CreateReadyPipelineUserData*>(userdata));
+
+        // We need to ensure createReadyPipelineUserData->server is still pointing to a valid
+        // object before doing any operations on it.
+        if (createReadyPipelineUserData->isServerAlive.expired()) {
+            return;
+        }
+
         createReadyPipelineUserData->server->OnCreateReadyComputePipelineCallback(
-            status, pipeline, message, createReadyPipelineUserData);
+            status, pipeline, message, createReadyPipelineUserData.release());
     }
 
     void Server::ForwardCreateReadyRenderPipeline(WGPUCreateReadyPipelineStatus status,
                                                   WGPURenderPipeline pipeline,
                                                   const char* message,
                                                   void* userdata) {
-        CreateReadyPipelineUserData* createReadyPipelineUserData =
-            static_cast<CreateReadyPipelineUserData*>(userdata);
+        std::unique_ptr<CreateReadyPipelineUserData> createReadyPipelineUserData(
+            static_cast<CreateReadyPipelineUserData*>(userdata));
+
+        // We need to ensure createReadyPipelineUserData->server is still pointing to a valid
+        // object before doing any operations on it.
+        if (createReadyPipelineUserData->isServerAlive.expired()) {
+            return;
+        }
+
         createReadyPipelineUserData->server->OnCreateReadyRenderPipelineCallback(
-            status, pipeline, message, createReadyPipelineUserData);
+            status, pipeline, message, createReadyPipelineUserData.release());
     }
 
     void Server::OnUncapturedError(WGPUErrorType type, const char* message) {
@@ -87,6 +101,7 @@
 
         std::unique_ptr<CreateReadyPipelineUserData> userdata =
             std::make_unique<CreateReadyPipelineUserData>();
+        userdata->isServerAlive = mIsAlive;
         userdata->server = this;
         userdata->requestSerial = requestSerial;
         userdata->pipelineObjectID = pipelineObjectHandle.id;
@@ -101,10 +116,27 @@
                                                       const char* message,
                                                       CreateReadyPipelineUserData* userdata) {
         std::unique_ptr<CreateReadyPipelineUserData> data(userdata);
-        if (status != WGPUCreateReadyPipelineStatus_Success) {
-            ComputePipelineObjects().Free(data->pipelineObjectID);
-        } else {
-            ComputePipelineObjects().Get(data->pipelineObjectID)->handle = pipeline;
+
+        auto* computePipelineObject = ComputePipelineObjects().Get(data->pipelineObjectID);
+        ASSERT(computePipelineObject != nullptr);
+
+        switch (status) {
+            case WGPUCreateReadyPipelineStatus_Success:
+                computePipelineObject->handle = pipeline;
+                break;
+
+            case WGPUCreateReadyPipelineStatus_Error:
+                ComputePipelineObjects().Free(data->pipelineObjectID);
+                break;
+
+            // Currently this code is unreachable because WireServer is always deleted before the
+            // removal of the device. In the future this logic may be changed when we decide to
+            // support sharing one pair of WireServer/WireClient to multiple devices.
+            case WGPUCreateReadyPipelineStatus_DeviceLost:
+            case WGPUCreateReadyPipelineStatus_DeviceDestroyed:
+            case WGPUCreateReadyPipelineStatus_Unknown:
+            default:
+                UNREACHABLE();
         }
 
         ReturnDeviceCreateReadyComputePipelineCallbackCmd cmd;
@@ -128,6 +160,7 @@
 
         std::unique_ptr<CreateReadyPipelineUserData> userdata =
             std::make_unique<CreateReadyPipelineUserData>();
+        userdata->isServerAlive = mIsAlive;
         userdata->server = this;
         userdata->requestSerial = requestSerial;
         userdata->pipelineObjectID = pipelineObjectHandle.id;
@@ -142,10 +175,27 @@
                                                      const char* message,
                                                      CreateReadyPipelineUserData* userdata) {
         std::unique_ptr<CreateReadyPipelineUserData> data(userdata);
-        if (status != WGPUCreateReadyPipelineStatus_Success) {
-            RenderPipelineObjects().Free(data->pipelineObjectID);
-        } else {
-            RenderPipelineObjects().Get(data->pipelineObjectID)->handle = pipeline;
+
+        auto* renderPipelineObject = RenderPipelineObjects().Get(data->pipelineObjectID);
+        ASSERT(renderPipelineObject != nullptr);
+
+        switch (status) {
+            case WGPUCreateReadyPipelineStatus_Success:
+                renderPipelineObject->handle = pipeline;
+                break;
+
+            case WGPUCreateReadyPipelineStatus_Error:
+                RenderPipelineObjects().Free(data->pipelineObjectID);
+                break;
+
+            // Currently this code is unreachable because WireServer is always deleted before the
+            // removal of the device. In the future this logic may be changed when we decide to
+            // support sharing one pair of WireServer/WireClient to multiple devices.
+            case WGPUCreateReadyPipelineStatus_DeviceLost:
+            case WGPUCreateReadyPipelineStatus_DeviceDestroyed:
+            case WGPUCreateReadyPipelineStatus_Unknown:
+            default:
+                UNREACHABLE();
         }
 
         ReturnDeviceCreateReadyRenderPipelineCallbackCmd cmd;
diff --git a/src/tests/end2end/CreateReadyPipelineTests.cpp b/src/tests/end2end/CreateReadyPipelineTests.cpp
index 60c4b06..a03443b 100644
--- a/src/tests/end2end/CreateReadyPipelineTests.cpp
+++ b/src/tests/end2end/CreateReadyPipelineTests.cpp
@@ -26,7 +26,10 @@
     };
 }  // anonymous namespace
 
-class CreateReadyPipelineTest : public DawnTest {};
+class CreateReadyPipelineTest : public DawnTest {
+  protected:
+    CreateReadyPipelineTask task;
+};
 
 // Verify the basic use of CreateReadyComputePipeline works on all backends.
 TEST_P(CreateReadyPipelineTest, BasicUseOfCreateReadyComputePipeline) {
@@ -42,7 +45,6 @@
         utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, computeShader);
     csDesc.computeStage.entryPoint = "main";
 
-    CreateReadyPipelineTask task;
     device.CreateReadyComputePipeline(
         &csDesc,
         [](WGPUCreateReadyPipelineStatus status, WGPUComputePipeline returnPipeline,
@@ -110,7 +112,6 @@
         utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, computeShader);
     csDesc.computeStage.entryPoint = "main0";
 
-    CreateReadyPipelineTask task;
     device.CreateReadyComputePipeline(
         &csDesc,
         [](WGPUCreateReadyPipelineStatus status, WGPUComputePipeline returnPipeline,
@@ -159,7 +160,6 @@
     renderPipelineDescriptor.cColorStates[0].format = kOutputAttachmentFormat;
     renderPipelineDescriptor.primitiveTopology = wgpu::PrimitiveTopology::PointList;
 
-    CreateReadyPipelineTask task;
     device.CreateReadyRenderPipeline(
         &renderPipelineDescriptor,
         [](WGPUCreateReadyPipelineStatus status, WGPURenderPipeline returnPipeline,
@@ -237,7 +237,6 @@
     renderPipelineDescriptor.cColorStates[0].format = kOutputAttachmentFormat;
     renderPipelineDescriptor.primitiveTopology = wgpu::PrimitiveTopology::PointList;
 
-    CreateReadyPipelineTask task;
     device.CreateReadyRenderPipeline(
         &renderPipelineDescriptor,
         [](WGPUCreateReadyPipelineStatus status, WGPURenderPipeline returnPipeline,
@@ -259,6 +258,75 @@
     ASSERT_EQ(nullptr, task.computePipeline.Get());
 }
 
+// Verify there is no error when the device is released before the callback of
+// CreateReadyComputePipeline() is called.
+TEST_P(CreateReadyPipelineTest, ReleaseDeviceBeforeCallbackOfCreateReadyComputePipeline) {
+    const char* computeShader = R"(
+        #version 450
+        void main() {
+        })";
+
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.computeStage.module =
+        utils::CreateShaderModule(device, utils::SingleShaderStage::Compute, computeShader);
+    csDesc.computeStage.entryPoint = "main";
+
+    device.CreateReadyComputePipeline(
+        &csDesc,
+        [](WGPUCreateReadyPipelineStatus status, WGPUComputePipeline returnPipeline,
+           const char* message, void* userdata) {
+            ASSERT_EQ(WGPUCreateReadyPipelineStatus::WGPUCreateReadyPipelineStatus_DeviceDestroyed,
+                      status);
+
+            CreateReadyPipelineTask* task = static_cast<CreateReadyPipelineTask*>(userdata);
+            task->computePipeline = wgpu::ComputePipeline::Acquire(returnPipeline);
+            task->isCompleted = true;
+            task->message = message;
+        },
+        &task);
+}
+
+// Verify there is no error when the device is released before the callback of
+// CreateReadyRenderPipeline() is called.
+TEST_P(CreateReadyPipelineTest, ReleaseDeviceBeforeCallbackOfCreateReadyRenderPipeline) {
+    const char* vertexShader = R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.f, 0.f, 0.f, 1.f);
+            gl_PointSize = 1.0f;
+        })";
+    const char* fragmentShader = R"(
+        #version 450
+        layout(location = 0) out vec4 o_color;
+        void main() {
+            o_color = vec4(0.f, 1.f, 0.f, 1.f);
+        })";
+
+    utils::ComboRenderPipelineDescriptor renderPipelineDescriptor(device);
+    wgpu::ShaderModule vsModule =
+        utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, vertexShader);
+    wgpu::ShaderModule fsModule =
+        utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, fragmentShader);
+    renderPipelineDescriptor.vertexStage.module = vsModule;
+    renderPipelineDescriptor.cFragmentStage.module = fsModule;
+    renderPipelineDescriptor.cColorStates[0].format = wgpu::TextureFormat::RGBA8Unorm;
+    renderPipelineDescriptor.primitiveTopology = wgpu::PrimitiveTopology::PointList;
+
+    device.CreateReadyRenderPipeline(
+        &renderPipelineDescriptor,
+        [](WGPUCreateReadyPipelineStatus status, WGPURenderPipeline returnPipeline,
+           const char* message, void* userdata) {
+            ASSERT_EQ(WGPUCreateReadyPipelineStatus::WGPUCreateReadyPipelineStatus_DeviceDestroyed,
+                      status);
+
+            CreateReadyPipelineTask* task = static_cast<CreateReadyPipelineTask*>(userdata);
+            task->renderPipeline = wgpu::RenderPipeline::Acquire(returnPipeline);
+            task->isCompleted = true;
+            task->message = message;
+        },
+        &task);
+}
+
 DAWN_INSTANTIATE_TEST(CreateReadyPipelineTest,
                       D3D12Backend(),
                       MetalBackend(),