Remap BindGroup bindingIndex for vulkan backend when using Tint Generator
Bug: dawn:750
Change-Id: I239f5544a5822422d61a249f2ef028df326f90ed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47380
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Shrek Shao <shrekshao@google.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 7774c61..90115f0 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -14,6 +14,7 @@
#include "dawn_native/ShaderModule.h"
+#include "common/HashUtils.h"
#include "common/VertexFormatUtils.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/CompilationMessages.h"
@@ -1376,4 +1377,11 @@
return std::move(result);
}
+ size_t PipelineLayoutEntryPointPairHashFunc::operator()(
+ const PipelineLayoutEntryPointPair& pair) const {
+ size_t hash = 0;
+ HashCombine(&hash, pair.first, pair.second);
+ return hash;
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index ab3a271..556d604 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -51,6 +51,11 @@
struct EntryPointMetadata;
+ using PipelineLayoutEntryPointPair = std::pair<PipelineLayoutBase*, std::string>;
+ struct PipelineLayoutEntryPointPairHashFunc {
+ size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
+ };
+
// A map from name to EntryPointMetadata.
using EntryPointMetadataTable =
std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>;
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 33ea80b..8b14210 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -261,7 +261,7 @@
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
- tint::Program& program = output.program;
+ const tint::Program& program = output.program;
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;
diff --git a/src/dawn_native/vulkan/BindGroupLayoutVk.cpp b/src/dawn_native/vulkan/BindGroupLayoutVk.cpp
index 700e850..78f7a7a 100644
--- a/src/dawn_native/vulkan/BindGroupLayoutVk.cpp
+++ b/src/dawn_native/vulkan/BindGroupLayoutVk.cpp
@@ -89,13 +89,16 @@
ityp::vector<BindingIndex, VkDescriptorSetLayoutBinding> bindings;
bindings.reserve(GetBindingCount());
+ bool useBindingIndex = GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator);
+
for (const auto& it : GetBindingMap()) {
BindingNumber bindingNumber = it.first;
BindingIndex bindingIndex = it.second;
const BindingInfo& bindingInfo = GetBindingInfo(bindingIndex);
VkDescriptorSetLayoutBinding vkBinding;
- vkBinding.binding = static_cast<uint32_t>(bindingNumber);
+ vkBinding.binding = useBindingIndex ? static_cast<uint32_t>(bindingIndex)
+ : static_cast<uint32_t>(bindingNumber);
vkBinding.descriptorType = VulkanDescriptorType(bindingInfo);
vkBinding.descriptorCount = 1;
vkBinding.stageFlags = VulkanShaderStageFlags(bindingInfo.visibility);
diff --git a/src/dawn_native/vulkan/BindGroupLayoutVk.h b/src/dawn_native/vulkan/BindGroupLayoutVk.h
index 72f8b69..cc502c9 100644
--- a/src/dawn_native/vulkan/BindGroupLayoutVk.h
+++ b/src/dawn_native/vulkan/BindGroupLayoutVk.h
@@ -43,6 +43,10 @@
// the pools are reused when no longer used. Minimizing the number of descriptor pool allocation
// is important because creating them can incur GPU memory allocation which is usually an
// expensive syscall.
+ //
+ // The Vulkan BindGroupLayout is dependent on UseTintGenerator or not.
+ // When UseTintGenerator is on, VkDescriptorSetLayoutBinding::binding is set to BindingIndex,
+ // otherwise it is set to BindingNumber.
class BindGroupLayout final : public BindGroupLayoutBase {
public:
static ResultOrError<Ref<BindGroupLayout>> Create(
diff --git a/src/dawn_native/vulkan/BindGroupVk.cpp b/src/dawn_native/vulkan/BindGroupVk.cpp
index 07653e8..b2334d1 100644
--- a/src/dawn_native/vulkan/BindGroupVk.cpp
+++ b/src/dawn_native/vulkan/BindGroupVk.cpp
@@ -47,6 +47,8 @@
ityp::stack_vec<uint32_t, VkDescriptorImageInfo, kMaxOptimalBindingsPerGroup>
writeImageInfo(bindingCount);
+ bool useBindingIndex = device->IsToggleEnabled(Toggle::UseTintGenerator);
+
uint32_t numWrites = 0;
for (const auto& it : GetLayout()->GetBindingMap()) {
BindingNumber bindingNumber = it.first;
@@ -57,7 +59,8 @@
write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write.pNext = nullptr;
write.dstSet = GetHandle();
- write.dstBinding = static_cast<uint32_t>(bindingNumber);
+ write.dstBinding = useBindingIndex ? static_cast<uint32_t>(bindingIndex)
+ : static_cast<uint32_t>(bindingNumber);
write.dstArrayElement = 0;
write.descriptorCount = 1;
write.descriptorType = VulkanDescriptorType(bindingInfo);
diff --git a/src/dawn_native/vulkan/BindGroupVk.h b/src/dawn_native/vulkan/BindGroupVk.h
index dac780b..14b6940 100644
--- a/src/dawn_native/vulkan/BindGroupVk.h
+++ b/src/dawn_native/vulkan/BindGroupVk.h
@@ -26,6 +26,9 @@
class Device;
+ // The Vulkan BindGroup is dependent on UseTintGenerator or not.
+ // When UseTintGenerator is on, VkWriteDescriptorSet::dstBinding is set to BindingIndex,
+ // otherwise it is set to BindingNumber.
class BindGroup final : public BindGroupBase, public PlacementAllocated {
public:
static ResultOrError<Ref<BindGroup>> Create(Device* device,
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index a81dee9..322c026 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -45,7 +45,15 @@
createInfo.stage.pNext = nullptr;
createInfo.stage.flags = 0;
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
- createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle();
+ if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
+ DAWN_TRY_ASSIGN(createInfo.stage.module,
+ ToBackend(descriptor->computeStage.module)
+ ->GetTransformedModuleHandle(descriptor->computeStage.entryPoint,
+ ToBackend(GetLayout())));
+ } else {
+ createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle();
+ }
createInfo.stage.pName = descriptor->computeStage.entryPoint;
createInfo.stage.pSpecializationInfo = nullptr;
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp
index bbb4f8ee..a743b8d 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp
@@ -332,12 +332,27 @@
VkPipelineShaderStageCreateInfo shaderStages[2];
{
+ if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ // Generate a new VkShaderModule with BindingRemapper tint transform for each
+ // pipeline
+ DAWN_TRY_ASSIGN(shaderStages[0].module,
+ ToBackend(descriptor->vertex.module)
+ ->GetTransformedModuleHandle(descriptor->vertex.entryPoint,
+ ToBackend(GetLayout())));
+ DAWN_TRY_ASSIGN(shaderStages[1].module,
+ ToBackend(descriptor->fragment->module)
+ ->GetTransformedModuleHandle(descriptor->fragment->entryPoint,
+ ToBackend(GetLayout())));
+ } else {
+ shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle();
+ shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle();
+ }
+
shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStages[0].pNext = nullptr;
shaderStages[0].flags = 0;
shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
shaderStages[0].pSpecializationInfo = nullptr;
- shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle();
shaderStages[0].pName = descriptor->vertex.entryPoint;
shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
@@ -345,7 +360,6 @@
shaderStages[1].flags = 0;
shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
shaderStages[1].pSpecializationInfo = nullptr;
- shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle();
shaderStages[1].pName = descriptor->fragment->entryPoint;
}
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index 0fb4c61..b8a2b23 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -15,8 +15,10 @@
#include "dawn_native/vulkan/ShaderModuleVk.h"
#include "dawn_native/TintUtils.h"
+#include "dawn_native/vulkan/BindGroupLayoutVk.h"
#include "dawn_native/vulkan/DeviceVk.h"
#include "dawn_native/vulkan/FencedDeleter.h"
+#include "dawn_native/vulkan/PipelineLayoutVk.h"
#include "dawn_native/vulkan/VulkanError.h"
#include <spirv_cross.hpp>
@@ -103,10 +105,106 @@
device->GetFencedDeleter()->DeleteWhenUnused(mHandle);
mHandle = VK_NULL_HANDLE;
}
+
+ for (const auto& iter : mTransformedShaderModuleCache) {
+ device->GetFencedDeleter()->DeleteWhenUnused(iter.second);
+ }
}
VkShaderModule ShaderModule::GetHandle() const {
+ ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
return mHandle;
}
+ ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle(
+ const char* entryPointName,
+ PipelineLayout* layout) {
+ ScopedTintICEHandler scopedICEHandler(GetDevice());
+
+ ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
+
+ auto cacheKey = std::make_pair(layout, entryPointName);
+ auto iter = mTransformedShaderModuleCache.find(cacheKey);
+ if (iter != mTransformedShaderModuleCache.end()) {
+ auto cached = iter->second;
+ return cached;
+ }
+
+ // Creation of VkShaderModule is deferred to this point when using tint generator
+ std::ostringstream errorStream;
+ errorStream << "Tint SPIR-V writer failure:" << std::endl;
+
+ // Remap BindingNumber to BindingIndex in WGSL shader
+ using BindingRemapper = tint::transform::BindingRemapper;
+ using BindingPoint = tint::transform::BindingPoint;
+ BindingRemapper::BindingPoints bindingPoints;
+ BindingRemapper::AccessControls accessControls;
+
+ const EntryPointMetadata::BindingInfoArray& moduleBindingInfo =
+ GetEntryPoint(entryPointName).bindings;
+
+ for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+ const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
+ const auto& groupBindingInfo = moduleBindingInfo[group];
+ for (const auto& it : groupBindingInfo) {
+ BindingNumber binding = it.first;
+ BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
+ BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
+ static_cast<uint32_t>(binding)};
+
+ BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
+ static_cast<uint32_t>(bindingIndex)};
+ if (srcBindingPoint != dstBindingPoint) {
+ bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
+ }
+ }
+ }
+
+ tint::transform::Manager transformManager;
+ transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
+
+ tint::transform::DataMap transformInputs;
+ transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
+ std::move(accessControls));
+ tint::transform::Transform::Output output =
+ transformManager.Run(GetTintProgram(), transformInputs);
+
+ const tint::Program& program = output.program;
+ if (!program.IsValid()) {
+ errorStream << "Tint program transform error: " << program.Diagnostics().str()
+ << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ tint::writer::spirv::Generator generator(&program);
+ if (!generator.Generate()) {
+ errorStream << "Generator: " << generator.error() << std::endl;
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ std::vector<uint32_t> spirv = generator.result();
+
+ // Don't save the transformedParseResult but just create a VkShaderModule
+ VkShaderModuleCreateInfo createInfo;
+ createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+ createInfo.pNext = nullptr;
+ createInfo.flags = 0;
+ std::vector<uint32_t> vulkanSource;
+ createInfo.codeSize = spirv.size() * sizeof(uint32_t);
+ createInfo.pCode = spirv.data();
+
+ Device* device = ToBackend(GetDevice());
+
+ VkShaderModule newHandle = VK_NULL_HANDLE;
+
+ DAWN_TRY(CheckVkSuccess(
+ device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
+ "CreateShaderModule"));
+ if (newHandle != VK_NULL_HANDLE) {
+ mTransformedShaderModuleCache.emplace(cacheKey, newHandle);
+ }
+
+ return newHandle;
+ }
+
}} // namespace dawn_native::vulkan
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 7c0d8ef..9dd7817 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -23,6 +23,11 @@
namespace dawn_native { namespace vulkan {
class Device;
+ class PipelineLayout;
+
+ using TransformedShaderModuleCache = std::unordered_map<PipelineLayoutEntryPointPair,
+ VkShaderModule,
+ PipelineLayoutEntryPointPairHashFunc>;
class ShaderModule final : public ShaderModuleBase {
public:
@@ -32,12 +37,19 @@
VkShaderModule GetHandle() const;
+ // This is only called when UseTintGenerator is on
+ ResultOrError<VkShaderModule> GetTransformedModuleHandle(const char* entryPointName,
+ PipelineLayout* layout);
+
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
VkShaderModule mHandle = VK_NULL_HANDLE;
+
+ // New handles created by GetTransformedModuleHandle at pipeline creation time
+ TransformedShaderModuleCache mTransformedShaderModuleCache;
};
}} // namespace dawn_native::vulkan