Remap BindGroup bindingIndex for vulkan backend when using Tint Generator

Bug: dawn:750
Change-Id: I239f5544a5822422d61a249f2ef028df326f90ed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47380
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Shrek Shao <shrekshao@google.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 7774c61..90115f0 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/ShaderModule.h"
 
+#include "common/HashUtils.h"
 #include "common/VertexFormatUtils.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/CompilationMessages.h"
@@ -1376,4 +1377,11 @@
         return std::move(result);
     }
 
+    size_t PipelineLayoutEntryPointPairHashFunc::operator()(
+        const PipelineLayoutEntryPointPair& pair) const {
+        size_t hash = 0;
+        HashCombine(&hash, pair.first, pair.second);
+        return hash;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index ab3a271..556d604 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -51,6 +51,11 @@
 
     struct EntryPointMetadata;
 
+    using PipelineLayoutEntryPointPair = std::pair<PipelineLayoutBase*, std::string>;
+    struct PipelineLayoutEntryPointPairHashFunc {
+        size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
+    };
+
     // A map from name to EntryPointMetadata.
     using EntryPointMetadataTable =
         std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>;
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 33ea80b..8b14210 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -261,7 +261,7 @@
         tint::transform::Transform::Output output =
             transformManager.Run(GetTintProgram(), transformInputs);
 
-        tint::Program& program = output.program;
+        const tint::Program& program = output.program;
         if (!program.IsValid()) {
             errorStream << "Tint program transform error: " << program.Diagnostics().str()
                         << std::endl;
diff --git a/src/dawn_native/vulkan/BindGroupLayoutVk.cpp b/src/dawn_native/vulkan/BindGroupLayoutVk.cpp
index 700e850..78f7a7a 100644
--- a/src/dawn_native/vulkan/BindGroupLayoutVk.cpp
+++ b/src/dawn_native/vulkan/BindGroupLayoutVk.cpp
@@ -89,13 +89,16 @@
         ityp::vector<BindingIndex, VkDescriptorSetLayoutBinding> bindings;
         bindings.reserve(GetBindingCount());
 
+        bool useBindingIndex = GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator);
+
         for (const auto& it : GetBindingMap()) {
             BindingNumber bindingNumber = it.first;
             BindingIndex bindingIndex = it.second;
             const BindingInfo& bindingInfo = GetBindingInfo(bindingIndex);
 
             VkDescriptorSetLayoutBinding vkBinding;
-            vkBinding.binding = static_cast<uint32_t>(bindingNumber);
+            vkBinding.binding = useBindingIndex ? static_cast<uint32_t>(bindingIndex)
+                                                : static_cast<uint32_t>(bindingNumber);
             vkBinding.descriptorType = VulkanDescriptorType(bindingInfo);
             vkBinding.descriptorCount = 1;
             vkBinding.stageFlags = VulkanShaderStageFlags(bindingInfo.visibility);
diff --git a/src/dawn_native/vulkan/BindGroupLayoutVk.h b/src/dawn_native/vulkan/BindGroupLayoutVk.h
index 72f8b69..cc502c9 100644
--- a/src/dawn_native/vulkan/BindGroupLayoutVk.h
+++ b/src/dawn_native/vulkan/BindGroupLayoutVk.h
@@ -43,6 +43,10 @@
     // the pools are reused when no longer used. Minimizing the number of descriptor pool allocation
     // is important because creating them can incur GPU memory allocation which is usually an
     // expensive syscall.
+    //
+    // The Vulkan BindGroupLayout is dependent on UseTintGenerator or not.
+    // When UseTintGenerator is on, VkDescriptorSetLayoutBinding::binding is set to BindingIndex,
+    // otherwise it is set to BindingNumber.
     class BindGroupLayout final : public BindGroupLayoutBase {
       public:
         static ResultOrError<Ref<BindGroupLayout>> Create(
diff --git a/src/dawn_native/vulkan/BindGroupVk.cpp b/src/dawn_native/vulkan/BindGroupVk.cpp
index 07653e8..b2334d1 100644
--- a/src/dawn_native/vulkan/BindGroupVk.cpp
+++ b/src/dawn_native/vulkan/BindGroupVk.cpp
@@ -47,6 +47,8 @@
         ityp::stack_vec<uint32_t, VkDescriptorImageInfo, kMaxOptimalBindingsPerGroup>
             writeImageInfo(bindingCount);
 
+        bool useBindingIndex = device->IsToggleEnabled(Toggle::UseTintGenerator);
+
         uint32_t numWrites = 0;
         for (const auto& it : GetLayout()->GetBindingMap()) {
             BindingNumber bindingNumber = it.first;
@@ -57,7 +59,8 @@
             write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
             write.pNext = nullptr;
             write.dstSet = GetHandle();
-            write.dstBinding = static_cast<uint32_t>(bindingNumber);
+            write.dstBinding = useBindingIndex ? static_cast<uint32_t>(bindingIndex)
+                                               : static_cast<uint32_t>(bindingNumber);
             write.dstArrayElement = 0;
             write.descriptorCount = 1;
             write.descriptorType = VulkanDescriptorType(bindingInfo);
diff --git a/src/dawn_native/vulkan/BindGroupVk.h b/src/dawn_native/vulkan/BindGroupVk.h
index dac780b..14b6940 100644
--- a/src/dawn_native/vulkan/BindGroupVk.h
+++ b/src/dawn_native/vulkan/BindGroupVk.h
@@ -26,6 +26,9 @@
 
     class Device;
 
+    // The Vulkan BindGroup is dependent on UseTintGenerator or not.
+    // When UseTintGenerator is on, VkWriteDescriptorSet::dstBinding is set to BindingIndex,
+    // otherwise it is set to BindingNumber.
     class BindGroup final : public BindGroupBase, public PlacementAllocated {
       public:
         static ResultOrError<Ref<BindGroup>> Create(Device* device,
diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp
index a81dee9..322c026 100644
--- a/src/dawn_native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp
@@ -45,7 +45,15 @@
         createInfo.stage.pNext = nullptr;
         createInfo.stage.flags = 0;
         createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
-        createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle();
+        if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
+            // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
+            DAWN_TRY_ASSIGN(createInfo.stage.module,
+                            ToBackend(descriptor->computeStage.module)
+                                ->GetTransformedModuleHandle(descriptor->computeStage.entryPoint,
+                                                             ToBackend(GetLayout())));
+        } else {
+            createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle();
+        }
         createInfo.stage.pName = descriptor->computeStage.entryPoint;
         createInfo.stage.pSpecializationInfo = nullptr;
 
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp
index bbb4f8ee..a743b8d 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp
@@ -332,12 +332,27 @@
 
         VkPipelineShaderStageCreateInfo shaderStages[2];
         {
+            if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+                // Generate a new VkShaderModule with BindingRemapper tint transform for each
+                // pipeline
+                DAWN_TRY_ASSIGN(shaderStages[0].module,
+                                ToBackend(descriptor->vertex.module)
+                                    ->GetTransformedModuleHandle(descriptor->vertex.entryPoint,
+                                                                 ToBackend(GetLayout())));
+                DAWN_TRY_ASSIGN(shaderStages[1].module,
+                                ToBackend(descriptor->fragment->module)
+                                    ->GetTransformedModuleHandle(descriptor->fragment->entryPoint,
+                                                                 ToBackend(GetLayout())));
+            } else {
+                shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle();
+                shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle();
+            }
+
             shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
             shaderStages[0].pNext = nullptr;
             shaderStages[0].flags = 0;
             shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
             shaderStages[0].pSpecializationInfo = nullptr;
-            shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle();
             shaderStages[0].pName = descriptor->vertex.entryPoint;
 
             shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
@@ -345,7 +360,6 @@
             shaderStages[1].flags = 0;
             shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
             shaderStages[1].pSpecializationInfo = nullptr;
-            shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle();
             shaderStages[1].pName = descriptor->fragment->entryPoint;
         }
 
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp
index 0fb4c61..b8a2b23 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp
@@ -15,8 +15,10 @@
 #include "dawn_native/vulkan/ShaderModuleVk.h"
 
 #include "dawn_native/TintUtils.h"
+#include "dawn_native/vulkan/BindGroupLayoutVk.h"
 #include "dawn_native/vulkan/DeviceVk.h"
 #include "dawn_native/vulkan/FencedDeleter.h"
+#include "dawn_native/vulkan/PipelineLayoutVk.h"
 #include "dawn_native/vulkan/VulkanError.h"
 
 #include <spirv_cross.hpp>
@@ -103,10 +105,106 @@
             device->GetFencedDeleter()->DeleteWhenUnused(mHandle);
             mHandle = VK_NULL_HANDLE;
         }
+
+        for (const auto& iter : mTransformedShaderModuleCache) {
+            device->GetFencedDeleter()->DeleteWhenUnused(iter.second);
+        }
     }
 
     VkShaderModule ShaderModule::GetHandle() const {
+        ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
         return mHandle;
     }
 
+    ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle(
+        const char* entryPointName,
+        PipelineLayout* layout) {
+        ScopedTintICEHandler scopedICEHandler(GetDevice());
+
+        ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
+
+        auto cacheKey = std::make_pair(layout, entryPointName);
+        auto iter = mTransformedShaderModuleCache.find(cacheKey);
+        if (iter != mTransformedShaderModuleCache.end()) {
+            auto cached = iter->second;
+            return cached;
+        }
+
+        // Creation of VkShaderModule is deferred to this point when using tint generator
+        std::ostringstream errorStream;
+        errorStream << "Tint SPIR-V writer failure:" << std::endl;
+
+        // Remap BindingNumber to BindingIndex in WGSL shader
+        using BindingRemapper = tint::transform::BindingRemapper;
+        using BindingPoint = tint::transform::BindingPoint;
+        BindingRemapper::BindingPoints bindingPoints;
+        BindingRemapper::AccessControls accessControls;
+
+        const EntryPointMetadata::BindingInfoArray& moduleBindingInfo =
+            GetEntryPoint(entryPointName).bindings;
+
+        for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+            const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
+            const auto& groupBindingInfo = moduleBindingInfo[group];
+            for (const auto& it : groupBindingInfo) {
+                BindingNumber binding = it.first;
+                BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
+                BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
+                                             static_cast<uint32_t>(binding)};
+
+                BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
+                                             static_cast<uint32_t>(bindingIndex)};
+                if (srcBindingPoint != dstBindingPoint) {
+                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
+                }
+            }
+        }
+
+        tint::transform::Manager transformManager;
+        transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
+
+        tint::transform::DataMap transformInputs;
+        transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
+                                                         std::move(accessControls));
+        tint::transform::Transform::Output output =
+            transformManager.Run(GetTintProgram(), transformInputs);
+
+        const tint::Program& program = output.program;
+        if (!program.IsValid()) {
+            errorStream << "Tint program transform error: " << program.Diagnostics().str()
+                        << std::endl;
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        tint::writer::spirv::Generator generator(&program);
+        if (!generator.Generate()) {
+            errorStream << "Generator: " << generator.error() << std::endl;
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        std::vector<uint32_t> spirv = generator.result();
+
+        // Don't save the transformedParseResult but just create a VkShaderModule
+        VkShaderModuleCreateInfo createInfo;
+        createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+        createInfo.pNext = nullptr;
+        createInfo.flags = 0;
+        std::vector<uint32_t> vulkanSource;
+        createInfo.codeSize = spirv.size() * sizeof(uint32_t);
+        createInfo.pCode = spirv.data();
+
+        Device* device = ToBackend(GetDevice());
+
+        VkShaderModule newHandle = VK_NULL_HANDLE;
+
+        DAWN_TRY(CheckVkSuccess(
+            device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
+            "CreateShaderModule"));
+        if (newHandle != VK_NULL_HANDLE) {
+            mTransformedShaderModuleCache.emplace(cacheKey, newHandle);
+        }
+
+        return newHandle;
+    }
+
 }}  // namespace dawn_native::vulkan
diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h
index 7c0d8ef..9dd7817 100644
--- a/src/dawn_native/vulkan/ShaderModuleVk.h
+++ b/src/dawn_native/vulkan/ShaderModuleVk.h
@@ -23,6 +23,11 @@
 namespace dawn_native { namespace vulkan {
 
     class Device;
+    class PipelineLayout;
+
+    using TransformedShaderModuleCache = std::unordered_map<PipelineLayoutEntryPointPair,
+                                                            VkShaderModule,
+                                                            PipelineLayoutEntryPointPairHashFunc>;
 
     class ShaderModule final : public ShaderModuleBase {
       public:
@@ -32,12 +37,19 @@
 
         VkShaderModule GetHandle() const;
 
+        // This is only called when UseTintGenerator is on
+        ResultOrError<VkShaderModule> GetTransformedModuleHandle(const char* entryPointName,
+                                                                 PipelineLayout* layout);
+
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override;
         MaybeError Initialize(ShaderModuleParseResult* parseResult);
 
         VkShaderModule mHandle = VK_NULL_HANDLE;
+
+        // New handles created by GetTransformedModuleHandle at pipeline creation time
+        TransformedShaderModuleCache mTransformedShaderModuleCache;
     };
 
 }}  // namespace dawn_native::vulkan