D3D12: Factor SetVertexBuffer tracking to match other tracking classes
Bug: dawn:201
Change-Id: I711e93a706b5043318263b203d3f3dc7f1a675bb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/11000
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 7ea4e40..aeb87d7 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -38,6 +38,7 @@
namespace dawn_native { namespace d3d12 {
namespace {
+
DXGI_FORMAT DXGIIndexFormat(dawn::IndexFormat format) {
switch (format) {
case dawn::IndexFormat::Uint16:
@@ -63,6 +64,12 @@
return false;
}
+ struct OMSetRenderTargetArgs {
+ unsigned int numRTVs = 0;
+ std::array<D3D12_CPU_DESCRIPTOR_HANDLE, kMaxColorAttachments> RTVs = {};
+ D3D12_CPU_DESCRIPTOR_HANDLE dsv = {};
+ };
+
} // anonymous namespace
class BindGroupStateTracker {
@@ -291,12 +298,6 @@
Device* mDevice;
};
- struct OMSetRenderTargetArgs {
- unsigned int numRTVs = 0;
- std::array<D3D12_CPU_DESCRIPTOR_HANDLE, kMaxColorAttachments> RTVs = {};
- D3D12_CPU_DESCRIPTOR_HANDLE dsv = {};
- };
-
class RenderPassDescriptorHeapTracker {
public:
RenderPassDescriptorHeapTracker(Device* device) : mDevice(device) {
@@ -325,8 +326,8 @@
}
}
- // TODO(jiawei.shao@intel.com): use hash map <RenderPass, OMSetRenderTargetArgs> as cache to
- // avoid redundant RTV and DSV memory allocations.
+ // TODO(jiawei.shao@intel.com): use hash map <RenderPass, OMSetRenderTargetArgs> as
+ // cache to avoid redundant RTV and DSV memory allocations.
OMSetRenderTargetArgs GetSubpassOMSetRenderTargetArgs(BeginRenderPassCmd* renderPass) {
OMSetRenderTargetArgs args = {};
@@ -380,6 +381,73 @@
namespace {
+ class VertexBufferTracker {
+ public:
+ void OnSetVertexBuffers(uint32_t startSlot,
+ uint32_t count,
+ Ref<BufferBase>* buffers,
+ uint64_t* offsets) {
+ mStartSlot = std::min(mStartSlot, startSlot);
+ mEndSlot = std::max(mEndSlot, startSlot + count);
+
+ for (uint32_t i = 0; i < count; ++i) {
+ Buffer* buffer = ToBackend(buffers[i].Get());
+ auto* d3d12BufferView = &mD3D12BufferViews[startSlot + i];
+ d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
+ d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
+ // The bufferView stride is set based on the input state before a draw.
+ }
+ }
+
+ void Apply(ID3D12GraphicsCommandList* commandList,
+ const RenderPipeline* renderPipeline) {
+ ASSERT(renderPipeline != nullptr);
+
+ std::bitset<kMaxVertexBuffers> inputsMask = renderPipeline->GetInputsSetMask();
+
+ uint32_t startSlot = mStartSlot;
+ uint32_t endSlot = mEndSlot;
+
+ // If the input state has changed, we need to update the StrideInBytes
+ // for the D3D12 buffer views. We also need to extend the dirty range to
+ // touch all these slots because the stride may have changed.
+ if (mLastAppliedRenderPipeline != renderPipeline) {
+ mLastAppliedRenderPipeline = renderPipeline;
+
+ for (uint32_t slot : IterateBitSet(inputsMask)) {
+ startSlot = std::min(startSlot, slot);
+ endSlot = std::max(endSlot, slot + 1);
+ mD3D12BufferViews[slot].StrideInBytes =
+ renderPipeline->GetInput(slot).stride;
+ }
+ }
+
+ if (endSlot <= startSlot) {
+ return;
+ }
+
+ // mD3D12BufferViews is kept up to date with the most recent data passed
+ // to SetVertexBuffers. This makes it correct to only track the start
+ // and end of the dirty range. When Apply is called,
+ // we will at worst set non-dirty vertex buffers in duplicate.
+ uint32_t count = endSlot - startSlot;
+ commandList->IASetVertexBuffers(startSlot, count, &mD3D12BufferViews[startSlot]);
+
+ mStartSlot = kMaxVertexBuffers;
+ mEndSlot = 0;
+ }
+
+ private:
+ // startSlot and endSlot indicate the range of dirty vertex buffers.
+ // If there are multiple calls to SetVertexBuffers, the start and end
+ // represent the union of the dirty ranges (the union may have non-dirty
+ // data in the middle of the range).
+ const RenderPipeline* mLastAppliedRenderPipeline = nullptr;
+ uint32_t mStartSlot = kMaxVertexBuffers;
+ uint32_t mEndSlot = 0;
+ std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> mD3D12BufferViews = {};
+ };
+
void AllocateAndSetDescriptorHeaps(Device* device,
BindGroupStateTracker* bindingTracker,
RenderPassDescriptorHeapTracker* renderPassTracker,
@@ -719,47 +787,6 @@
DAWN_ASSERT(renderPassTracker.IsHeapAllocationCompleted());
}
- void CommandBuffer::FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
- VertexBuffersInfo* vertexBuffersInfo,
- const RenderPipeline* renderPipeline) {
- DAWN_ASSERT(vertexBuffersInfo != nullptr);
- DAWN_ASSERT(renderPipeline != nullptr);
-
- auto inputsMask = renderPipeline->GetInputsSetMask();
-
- uint32_t startSlot = vertexBuffersInfo->startSlot;
- uint32_t endSlot = vertexBuffersInfo->endSlot;
-
- // If the input state has changed, we need to update the StrideInBytes
- // for the D3D12 buffer views. We also need to extend the dirty range to
- // touch all these slots because the stride may have changed.
- if (vertexBuffersInfo->lastRenderPipeline != renderPipeline) {
- vertexBuffersInfo->lastRenderPipeline = renderPipeline;
-
- for (uint32_t slot : IterateBitSet(inputsMask)) {
- startSlot = std::min(startSlot, slot);
- endSlot = std::max(endSlot, slot + 1);
- vertexBuffersInfo->d3d12BufferViews[slot].StrideInBytes =
- renderPipeline->GetInput(slot).stride;
- }
- }
-
- if (endSlot <= startSlot) {
- return;
- }
-
- // d3d12BufferViews is kept up to date with the most recent data passed
- // to SetVertexBuffers. This makes it correct to only track the start
- // and end of the dirty range. When FlushSetVertexBuffers is called,
- // we will at worst set non-dirty vertex buffers in duplicate.
- uint32_t count = endSlot - startSlot;
- commandList->IASetVertexBuffers(startSlot, count,
- &vertexBuffersInfo->d3d12BufferViews[startSlot]);
-
- vertexBuffersInfo->startSlot = kMaxVertexBuffers;
- vertexBuffersInfo->endSlot = 0;
- }
-
void CommandBuffer::RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
BindGroupStateTracker* bindingTracker) {
PipelineLayout* lastLayout = nullptr;
@@ -969,14 +996,14 @@
RenderPipeline* lastPipeline = nullptr;
PipelineLayout* lastLayout = nullptr;
- VertexBuffersInfo vertexBuffersInfo = {};
+ VertexBufferTracker vertexBufferTracker = {};
auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
switch (type) {
case Command::Draw: {
DrawCmd* draw = iter->NextCommand<DrawCmd>();
- FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+ vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
draw->firstVertex, draw->firstInstance);
} break;
@@ -984,7 +1011,7 @@
case Command::DrawIndexed: {
DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
- FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+ vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
draw->firstIndex, draw->baseVertex,
draw->firstInstance);
@@ -993,7 +1020,7 @@
case Command::DrawIndirect: {
DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
- FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+ vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
ComPtr<ID3D12CommandSignature> signature =
ToBackend(GetDevice())->GetDrawIndirectSignature();
@@ -1005,7 +1032,7 @@
case Command::DrawIndexedIndirect: {
DrawIndexedIndirectCmd* draw = iter->NextCommand<DrawIndexedIndirectCmd>();
- FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
+ vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
ComPtr<ID3D12CommandSignature> signature =
ToBackend(GetDevice())->GetDrawIndexedIndirectSignature();
@@ -1096,22 +1123,11 @@
case Command::SetVertexBuffers: {
SetVertexBuffersCmd* cmd = iter->NextCommand<SetVertexBuffersCmd>();
- auto buffers = iter->NextData<Ref<BufferBase>>(cmd->count);
- auto offsets = iter->NextData<uint64_t>(cmd->count);
+ Ref<BufferBase>* buffers = iter->NextData<Ref<BufferBase>>(cmd->count);
+ uint64_t* offsets = iter->NextData<uint64_t>(cmd->count);
- vertexBuffersInfo.startSlot =
- std::min(vertexBuffersInfo.startSlot, cmd->startSlot);
- vertexBuffersInfo.endSlot =
- std::max(vertexBuffersInfo.endSlot, cmd->startSlot + cmd->count);
-
- for (uint32_t i = 0; i < cmd->count; ++i) {
- Buffer* buffer = ToBackend(buffers[i].Get());
- auto* d3d12BufferView =
- &vertexBuffersInfo.d3d12BufferViews[cmd->startSlot + i];
- d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
- d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
- // The bufferView stride is set based on the input state before a draw.
- }
+ vertexBufferTracker.OnSetVertexBuffers(cmd->startSlot, cmd->count, buffers,
+ offsets);
} break;
default:
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.h b/src/dawn_native/d3d12/CommandBufferD3D12.h
index a367b4b..78c5630 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.h
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.h
@@ -35,17 +35,6 @@
class RenderPassDescriptorHeapTracker;
class RenderPipeline;
- struct VertexBuffersInfo {
- // startSlot and endSlot indicate the range of dirty vertex buffers.
- // If there are multiple calls to SetVertexBuffers, the start and end
- // represent the union of the dirty ranges (the union may have non-dirty
- // data in the middle of the range).
- const RenderPipeline* lastRenderPipeline = nullptr;
- uint32_t startSlot = kMaxVertexBuffers;
- uint32_t endSlot = 0;
- std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> d3d12BufferViews = {};
- };
-
class CommandBuffer : public CommandBufferBase {
public:
CommandBuffer(CommandEncoderBase* encoder, const CommandBufferDescriptor* descriptor);
@@ -54,9 +43,6 @@
void RecordCommands(ComPtr<ID3D12GraphicsCommandList> commandList, uint32_t indexInSubmit);
private:
- void FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
- VertexBuffersInfo* vertexBuffersInfo,
- const RenderPipeline* lastRenderPipeline);
void RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
BindGroupStateTracker* bindingTracker);
void RecordRenderPass(ComPtr<ID3D12GraphicsCommandList> commandList,