Vulkan: Support creating compute pipeline asynchronously
BUG=dawn:529
TEST=dawn_end2end_tests
Change-Id: Id2b2bebe164ccc829e4f2cf737255d634d6572a0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/53760
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 9fa167a..efd942a 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -14,7 +14,6 @@
#include "dawn_native/d3d12/ComputePipelineD3D12.h"
-#include "dawn_native/AsyncTask.h"
#include "dawn_native/CreatePipelineAsyncTask.h"
#include "dawn_native/d3d12/D3D12Error.h"
#include "dawn_native/d3d12/DeviceD3D12.h"
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index 322c026..5bc1d4d 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -14,6 +14,7 @@
#include "dawn_native/vulkan/ComputePipelineVk.h"
+#include "dawn_native/CreatePipelineAsyncTask.h"
#include "dawn_native/vulkan/DeviceVk.h"
#include "dawn_native/vulkan/FencedDeleter.h"
#include "dawn_native/vulkan/PipelineLayoutVk.h"
@@ -88,4 +89,16 @@
return mHandle;
}
+ void ComputePipeline::CreateAsync(Device* device,
+ const ComputePipelineDescriptor* descriptor,
+ size_t blueprintHash,
+ WGPUCreateComputePipelineAsyncCallback callback,
+ void* userdata) {
+ Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
+ std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
+ std::make_unique<CreateComputePipelineAsyncTask>(pipeline, descriptor, blueprintHash,
+ callback, userdata);
+ CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
+ }
+
}} // namespace dawn_native::vulkan
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.h b/src/dawn_native/vulkan/ComputePipelineVk.h
index 97a32bf..f5ac787 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.h
+++ b/src/dawn_native/vulkan/ComputePipelineVk.h
@@ -29,6 +29,11 @@
static ResultOrError<Ref<ComputePipeline>> Create(
Device* device,
const ComputePipelineDescriptor* descriptor);
+ static void CreateAsync(Device* device,
+ const ComputePipelineDescriptor* descriptor,
+ size_t blueprintHash,
+ WGPUCreateComputePipelineAsyncCallback callback,
+ void* userdata);
VkPipeline GetHandle() const;
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index cc69fa9..5541dea 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -160,6 +160,12 @@
const TextureViewDescriptor* descriptor) {
return TextureView::Create(texture, descriptor);
}
+ void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+ size_t blueprintHash,
+ WGPUCreateComputePipelineAsyncCallback callback,
+ void* userdata) {
+ ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
+ }
MaybeError Device::TickImpl() {
RecycleCompletedCommands();
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 376f15a..497defe 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -143,6 +143,10 @@
ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
TextureBase* texture,
const TextureViewDescriptor* descriptor) override;
+ void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
+ size_t blueprintHash,
+ WGPUCreateComputePipelineAsyncCallback callback,
+ void* userdata) override;
ResultOrError<VulkanDeviceKnobs> CreateDevice(VkPhysicalDevice physicalDevice);
void GatherQueueFromDevice();
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index 2e256da..e687505 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -31,6 +31,45 @@
namespace dawn_native { namespace vulkan {
+ ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache(
+ Device* device)
+ : mDevice(device) {
+ }
+
+ ShaderModule::ConcurrentTransformedShaderModuleCache::
+ ~ConcurrentTransformedShaderModuleCache() {
+ std::lock_guard<std::mutex> lock(mMutex);
+ for (const auto& iter : mTransformedShaderModuleCache) {
+ mDevice->GetFencedDeleter()->DeleteWhenUnused(iter.second);
+ }
+ }
+
+ VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::FindShaderModule(
+ const PipelineLayoutEntryPointPair& key) {
+ std::lock_guard<std::mutex> lock(mMutex);
+ auto iter = mTransformedShaderModuleCache.find(key);
+ if (iter != mTransformedShaderModuleCache.end()) {
+ auto cached = iter->second;
+ return cached;
+ }
+ return VK_NULL_HANDLE;
+ }
+
+ VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGetCachedShaderModule(
+ const PipelineLayoutEntryPointPair& key,
+ VkShaderModule value) {
+ ASSERT(value != VK_NULL_HANDLE);
+ std::lock_guard<std::mutex> lock(mMutex);
+ auto iter = mTransformedShaderModuleCache.find(key);
+ if (iter == mTransformedShaderModuleCache.end()) {
+ mTransformedShaderModuleCache.emplace(key, value);
+ return value;
+ } else {
+ mDevice->GetFencedDeleter()->DeleteWhenUnused(value);
+ return iter->second;
+ }
+ }
+
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor,
@@ -41,7 +80,7 @@
}
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
- : ShaderModuleBase(device, descriptor) {
+ : ShaderModuleBase(device, descriptor), mTransformedShaderModuleCache(device) {
}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
@@ -112,10 +151,6 @@
device->GetFencedDeleter()->DeleteWhenUnused(mHandle);
mHandle = VK_NULL_HANDLE;
}
-
- for (const auto& iter : mTransformedShaderModuleCache) {
- device->GetFencedDeleter()->DeleteWhenUnused(iter.second);
- }
}
VkShaderModule ShaderModule::GetHandle() const {
@@ -131,10 +166,10 @@
ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
auto cacheKey = std::make_pair(layout, entryPointName);
- auto iter = mTransformedShaderModuleCache.find(cacheKey);
- if (iter != mTransformedShaderModuleCache.end()) {
- auto cached = iter->second;
- return cached;
+ VkShaderModule cachedShaderModule =
+ mTransformedShaderModuleCache.FindShaderModule(cacheKey);
+ if (cachedShaderModule != VK_NULL_HANDLE) {
+ return cachedShaderModule;
}
// Creation of VkShaderModule is deferred to this point when using tint generator
@@ -204,7 +239,8 @@
device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
"CreateShaderModule"));
if (newHandle != VK_NULL_HANDLE) {
- mTransformedShaderModuleCache.emplace(cacheKey, newHandle);
+ newHandle =
+ mTransformedShaderModuleCache.AddOrGetCachedShaderModule(cacheKey, newHandle);
}
return newHandle;
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 9dd7817..7bd0c2f 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -20,15 +20,13 @@
#include "common/vulkan_platform.h"
#include "dawn_native/Error.h"
+#include <mutex>
+
namespace dawn_native { namespace vulkan {
class Device;
class PipelineLayout;
- using TransformedShaderModuleCache = std::unordered_map<PipelineLayoutEntryPointPair,
- VkShaderModule,
- PipelineLayoutEntryPointPairHashFunc>;
-
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
@@ -49,7 +47,23 @@
VkShaderModule mHandle = VK_NULL_HANDLE;
// New handles created by GetTransformedModuleHandle at pipeline creation time
- TransformedShaderModuleCache mTransformedShaderModuleCache;
+ class ConcurrentTransformedShaderModuleCache {
+ public:
+ explicit ConcurrentTransformedShaderModuleCache(Device* device);
+ ~ConcurrentTransformedShaderModuleCache();
+ VkShaderModule FindShaderModule(const PipelineLayoutEntryPointPair& key);
+ VkShaderModule AddOrGetCachedShaderModule(const PipelineLayoutEntryPointPair& key,
+ VkShaderModule value);
+
+ private:
+ Device* mDevice;
+ std::mutex mMutex;
+ std::unordered_map<PipelineLayoutEntryPointPair,
+ VkShaderModule,
+ PipelineLayoutEntryPointPairHashFunc>
+ mTransformedShaderModuleCache;
+ };
+ ConcurrentTransformedShaderModuleCache mTransformedShaderModuleCache;
};
}} // namespace dawn_native::vulkan