d3d12: Lazily apply IASetVertexBuffers before draw
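This fixes a bug where the D3D12 backend's SetVertexBuffers took the vertex
buffer stride from the InputState of whichever pipeline was bound when the
command was recorded, rather than the pipeline in use at draw time.
IASetVertexBuffers is now deferred until just before each draw and is
reapplied whenever the input state changes.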
Bug: dawn:91
Change-Id: I7d0b308ea20cee6d595f6b29548f57d82c8e47a4
Reviewed-on: https://dawn-review.googlesource.com/c/3860
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 5db2375..bff0a4d 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -420,6 +420,47 @@
}
}
+ void CommandBuffer::FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
+ VertexBuffersInfo* vertexBuffersInfo,
+ const InputState* inputState) {
+ DAWN_ASSERT(vertexBuffersInfo != nullptr);
+ DAWN_ASSERT(inputState != nullptr);
+
+ auto inputsMask = inputState->GetInputsSetMask();
+
+ uint32_t startSlot = vertexBuffersInfo->startSlot;
+ uint32_t endSlot = vertexBuffersInfo->endSlot;
+
+ // If the input state differs from the one seen at the previous flush, refresh
+ // StrideInBytes for every slot the new input state uses, and widen the dirty
+ // range [startSlot, endSlot) to cover those slots so the new strides are applied.
+ if (vertexBuffersInfo->lastInputState != inputState) {
+ vertexBuffersInfo->lastInputState = inputState;
+
+ for (uint32_t slot : IterateBitSet(inputsMask)) {
+ startSlot = std::min(startSlot, slot);
+ endSlot = std::max(endSlot, slot + 1);
+ vertexBuffersInfo->d3d12BufferViews[slot].StrideInBytes =
+ inputState->GetInput(slot).stride;
+ }
+ }
+
+ if (endSlot <= startSlot) {
+ return;
+ }
+
+ // d3d12BufferViews always holds the most recent data passed to SetVertexBuffers,
+ // so tracking only the start and end of the dirty range is sufficient: at worst
+ // this re-sets unchanged buffers that lie inside the range. Afterwards the range
+ // is reset to (kMaxVertexInputs, 0), which the check above treats as empty.
+ uint32_t count = endSlot - startSlot;
+ commandList->IASetVertexBuffers(startSlot, count,
+ &vertexBuffersInfo->d3d12BufferViews[startSlot]);
+
+ vertexBuffersInfo->startSlot = kMaxVertexInputs;
+ vertexBuffersInfo->endSlot = 0;
+ }
+
void CommandBuffer::RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
BindGroupStateTracker* bindingTracker) {
PipelineLayout* lastLayout = nullptr;
@@ -532,6 +573,8 @@
RenderPipeline* lastPipeline = nullptr;
PipelineLayout* lastLayout = nullptr;
+ InputState* lastInputState = nullptr;
+ VertexBuffersInfo vertexBuffersInfo = {};
Command type;
while (mCommands.NextCommandId(&type)) {
@@ -543,12 +586,16 @@
case Command::Draw: {
DrawCmd* draw = mCommands.NextCommand<DrawCmd>();
+
+ FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastInputState);
commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
draw->firstVertex, draw->firstInstance);
} break;
case Command::DrawIndexed: {
DrawIndexedCmd* draw = mCommands.NextCommand<DrawIndexedCmd>();
+
+ FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastInputState);
commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
draw->firstIndex, draw->baseVertex,
draw->firstInstance);
@@ -558,6 +605,7 @@
SetRenderPipelineCmd* cmd = mCommands.NextCommand<SetRenderPipelineCmd>();
RenderPipeline* pipeline = ToBackend(cmd->pipeline).Get();
PipelineLayout* layout = ToBackend(pipeline->GetLayout());
+ InputState* inputState = ToBackend(pipeline->GetInputState());
commandList->SetGraphicsRootSignature(layout->GetRootSignature().Get());
commandList->SetPipelineState(pipeline->GetPipelineState().Get());
@@ -567,6 +615,7 @@
lastPipeline = pipeline;
lastLayout = layout;
+ lastInputState = inputState;
} break;
case Command::SetStencilReference: {
@@ -617,19 +666,19 @@
auto buffers = mCommands.NextData<Ref<BufferBase>>(cmd->count);
auto offsets = mCommands.NextData<uint32_t>(cmd->count);
- auto inputState = ToBackend(lastPipeline->GetInputState());
+ vertexBuffersInfo.startSlot =
+ std::min(vertexBuffersInfo.startSlot, cmd->startSlot);
+ vertexBuffersInfo.endSlot =
+ std::max(vertexBuffersInfo.endSlot, cmd->startSlot + cmd->count);
- std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexInputs> d3d12BufferViews;
for (uint32_t i = 0; i < cmd->count; ++i) {
- auto input = inputState->GetInput(cmd->startSlot + i);
Buffer* buffer = ToBackend(buffers[i].Get());
- d3d12BufferViews[i].BufferLocation = buffer->GetVA() + offsets[i];
- d3d12BufferViews[i].StrideInBytes = input.stride;
- d3d12BufferViews[i].SizeInBytes = buffer->GetSize() - offsets[i];
+ auto* d3d12BufferView =
+ &vertexBuffersInfo.d3d12BufferViews[cmd->startSlot + i];
+ d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
+ d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
+ // The bufferView stride is set based on the input state before a draw.
}
-
- commandList->IASetVertexBuffers(cmd->startSlot, cmd->count,
- d3d12BufferViews.data());
} break;
default: { UNREACHABLE(); } break;
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.h b/src/dawn_native/d3d12/CommandBufferD3D12.h
index 95ac8bd..2ab0bb4 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.h
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.h
@@ -18,6 +18,7 @@
#include "dawn_native/CommandAllocator.h"
#include "dawn_native/CommandBuffer.h"
+#include "dawn_native/d3d12/InputStateD3D12.h"
#include "dawn_native/d3d12/d3d12_platform.h"
namespace dawn_native { namespace d3d12 {
@@ -27,6 +28,17 @@
struct BindGroupStateTracker;
+ struct VertexBuffersInfo {
+ // lastInputState is the input state whose strides were last written to d3d12BufferViews.
+ // startSlot and endSlot bound the half-open range [startSlot, endSlot) of dirty vertex
+ // buffers: the union of all SetVertexBuffers ranges since the last flush (the union may
+ // have non-dirty data in the middle). endSlot <= startSlot means nothing is dirty.
+ const InputState* lastInputState = nullptr;
+ uint32_t startSlot = kMaxVertexInputs;
+ uint32_t endSlot = 0;
+ std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexInputs> d3d12BufferViews = {};
+ };
+
class CommandBuffer : public CommandBufferBase {
public:
CommandBuffer(CommandBufferBuilder* builder);
@@ -35,6 +47,9 @@
void RecordCommands(ComPtr<ID3D12GraphicsCommandList> commandList, uint32_t indexInSubmit);
private:
+ void FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
+ VertexBuffersInfo* vertexBuffersInfo,
+ const InputState* inputState);
void RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
BindGroupStateTracker* bindingTracker);
void RecordRenderPass(ComPtr<ID3D12GraphicsCommandList> commandList,
diff --git a/src/tests/end2end/InputStateTests.cpp b/src/tests/end2end/InputStateTests.cpp
index 21cb337..922702d 100644
--- a/src/tests/end2end/InputStateTests.cpp
+++ b/src/tests/end2end/InputStateTests.cpp
@@ -185,6 +185,10 @@
dawn::CommandBuffer commands = builder.GetResult();
queue.Submit(1, &commands);
+ CheckResult(triangles, instances);
+ }
+
+ void CheckResult(unsigned int triangles, unsigned int instances) {
// Check that the center of each triangle is pure green, so that if a single vertex shader
// instance fails, linear interpolation makes the pixel check fail.
for (unsigned int triangle = 0; triangle < 4; triangle++) {
@@ -423,6 +427,89 @@
DoTestDraw(pipeline, 1, 1, {{0, &buffer0}, {1, &buffer1}});
}
+// Test that drawing is unaffected by a vertex buffer set in a slot the input state does not use
+TEST_P(InputStateTest, UnusedVertexSlot) {
+ // Instance input state, using slot 1
+ dawn::InputState instanceInputState =
+ MakeInputState({{1, 4 * sizeof(float), InputStepMode::Instance}},
+ {{0, 1, 0, VertexFormat::FloatR32G32B32A32}});
+ dawn::RenderPipeline instancePipeline = MakeTestPipeline(
+ instanceInputState, 1, {{0, VertexFormat::FloatR32G32B32A32, InputStepMode::Instance}});
+
+ dawn::Buffer buffer = MakeVertexBuffer<float>({
+ 0, 1, 2, 3,
+ 1, 2, 3, 4,
+ 2, 3, 4, 5,
+ 3, 4, 5, 6,
+ });
+
+ dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
+
+ dawn::RenderPassEncoder pass = builder.BeginRenderPass(renderPass.renderPassInfo);
+
+ uint32_t zeroOffset = 0;
+ pass.SetVertexBuffers(0, 1, &buffer, &zeroOffset);
+ pass.SetVertexBuffers(1, 1, &buffer, &zeroOffset);
+
+ pass.SetPipeline(instancePipeline);
+ pass.Draw(1 * 3, 4, 0, 0);
+
+ pass.EndPass();
+
+ dawn::CommandBuffer commands = builder.GetResult();
+ queue.Submit(1, &commands);
+
+ CheckResult(1, 4);
+}
+
+// Test two pipelines with different input states drawing from vertex buffers that are
+// set once, before either pipeline is bound. This was a problem with the D3D12 backend,
+// where SetVertexBuffers used the input of the pipeline set when it was recorded rather
+// than the one bound at draw time. Buffers must be reapplied when the input state changes.
+TEST_P(InputStateTest, MultiplePipelinesMixedInputState) {
+ // Basic input state, using slot 0
+ dawn::InputState vertexInputState =
+ MakeInputState({{0, 4 * sizeof(float), InputStepMode::Vertex}},
+ {{0, 0, 0, VertexFormat::FloatR32G32B32A32}});
+ dawn::RenderPipeline vertexPipeline = MakeTestPipeline(
+ vertexInputState, 1, {{0, VertexFormat::FloatR32G32B32A32, InputStepMode::Vertex}});
+
+ // Instance input state, using slot 1
+ dawn::InputState instanceInputState =
+ MakeInputState({{1, 4 * sizeof(float), InputStepMode::Instance}},
+ {{0, 1, 0, VertexFormat::FloatR32G32B32A32}});
+ dawn::RenderPipeline instancePipeline = MakeTestPipeline(
+ instanceInputState, 1, {{0, VertexFormat::FloatR32G32B32A32, InputStepMode::Instance}});
+
+ dawn::Buffer buffer = MakeVertexBuffer<float>({
+ 0, 1, 2, 3,
+ 1, 2, 3, 4,
+ 2, 3, 4, 5,
+ 3, 4, 5, 6,
+ });
+
+ dawn::CommandBufferBuilder builder = device.CreateCommandBufferBuilder();
+
+ dawn::RenderPassEncoder pass = builder.BeginRenderPass(renderPass.renderPassInfo);
+
+ uint32_t zeroOffset = 0;
+ pass.SetVertexBuffers(0, 1, &buffer, &zeroOffset);
+ pass.SetVertexBuffers(1, 1, &buffer, &zeroOffset);
+
+ pass.SetPipeline(vertexPipeline);
+ pass.Draw(1 * 3, 1, 0, 0);
+
+ pass.SetPipeline(instancePipeline);
+ pass.Draw(1 * 3, 4, 0, 0);
+
+ pass.EndPass();
+
+ dawn::CommandBuffer commands = builder.GetResult();
+ queue.Submit(1, &commands);
+
+ CheckResult(1, 4);
+}
+
DAWN_INSTANTIATE_TEST(InputStateTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend)
// TODO for the input state: