Refactor APICreateComputePipelineAsync to support both sync and async paths

This patch refactors the implementation of APICreateComputePipelineAsync
in preparation for the async path of compute pipeline creation.

Now the code path of APICreateComputePipelineAsync() includes the following
3 parts:
- When an error occurs in the front-end validations, the callback will be
  called at once in the main thread.
- When we can find a proper compute pipeline object in the cache, the
  callback will be called at once in the main thread.
- When we cannot find the proper compute pipeline object in the cache, the
  newly-created pipeline object, the callback and userdata will be saved
  into the CreatePipelineAsyncTracker, and the callback will be called in
  device.Tick(). All the logic mentioned in this section has been put into
  one function CreateComputePipelineAsyncImpl(), which will be overridden
  by its asynchronous version on all the backends that support creating
  pipeline objects asynchronously.

Note that APICreateRenderPipelineAsync is not changed in this patch because
it is now under refactoring to match the current updates in WebGPU SPEC.

BUG=dawn:529
TEST=dawn_end2end_tests

Change-Id: Ie1cf2f9fc8e18c3e6ad723c6a0cefce29a0eb69c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/45842
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/CreatePipelineAsyncTracker.cpp b/src/dawn_native/CreatePipelineAsyncTracker.cpp
index 2228111..a9fcf62 100644
--- a/src/dawn_native/CreatePipelineAsyncTracker.cpp
+++ b/src/dawn_native/CreatePipelineAsyncTracker.cpp
@@ -20,68 +20,86 @@
 
 namespace dawn_native {
 
-    CreatePipelineAsyncTaskBase::CreatePipelineAsyncTaskBase(void* userdata) : mUserData(userdata) {
+    CreatePipelineAsyncTaskBase::CreatePipelineAsyncTaskBase(std::string errorMessage,
+                                                             void* userdata)
+        : mErrorMessage(errorMessage), mUserData(userdata) {
     }
 
     CreatePipelineAsyncTaskBase::~CreatePipelineAsyncTaskBase() {
     }
 
     CreateComputePipelineAsyncTask::CreateComputePipelineAsyncTask(
-        ComputePipelineBase* pipeline,
+        Ref<ComputePipelineBase> pipeline,
+        std::string errorMessage,
         WGPUCreateComputePipelineAsyncCallback callback,
         void* userdata)
-        : CreatePipelineAsyncTaskBase(userdata),
-          mPipeline(pipeline),
+        : CreatePipelineAsyncTaskBase(errorMessage, userdata),
+          mPipeline(std::move(pipeline)),
           mCreateComputePipelineAsyncCallback(callback) {
     }
 
-    void CreateComputePipelineAsyncTask::Finish(WGPUCreatePipelineAsyncStatus status) {
-        ASSERT(mPipeline != nullptr);
+    void CreateComputePipelineAsyncTask::Finish() {
         ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
 
-        if (status != WGPUCreatePipelineAsyncStatus_Success) {
-            // TODO(jiawei.shao@intel.com): support handling device lost
-            ASSERT(status == WGPUCreatePipelineAsyncStatus_DeviceDestroyed);
-            mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed,
-                                                nullptr, "Device destroyed before callback",
-                                                mUserData);
-            mPipeline->Release();
-        } else {
+        if (mPipeline.Get() != nullptr) {
             mCreateComputePipelineAsyncCallback(
-                status, reinterpret_cast<WGPUComputePipeline>(mPipeline), "", mUserData);
+                WGPUCreatePipelineAsyncStatus_Success,
+                reinterpret_cast<WGPUComputePipeline>(mPipeline.Detach()), "", mUserData);
+        } else {
+            mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_Error, nullptr,
+                                                mErrorMessage.c_str(), mUserData);
         }
+    }
 
-        // Set mCreateComputePipelineAsyncCallback to nullptr in case it is called more than once.
-        mCreateComputePipelineAsyncCallback = nullptr;
+    void CreateComputePipelineAsyncTask::HandleShutDown() {
+        ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
+
+        mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
+                                            "Device destroyed before callback", mUserData);
+    }
+
+    void CreateComputePipelineAsyncTask::HandleDeviceLoss() {
+        ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
+
+        mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
+                                            "Device lost before callback", mUserData);
     }
 
     CreateRenderPipelineAsyncTask::CreateRenderPipelineAsyncTask(
-        RenderPipelineBase* pipeline,
+        Ref<RenderPipelineBase> pipeline,
+        std::string errorMessage,
         WGPUCreateRenderPipelineAsyncCallback callback,
         void* userdata)
-        : CreatePipelineAsyncTaskBase(userdata),
-          mPipeline(pipeline),
+        : CreatePipelineAsyncTaskBase(errorMessage, userdata),
+          mPipeline(std::move(pipeline)),
           mCreateRenderPipelineAsyncCallback(callback) {
     }
 
-    void CreateRenderPipelineAsyncTask::Finish(WGPUCreatePipelineAsyncStatus status) {
-        ASSERT(mPipeline != nullptr);
+    void CreateRenderPipelineAsyncTask::Finish() {
         ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
 
-        if (status != WGPUCreatePipelineAsyncStatus_Success) {
-            // TODO(jiawei.shao@intel.com): support handling device lost
-            ASSERT(status == WGPUCreatePipelineAsyncStatus_DeviceDestroyed);
-            mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed,
-                                               nullptr, "Device destroyed before callback",
-                                               mUserData);
-            mPipeline->Release();
-        } else {
+        if (mPipeline.Get() != nullptr) {
             mCreateRenderPipelineAsyncCallback(
-                status, reinterpret_cast<WGPURenderPipeline>(mPipeline), "", mUserData);
+                WGPUCreatePipelineAsyncStatus_Success,
+                reinterpret_cast<WGPURenderPipeline>(mPipeline.Detach()), "", mUserData);
+        } else {
+            mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_Error, nullptr,
+                                               mErrorMessage.c_str(), mUserData);
         }
+    }
 
-        // Set mCreatePipelineAsyncCallback to nullptr in case it is called more than once.
-        mCreateRenderPipelineAsyncCallback = nullptr;
+    void CreateRenderPipelineAsyncTask::HandleShutDown() {
+        ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
+
+        mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
+                                           "Device destroyed before callback", mUserData);
+    }
+
+    void CreateRenderPipelineAsyncTask::HandleDeviceLoss() {
+        ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
+
+        mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
+                                           "Device lost before callback", mUserData);
     }
 
     CreatePipelineAsyncTracker::CreatePipelineAsyncTracker(DeviceBase* device) : mDevice(device) {
@@ -110,13 +128,17 @@
         mCreatePipelineAsyncTasksInFlight.ClearUpTo(finishedSerial);
 
         for (auto& task : tasks) {
-            task->Finish(WGPUCreatePipelineAsyncStatus_Success);
+            if (mDevice->IsLost()) {
+                task->HandleDeviceLoss();
+            } else {
+                task->Finish();
+            }
         }
     }
 
     void CreatePipelineAsyncTracker::ClearForShutDown() {
         for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateAll()) {
-            task->Finish(WGPUCreatePipelineAsyncStatus_DeviceDestroyed);
+            task->HandleShutDown();
         }
         mCreatePipelineAsyncTasksInFlight.Clear();
     }
diff --git a/src/dawn_native/CreatePipelineAsyncTracker.h b/src/dawn_native/CreatePipelineAsyncTracker.h
index 438427a..b84daed 100644
--- a/src/dawn_native/CreatePipelineAsyncTracker.h
+++ b/src/dawn_native/CreatePipelineAsyncTracker.h
@@ -15,11 +15,13 @@
 #ifndef DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_
 #define DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_
 
+#include "common/RefCounted.h"
 #include "common/SerialQueue.h"
 #include "dawn/webgpu.h"
 #include "dawn_native/IntegerTypes.h"
 
 #include <memory>
+#include <string>
 
 namespace dawn_native {
 
@@ -28,42 +30,51 @@
     class RenderPipelineBase;
 
     struct CreatePipelineAsyncTaskBase {
-        CreatePipelineAsyncTaskBase(void* userData);
+        CreatePipelineAsyncTaskBase(std::string errorMessage, void* userData);
         virtual ~CreatePipelineAsyncTaskBase();
 
-        virtual void Finish(WGPUCreatePipelineAsyncStatus status) = 0;
+        virtual void Finish() = 0;
+        virtual void HandleShutDown() = 0;
+        virtual void HandleDeviceLoss() = 0;
 
       protected:
+        std::string mErrorMessage;
         void* mUserData;
     };
 
     struct CreateComputePipelineAsyncTask final : public CreatePipelineAsyncTaskBase {
-        CreateComputePipelineAsyncTask(ComputePipelineBase* pipeline,
+        CreateComputePipelineAsyncTask(Ref<ComputePipelineBase> pipeline,
+                                       std::string errorMessage,
                                        WGPUCreateComputePipelineAsyncCallback callback,
                                        void* userdata);
 
-        void Finish(WGPUCreatePipelineAsyncStatus status) final;
+        void Finish() final;
+        void HandleShutDown() final;
+        void HandleDeviceLoss() final;
 
       private:
-        ComputePipelineBase* mPipeline;
+        Ref<ComputePipelineBase> mPipeline;
         WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback;
     };
 
     struct CreateRenderPipelineAsyncTask final : public CreatePipelineAsyncTaskBase {
-        CreateRenderPipelineAsyncTask(RenderPipelineBase* pipeline,
+        CreateRenderPipelineAsyncTask(Ref<RenderPipelineBase> pipeline,
+                                      std::string errorMessage,
                                       WGPUCreateRenderPipelineAsyncCallback callback,
                                       void* userdata);
 
-        void Finish(WGPUCreatePipelineAsyncStatus status) final;
+        void Finish() final;
+        void HandleShutDown() final;
+        void HandleDeviceLoss() final;
 
       private:
-        RenderPipelineBase* mPipeline;
+        Ref<RenderPipelineBase> mPipeline;
         WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback;
     };
 
     class CreatePipelineAsyncTracker {
       public:
-        CreatePipelineAsyncTracker(DeviceBase* device);
+        explicit CreatePipelineAsyncTracker(DeviceBase* device);
         ~CreatePipelineAsyncTracker();
 
         void TrackTask(std::unique_ptr<CreatePipelineAsyncTaskBase> task, ExecutionSerial serial);
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 49898d9..259020b 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -480,7 +480,7 @@
         return mEmptyBindGroupLayout.Get();
     }
 
-    ResultOrError<Ref<ComputePipelineBase>> DeviceBase::GetOrCreateComputePipeline(
+    std::pair<Ref<ComputePipelineBase>, size_t> DeviceBase::GetCachedComputePipeline(
         const ComputePipelineDescriptor* descriptor) {
         ComputePipelineBase blueprint(this, descriptor);
 
@@ -491,14 +491,22 @@
         auto iter = mCaches->computePipelines.find(&blueprint);
         if (iter != mCaches->computePipelines.end()) {
             result = *iter;
-        } else {
-            DAWN_TRY_ASSIGN(result, CreateComputePipelineImpl(descriptor));
-            result->SetIsCachedReference();
-            result->SetContentHash(blueprintHash);
-            mCaches->computePipelines.insert(result.Get());
         }
 
-        return std::move(result);
+        return std::make_pair(result, blueprintHash);
+    }
+
+    Ref<ComputePipelineBase> DeviceBase::AddOrGetCachedPipeline(
+        Ref<ComputePipelineBase> computePipeline,
+        size_t blueprintHash) {
+        computePipeline->SetContentHash(blueprintHash);
+        auto insertion = mCaches->computePipelines.insert(computePipeline.Get());
+        if (insertion.second) {
+            computePipeline->SetIsCachedReference();
+            return computePipeline;
+        } else {
+            return *(insertion.first);
+        }
     }
 
     void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
@@ -711,19 +719,16 @@
     void DeviceBase::APICreateComputePipelineAsync(const ComputePipelineDescriptor* descriptor,
                                                    WGPUCreateComputePipelineAsyncCallback callback,
                                                    void* userdata) {
-        ResultOrError<Ref<ComputePipelineBase>> maybeResult =
-            CreateComputePipelineInternal(descriptor);
+        MaybeError maybeResult = CreateComputePipelineAsyncInternal(descriptor, callback, userdata);
+
+        // Call the callback directly when a validation error has been found in the front-end
+        // validations. If there is no error, then CreateComputePipelineAsyncInternal will call the
+        // callback.
         if (maybeResult.IsError()) {
             std::unique_ptr<ErrorData> error = maybeResult.AcquireError();
             callback(WGPUCreatePipelineAsyncStatus_Error, nullptr, error->GetMessage().c_str(),
                      userdata);
-            return;
         }
-
-        std::unique_ptr<CreateComputePipelineAsyncTask> request =
-            std::make_unique<CreateComputePipelineAsyncTask>(maybeResult.AcquireSuccess().Detach(),
-                                                             callback, userdata);
-        mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial());
     }
     PipelineLayoutBase* DeviceBase::APICreatePipelineLayout(
         const PipelineLayoutDescriptor* descriptor) {
@@ -759,9 +764,10 @@
             return;
         }
 
+        Ref<RenderPipelineBase> result = maybeResult.AcquireSuccess();
         std::unique_ptr<CreateRenderPipelineAsyncTask> request =
-            std::make_unique<CreateRenderPipelineAsyncTask>(maybeResult.AcquireSuccess().Detach(),
-                                                            callback, userdata);
+            std::make_unique<CreateRenderPipelineAsyncTask>(std::move(result), "", callback,
+                                                            userdata);
         mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial());
     }
     RenderBundleEncoder* DeviceBase::APICreateRenderBundleEncoder(
@@ -1074,23 +1080,95 @@
             DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
         }
 
-        if (descriptor->layout == nullptr) {
-            ComputePipelineDescriptor descriptorWithDefaultLayout = *descriptor;
+        // Ref will keep the pipeline layout alive until the end of the function where
+        // the pipeline will take another reference.
+        Ref<PipelineLayoutBase> layoutRef;
+        ComputePipelineDescriptor appliedDescriptor;
+        DAWN_TRY_ASSIGN(layoutRef, ValidateAndGetComputePipelineDescriptorWithDefaults(
+                                       *descriptor, &appliedDescriptor));
 
-            // Ref will keep the pipeline layout alive until the end of the function where
-            // the pipeline will take another reference.
-            Ref<PipelineLayoutBase> layoutRef;
-            DAWN_TRY_ASSIGN(layoutRef,
-                            PipelineLayoutBase::CreateDefault(
-                                this, {{SingleShaderStage::Compute, descriptor->computeStage.module,
-                                        descriptor->computeStage.entryPoint}}));
-
-            descriptorWithDefaultLayout.layout = layoutRef.Get();
-
-            return GetOrCreateComputePipeline(&descriptorWithDefaultLayout);
-        } else {
-            return GetOrCreateComputePipeline(descriptor);
+        auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor);
+        if (pipelineAndBlueprintFromCache.first.Get() != nullptr) {
+            return std::move(pipelineAndBlueprintFromCache.first);
         }
+
+        Ref<ComputePipelineBase> backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreateComputePipelineImpl(&appliedDescriptor));
+        size_t blueprintHash = pipelineAndBlueprintFromCache.second;
+        return AddOrGetCachedPipeline(backendObj, blueprintHash);
+    }
+
+    MaybeError DeviceBase::CreateComputePipelineAsyncInternal(
+        const ComputePipelineDescriptor* descriptor,
+        WGPUCreateComputePipelineAsyncCallback callback,
+        void* userdata) {
+        DAWN_TRY(ValidateIsAlive());
+        if (IsValidationEnabled()) {
+            DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
+        }
+
+        // Ref will keep the pipeline layout alive until the end of the function where
+        // the pipeline will take another reference.
+        Ref<PipelineLayoutBase> layoutRef;
+        ComputePipelineDescriptor appliedDescriptor;
+        DAWN_TRY_ASSIGN(layoutRef, ValidateAndGetComputePipelineDescriptorWithDefaults(
+                                       *descriptor, &appliedDescriptor));
+
+        // Call the callback directly when we can get a cached compute pipeline object.
+        auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor);
+        if (pipelineAndBlueprintFromCache.first.Get() != nullptr) {
+            Ref<ComputePipelineBase> result = std::move(pipelineAndBlueprintFromCache.first);
+            callback(WGPUCreatePipelineAsyncStatus_Success,
+                     reinterpret_cast<WGPUComputePipeline>(result.Detach()), "", userdata);
+        } else {
+            // Otherwise we will create the pipeline object in CreateComputePipelineAsyncImpl(),
+            // where the pipeline object may be created asynchronously and the result will be saved
+            // to mCreatePipelineAsyncTracker.
+            const size_t blueprintHash = pipelineAndBlueprintFromCache.second;
+            CreateComputePipelineAsyncImpl(&appliedDescriptor, blueprintHash, callback, userdata);
+        }
+
+        return {};
+    }
+
+    ResultOrError<Ref<PipelineLayoutBase>>
+    DeviceBase::ValidateAndGetComputePipelineDescriptorWithDefaults(
+        const ComputePipelineDescriptor& descriptor,
+        ComputePipelineDescriptor* outDescriptor) {
+        Ref<PipelineLayoutBase> layoutRef;
+        *outDescriptor = descriptor;
+        if (outDescriptor->layout == nullptr) {
+            DAWN_TRY_ASSIGN(layoutRef, PipelineLayoutBase::CreateDefault(
+                                           this, {{SingleShaderStage::Compute,
+                                                   outDescriptor->computeStage.module,
+                                                   outDescriptor->computeStage.entryPoint}}));
+            outDescriptor->layout = layoutRef.Get();
+        }
+
+        return layoutRef;
+    }
+
+    // TODO(jiawei.shao@intel.com): override this function with the async version on the backends
+    // that support creating compute pipelines asynchronously
+    void DeviceBase::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                    size_t blueprintHash,
+                                                    WGPUCreateComputePipelineAsyncCallback callback,
+                                                    void* userdata) {
+        Ref<ComputePipelineBase> result;
+        std::string errorMessage;
+
+        auto resultOrError = CreateComputePipelineImpl(descriptor);
+        if (resultOrError.IsError()) {
+            std::unique_ptr<ErrorData> error = resultOrError.AcquireError();
+            errorMessage = error->GetMessage();
+        } else {
+            result = AddOrGetCachedPipeline(resultOrError.AcquireSuccess(), blueprintHash);
+        }
+
+        std::unique_ptr<CreateComputePipelineAsyncTask> request =
+            std::make_unique<CreateComputePipelineAsyncTask>(result, errorMessage, callback,
+                                                             userdata);
+        mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial());
     }
 
     ResultOrError<Ref<PipelineLayoutBase>> DeviceBase::CreatePipelineLayoutInternal(
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 1f014f3..5bb6f6b 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -27,6 +27,7 @@
 #include "dawn_native/dawn_platform.h"
 
 #include <memory>
+#include <utility>
 
 namespace dawn_native {
     class AdapterBase;
@@ -111,8 +112,6 @@
 
         BindGroupLayoutBase* GetEmptyBindGroupLayout();
 
-        ResultOrError<Ref<ComputePipelineBase>> GetOrCreateComputePipeline(
-            const ComputePipelineDescriptor* descriptor);
         void UncacheComputePipeline(ComputePipelineBase* obj);
 
         ResultOrError<Ref<PipelineLayoutBase>> GetOrCreatePipelineLayout(
@@ -304,6 +303,10 @@
         ResultOrError<Ref<BindGroupLayoutBase>> CreateBindGroupLayoutInternal(
             const BindGroupLayoutDescriptor* descriptor);
         ResultOrError<Ref<BufferBase>> CreateBufferInternal(const BufferDescriptor* descriptor);
+        MaybeError CreateComputePipelineAsyncInternal(
+            const ComputePipelineDescriptor* descriptor,
+            WGPUCreateComputePipelineAsyncCallback callback,
+            void* userdata);
         ResultOrError<Ref<ComputePipelineBase>> CreateComputePipelineInternal(
             const ComputePipelineDescriptor* descriptor);
         ResultOrError<Ref<PipelineLayoutBase>> CreatePipelineLayoutInternal(
@@ -327,6 +330,18 @@
             TextureBase* texture,
             const TextureViewDescriptor* descriptor);
 
+        ResultOrError<Ref<PipelineLayoutBase>> ValidateAndGetComputePipelineDescriptorWithDefaults(
+            const ComputePipelineDescriptor& descriptor,
+            ComputePipelineDescriptor* outDescriptor);
+        std::pair<Ref<ComputePipelineBase>, size_t> GetCachedComputePipeline(
+            const ComputePipelineDescriptor* descriptor);
+        Ref<ComputePipelineBase> AddOrGetCachedPipeline(Ref<ComputePipelineBase> computePipeline,
+                                                        size_t blueprintHash);
+        void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                            size_t blueprintHash,
+                                            WGPUCreateComputePipelineAsyncCallback callback,
+                                            void* userdata);
+
         void ApplyToggleOverrides(const DeviceDescriptor* deviceDescriptor);
         void ApplyExtensions(const DeviceDescriptor* deviceDescriptor);
 
diff --git a/src/tests/end2end/CreatePipelineAsyncTests.cpp b/src/tests/end2end/CreatePipelineAsyncTests.cpp
index f5dffe3..7646035 100644
--- a/src/tests/end2end/CreatePipelineAsyncTests.cpp
+++ b/src/tests/end2end/CreatePipelineAsyncTests.cpp
@@ -28,6 +28,46 @@
 
 class CreatePipelineAsyncTest : public DawnTest {
   protected:
+    void ValidateCreateComputePipelineAsync(CreatePipelineAsyncTask* currentTask) {
+        wgpu::BufferDescriptor bufferDesc;
+        bufferDesc.size = sizeof(uint32_t);
+        bufferDesc.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc;
+        wgpu::Buffer ssbo = device.CreateBuffer(&bufferDesc);
+
+        wgpu::CommandBuffer commands;
+        {
+            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+            wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+            while (!currentTask->isCompleted) {
+                WaitABit();
+            }
+            ASSERT_TRUE(currentTask->message.empty());
+            ASSERT_NE(nullptr, currentTask->computePipeline.Get());
+            wgpu::BindGroup bindGroup =
+                utils::MakeBindGroup(device, currentTask->computePipeline.GetBindGroupLayout(0),
+                                     {
+                                         {0, ssbo, 0, sizeof(uint32_t)},
+                                     });
+            pass.SetBindGroup(0, bindGroup);
+            pass.SetPipeline(currentTask->computePipeline);
+
+            pass.Dispatch(1);
+            pass.EndPass();
+
+            commands = encoder.Finish();
+        }
+
+        queue.Submit(1, &commands);
+
+        constexpr uint32_t kExpected = 1u;
+        EXPECT_BUFFER_U32_EQ(kExpected, ssbo, 0);
+    }
+
+    void ValidateCreateComputePipelineAsync() {
+        ValidateCreateComputePipelineAsync(&task);
+    }
+
     CreatePipelineAsyncTask task;
 };
 
@@ -58,39 +98,7 @@
         },
         &task);
 
-    wgpu::BufferDescriptor bufferDesc;
-    bufferDesc.size = sizeof(uint32_t);
-    bufferDesc.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc;
-    wgpu::Buffer ssbo = device.CreateBuffer(&bufferDesc);
-
-    wgpu::CommandBuffer commands;
-    {
-        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
-
-        while (!task.isCompleted) {
-            WaitABit();
-        }
-        ASSERT_TRUE(task.message.empty());
-        ASSERT_NE(nullptr, task.computePipeline.Get());
-        wgpu::BindGroup bindGroup =
-            utils::MakeBindGroup(device, task.computePipeline.GetBindGroupLayout(0),
-                                 {
-                                     {0, ssbo, 0, sizeof(uint32_t)},
-                                 });
-        pass.SetBindGroup(0, bindGroup);
-        pass.SetPipeline(task.computePipeline);
-
-        pass.Dispatch(1);
-        pass.EndPass();
-
-        commands = encoder.Finish();
-    }
-
-    queue.Submit(1, &commands);
-
-    constexpr uint32_t kExpected = 1u;
-    EXPECT_BUFFER_U32_EQ(kExpected, ssbo, 0);
+    ValidateCreateComputePipelineAsync();
 }
 
 // Verify CreateComputePipelineAsync() works as expected when there is any error that happens during
@@ -302,6 +310,88 @@
         &task);
 }
 
+// Verify the code path of CreateComputePipelineAsync() to directly return the compute pipeline
+// object from cache works correctly.
+TEST_P(CreatePipelineAsyncTest, CreateSameComputePipelineTwice) {
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.computeStage.module = utils::CreateShaderModule(device, R"(
+        [[block]] struct SSBO {
+            value : u32;
+        };
+        [[group(0), binding(0)]] var<storage> ssbo : [[access(read_write)]] SSBO;
+
+        [[stage(compute)]] fn main() -> void {
+            ssbo.value = 1u;
+        })");
+    csDesc.computeStage.entryPoint = "main";
+
+    auto callback = [](WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline returnPipeline,
+                       const char* message, void* userdata) {
+        EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_Success, status);
+
+        CreatePipelineAsyncTask* task = static_cast<CreatePipelineAsyncTask*>(userdata);
+        task->computePipeline = wgpu::ComputePipeline::Acquire(returnPipeline);
+        task->isCompleted = true;
+        task->message = message;
+    };
+
+    // Create a pipeline object and save it into anotherTask.computePipeline.
+    CreatePipelineAsyncTask anotherTask;
+    device.CreateComputePipelineAsync(&csDesc, callback, &anotherTask);
+    while (!anotherTask.isCompleted) {
+        WaitABit();
+    }
+    ASSERT_TRUE(anotherTask.message.empty());
+    ASSERT_NE(nullptr, anotherTask.computePipeline.Get());
+
+    // Create another pipeline object task.computePipeline with the same compute pipeline
+    // descriptor used in the creation of anotherTask.computePipeline. This time the pipeline
+    // object should be directly got from the pipeline object cache.
+    device.CreateComputePipelineAsync(&csDesc, callback, &task);
+    ValidateCreateComputePipelineAsync();
+}
+
+// Verify creating compute pipeline with same descriptor and CreateComputePipelineAsync() at the
+// same time works correctly.
+TEST_P(CreatePipelineAsyncTest, CreateSamePipelineTwiceAtSameTime) {
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.computeStage.module = utils::CreateShaderModule(device, R"(
+        [[block]] struct SSBO {
+            value : u32;
+        };
+        [[group(0), binding(0)]] var<storage> ssbo : [[access(read_write)]] SSBO;
+
+        [[stage(compute)]] fn main() -> void {
+            ssbo.value = 1u;
+        })");
+    csDesc.computeStage.entryPoint = "main";
+
+    auto callback = [](WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline returnPipeline,
+                       const char* message, void* userdata) {
+        EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_Success, status);
+
+        CreatePipelineAsyncTask* task = static_cast<CreatePipelineAsyncTask*>(userdata);
+        task->computePipeline = wgpu::ComputePipeline::Acquire(returnPipeline);
+        task->isCompleted = true;
+        task->message = message;
+    };
+
+    // Create two pipeline objects with same descriptor.
+    CreatePipelineAsyncTask anotherTask;
+    device.CreateComputePipelineAsync(&csDesc, callback, &task);
+    device.CreateComputePipelineAsync(&csDesc, callback, &anotherTask);
+
+    // Verify both task.computePipeline and anotherTask.computePipeline are created correctly.
+    ValidateCreateComputePipelineAsync(&anotherTask);
+    ValidateCreateComputePipelineAsync(&task);
+
+    // Verify task.computePipeline and anotherTask.computePipeline are pointing to the same Dawn
+    // object.
+    if (!UsesWire()) {
+        EXPECT_EQ(task.computePipeline.Get(), anotherTask.computePipeline.Get());
+    }
+}
+
 DAWN_INSTANTIATE_TEST(CreatePipelineAsyncTest,
                       D3D12Backend(),
                       MetalBackend(),