D3D12: Factor SetVertexBuffer tracking to match other tracking classes

Bug: dawn:201
Change-Id: I711e93a706b5043318263b203d3f3dc7f1a675bb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/11000
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 7ea4e40..aeb87d7 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -38,6 +38,7 @@
 namespace dawn_native { namespace d3d12 {
 
     namespace {
+
         DXGI_FORMAT DXGIIndexFormat(dawn::IndexFormat format) {
             switch (format) {
                 case dawn::IndexFormat::Uint16:
@@ -63,6 +64,12 @@
             return false;
         }
 
+        struct OMSetRenderTargetArgs {
+            unsigned int numRTVs = 0;
+            std::array<D3D12_CPU_DESCRIPTOR_HANDLE, kMaxColorAttachments> RTVs = {};
+            D3D12_CPU_DESCRIPTOR_HANDLE dsv = {};
+        };
+
     }  // anonymous namespace
 
     class BindGroupStateTracker {
@@ -291,12 +298,6 @@
         Device* mDevice;
     };
 
-    struct OMSetRenderTargetArgs {
-        unsigned int numRTVs = 0;
-        std::array<D3D12_CPU_DESCRIPTOR_HANDLE, kMaxColorAttachments> RTVs = {};
-        D3D12_CPU_DESCRIPTOR_HANDLE dsv = {};
-    };
-
     class RenderPassDescriptorHeapTracker {
       public:
         RenderPassDescriptorHeapTracker(Device* device) : mDevice(device) {
@@ -325,8 +326,8 @@
             }
         }
 
-        // TODO(jiawei.shao@intel.com): use hash map <RenderPass, OMSetRenderTargetArgs> as cache to
-        // avoid redundant RTV and DSV memory allocations.
+        // TODO(jiawei.shao@intel.com): use hash map <RenderPass, OMSetRenderTargetArgs> as
+        // cache to avoid redundant RTV and DSV memory allocations.
         OMSetRenderTargetArgs GetSubpassOMSetRenderTargetArgs(BeginRenderPassCmd* renderPass) {
             OMSetRenderTargetArgs args = {};
 
@@ -380,6 +381,73 @@
 
     namespace {
 
+        class VertexBufferTracker {
+          public:
+            void OnSetVertexBuffers(uint32_t startSlot,
+                                    uint32_t count,
+                                    Ref<BufferBase>* buffers,
+                                    uint64_t* offsets) {
+                mStartSlot = std::min(mStartSlot, startSlot);
+                mEndSlot = std::max(mEndSlot, startSlot + count);
+
+                for (uint32_t i = 0; i < count; ++i) {
+                    Buffer* buffer = ToBackend(buffers[i].Get());
+                    auto* d3d12BufferView = &mD3D12BufferViews[startSlot + i];
+                    d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
+                    d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
+                    // The bufferView stride is set based on the input state before a draw.
+                }
+            }
+
+            void Apply(ID3D12GraphicsCommandList* commandList,
+                       const RenderPipeline* renderPipeline) {
+                ASSERT(renderPipeline != nullptr);
+
+                std::bitset<kMaxVertexBuffers> inputsMask = renderPipeline->GetInputsSetMask();
+
+                uint32_t startSlot = mStartSlot;
+                uint32_t endSlot = mEndSlot;
+
+                // If the input state has changed, we need to update the StrideInBytes
+                // for the D3D12 buffer views. We also need to extend the dirty range to
+                // touch all these slots because the stride may have changed.
+                if (mLastAppliedRenderPipeline != renderPipeline) {
+                    mLastAppliedRenderPipeline = renderPipeline;
+
+                    for (uint32_t slot : IterateBitSet(inputsMask)) {
+                        startSlot = std::min(startSlot, slot);
+                        endSlot = std::max(endSlot, slot + 1);
+                        mD3D12BufferViews[slot].StrideInBytes =
+                            renderPipeline->GetInput(slot).stride;
+                    }
+                }
+
+                if (endSlot <= startSlot) {
+                    return;
+                }
+
+                // mD3D12BufferViews is kept up to date with the most recent data passed
+                // to SetVertexBuffers. This makes it correct to only track the start
+                // and end of the dirty range. When Apply is called,
+                // we will at worst set non-dirty vertex buffers in duplicate.
+                uint32_t count = endSlot - startSlot;
+                commandList->IASetVertexBuffers(startSlot, count, &mD3D12BufferViews[startSlot]);
+
+                mStartSlot = kMaxVertexBuffers;
+                mEndSlot = 0;
+            }
+
+          private:
+            // startSlot and endSlot indicate the range of dirty vertex buffers.
+            // If there are multiple calls to SetVertexBuffers, the start and end
+            // represent the union of the dirty ranges (the union may have non-dirty
+            // data in the middle of the range).
+            const RenderPipeline* mLastAppliedRenderPipeline = nullptr;
+            uint32_t mStartSlot = kMaxVertexBuffers;
+            uint32_t mEndSlot = 0;
+            std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> mD3D12BufferViews = {};
+        };
+
         void AllocateAndSetDescriptorHeaps(Device* device,
                                            BindGroupStateTracker* bindingTracker,
                                            RenderPassDescriptorHeapTracker* renderPassTracker,
@@ -719,47 +787,6 @@
         DAWN_ASSERT(renderPassTracker.IsHeapAllocationCompleted());
     }
 
-    void CommandBuffer::FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
-                                              VertexBuffersInfo* vertexBuffersInfo,
-                                              const RenderPipeline* renderPipeline) {
-        DAWN_ASSERT(vertexBuffersInfo != nullptr);
-        DAWN_ASSERT(renderPipeline != nullptr);
-
-        auto inputsMask = renderPipeline->GetInputsSetMask();
-
-        uint32_t startSlot = vertexBuffersInfo->startSlot;
-        uint32_t endSlot = vertexBuffersInfo->endSlot;
-
-        // If the input state has changed, we need to update the StrideInBytes
-        // for the D3D12 buffer views. We also need to extend the dirty range to
-        // touch all these slots because the stride may have changed.
-        if (vertexBuffersInfo->lastRenderPipeline != renderPipeline) {
-            vertexBuffersInfo->lastRenderPipeline = renderPipeline;
-
-            for (uint32_t slot : IterateBitSet(inputsMask)) {
-                startSlot = std::min(startSlot, slot);
-                endSlot = std::max(endSlot, slot + 1);
-                vertexBuffersInfo->d3d12BufferViews[slot].StrideInBytes =
-                    renderPipeline->GetInput(slot).stride;
-            }
-        }
-
-        if (endSlot <= startSlot) {
-            return;
-        }
-
-        // d3d12BufferViews is kept up to date with the most recent data passed
-        // to SetVertexBuffers. This makes it correct to only track the start
-        // and end of the dirty range. When FlushSetVertexBuffers is called,
-        // we will at worst set non-dirty vertex buffers in duplicate.
-        uint32_t count = endSlot - startSlot;
-        commandList->IASetVertexBuffers(startSlot, count,
-                                        &vertexBuffersInfo->d3d12BufferViews[startSlot]);
-
-        vertexBuffersInfo->startSlot = kMaxVertexBuffers;
-        vertexBuffersInfo->endSlot = 0;
-    }
-
     void CommandBuffer::RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
                                           BindGroupStateTracker* bindingTracker) {
         PipelineLayout* lastLayout = nullptr;
@@ -969,14 +996,14 @@
 
         RenderPipeline* lastPipeline = nullptr;
         PipelineLayout* lastLayout = nullptr;
-        VertexBuffersInfo vertexBuffersInfo = {};
+        VertexBufferTracker vertexBufferTracker = {};
 
         auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
             switch (type) {
                 case Command::Draw: {
                     DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
-                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+                    vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
                                                draw->firstVertex, draw->firstInstance);
                 } break;
@@ -984,7 +1011,7 @@
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
-                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+                    vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
                                                       draw->firstIndex, draw->baseVertex,
                                                       draw->firstInstance);
@@ -993,7 +1020,7 @@
                 case Command::DrawIndirect: {
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
 
-                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+                    vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     ComPtr<ID3D12CommandSignature> signature =
                         ToBackend(GetDevice())->GetDrawIndirectSignature();
@@ -1005,7 +1032,7 @@
                 case Command::DrawIndexedIndirect: {
                     DrawIndexedIndirectCmd* draw = iter->NextCommand<DrawIndexedIndirectCmd>();
 
-                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+                    vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     ComPtr<ID3D12CommandSignature> signature =
                         ToBackend(GetDevice())->GetDrawIndexedIndirectSignature();
@@ -1096,22 +1123,11 @@
 
                 case Command::SetVertexBuffers: {
                     SetVertexBuffersCmd* cmd = iter->NextCommand<SetVertexBuffersCmd>();
-                    auto buffers = iter->NextData<Ref<BufferBase>>(cmd->count);
-                    auto offsets = iter->NextData<uint64_t>(cmd->count);
+                    Ref<BufferBase>* buffers = iter->NextData<Ref<BufferBase>>(cmd->count);
+                    uint64_t* offsets = iter->NextData<uint64_t>(cmd->count);
 
-                    vertexBuffersInfo.startSlot =
-                        std::min(vertexBuffersInfo.startSlot, cmd->startSlot);
-                    vertexBuffersInfo.endSlot =
-                        std::max(vertexBuffersInfo.endSlot, cmd->startSlot + cmd->count);
-
-                    for (uint32_t i = 0; i < cmd->count; ++i) {
-                        Buffer* buffer = ToBackend(buffers[i].Get());
-                        auto* d3d12BufferView =
-                            &vertexBuffersInfo.d3d12BufferViews[cmd->startSlot + i];
-                        d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
-                        d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
-                        // The bufferView stride is set based on the input state before a draw.
-                    }
+                    vertexBufferTracker.OnSetVertexBuffers(cmd->startSlot, cmd->count, buffers,
+                                                           offsets);
                 } break;
 
                 default:
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.h b/src/dawn_native/d3d12/CommandBufferD3D12.h
index a367b4b..78c5630 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.h
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.h
@@ -35,17 +35,6 @@
     class RenderPassDescriptorHeapTracker;
     class RenderPipeline;
 
-    struct VertexBuffersInfo {
-        // startSlot and endSlot indicate the range of dirty vertex buffers.
-        // If there are multiple calls to SetVertexBuffers, the start and end
-        // represent the union of the dirty ranges (the union may have non-dirty
-        // data in the middle of the range).
-        const RenderPipeline* lastRenderPipeline = nullptr;
-        uint32_t startSlot = kMaxVertexBuffers;
-        uint32_t endSlot = 0;
-        std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> d3d12BufferViews = {};
-    };
-
     class CommandBuffer : public CommandBufferBase {
       public:
         CommandBuffer(CommandEncoderBase* encoder, const CommandBufferDescriptor* descriptor);
@@ -54,9 +43,6 @@
         void RecordCommands(ComPtr<ID3D12GraphicsCommandList> commandList, uint32_t indexInSubmit);
 
       private:
-        void FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
-                                   VertexBuffersInfo* vertexBuffersInfo,
-                                   const RenderPipeline* lastRenderPipeline);
         void RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
                                BindGroupStateTracker* bindingTracker);
         void RecordRenderPass(ComPtr<ID3D12GraphicsCommandList> commandList,