d3d11: VertexBufferTracker

This introduces VertexBufferTracker to correctly track pipeline's
vertex buffer state.

Bug: dawn:1799
Bug: dawn:1705
Change-Id: I06f32b501a3637b22318ec201b1953eba6ed0cf2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/131700
Reviewed-by: Peng Huang <penghuang@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jie A Chen <jie.a.chen@intel.com>
diff --git a/src/dawn/native/d3d11/CommandBufferD3D11.cpp b/src/dawn/native/d3d11/CommandBufferD3D11.cpp
index 38bf0eb..4db4f9b 100644
--- a/src/dawn/native/d3d11/CommandBufferD3D11.cpp
+++ b/src/dawn/native/d3d11/CommandBufferD3D11.cpp
@@ -50,6 +50,48 @@
     }
 }
 
+class VertexBufferTracker {
+  public:
+    explicit VertexBufferTracker(CommandRecordingContext* commandContext)
+        : mCommandContext(commandContext) {}
+
+    ~VertexBufferTracker() {
+        mD3D11Buffers = {};
+        mStrides = {};
+        mOffsets = {};
+        mCommandContext->GetD3D11DeviceContext()->IASetVertexBuffers(
+            0, kMaxVertexBuffers, mD3D11Buffers.data(), mStrides.data(), mOffsets.data());
+    }
+
+    void OnSetVertexBuffer(VertexBufferSlot slot, ID3D11Buffer* buffer, uint64_t offset) {
+        mD3D11Buffers[slot] = buffer;
+        mOffsets[slot] = offset;
+    }
+
+    void Apply(const RenderPipeline* renderPipeline) {
+        ASSERT(renderPipeline != nullptr);
+
+        // If the vertex state has changed, we need to update the strides.
+        if (mLastAppliedRenderPipeline != renderPipeline) {
+            mLastAppliedRenderPipeline = renderPipeline;
+            for (VertexBufferSlot slot :
+                 IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
+                mStrides[slot] = renderPipeline->GetVertexBuffer(slot).arrayStride;
+            }
+        }
+
+        mCommandContext->GetD3D11DeviceContext()->IASetVertexBuffers(
+            0, kMaxVertexBuffers, mD3D11Buffers.data(), mStrides.data(), mOffsets.data());
+    }
+
+  private:
+    CommandRecordingContext* const mCommandContext;
+    const RenderPipeline* mLastAppliedRenderPipeline = nullptr;
+    ityp::array<VertexBufferSlot, ID3D11Buffer*, kMaxVertexBuffers> mD3D11Buffers = {};
+    ityp::array<VertexBufferSlot, UINT, kMaxVertexBuffers> mStrides = {};
+    ityp::array<VertexBufferSlot, UINT, kMaxVertexBuffers> mOffsets = {};
+};
+
 }  // namespace
 
 // Create CommandBuffer
@@ -437,6 +479,7 @@
 
     RenderPipeline* lastPipeline = nullptr;
     BindGroupTracker bindGroupTracker(commandContext);
+    VertexBufferTracker vertexBufferTracker(commandContext);
     std::array<float, 4> blendColor = {0.0f, 0.0f, 0.0f, 0.0f};
     uint32_t stencilReference = 0;
 
@@ -446,6 +489,7 @@
                 DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
                 DAWN_TRY(bindGroupTracker.Apply());
+                vertexBufferTracker.Apply(lastPipeline);
                 DAWN_TRY(RecordFirstIndexOffset(lastPipeline, commandContext, draw->firstVertex,
                                                 draw->firstInstance));
                 commandContext->GetD3D11DeviceContext()->DrawInstanced(
@@ -458,6 +502,7 @@
                 DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
                 DAWN_TRY(bindGroupTracker.Apply());
+                vertexBufferTracker.Apply(lastPipeline);
                 DAWN_TRY(RecordFirstIndexOffset(lastPipeline, commandContext, draw->baseVertex,
                                                 draw->firstInstance));
                 commandContext->GetD3D11DeviceContext()->DrawIndexedInstanced(
@@ -474,6 +519,7 @@
                 ASSERT(indirectBuffer != nullptr);
 
                 DAWN_TRY(bindGroupTracker.Apply());
+                vertexBufferTracker.Apply(lastPipeline);
 
                 if (lastPipeline->GetUsesVertexOrInstanceIndex()) {
                     // Copy StartVertexLocation and StartInstanceLocation into the uniform buffer
@@ -499,6 +545,7 @@
                 ASSERT(indirectBuffer != nullptr);
 
                 DAWN_TRY(bindGroupTracker.Apply());
+                vertexBufferTracker.Apply(lastPipeline);
 
                 if (lastPipeline->GetUsesVertexOrInstanceIndex()) {
                     // Copy StartVertexLocation and StartInstanceLocation into the uniform buffer
@@ -555,17 +602,8 @@
 
             case Command::SetVertexBuffer: {
                 SetVertexBufferCmd* cmd = iter->NextCommand<SetVertexBufferCmd>();
-                ASSERT(lastPipeline);
-                const VertexBufferInfo& info = lastPipeline->GetVertexBuffer(cmd->slot);
-
-                // TODO(dawn:1705): should we set vertex back to nullptr after the draw call?
-                UINT slot = static_cast<uint8_t>(cmd->slot);
                 ID3D11Buffer* buffer = ToBackend(cmd->buffer)->GetD3D11Buffer();
-                UINT arrayStride = info.arrayStride;
-                UINT offset = cmd->offset;
-                commandContext->GetD3D11DeviceContext()->IASetVertexBuffers(slot, 1, &buffer,
-                                                                            &arrayStride, &offset);
-
+                vertexBufferTracker.OnSetVertexBuffer(cmd->slot, buffer, cmd->offset);
                 break;
             }
 
diff --git a/src/dawn/tests/end2end/BufferZeroInitTests.cpp b/src/dawn/tests/end2end/BufferZeroInitTests.cpp
index faf3d7d..05982fa 100644
--- a/src/dawn/tests/end2end/BufferZeroInitTests.cpp
+++ b/src/dawn/tests/end2end/BufferZeroInitTests.cpp
@@ -1122,8 +1122,6 @@
 
 // Test the buffer will be lazily initialized correctly when its first use is in SetVertexBuffer.
 TEST_P(BufferZeroInitTest, SetVertexBuffer) {
-    // TODO(dawn:1799): Figure this out.
-    DAWN_SUPPRESS_TEST_IF(IsD3D11());
     // Bind the whole buffer as a vertex buffer.
     {
         constexpr uint64_t kVertexBufferOffset = 0u;