[YUV AHB] Factor state sync logic in CommandBufferVk.

All the Draw and Dispatch operations needed to make sure to apply
updates for dirty immediates and descriptor sets. Introduce a "state"
class that holds the various trackers and has a wrapper function
SyncAndRun that runs the draw/dispatch functor after syncing.

Also defer binding the Vulkan pipeline until right before the operation.

This is useful in the short term to add support for JITed pipelines for
YUV AHB, since it allows detecting that the pipeline needs JITing just
before the Draw/Dispatch, and instrumenting the code that does state
synchronization for that purpose. Long-term it will be useful to reduce
the state trackend by the Vulkan backend (some of it in BindGroupTracker
is not needed), and moving to a tail-call interpreter loop.

Bug: 468988322
Change-Id: Iaed78347dc49c1f468bb585f368a058b85d36e82
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/297735
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kyle Charbonneau <kylechar@google.com>
Reviewed-by: Brandon Jones <bajones@chromium.org>
diff --git a/src/dawn/native/vulkan/CommandBufferVk.cpp b/src/dawn/native/vulkan/CommandBufferVk.cpp
index b8bc53e..df814fa 100644
--- a/src/dawn/native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn/native/vulkan/CommandBufferVk.cpp
@@ -222,8 +222,8 @@
 
     void SetResourceTable(ResourceTable* resourceTable) { mResourceTable = resourceTable; }
 
-    void Apply(Device* device,
-               CommandRecordingContext* recordingContext,
+    void Apply(const VulkanFunctions& vk,
+               VkCommandBuffer commandBuffer,
                VkPipelineBindPoint bindPoint) {
         BeforeApply();
 
@@ -246,8 +246,8 @@
             if (mUsesResourceTable) {
                 DAWN_ASSERT(mResourceTable != nullptr);
                 VkDescriptorSet set = mResourceTable->GetHandle();
-                device->fn.CmdBindDescriptorSets(recordingContext->commandBuffer, bindPoint,
-                                                 mVkLayout, 0, 1, &*set, 0, nullptr);
+                vk.CmdBindDescriptorSets(commandBuffer, bindPoint, mVkLayout, 0, 1, &*set, 0,
+                                         nullptr);
             }
         }
         BindGroupIndex startOfBindGroups{mUsesResourceTable ? 1u : 0u};
@@ -260,8 +260,8 @@
             uint32_t count = static_cast<uint32_t>(dynamicOffsetSpan.size());
             const uint32_t* dynamicOffset = count > 0 ? dynamicOffsetSpan.data() : nullptr;
 
-            device->fn.CmdBindDescriptorSets(recordingContext->commandBuffer, bindPoint, mVkLayout,
-                                             setIndex, 1, &*set, count, dynamicOffset);
+            vk.CmdBindDescriptorSets(commandBuffer, bindPoint, mVkLayout, setIndex, 1, &*set, count,
+                                     dynamicOffset);
         }
 
         // Update PipelineLayout
@@ -286,7 +286,7 @@
   public:
     ImmediateConstantTracker() = default;
 
-    void Apply(Device* device, VkCommandBuffer commandBuffer) {
+    void Apply(const VulkanFunctions& vk, VkCommandBuffer commandBuffer) {
         DAWN_ASSERT(this->mLastPipeline != nullptr);
 
         auto* lastPipeline = this->mLastPipeline;
@@ -298,11 +298,11 @@
             uint32_t pushConstantRangeStartOffset =
                 GetImmediateIndexInPipeline(static_cast<uint32_t>(offset), pipelineMask) *
                 kImmediateConstantElementByteSize;
-            device->fn.CmdPushConstants(
-                commandBuffer, ToBackend(lastPipeline)->GetVkLayout(),
-                ToBackend(lastPipeline->GetLayout())->GetImmediateDataRangeStage(),
-                pushConstantRangeStartOffset, size * kImmediateConstantElementByteSize,
-                this->mContent.template Get<uint32_t>(immediateContentStartOffset));
+            vk.CmdPushConstants(commandBuffer, ToBackend(lastPipeline)->GetVkLayout(),
+                                ToBackend(lastPipeline->GetLayout())->GetImmediateDataRangeStage(),
+                                pushConstantRangeStartOffset,
+                                size * kImmediateConstantElementByteSize,
+                                this->mContent.template Get<uint32_t>(immediateContentStartOffset));
         }
 
         // Reset all dirty bits after uploading.
@@ -541,6 +541,54 @@
     return clearValue;
 }
 
+// A number of WebGPU commands cannot be immediately turned into Vulkan commands and instead just
+// dirty state which needs to be applied just before the Draw/Dispatch. This State class factors the
+// tracking logic of both Compute and Render passes.
+template <typename BaseImmediateTracker, typename Pipeline, VkPipelineBindPoint PipelineBindPoint>
+struct ProgrammablePassState : public StackAllocated {
+    ProgrammablePassState(const VulkanFunctions& vk, CommandRecordingContext* recordingContext)
+        : vk(vk), recordingContext(recordingContext) {}
+
+    void OnSetPipeline(Pipeline* pipeline) {
+        lastPipeline = pipeline;
+        descriptorSets.OnSetPipeline<Pipeline>(pipeline);
+        immediates.OnSetPipeline(pipeline);
+    }
+
+    // Synchronizes all the dirty state before doing the operation.
+    template <typename F>
+    MaybeError SyncAndRun(F&& DoOperation) {
+        VkCommandBuffer commands = recordingContext->commandBuffer;
+
+        if (lastAppliedPipeline != lastPipeline) {
+            vk.CmdBindPipeline(commands, PipelineBindPoint, lastPipeline->GetHandle());
+            lastAppliedPipeline = lastPipeline;
+        }
+
+        descriptorSets.Apply(vk, commands, PipelineBindPoint);
+        immediates.Apply(vk, commands);
+
+        DoOperation(vk, commands);
+        return {};
+    }
+
+    const VulkanFunctions& vk;
+    CommandRecordingContext* recordingContext;
+
+    DescriptorSetTracker descriptorSets;
+    ImmediateConstantTracker<BaseImmediateTracker> immediates;
+
+    Pipeline* lastPipeline = nullptr;
+    Pipeline* lastAppliedPipeline = nullptr;
+};
+
+using RenderPassState = ProgrammablePassState<RenderImmediateConstantsTrackerBase,
+                                              RenderPipeline,
+                                              VK_PIPELINE_BIND_POINT_GRAPHICS>;
+using ComputePassState = ProgrammablePassState<ComputeImmediateConstantsTrackerBase,
+                                               ComputePipeline,
+                                               VK_PIPELINE_BIND_POINT_COMPUTE>;
+
 }  // anonymous namespace
 
 MaybeError RecordBeginDynamicRenderPass(CommandRecordingContext* recordingContext,
@@ -1307,8 +1355,7 @@
     VkCommandBuffer commands = recordingContext->commandBuffer;
 
     uint64_t currentDispatch = 0;
-    DescriptorSetTracker descriptorSets;
-    ImmediateConstantTracker<ComputeImmediateConstantsTrackerBase> immediates = {};
+    ComputePassState state(device->fn, recordingContext);
 
     Command type;
     while (mCommands.NextCommandId(&type)) {
@@ -1332,10 +1379,11 @@
 
                 DAWN_TRY(PrepareResourcesForSyncScope(
                     device, recordingContext, resourceUsages.dispatchUsages[currentDispatch]));
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_COMPUTE);
-                immediates.Apply(device, commands);
-                device->fn.CmdDispatch(commands, dispatch->x, dispatch->y, dispatch->z);
                 currentDispatch++;
+
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    vk.CmdDispatch(commands, dispatch->x, dispatch->y, dispatch->z);
+                }));
                 break;
             }
 
@@ -1345,11 +1393,12 @@
 
                 DAWN_TRY(PrepareResourcesForSyncScope(
                     device, recordingContext, resourceUsages.dispatchUsages[currentDispatch]));
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_COMPUTE);
-                immediates.Apply(device, commands);
-                device->fn.CmdDispatchIndirect(commands, indirectBuffer,
-                                               static_cast<VkDeviceSize>(dispatch->indirectOffset));
                 currentDispatch++;
+
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    vk.CmdDispatchIndirect(commands, indirectBuffer,
+                                           static_cast<VkDeviceSize>(dispatch->indirectOffset));
+                }));
                 break;
             }
 
@@ -1362,19 +1411,14 @@
                     dynamicOffsets = mCommands.NextData<uint32_t>(cmd->dynamicOffsetCount);
                 }
 
-                descriptorSets.OnSetBindGroup(cmd->index, bindGroup, cmd->dynamicOffsetCount,
-                                              dynamicOffsets);
+                state.descriptorSets.OnSetBindGroup(cmd->index, bindGroup, cmd->dynamicOffsetCount,
+                                                    dynamicOffsets);
                 break;
             }
 
             case Command::SetComputePipeline: {
                 SetComputePipelineCmd* cmd = mCommands.NextCommand<SetComputePipelineCmd>();
-                ComputePipeline* pipeline = ToBackend(cmd->pipeline).Get();
-
-                device->fn.CmdBindPipeline(commands, VK_PIPELINE_BIND_POINT_COMPUTE,
-                                           pipeline->GetHandle());
-                descriptorSets.OnSetPipeline<ComputePipeline>(pipeline);
-                immediates.OnSetPipeline(pipeline);
+                state.OnSetPipeline(ToBackend(cmd->pipeline).Get());
                 break;
             }
 
@@ -1441,13 +1485,13 @@
                 DAWN_ASSERT(cmd->size > 0);
                 uint8_t* value = nullptr;
                 value = mCommands.NextData<uint8_t>(cmd->size);
-                immediates.SetImmediates(cmd->offset, value, cmd->size);
+                state.immediates.SetImmediates(cmd->offset, value, cmd->size);
                 break;
             }
 
             case Command::SetResourceTable: {
                 SetResourceTableCmd* cmd = mCommands.NextCommand<SetResourceTableCmd>();
-                descriptorSets.SetResourceTable(ToBackend(cmd->table.Get()));
+                state.descriptorSets.SetResourceTable(ToBackend(cmd->table.Get()));
                 break;
             }
 
@@ -1477,7 +1521,8 @@
 
     DAWN_TRY(RecordBeginRenderPass(recordingContext, device, renderPassCmd));
 
-    ImmediateConstantTracker<RenderImmediateConstantsTrackerBase> immediates = {};
+    RenderPassState state(device->fn, recordingContext);
+
     // Set the default value for the dynamic state
     {
         device->fn.CmdSetLineWidth(commands, 1.0f);
@@ -1511,25 +1556,23 @@
         device->fn.CmdSetScissor(commands, 0, 1, &scissorRect);
 
         // Apply default frag depth
-        immediates.SetClampFragDepth(0.0, 1.0);
+        state.immediates.SetClampFragDepth(0.0, 1.0);
     }
 
-    DescriptorSetTracker descriptorSets;
-    RenderPipeline* lastPipeline = nullptr;
 
     // Tracks the number of commands that do significant GPU work (a draw or query write) this pass.
     uint32_t workCommandCount = 0;
 
-    auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
+    auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) -> MaybeError {
         switch (type) {
             case Command::Draw: {
                 workCommandCount++;
                 DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                immediates.Apply(device, commands);
-                device->fn.CmdDraw(commands, draw->vertexCount, draw->instanceCount,
-                                   draw->firstVertex, draw->firstInstance);
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    vk.CmdDraw(commands, draw->vertexCount, draw->instanceCount, draw->firstVertex,
+                               draw->firstInstance);
+                }));
                 break;
             }
 
@@ -1537,10 +1580,10 @@
                 workCommandCount++;
                 DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                immediates.Apply(device, commands);
-                device->fn.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
-                                          draw->firstIndex, draw->baseVertex, draw->firstInstance);
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    vk.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
+                                      draw->firstIndex, draw->baseVertex, draw->firstInstance);
+                }));
                 break;
             }
 
@@ -1549,10 +1592,10 @@
                 DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
                 Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
 
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                immediates.Apply(device, commands);
-                device->fn.CmdDrawIndirect(commands, buffer->GetHandle(),
-                                           static_cast<VkDeviceSize>(draw->indirectOffset), 1, 0);
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    vk.CmdDrawIndirect(commands, buffer->GetHandle(),
+                                       static_cast<VkDeviceSize>(draw->indirectOffset), 1, 0);
+                }));
                 break;
             }
 
@@ -1562,11 +1605,11 @@
                 Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                 DAWN_ASSERT(buffer != nullptr);
 
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                immediates.Apply(device, commands);
-                device->fn.CmdDrawIndexedIndirect(commands, buffer->GetHandle(),
-                                                  static_cast<VkDeviceSize>(draw->indirectOffset),
-                                                  1, 0);
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    vk.CmdDrawIndexedIndirect(commands, buffer->GetHandle(),
+                                              static_cast<VkDeviceSize>(draw->indirectOffset), 1,
+                                              0);
+                }));
                 break;
             }
 
@@ -1580,20 +1623,19 @@
                 // Count buffer is optional
                 Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
 
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                immediates.Apply(device, commands);
-
-                if (countBuffer == nullptr) {
-                    device->fn.CmdDrawIndirect(commands, indirectBuffer->GetHandle(),
-                                               static_cast<VkDeviceSize>(cmd->indirectOffset),
-                                               cmd->maxDrawCount, kDrawIndirectSize);
-                } else {
-                    device->fn.CmdDrawIndirectCountKHR(
-                        commands, indirectBuffer->GetHandle(),
-                        static_cast<VkDeviceSize>(cmd->indirectOffset), countBuffer->GetHandle(),
-                        static_cast<VkDeviceSize>(cmd->drawCountOffset), cmd->maxDrawCount,
-                        kDrawIndirectSize);
-                }
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    if (countBuffer == nullptr) {
+                        vk.CmdDrawIndirect(commands, indirectBuffer->GetHandle(),
+                                           static_cast<VkDeviceSize>(cmd->indirectOffset),
+                                           cmd->maxDrawCount, kDrawIndirectSize);
+                    } else {
+                        vk.CmdDrawIndirectCountKHR(commands, indirectBuffer->GetHandle(),
+                                                   static_cast<VkDeviceSize>(cmd->indirectOffset),
+                                                   countBuffer->GetHandle(),
+                                                   static_cast<VkDeviceSize>(cmd->drawCountOffset),
+                                                   cmd->maxDrawCount, kDrawIndirectSize);
+                    }
+                }));
                 break;
             }
             case Command::MultiDrawIndexedIndirect: {
@@ -1606,21 +1648,20 @@
                 // Count buffer is optional
                 Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
 
-                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                immediates.Apply(device, commands);
-
-                if (countBuffer == nullptr) {
-                    device->fn.CmdDrawIndexedIndirect(
-                        commands, indirectBuffer->GetHandle(),
-                        static_cast<VkDeviceSize>(cmd->indirectOffset), cmd->maxDrawCount,
-                        kDrawIndexedIndirectSize);
-                } else {
-                    device->fn.CmdDrawIndexedIndirectCountKHR(
-                        commands, indirectBuffer->GetHandle(),
-                        static_cast<VkDeviceSize>(cmd->indirectOffset), countBuffer->GetHandle(),
-                        static_cast<VkDeviceSize>(cmd->drawCountOffset), cmd->maxDrawCount,
-                        kDrawIndexedIndirectSize);
-                }
+                DAWN_TRY(state.SyncAndRun([&](const VulkanFunctions& vk, VkCommandBuffer commands) {
+                    if (countBuffer == nullptr) {
+                        vk.CmdDrawIndexedIndirect(commands, indirectBuffer->GetHandle(),
+                                                  static_cast<VkDeviceSize>(cmd->indirectOffset),
+                                                  cmd->maxDrawCount, kDrawIndexedIndirectSize);
+                    } else {
+                        vk.CmdDrawIndexedIndirectCountKHR(
+                            commands, indirectBuffer->GetHandle(),
+                            static_cast<VkDeviceSize>(cmd->indirectOffset),
+                            countBuffer->GetHandle(),
+                            static_cast<VkDeviceSize>(cmd->drawCountOffset), cmd->maxDrawCount,
+                            kDrawIndexedIndirectSize);
+                    }
+                }));
 
                 break;
             }
@@ -1683,8 +1724,8 @@
                     dynamicOffsets = iter->NextData<uint32_t>(cmd->dynamicOffsetCount);
                 }
 
-                descriptorSets.OnSetBindGroup(cmd->index, bindGroup, cmd->dynamicOffsetCount,
-                                              dynamicOffsets);
+                state.descriptorSets.OnSetBindGroup(cmd->index, bindGroup, cmd->dynamicOffsetCount,
+                                                    dynamicOffsets);
                 break;
             }
 
@@ -1699,14 +1740,7 @@
 
             case Command::SetRenderPipeline: {
                 SetRenderPipelineCmd* cmd = iter->NextCommand<SetRenderPipelineCmd>();
-                RenderPipeline* pipeline = ToBackend(cmd->pipeline).Get();
-
-                device->fn.CmdBindPipeline(commands, VK_PIPELINE_BIND_POINT_GRAPHICS,
-                                           pipeline->GetHandle());
-                lastPipeline = pipeline;
-
-                descriptorSets.OnSetPipeline<RenderPipeline>(pipeline);
-                immediates.OnSetPipeline(pipeline);
+                state.OnSetPipeline(ToBackend(cmd->pipeline).Get());
                 break;
             }
 
@@ -1725,13 +1759,13 @@
                 DAWN_ASSERT(cmd->size > 0);
                 uint8_t* value = nullptr;
                 value = iter->NextData<uint8_t>(cmd->size);
-                immediates.SetImmediates(cmd->offset, value, cmd->size);
+                state.immediates.SetImmediates(cmd->offset, value, cmd->size);
                 break;
             }
 
             case Command::SetResourceTable: {
                 SetResourceTableCmd* cmd = iter->NextCommand<SetResourceTableCmd>();
-                descriptorSets.SetResourceTable(ToBackend(cmd->table.Get()));
+                state.descriptorSets.SetResourceTable(ToBackend(cmd->table.Get()));
                 break;
             }
 
@@ -1739,6 +1773,8 @@
                 DAWN_UNREACHABLE();
                 break;
         }
+
+        return {};
     };
 
     Command type;
@@ -1812,7 +1848,7 @@
 
                 // Try applying the immediate data that contain min/maxDepth immediately. This can
                 // be deferred if no pipeline is currently bound.
-                immediates.SetClampFragDepth(viewport.minDepth, viewport.maxDepth);
+                state.immediates.SetClampFragDepth(viewport.minDepth, viewport.maxDepth);
                 break;
             }
 
@@ -1836,7 +1872,7 @@
                     CommandIterator* iter = bundles[i]->GetCommands();
                     iter->Reset();
                     while (iter->NextCommandId(&type)) {
-                        EncodeRenderBundleCommand(iter, type);
+                        DAWN_TRY(EncodeRenderBundleCommand(iter, type));
                     }
                 }
                 break;
@@ -1870,7 +1906,7 @@
             }
 
             default: {
-                EncodeRenderBundleCommand(&mCommands, type);
+                DAWN_TRY(EncodeRenderBundleCommand(&mCommands, type));
                 break;
             }
         }