[YUV AHB] Make vulkan::BGL return the full TextureToStaticSamplerMap

And split off BGL's code that writes the descriptors to a separate
method that takes that map as input.

This will be useful in follow-up CLs to let the special handling of
JITed pipelines for the OpaqueYCbCrAndroid ExternalTextures to recreate
VkDescriptorSets but with a different set of static samplers.

Bug: 468988322
Change-Id: If61d53938a677c62f0e88a846f9e8fae69aed1e0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/298055
Reviewed-by: Kyle Charbonneau <kylechar@google.com>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Brandon Jones <bajones@chromium.org>
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
index 56250a7..8b9d239 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
@@ -51,7 +51,7 @@
 struct VulkanStaticBindings {
     ityp::vector<BindingIndex, VkDescriptorSetLayoutBinding> bindings;
     absl::flat_hash_map<VkDescriptorType, uint32_t> descriptorCountPerType;
-    absl::flat_hash_map<BindingIndex, BindingIndex> textureToStaticSamplerIndex;
+    TextureToStaticSamplerMap textureToStaticSampler;
 };
 VulkanStaticBindings ComputeVulkanStaticBindings(const BindGroupLayoutInternalBase* layout) {
     VulkanStaticBindings res;
@@ -67,7 +67,7 @@
             continue;
         }
 
-        res.textureToStaticSamplerIndex[samplerBindingInfo.sampledTextureIndex] = bindingIndex;
+        res.textureToStaticSampler[samplerBindingInfo.sampledTextureIndex] = bindingIndex;
     }
 
     // Compute the bindings that will be chained in the DescriptorSetLayout create info. We add
@@ -88,7 +88,7 @@
         }
 
         // This texture will be bound into the VkDescriptorSet at the index for the sampler itself.
-        if (res.textureToStaticSamplerIndex.contains(bindingIndex)) {
+        if (res.textureToStaticSampler.contains(bindingIndex)) {
             continue;
         }
 
@@ -211,7 +211,7 @@
     mDescriptorSetAllocator =
         DescriptorSetAllocator::Create(device, std::move(bindings.descriptorCountPerType));
 
-    mTextureToStaticSamplerIndex = std::move(bindings.textureToStaticSamplerIndex);
+    mTextureToStaticSampler = std::move(bindings.textureToStaticSampler);
 
     VkDescriptorSetLayoutCreateInfo createInfo{
         .sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
@@ -273,12 +273,8 @@
     mBindGroupAllocator->DeleteEmptySlabs();
 }
 
-std::optional<BindingIndex> BindGroupLayout::GetStaticSamplerIndexForTexture(
-    BindingIndex textureBinding) const {
-    if (mTextureToStaticSamplerIndex.contains(textureBinding)) {
-        return mTextureToStaticSamplerIndex.at(textureBinding);
-    }
-    return {};
+const TextureToStaticSamplerMap& BindGroupLayout::GetTextureToStaticSamplerMap() const {
+    return mTextureToStaticSampler;
 }
 
 void BindGroupLayout::SetLabelImpl() {
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.h b/src/dawn/native/vulkan/BindGroupLayoutVk.h
index 20b96dc..76adc92 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.h
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.h
@@ -66,10 +66,7 @@
     void DeallocateDescriptorSet(DescriptorSetAllocation* descriptorSetAllocation);
     void ReduceMemoryUsage() override;
 
-    // If the client specified that the texture at `textureBinding` should be
-    // combined with a static sampler, returns the binding index of the static
-    // sampler that is sampling this texture.
-    std::optional<BindingIndex> GetStaticSamplerIndexForTexture(BindingIndex textureBinding) const;
+    const TextureToStaticSamplerMap& GetTextureToStaticSamplerMap() const;
 
   protected:
     BindGroupLayout(DeviceBase* device, const UnpackedPtr<BindGroupLayoutDescriptor>& descriptor);
@@ -88,7 +85,7 @@
 
     // Maps from indices of texture entries that are paired with static samplers
     // to indices of the entries of their respective samplers.
-    absl::flat_hash_map<BindingIndex, BindingIndex> mTextureToStaticSamplerIndex;
+    TextureToStaticSamplerMap mTextureToStaticSampler;
 
     Ref<DescriptorSetAllocator> mDescriptorSetAllocator;
 };
diff --git a/src/dawn/native/vulkan/BindGroupVk.cpp b/src/dawn/native/vulkan/BindGroupVk.cpp
index dc91c5b..f351c39 100644
--- a/src/dawn/native/vulkan/BindGroupVk.cpp
+++ b/src/dawn/native/vulkan/BindGroupVk.cpp
@@ -66,6 +66,14 @@
 BindGroup::~BindGroup() = default;
 
 MaybeError BindGroup::InitializeImpl() {
+    WriteDescriptorSet(GetHandle(), ToBackend(GetLayout())->GetTextureToStaticSamplerMap());
+
+    SetLabelImpl();
+    return {};
+}
+
+void BindGroup::WriteDescriptorSet(VkDescriptorSet dsSet,
+                                   const TextureToStaticSamplerMap& textureToStaticSampler) {
     const auto* layout = ToBackend(GetLayout());
 
     // Now do a write of a single descriptor set with all possible chained data allocated on the
@@ -73,7 +81,7 @@
     // invalidate the pointers chained in `writes`.
     // TODO(https://crbug.com/438554018): Use Vulkan's descriptor set update template so as to need
     // a single allocation, and one that could be reused at the layout level.
-    const uint32_t bindingCount = static_cast<uint32_t>((GetLayout()->GetBindingCount()));
+    const uint32_t bindingCount = static_cast<uint32_t>((layout->GetBindingCount()));
     ityp::stack_vec<uint32_t, VkWriteDescriptorSet, kMaxOptimalBindingsPerGroup> writes(
         bindingCount);
     ityp::stack_vec<uint32_t, VkDescriptorBufferInfo, kMaxOptimalBindingsPerGroup> writeBufferInfo(
@@ -100,7 +108,7 @@
         auto& write = writes[writeIndex];
         write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
         write.pNext = nullptr;
-        write.dstSet = GetHandle();
+        write.dstSet = dsSet;
         // Arrays all have a single binding, so compute the binding index for the array, which is
         // the same as the binding index for the 0th element.
         write.dstBinding = uint32_t(bindingIndex - bindingInfo.indexInArray);
@@ -166,9 +174,9 @@
             // TODO(https://crbug.com/438554018): Alternatively take advantage of the precomputed
             // descriptor update template to do set this up once in the layout and have it be
             // transparent in the BindGroup.
-            if (auto samplerIndex = ToBackend(GetLayout())->GetStaticSamplerIndexForTexture(i)) {
+            if (auto it = textureToStaticSampler.find(i); it != textureToStaticSampler.end()) {
                 // Write the info of the texture at the binding index for the sampler.
-                write->dstBinding = static_cast<uint32_t>(samplerIndex.value());
+                write->dstBinding = static_cast<uint32_t>(it->second);
                 write->descriptorType = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
             }
 
@@ -217,9 +225,6 @@
     Device* device = ToBackend(GetDevice());
     // TODO(https://crbug.com/42242088): Batch these updates
     device->fn.UpdateDescriptorSets(device->GetVkDevice(), numWrites, writes.data(), 0, nullptr);
-
-    SetLabelImpl();
-    return {};
 }
 
 void BindGroup::DestroyImpl(DestroyReason reason) {
diff --git a/src/dawn/native/vulkan/BindGroupVk.h b/src/dawn/native/vulkan/BindGroupVk.h
index 557782a..14a2eac 100644
--- a/src/dawn/native/vulkan/BindGroupVk.h
+++ b/src/dawn/native/vulkan/BindGroupVk.h
@@ -39,6 +39,11 @@
 
 class Device;
 
+// The sampled texture bindings in Vulkan need to be moved from whatever the binding was going to
+// be, to instead use the same slot as the static sampler they will be co-written with in the
+// VkDescriptorSet.
+using TextureToStaticSamplerMap = absl::flat_hash_map<BindingIndex, BindingIndex>;
+
 class BindGroup final : public BindGroupBase, public PlacementAllocated {
   public:
     static ResultOrError<Ref<BindGroup>> Create(Device* device,
@@ -57,6 +62,9 @@
     void DestroyImpl(DestroyReason reason) override;
     void DeleteThis() override;
 
+    void WriteDescriptorSet(VkDescriptorSet dsSet,
+                            const TextureToStaticSamplerMap& textureToStaticSampler);
+
     // Dawn API
     void SetLabelImpl() override;
 
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index bfc251e..e30d0a4 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -161,14 +161,15 @@
     std::unordered_set<tint::BindingPoint> staticallyPairedTextureBindingPoints;
     for (BindGroupIndex group : in.layout->GetBindGroupLayoutsMask()) {
         const BindGroupLayout* bgl = ToBackend(in.layout->GetBindGroupLayout(group));
+        const auto& textureToStaticSampler = bgl->GetTextureToStaticSamplerMap();
 
         for (BindingIndex index : bgl->GetSampledTextureIndices()) {
             const auto& bindingInfo = bgl->GetBindingInfo(index);
 
-            if (auto samplerIndex = bgl->GetStaticSamplerIndexForTexture(index)) {
+            if (auto it = textureToStaticSampler.find(index); it != textureToStaticSampler.end()) {
                 tint::BindingPoint wgslBindingPoint = {.group = uint32_t(startOfBindGroups + group),
                                                        .binding = uint32_t(bindingInfo.binding)};
-                bindings.texture[wgslBindingPoint].binding = uint32_t(samplerIndex.value());
+                bindings.texture[wgslBindingPoint].binding = uint32_t(it->second);
                 staticallyPairedTextureBindingPoints.insert(wgslBindingPoint);
             }
         }