Add compute pipeline cache key generation for Vulkan.
- Adds dependency to vulkan-tools for pNext chain helpers.
- Adds extra caching to vulkan shaders to keep the spirv in the in-memory cache as well.
- Adds pNext chain serializer infra for Vulkan.
Change-Id: Ibe73183fbff15f7310eaaeae92fbd622be1ac096
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/85022
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8e4d6da..6adb40c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -142,7 +142,8 @@
set_if_not_defined(DAWN_SPIRV_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-headers/src" "Directory in which to find SPIRV-Headers")
set_if_not_defined(DAWN_SPIRV_TOOLS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-tools/src" "Directory in which to find SPIRV-Tools")
set_if_not_defined(DAWN_TINT_DIR "${Dawn_SOURCE_DIR}" "Directory in which to find Tint")
-set_if_not_defined(DAWN_VULKAN_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/vulkan-headers/src" "Directory in which to find Vulkan-Headers")
+set_if_not_defined(DAWN_VULKAN_DEPS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps" "Directory in which to find vulkan-deps")
+set_if_not_defined(DAWN_VULKAN_HEADERS_DIR "${DAWN_VULKAN_DEPS_DIR}/vulkan-headers/src" "Directory in which to find Vulkan-Headers")
# Dependencies for DAWN_BUILD_NODE_BINDINGS
set_if_not_defined(NODE_ADDON_API_DIR "${DAWN_THIRD_PARTY_DIR}/node-addon-api" "Directory in which to find node-addon-api")
diff --git a/scripts/dawn_overrides_with_defaults.gni b/scripts/dawn_overrides_with_defaults.gni
index 46f44ef..b4142ac 100644
--- a/scripts/dawn_overrides_with_defaults.gni
+++ b/scripts/dawn_overrides_with_defaults.gni
@@ -54,14 +54,17 @@
dawn_swiftshader_dir = ""
}
-if (!defined(dawn_vulkan_headers_dir)) {
- dawn_vulkan_headers_dir = "//third_party/vulkan-deps/vulkan-headers/src"
+if (!defined(dawn_vulkan_deps_dir)) {
+ dawn_vulkan_deps_dir = "//third_party/vulkan-deps"
if (dawn_standalone) {
- dawn_vulkan_headers_dir =
- "${dawn_root}/third_party/vulkan-deps/vulkan-headers/src"
+ dawn_vulkan_deps_dir = "${dawn_root}/third_party/vulkan-deps"
}
}
+if (!defined(dawn_vulkan_headers_dir)) {
+ dawn_vulkan_headers_dir = "${dawn_vulkan_deps_dir}/vulkan-headers/src"
+}
+
if (!defined(dawn_vulkan_loader_dir)) {
# Default to the Vulkan loader not being available except in standalone.
dawn_vulkan_loader_dir = ""
@@ -70,6 +73,10 @@
}
}
+if (!defined(dawn_vulkan_tools_dir)) {
+ dawn_vulkan_tools_dir = "${dawn_vulkan_deps_dir}/vulkan-tools/src"
+}
+
if (!defined(dawn_vulkan_validation_layers_dir)) {
# Default to VVLs not being available.
dawn_vulkan_validation_layers_dir = ""
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 4547138..e6d90dc 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -99,6 +99,11 @@
}
}
+# Config that adds include directory for vulkan-deps, specifically for Vulkan-Tools.
+config("vulkan_deps_include") {
+ include_dirs = [ "${dawn_vulkan_deps_dir}" ]
+}
+
dawn_json_generator("utils_gen") {
target = "native_utils"
outputs = [
@@ -571,6 +576,8 @@
}
if (dawn_enable_vulkan) {
+ configs += [ ":vulkan_deps_include" ]
+ deps += [ "${dawn_vulkan_tools_dir}:vulkan_tools_headers" ]
public_deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ]
sources += [
"vulkan/AdapterVk.cpp",
@@ -583,6 +590,8 @@
"vulkan/BindGroupVk.h",
"vulkan/BufferVk.cpp",
"vulkan/BufferVk.h",
+ "vulkan/CacheKeyVk.cpp",
+ "vulkan/CacheKeyVk.h",
"vulkan/CommandBufferVk.cpp",
"vulkan/CommandBufferVk.h",
"vulkan/CommandRecordingContext.h",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index 170ea70..e7cab14 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -455,6 +455,8 @@
"vulkan/BindGroupVk.h"
"vulkan/BufferVk.cpp"
"vulkan/BufferVk.h"
+ "vulkan/CacheKeyVk.cpp"
+ "vulkan/CacheKeyVk.h"
"vulkan/CommandBufferVk.cpp"
"vulkan/CommandBufferVk.h"
"vulkan/CommandRecordingContext.h"
@@ -510,6 +512,7 @@
)
target_link_libraries(dawn_native PUBLIC Vulkan-Headers)
+ target_include_directories(dawn_native PRIVATE ${DAWN_VULKAN_DEPS_DIR})
if (UNIX AND NOT APPLE)
target_sources(dawn_native PRIVATE
diff --git a/src/dawn/native/CacheKey.cpp b/src/dawn/native/CacheKey.cpp
index 3495577..dea67f8 100644
--- a/src/dawn/native/CacheKey.cpp
+++ b/src/dawn/native/CacheKey.cpp
@@ -14,8 +14,19 @@
#include "dawn/native/CacheKey.h"
+#include <iomanip>
+
namespace dawn::native {
+ std::ostream& operator<<(std::ostream& os, const CacheKey& key) {
+ os << std::hex;
+ for (const int b : key) {
+ os << std::setfill('0') << std::setw(2) << b << " ";
+ }
+ os << std::dec;
+ return os;
+ }
+
template <>
void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) {
key->Record(static_cast<size_t>(t.length()));
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h
index ce21f6d..e97e770 100644
--- a/src/dawn/native/CacheKey.h
+++ b/src/dawn/native/CacheKey.h
@@ -15,17 +15,20 @@
#ifndef DAWNNATIVE_CACHE_KEY_H_
#define DAWNNATIVE_CACHE_KEY_H_
+#include <iostream>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
-#include "dawn/common/Assert.h"
-
namespace dawn::native {
- // Forward declare CacheKey class because of co-dependency.
+ // Forward declare classes because of co-dependency.
class CacheKey;
+ class CachedObject;
+
+ // Stream operator for CacheKey for debugging.
+ std::ostream& operator<<(std::ostream& os, const CacheKey& key);
// Overridable serializer struct that should be implemented for cache key serializable
// types/classes.
@@ -82,7 +85,31 @@
}
};
- // Specialized overload for string literals. Note we drop the null-terminator.
+ // Specialized overload for enums.
+ template <typename T>
+ class CacheKeySerializer<T, std::enable_if_t<std::is_enum_v<T>>> {
+ public:
+ static void Serialize(CacheKey* key, const T t) {
+ CacheKeySerializer<std::underlying_type_t<T>>::Serialize(
+ key, static_cast<std::underlying_type_t<T>>(t));
+ }
+ };
+
+ // Specialized overload for pointers. Since we are serializing for a cache key, we always
+ // serialize via value, not by pointer. To handle nullptr scenarios, we always serialize whether
+ // the pointer was nullptr followed by the contents if applicable.
+ template <typename T>
+ class CacheKeySerializer<T, std::enable_if_t<std::is_pointer_v<T>>> {
+ public:
+ static void Serialize(CacheKey* key, const T t) {
+ key->Record(t == nullptr);
+ if (t != nullptr) {
+ CacheKeySerializer<std::remove_cv_t<std::remove_pointer_t<T>>>::Serialize(key, *t);
+ }
+ }
+ };
+
+ // Specialized overload for string literals.
template <size_t N>
class CacheKeySerializer<char[N]> {
public:
@@ -93,6 +120,15 @@
}
};
+ // Specialized overload for CachedObjects.
+ template <typename T>
+ class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>> {
+ public:
+ static void Serialize(CacheKey* key, const T& t) {
+ key->Record(t.GetCacheKey());
+ }
+ };
+
} // namespace dawn::native
#endif // DAWNNATIVE_CACHE_KEY_H_
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
index 8ed4340..c377c96 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
@@ -16,6 +16,7 @@
#include "dawn/common/BitSetIterator.h"
#include "dawn/common/ityp_vector.h"
+#include "dawn/native/CacheKey.h"
#include "dawn/native/vulkan/BindGroupVk.h"
#include "dawn/native/vulkan/DescriptorSetAllocator.h"
#include "dawn/native/vulkan/DeviceVk.h"
@@ -115,6 +116,9 @@
createInfo.bindingCount = static_cast<uint32_t>(bindings.size());
createInfo.pBindings = bindings.data();
+ // Record cache key information now since the createInfo is not stored.
+ GetCacheKey()->Record(createInfo);
+
Device* device = ToBackend(GetDevice());
DAWN_TRY(CheckVkSuccess(device->fn.CreateDescriptorSetLayout(
device->GetVkDevice(), &createInfo, nullptr, &*mHandle),
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.h b/src/dawn/native/vulkan/BindGroupLayoutVk.h
index 558ff7f..d8adedc 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.h
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.h
@@ -22,6 +22,10 @@
#include <vector>
+namespace dawn::native {
+ class CacheKey;
+} // namespace dawn::native
+
namespace dawn::native::vulkan {
class BindGroup;
diff --git a/src/dawn/native/vulkan/CacheKeyVk.cpp b/src/dawn/native/vulkan/CacheKeyVk.cpp
new file mode 100644
index 0000000..9fffbd1
--- /dev/null
+++ b/src/dawn/native/vulkan/CacheKeyVk.cpp
@@ -0,0 +1,97 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/vulkan/CacheKeyVk.h"
+
+#include <cstring>
+
+namespace dawn::native {
+
+ template <>
+ void CacheKeySerializer<VkDescriptorSetLayoutBinding>::Serialize(
+ CacheKey* key,
+ const VkDescriptorSetLayoutBinding& t) {
+ key->Record(t.binding, t.descriptorType, t.descriptorCount, t.stageFlags);
+ }
+
+ template <>
+ void CacheKeySerializer<VkDescriptorSetLayoutCreateInfo>::Serialize(
+ CacheKey* key,
+ const VkDescriptorSetLayoutCreateInfo& t) {
+ key->Record(t.flags).RecordIterable(t.pBindings, t.bindingCount);
+ vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+ }
+
+ template <>
+ void CacheKeySerializer<VkPushConstantRange>::Serialize(CacheKey* key,
+ const VkPushConstantRange& t) {
+ key->Record(t.stageFlags, t.offset, t.size);
+ }
+
+ template <>
+ void CacheKeySerializer<VkPipelineLayoutCreateInfo>::Serialize(
+ CacheKey* key,
+ const VkPipelineLayoutCreateInfo& t) {
+ // The set layouts are not serialized here because they are pointers to backend objects.
+ // They need to be cross-referenced with the frontend objects and serialized from there.
+ key->Record(t.flags).RecordIterable(t.pPushConstantRanges, t.pushConstantRangeCount);
+ vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+ }
+
+ template <>
+ void CacheKeySerializer<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>::Serialize(
+ CacheKey* key,
+ const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT& t) {
+ key->Record(t.requiredSubgroupSize);
+ }
+
+ template <>
+ void CacheKeySerializer<VkSpecializationMapEntry>::Serialize(
+ CacheKey* key,
+ const VkSpecializationMapEntry& t) {
+ key->Record(t.constantID, t.offset, t.size);
+ }
+
+ template <>
+ void CacheKeySerializer<VkSpecializationInfo>::Serialize(CacheKey* key,
+ const VkSpecializationInfo& t) {
+ key->RecordIterable(t.pMapEntries, t.mapEntryCount)
+ .RecordIterable(static_cast<const uint8_t*>(t.pData), t.dataSize);
+ }
+
+ template <>
+ void CacheKeySerializer<VkPipelineShaderStageCreateInfo>::Serialize(
+ CacheKey* key,
+ const VkPipelineShaderStageCreateInfo& t) {
+ // The shader module is not serialized here because it is a pointer to a backend object.
+ key->Record(t.flags, t.stage)
+ .RecordIterable(t.pName, strlen(t.pName))
+ .Record(t.pSpecializationInfo);
+ vulkan::SerializePnext<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(
+ key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+ }
+
+ template <>
+ void CacheKeySerializer<VkComputePipelineCreateInfo>::Serialize(
+ CacheKey* key,
+ const VkComputePipelineCreateInfo& t) {
+ // The pipeline layout is not serialized here because it is a pointer to a backend object.
+ // It needs to be cross-referenced with the frontend objects and serialized from there. The
+ // base pipeline information is also currently not recorded since we do not use them in our
+ // backend implementation. If we decide to use them later on, they also need to be
+ // cross-referenced from the frontend.
+ key->Record(t.flags, t.stage);
+ }
+
+} // namespace dawn::native
diff --git a/src/dawn/native/vulkan/CacheKeyVk.h b/src/dawn/native/vulkan/CacheKeyVk.h
new file mode 100644
index 0000000..ab8e02d
--- /dev/null
+++ b/src/dawn/native/vulkan/CacheKeyVk.h
@@ -0,0 +1,85 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/common/Assert.h"
+#include "dawn/common/vulkan_platform.h"
+#include "dawn/native/CacheKey.h"
+
+#include "vulkan-tools/src/icd/generated/vk_typemap_helper.h"
+
+#include <map>
+
+namespace dawn::native::vulkan {
+
+ namespace detail {
+
+ template <typename... VK_STRUCT_TYPES>
+ void ValidatePnextImpl(const VkBaseOutStructure* root) {
+ const VkBaseOutStructure* next =
+ reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
+ while (next != nullptr) {
+ // Assert that the type of each pNext struct is exactly one of the specified
+ // templates.
+ ASSERT(((LvlTypeMap<VK_STRUCT_TYPES>::kSType == next->sType ? 1 : 0) + ... + 0) ==
+ 1);
+ next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
+ }
+ }
+
+ template <typename VK_STRUCT_TYPE>
+ void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
+ const VkBaseOutStructure* next =
+ reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
+ const VK_STRUCT_TYPE* found = nullptr;
+ while (next != nullptr) {
+ if (LvlTypeMap<VK_STRUCT_TYPE>::kSType == next->sType) {
+ if (found == nullptr) {
+ found = reinterpret_cast<const VK_STRUCT_TYPE*>(next);
+ } else {
+ // Fail an assert here since that means that the chain had more than one of
+ // the same typed chained object.
+ ASSERT(false);
+ }
+ }
+ next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
+ }
+ if (found != nullptr) {
+ key->Record(found);
+ }
+ }
+
+ template <typename VK_STRUCT_TYPE,
+ typename... VK_STRUCT_TYPES,
+ typename = std::enable_if_t<(sizeof...(VK_STRUCT_TYPES) > 0)>>
+ void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
+ SerializePnextImpl<VK_STRUCT_TYPE>(key, root);
+ SerializePnextImpl<VK_STRUCT_TYPES...>(key, root);
+ }
+
+ } // namespace detail
+
+ template <typename... VK_STRUCT_TYPES>
+ void SerializePnext(CacheKey* key, const VkBaseOutStructure* root) {
+ detail::ValidatePnextImpl<VK_STRUCT_TYPES...>(root);
+ detail::SerializePnextImpl<VK_STRUCT_TYPES...>(key, root);
+ }
+
+ // Empty template specialization so that we can put this in to ensure failures occur if new
+ // extensions are added without updating serialization.
+ template <>
+ void SerializePnext(CacheKey* key, const VkBaseOutStructure* root) {
+ detail::ValidatePnextImpl<>(root);
+ }
+
+} // namespace dawn::native::vulkan
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp
index fa13e26..68ac7d9 100644
--- a/src/dawn/native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -22,6 +22,8 @@
#include "dawn/native/vulkan/UtilsVulkan.h"
#include "dawn/native/vulkan/VulkanError.h"
+#include <utility>
+
namespace dawn::native::vulkan {
// static
@@ -46,10 +48,11 @@
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
// Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
- DAWN_TRY_ASSIGN(createInfo.stage.module,
- ToBackend(computeStage.module.Get())
- ->GetTransformedModuleHandle(computeStage.entryPoint.c_str(),
- ToBackend(GetLayout())));
+ ShaderModule* module = ToBackend(computeStage.module.Get());
+ PipelineLayout* layout = ToBackend(GetLayout());
+ const ShaderModule::Spirv* spirv;
+ DAWN_TRY_ASSIGN((std::tie(createInfo.stage.module, spirv)),
+ module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
createInfo.stage.pName = computeStage.entryPoint.c_str();
@@ -74,6 +77,11 @@
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
}
+ // Record cache key information now since the createInfo is not stored.
+ GetCacheKey()
+ ->Record(createInfo, static_cast<const ComputePipeline*>(this)->GetLayout())
+ .RecordIterable(*spirv);
+
DAWN_TRY(CheckVkSuccess(
device->fn.CreateComputePipelines(device->GetVkDevice(), ::VK_NULL_HANDLE, 1,
&createInfo, nullptr, &*mHandle),
diff --git a/src/dawn/native/vulkan/PipelineLayoutVk.cpp b/src/dawn/native/vulkan/PipelineLayoutVk.cpp
index 245f2c9..e6653bd 100644
--- a/src/dawn/native/vulkan/PipelineLayoutVk.cpp
+++ b/src/dawn/native/vulkan/PipelineLayoutVk.cpp
@@ -38,8 +38,11 @@
// this constraints at the Dawn level?
uint32_t numSetLayouts = 0;
std::array<VkDescriptorSetLayout, kMaxBindGroups> setLayouts;
+ std::array<const CachedObject*, kMaxBindGroups> cachedObjects;
for (BindGroupIndex setIndex : IterateBitSet(GetBindGroupLayoutsMask())) {
- setLayouts[numSetLayouts] = ToBackend(GetBindGroupLayout(setIndex))->GetHandle();
+ const BindGroupLayoutBase* bindGroupLayout = GetBindGroupLayout(setIndex);
+ setLayouts[numSetLayouts] = ToBackend(bindGroupLayout)->GetHandle();
+ cachedObjects[numSetLayouts] = bindGroupLayout;
numSetLayouts++;
}
@@ -52,6 +55,9 @@
createInfo.pushConstantRangeCount = 0;
createInfo.pPushConstantRanges = nullptr;
+ // Record cache key information now since the createInfo is not stored.
+ GetCacheKey()->RecordIterable(cachedObjects.data(), numSetLayouts).Record(createInfo);
+
Device* device = ToBackend(GetDevice());
DAWN_TRY(CheckVkSuccess(
device->fn.CreatePipelineLayout(device->GetVkDevice(), &createInfo, nullptr, &*mHandle),
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp
index 4f30496..405da49 100644
--- a/src/dawn/native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -332,6 +332,7 @@
MaybeError RenderPipeline::Initialize() {
Device* device = ToBackend(GetDevice());
+ PipelineLayout* layout = ToBackend(GetLayout());
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
@@ -344,10 +345,11 @@
VkPipelineShaderStageCreateInfo shaderStage;
const ProgrammableStage& programmableStage = GetStage(stage);
- DAWN_TRY_ASSIGN(shaderStage.module,
- ToBackend(programmableStage.module)
- ->GetTransformedModuleHandle(programmableStage.entryPoint.c_str(),
- ToBackend(GetLayout())));
+ ShaderModule* module = ToBackend(programmableStage.module.Get());
+ const ShaderModule::Spirv* spirv;
+ DAWN_TRY_ASSIGN(
+ std::tie(shaderStage.module, spirv),
+ module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStage.pNext = nullptr;
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 9b8c291..d56e599 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -38,35 +38,38 @@
ShaderModule::ConcurrentTransformedShaderModuleCache::
~ConcurrentTransformedShaderModuleCache() {
std::lock_guard<std::mutex> lock(mMutex);
- for (const auto& [_, module] : mTransformedShaderModuleCache) {
- mDevice->GetFencedDeleter()->DeleteWhenUnused(module);
+ for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) {
+ mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.first);
}
}
- VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::FindShaderModule(
+ std::optional<ShaderModule::ModuleAndSpirv>
+ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
const PipelineLayoutEntryPointPair& key) {
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) {
- auto cached = iter->second;
- return cached;
+ return std::make_pair(iter->second.first, iter->second.second.get());
}
- return VK_NULL_HANDLE;
+ return {};
}
- VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGetCachedShaderModule(
+ ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
const PipelineLayoutEntryPointPair& key,
- VkShaderModule value) {
- ASSERT(value != VK_NULL_HANDLE);
+ VkShaderModule module,
+ std::vector<uint32_t>&& spirv) {
+ ASSERT(module != VK_NULL_HANDLE);
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter == mTransformedShaderModuleCache.end()) {
- mTransformedShaderModuleCache.emplace(key, value);
- return value;
+ mTransformedShaderModuleCache.emplace(
+ key, std::make_pair(module, std::unique_ptr<Spirv>(new Spirv(spirv))));
} else {
- mDevice->GetFencedDeleter()->DeleteWhenUnused(value);
- return iter->second;
+ mDevice->GetFencedDeleter()->DeleteWhenUnused(module);
}
+ // Now the key should exist in the map, so find it again and return it.
+ iter = mTransformedShaderModuleCache.find(key);
+ return std::make_pair(iter->second.first, iter->second.second.get());
}
// static
@@ -109,25 +112,24 @@
ShaderModule::~ShaderModule() = default;
- ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle(
+ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
const char* entryPointName,
PipelineLayout* layout) {
- TRACE_EVENT0(GetDevice()->GetPlatform(), General,
- "ShaderModuleVk::GetTransformedModuleHandle");
+ TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
// If the shader was destroyed, we should never call this function.
ASSERT(IsAlive());
ScopedTintICEHandler scopedICEHandler(GetDevice());
+ // Check to see if we have the handle and spirv cached already.
auto cacheKey = std::make_pair(layout, entryPointName);
- VkShaderModule cachedShaderModule =
- mTransformedShaderModuleCache->FindShaderModule(cacheKey);
- if (cachedShaderModule != VK_NULL_HANDLE) {
- return cachedShaderModule;
+ auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
+ if (handleAndSpirv.has_value()) {
+ return std::move(*handleAndSpirv);
}
- // Creation of VkShaderModule is deferred to this point when using tint generator
+ // Creation of module and spirv is deferred to this point when using tint generator
// Remap BindingNumber to BindingIndex in WGSL shader
using BindingRemapper = tint::transform::BindingRemapper;
@@ -207,7 +209,7 @@
options.use_zero_initialize_workgroup_memory_extension =
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
- std::vector<uint32_t> spirv;
+ Spirv spirv;
{
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::spirv::Generate()");
auto result = tint::writer::spirv::Generate(&program, options);
@@ -236,15 +238,17 @@
device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
"CreateShaderModule"));
}
+ ModuleAndSpirv moduleAndSpirv;
if (newHandle != VK_NULL_HANDLE) {
- newHandle =
- mTransformedShaderModuleCache->AddOrGetCachedShaderModule(cacheKey, newHandle);
+ moduleAndSpirv =
+ mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, std::move(spirv));
}
SetDebugName(ToBackend(GetDevice()), VK_OBJECT_TYPE_SHADER_MODULE,
- reinterpret_cast<uint64_t&>(newHandle), "Dawn_ShaderModule", GetLabel());
+ reinterpret_cast<uint64_t&>(moduleAndSpirv.first), "Dawn_ShaderModule",
+ GetLabel());
- return newHandle;
+ return std::move(moduleAndSpirv);
}
} // namespace dawn::native::vulkan
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index 7040b74..3b69b75 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -20,7 +20,11 @@
#include "dawn/common/vulkan_platform.h"
#include "dawn/native/Error.h"
+#include <memory>
#include <mutex>
+#include <optional>
+#include <utility>
+#include <vector>
namespace dawn::native::vulkan {
@@ -29,12 +33,15 @@
class ShaderModule final : public ShaderModuleBase {
public:
+ using Spirv = std::vector<uint32_t>;
+ using ModuleAndSpirv = std::pair<VkShaderModule, const Spirv*>;
+
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
- ResultOrError<VkShaderModule> GetTransformedModuleHandle(const char* entryPointName,
- PipelineLayout* layout);
+ ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
+ PipelineLayout* layout);
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
@@ -42,20 +49,24 @@
MaybeError Initialize(ShaderModuleParseResult* parseResult);
void DestroyImpl() override;
- // New handles created by GetTransformedModuleHandle at pipeline creation time
+ // New handles created by GetHandleAndSpirv at pipeline creation time.
class ConcurrentTransformedShaderModuleCache {
public:
explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache();
- VkShaderModule FindShaderModule(const PipelineLayoutEntryPointPair& key);
- VkShaderModule AddOrGetCachedShaderModule(const PipelineLayoutEntryPointPair& key,
- VkShaderModule value);
+
+ std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
+ ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
+ VkShaderModule module,
+ std::vector<uint32_t>&& spirv);
private:
+ using Entry = std::pair<VkShaderModule, std::unique_ptr<Spirv>>;
+
Device* mDevice;
std::mutex mMutex;
std::unordered_map<PipelineLayoutEntryPointPair,
- VkShaderModule,
+ Entry,
PipelineLayoutEntryPointPairHashFunc>
mTransformedShaderModuleCache;
};
diff --git a/src/dawn/tests/unittests/native/CacheKeyTests.cpp b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
index 45fd360..009b7b6 100644
--- a/src/dawn/tests/unittests/native/CacheKeyTests.cpp
+++ b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
@@ -51,7 +51,7 @@
// Matcher to compare CacheKeys for easier testing.
MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
- return memcmp(arg.data(), key.data(), arg.size()) == 0;
+ return arg.size() == key.size() && memcmp(arg.data(), key.data(), key.size()) == 0;
}
TEST(CacheKeyTests, RecordSingleMember) {