d3d12: Lazily apply IASetVertexBuffers before draw

This fixes a bug where D3D12 SetVertexBuffers was using the input info
from the last set pipeline's InputState. IASetVertexBuffers needs to be
reapplied if the input state changes.

Bug: dawn:91
Change-Id: I7d0b308ea20cee6d595f6b29548f57d82c8e47a4
Reviewed-on: https://dawn-review.googlesource.com/c/3860
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 5db2375..bff0a4d 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -420,6 +420,47 @@
         }
     }
 
+    void CommandBuffer::FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
+                                              VertexBuffersInfo* vertexBuffersInfo,
+                                              const InputState* inputState) {
+        DAWN_ASSERT(vertexBuffersInfo != nullptr);
+        DAWN_ASSERT(inputState != nullptr);
+
+        auto inputsMask = inputState->GetInputsSetMask();
+
+        uint32_t startSlot = vertexBuffersInfo->startSlot;
+        uint32_t endSlot = vertexBuffersInfo->endSlot;
+
+        // If the input state has changed, we need to update the StrideInBytes
+        // for the D3D12 buffer views. We also need to extend the dirty range to
+        // touch all these slots because the stride may have changed.
+        if (vertexBuffersInfo->lastInputState != inputState) {
+            vertexBuffersInfo->lastInputState = inputState;
+
+            for (uint32_t slot : IterateBitSet(inputsMask)) {
+                startSlot = std::min(startSlot, slot);
+                endSlot = std::max(endSlot, slot + 1);
+                vertexBuffersInfo->d3d12BufferViews[slot].StrideInBytes =
+                    inputState->GetInput(slot).stride;
+            }
+        }
+
+        if (endSlot <= startSlot) {
+            return;
+        }
+
+        // d3d12BufferViews is kept up to date with the most recent data passed
+        // to SetVertexBuffers. This makes it correct to only track the start
+        // and end of the dirty range. When FlushSetVertexBuffers is called,
+        // we will at worst set non-dirty vertex buffers in duplicate.
+        uint32_t count = endSlot - startSlot;
+        commandList->IASetVertexBuffers(startSlot, count,
+                                        &vertexBuffersInfo->d3d12BufferViews[startSlot]);
+
+        vertexBuffersInfo->startSlot = kMaxVertexInputs;
+        vertexBuffersInfo->endSlot = 0;
+    }
+
     void CommandBuffer::RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
                                           BindGroupStateTracker* bindingTracker) {
         PipelineLayout* lastLayout = nullptr;
@@ -532,6 +573,8 @@
 
         RenderPipeline* lastPipeline = nullptr;
         PipelineLayout* lastLayout = nullptr;
+        InputState* lastInputState = nullptr;
+        VertexBuffersInfo vertexBuffersInfo = {};
 
         Command type;
         while (mCommands.NextCommandId(&type)) {
@@ -543,12 +586,16 @@
 
                 case Command::Draw: {
                     DrawCmd* draw = mCommands.NextCommand<DrawCmd>();
+
+                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastInputState);
                     commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
                                                draw->firstVertex, draw->firstInstance);
                 } break;
 
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = mCommands.NextCommand<DrawIndexedCmd>();
+
+                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastInputState);
                     commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
                                                       draw->firstIndex, draw->baseVertex,
                                                       draw->firstInstance);
@@ -558,6 +605,7 @@
                     SetRenderPipelineCmd* cmd = mCommands.NextCommand<SetRenderPipelineCmd>();
                     RenderPipeline* pipeline = ToBackend(cmd->pipeline).Get();
                     PipelineLayout* layout = ToBackend(pipeline->GetLayout());
+                    InputState* inputState = ToBackend(pipeline->GetInputState());
 
                     commandList->SetGraphicsRootSignature(layout->GetRootSignature().Get());
                     commandList->SetPipelineState(pipeline->GetPipelineState().Get());
@@ -567,6 +615,7 @@
 
                     lastPipeline = pipeline;
                     lastLayout = layout;
+                    lastInputState = inputState;
                 } break;
 
                 case Command::SetStencilReference: {
@@ -617,19 +666,19 @@
                     auto buffers = mCommands.NextData<Ref<BufferBase>>(cmd->count);
                     auto offsets = mCommands.NextData<uint32_t>(cmd->count);
 
-                    auto inputState = ToBackend(lastPipeline->GetInputState());
+                    vertexBuffersInfo.startSlot =
+                        std::min(vertexBuffersInfo.startSlot, cmd->startSlot);
+                    vertexBuffersInfo.endSlot =
+                        std::max(vertexBuffersInfo.endSlot, cmd->startSlot + cmd->count);
 
-                    std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexInputs> d3d12BufferViews;
                     for (uint32_t i = 0; i < cmd->count; ++i) {
-                        auto input = inputState->GetInput(cmd->startSlot + i);
                         Buffer* buffer = ToBackend(buffers[i].Get());
-                        d3d12BufferViews[i].BufferLocation = buffer->GetVA() + offsets[i];
-                        d3d12BufferViews[i].StrideInBytes = input.stride;
-                        d3d12BufferViews[i].SizeInBytes = buffer->GetSize() - offsets[i];
+                        auto* d3d12BufferView =
+                            &vertexBuffersInfo.d3d12BufferViews[cmd->startSlot + i];
+                        d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
+                        d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
+                        // The bufferView stride is set based on the input state before a draw.
                     }
-
-                    commandList->IASetVertexBuffers(cmd->startSlot, cmd->count,
-                                                    d3d12BufferViews.data());
                 } break;
 
                 default: { UNREACHABLE(); } break;
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.h b/src/dawn_native/d3d12/CommandBufferD3D12.h
index 95ac8bd..2ab0bb4 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.h
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.h
@@ -18,6 +18,7 @@
 #include "dawn_native/CommandAllocator.h"
 #include "dawn_native/CommandBuffer.h"
 
+#include "dawn_native/d3d12/InputStateD3D12.h"
 #include "dawn_native/d3d12/d3d12_platform.h"
 
 namespace dawn_native { namespace d3d12 {
@@ -27,6 +28,17 @@
 
     struct BindGroupStateTracker;
 
+    struct VertexBuffersInfo {
+        // startSlot and endSlot indicate the range of dirty vertex buffers.
+        // If there are multiple calls to SetVertexBuffers, the start and end
+        // represent the union of the dirty ranges (the union may have non-dirty
+        // data in the middle of the range).
+        const InputState* lastInputState = nullptr;
+        uint32_t startSlot = kMaxVertexInputs;
+        uint32_t endSlot = 0;
+        std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexInputs> d3d12BufferViews = {};
+    };
+
     class CommandBuffer : public CommandBufferBase {
       public:
         CommandBuffer(CommandBufferBuilder* builder);
@@ -35,6 +47,9 @@
         void RecordCommands(ComPtr<ID3D12GraphicsCommandList> commandList, uint32_t indexInSubmit);
 
       private:
+        void FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
+                                   VertexBuffersInfo* vertexBuffersInfo,
+                                   const InputState* inputState);
         void RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
                                BindGroupStateTracker* bindingTracker);
         void RecordRenderPass(ComPtr<ID3D12GraphicsCommandList> commandList,
diff --git a/src/tests/end2end/InputStateTests.cpp b/src/tests/end2end/InputStateTests.cpp
index 21cb337..922702d 100644
--- a/src/tests/end2end/InputStateTests.cpp
+++ b/src/tests/end2end/InputStateTests.cpp
@@ -185,6 +185,10 @@
             dawn::CommandBuffer commands = builder.GetResult();
             queue.Submit(1, &commands);
 
+            CheckResult(triangles, instances);
+        }
+
+        void CheckResult(unsigned int triangles, unsigned int instances) {
             // Check that the center of each triangle is pure green, so that if a single vertex shader
             // instance fails, linear interpolation makes the pixel check fail.
             for (unsigned int triangle = 0; triangle < 4; triangle++) {
@@ -423,6 +427,89 @@
     DoTestDraw(pipeline, 1, 1, {{0, &buffer0}, {1, &buffer1}});
 }
 
+// Test input state is unaffected by unused vertex slot
+TEST_P(InputStateTest, UnusedVertexSlot) {
+    // Instance input state, using slot 1
+    dawn::InputState instanceInputState =
+        MakeInputState({{1, 4 * sizeof(float), InputStepMode::Instance}},
+                       {{0, 1, 0, VertexFormat::FloatR32G32B32A32}});
+    dawn::RenderPipeline instancePipeline = MakeTestPipeline(
+        instanceInputState, 1, {{0, VertexFormat::FloatR32G32B32A32, InputStepMode::Instance}});
+
+    dawn::Buffer buffer = MakeVertexBuffer<float>({
+        0, 1, 2, 3,
+        1, 2, 3, 4,
+        2, 3, 4, 5,
+        3, 4, 5, 6,
+    });
+
+    dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
+
+    dawn::RenderPassEncoder pass = builder.BeginRenderPass(renderPass.renderPassInfo);
+
+    uint32_t zeroOffset = 0;
+    pass.SetVertexBuffers(0, 1, &buffer, &zeroOffset);
+    pass.SetVertexBuffers(1, 1, &buffer, &zeroOffset);
+
+    pass.SetPipeline(instancePipeline);
+    pass.Draw(1 * 3, 4, 0, 0);
+
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = builder.GetResult();
+    queue.Submit(1, &commands);
+
+    CheckResult(1, 4);
+}
+
+// Test setting a different pipeline with a different input state.
+// This was a problem with the D3D12 backend where SetVertexBuffers
+// was getting the input from the last set pipeline, not the current.
+// SetVertexBuffers should be reapplied when the input state changes.
+TEST_P(InputStateTest, MultiplePipelinesMixedInputState) {
+    // Basic input state, using slot 0
+    dawn::InputState vertexInputState =
+        MakeInputState({{0, 4 * sizeof(float), InputStepMode::Vertex}},
+                       {{0, 0, 0, VertexFormat::FloatR32G32B32A32}});
+    dawn::RenderPipeline vertexPipeline = MakeTestPipeline(
+        vertexInputState, 1, {{0, VertexFormat::FloatR32G32B32A32, InputStepMode::Vertex}});
+
+    // Instance input state, using slot 1
+    dawn::InputState instanceInputState =
+        MakeInputState({{1, 4 * sizeof(float), InputStepMode::Instance}},
+                       {{0, 1, 0, VertexFormat::FloatR32G32B32A32}});
+    dawn::RenderPipeline instancePipeline = MakeTestPipeline(
+        instanceInputState, 1, {{0, VertexFormat::FloatR32G32B32A32, InputStepMode::Instance}});
+
+    dawn::Buffer buffer = MakeVertexBuffer<float>({
+        0, 1, 2, 3,
+        1, 2, 3, 4,
+        2, 3, 4, 5,
+        3, 4, 5, 6,
+    });
+
+    dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
+
+    dawn::RenderPassEncoder pass = builder.BeginRenderPass(renderPass.renderPassInfo);
+
+    uint32_t zeroOffset = 0;
+    pass.SetVertexBuffers(0, 1, &buffer, &zeroOffset);
+    pass.SetVertexBuffers(1, 1, &buffer, &zeroOffset);
+
+    pass.SetPipeline(vertexPipeline);
+    pass.Draw(1 * 3, 1, 0, 0);
+
+    pass.SetPipeline(instancePipeline);
+    pass.Draw(1 * 3, 4, 0, 0);
+
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = builder.GetResult();
+    queue.Submit(1, &commands);
+
+    CheckResult(1, 4);
+}
+
 DAWN_INSTANTIATE_TEST(InputStateTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend)
 
 // TODO for the input state: