Vulkan: Support creating compute pipeline asynchronously

BUG=dawn:529
TEST=dawn_end2end_tests

Change-Id: Id2b2bebe164ccc829e4f2cf737255d634d6572a0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/53760
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 9fa167a..efd942a 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -14,7 +14,6 @@
 
 #include "dawn_native/d3d12/ComputePipelineD3D12.h"
 
-#include "dawn_native/AsyncTask.h"
 #include "dawn_native/CreatePipelineAsyncTask.h"
 #include "dawn_native/d3d12/D3D12Error.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index 322c026..5bc1d4d 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/vulkan/ComputePipelineVk.h"
 
+#include "dawn_native/CreatePipelineAsyncTask.h"
 #include "dawn_native/vulkan/DeviceVk.h"
 #include "dawn_native/vulkan/FencedDeleter.h"
 #include "dawn_native/vulkan/PipelineLayoutVk.h"
@@ -88,4 +89,16 @@
         return mHandle;
     }
 
+    void ComputePipeline::CreateAsync(Device* device,
+                                      const ComputePipelineDescriptor* descriptor,
+                                      size_t blueprintHash,
+                                      WGPUCreateComputePipelineAsyncCallback callback,
+                                      void* userdata) {
+        Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
+        std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
+            std::make_unique<CreateComputePipelineAsyncTask>(pipeline, descriptor, blueprintHash,
+                                                             callback, userdata);
+        CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
+    }
+
 }}  // namespace dawn_native::vulkan
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.h b/src/dawn_native/vulkan/ComputePipelineVk.h
index 97a32bf..f5ac787 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.h
+++ b/src/dawn_native/vulkan/ComputePipelineVk.h
@@ -29,6 +29,11 @@
         static ResultOrError<Ref<ComputePipeline>> Create(
             Device* device,
             const ComputePipelineDescriptor* descriptor);
+        static void CreateAsync(Device* device,
+                                const ComputePipelineDescriptor* descriptor,
+                                size_t blueprintHash,
+                                WGPUCreateComputePipelineAsyncCallback callback,
+                                void* userdata);
 
         VkPipeline GetHandle() const;
 
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index cc69fa9..5541dea 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -160,6 +160,12 @@
         const TextureViewDescriptor* descriptor) {
         return TextureView::Create(texture, descriptor);
     }
+    void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                                size_t blueprintHash,
+                                                WGPUCreateComputePipelineAsyncCallback callback,
+                                                void* userdata) {
+        ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
+    }
 
     MaybeError Device::TickImpl() {
         RecycleCompletedCommands();
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 376f15a..497defe 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -143,6 +143,10 @@
         ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
             TextureBase* texture,
             const TextureViewDescriptor* descriptor) override;
+        void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+                                            size_t blueprintHash,
+                                            WGPUCreateComputePipelineAsyncCallback callback,
+                                            void* userdata) override;
 
         ResultOrError<VulkanDeviceKnobs> CreateDevice(VkPhysicalDevice physicalDevice);
         void GatherQueueFromDevice();
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index 2e256da..e687505 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -31,6 +31,45 @@
 
 namespace dawn_native { namespace vulkan {
 
+    ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache(
+        Device* device)
+        : mDevice(device) {
+    }
+
+    ShaderModule::ConcurrentTransformedShaderModuleCache::
+        ~ConcurrentTransformedShaderModuleCache() {
+        std::lock_guard<std::mutex> lock(mMutex);
+        for (const auto& iter : mTransformedShaderModuleCache) {
+            mDevice->GetFencedDeleter()->DeleteWhenUnused(iter.second);
+        }
+    }
+
+    VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::FindShaderModule(
+        const PipelineLayoutEntryPointPair& key) {
+        std::lock_guard<std::mutex> lock(mMutex);
+        auto iter = mTransformedShaderModuleCache.find(key);
+        if (iter != mTransformedShaderModuleCache.end()) {
+            auto cached = iter->second;
+            return cached;
+        }
+        return VK_NULL_HANDLE;
+    }
+
+    VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGetCachedShaderModule(
+        const PipelineLayoutEntryPointPair& key,
+        VkShaderModule value) {
+        ASSERT(value != VK_NULL_HANDLE);
+        std::lock_guard<std::mutex> lock(mMutex);
+        auto iter = mTransformedShaderModuleCache.find(key);
+        if (iter == mTransformedShaderModuleCache.end()) {
+            mTransformedShaderModuleCache.emplace(key, value);
+            return value;
+        } else {
+            mDevice->GetFencedDeleter()->DeleteWhenUnused(value);
+            return iter->second;
+        }
+    }
+
     // static
     ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
                                                           const ShaderModuleDescriptor* descriptor,
@@ -41,7 +80,7 @@
     }
 
     ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
-        : ShaderModuleBase(device, descriptor) {
+        : ShaderModuleBase(device, descriptor), mTransformedShaderModuleCache(device) {
     }
 
     MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
@@ -112,10 +151,6 @@
             device->GetFencedDeleter()->DeleteWhenUnused(mHandle);
             mHandle = VK_NULL_HANDLE;
         }
-
-        for (const auto& iter : mTransformedShaderModuleCache) {
-            device->GetFencedDeleter()->DeleteWhenUnused(iter.second);
-        }
     }
 
     VkShaderModule ShaderModule::GetHandle() const {
@@ -131,10 +166,10 @@
         ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
 
         auto cacheKey = std::make_pair(layout, entryPointName);
-        auto iter = mTransformedShaderModuleCache.find(cacheKey);
-        if (iter != mTransformedShaderModuleCache.end()) {
-            auto cached = iter->second;
-            return cached;
+        VkShaderModule cachedShaderModule =
+            mTransformedShaderModuleCache.FindShaderModule(cacheKey);
+        if (cachedShaderModule != VK_NULL_HANDLE) {
+            return cachedShaderModule;
         }
 
         // Creation of VkShaderModule is deferred to this point when using tint generator
@@ -204,7 +239,8 @@
             device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
             "CreateShaderModule"));
         if (newHandle != VK_NULL_HANDLE) {
-            mTransformedShaderModuleCache.emplace(cacheKey, newHandle);
+            newHandle =
+                mTransformedShaderModuleCache.AddOrGetCachedShaderModule(cacheKey, newHandle);
         }
 
         return newHandle;
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 9dd7817..7bd0c2f 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -20,15 +20,13 @@
 #include "common/vulkan_platform.h"
 #include "dawn_native/Error.h"
 
+#include <mutex>
+
 namespace dawn_native { namespace vulkan {
 
     class Device;
     class PipelineLayout;
 
-    using TransformedShaderModuleCache = std::unordered_map<PipelineLayoutEntryPointPair,
-                                                            VkShaderModule,
-                                                            PipelineLayoutEntryPointPairHashFunc>;
-
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<Ref<ShaderModule>> Create(Device* device,
@@ -49,7 +47,23 @@
         VkShaderModule mHandle = VK_NULL_HANDLE;
 
         // New handles created by GetTransformedModuleHandle at pipeline creation time
-        TransformedShaderModuleCache mTransformedShaderModuleCache;
+        class ConcurrentTransformedShaderModuleCache {
+          public:
+            explicit ConcurrentTransformedShaderModuleCache(Device* device);
+            ~ConcurrentTransformedShaderModuleCache();
+            VkShaderModule FindShaderModule(const PipelineLayoutEntryPointPair& key);
+            VkShaderModule AddOrGetCachedShaderModule(const PipelineLayoutEntryPointPair& key,
+                                                      VkShaderModule value);
+
+          private:
+            Device* mDevice;
+            std::mutex mMutex;
+            std::unordered_map<PipelineLayoutEntryPointPair,
+                               VkShaderModule,
+                               PipelineLayoutEntryPointPairHashFunc>
+                mTransformedShaderModuleCache;
+        };
+        ConcurrentTransformedShaderModuleCache mTransformedShaderModuleCache;
     };
 
 }}  // namespace dawn_native::vulkan