D3D11: Move BindGroupTracker's common code to helper functions.
Bug: None
Change-Id: Iccff35c38fe3956b4228f293a451c381d99f5261
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/235094
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Geoff Lang <geofflang@chromium.org>
Commit-Queue: Quyen Le <lehoangquyen@chromium.org>
diff --git a/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp b/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
index add7887..7136145 100644
--- a/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
+++ b/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
@@ -28,6 +28,8 @@
#include "dawn/native/d3d11/BindGroupTrackerD3D11.h"
#include <algorithm>
+#include <tuple>
+#include <type_traits>
#include <utility>
#include <vector>
@@ -104,6 +106,21 @@
return true;
}
+std::tuple<const BindingInfo&, BufferBinding> ExtractBufferBindingInfo(
+ BindGroupBase* group,
+ BindingIndex bindingIndex,
+ const BufferBindingInfo& layout,
+ const ityp::vector<BindingIndex, uint64_t>& dynamicOffsets) {
+ const BindingInfo& bindingInfo = group->GetLayout()->GetBindingInfo(bindingIndex);
+
+ BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
+ if (layout.hasDynamicOffset) {
+ binding.offset += dynamicOffsets[bindingIndex];
+ }
+
+ return std::make_tuple(std::cref(bindingInfo), std::move(binding));
+}
+
} // namespace
template <typename T, uint32_t InitialCapacity>
@@ -287,11 +304,134 @@
mPSMinUAVSlot = std::min(mPSMinUAVSlot, startSlot);
mPSMaxUAVSlot = std::max(mPSMaxUAVSlot, startSlot + count);
- GetCommandContext()->GetD3D11DeviceContext3()->OMSetRenderTargetsAndUnorderedAccessViews(
+ mCommandContext->GetD3D11DeviceContext3()->OMSetRenderTargetsAndUnorderedAccessViews(
D3D11_KEEP_RENDER_TARGETS_AND_DEPTH_STENCIL, nullptr, nullptr, startSlot, count, uavs,
nullptr);
}
+ResultOrError<BindGroupTracker::ConstantBufferBinding> BindGroupTracker::GetConstantBufferBinding(
+ BindGroupBase* group,
+ BindingIndex bindingIndex,
+ const BufferBindingInfo& layout,
+ const ityp::vector<BindingIndex, uint64_t>& dynamicOffsets) {
+ const auto& [bindingInfo, binding] =
+ ExtractBufferBindingInfo(group, bindingIndex, layout, dynamicOffsets);
+
+ DAWN_ASSERT(layout.type == wgpu::BufferBindingType::Uniform);
+
+ ID3D11Buffer* d3d11Buffer;
+ DAWN_TRY_ASSIGN(d3d11Buffer,
+ ToGPUUsableBuffer(binding.buffer)->GetD3D11ConstantBuffer(mCommandContext));
+ // https://learn.microsoft.com/en-us/windows/win32/api/d3d11_1/nf-d3d11_1-id3d11devicecontext1-vssetconstantbuffers1
+ // Offset and size are measured in shader constants, which are 16 bytes
+ // (4*32-bit components). And the offsets and counts must be multiples
+ // of 16.
+ // WebGPU's minUniformBufferOffsetAlignment is 256.
+ DAWN_ASSERT(IsAligned(binding.offset, 256));
+ UINT firstConstant = static_cast<UINT>(binding.offset / 16);
+ UINT size = static_cast<UINT>(Align(binding.size, 16) / 16);
+ UINT numConstants = Align(size, 16);
+ DAWN_ASSERT(binding.offset + numConstants * 16 <= binding.buffer->GetAllocatedSize());
+
+ return ConstantBufferBinding{d3d11Buffer, firstConstant, numConstants};
+}
+
+template <typename T>
+ResultOrError<ComPtr<T>> BindGroupTracker::GetBufferD3DView(
+ BindGroupBase* group,
+ BindingIndex bindingIndex,
+ const BufferBindingInfo& layout,
+ const ityp::vector<BindingIndex, uint64_t>& dynamicOffsets) {
+ const auto& [bindingInfo, binding] =
+ ExtractBufferBindingInfo(group, bindingIndex, layout, dynamicOffsets);
+
+ if constexpr (std::is_same_v<T, ID3D11ShaderResourceView>) {
+ DAWN_ASSERT(layout.type == wgpu::BufferBindingType::ReadOnlyStorage ||
+ layout.type == kInternalReadOnlyStorageBufferBinding);
+
+ return ToGPUUsableBuffer(binding.buffer)
+ ->UseAsSRV(mCommandContext, binding.offset, binding.size);
+ } else if constexpr (std::is_same_v<T, ID3D11UnorderedAccessView>) {
+ DAWN_ASSERT(layout.type == wgpu::BufferBindingType::Storage ||
+ layout.type == kInternalStorageBufferBinding);
+
+ return ToGPUUsableBuffer(binding.buffer)
+ ->UseAsUAV(mCommandContext, binding.offset, binding.size);
+ } else {
+ DAWN_UNREACHABLE();
+ return ComPtr<T>();
+ }
+}
+
+ResultOrError<ComPtr<ID3D11ShaderResourceView>> BindGroupTracker::GetTextureShaderResourceView(
+ BindGroupBase* group,
+ BindingIndex bindingIndex) {
+ const BindingInfo& bindingInfo = group->GetLayout()->GetBindingInfo(bindingIndex);
+
+ DAWN_ASSERT(std::holds_alternative<TextureBindingInfo>(bindingInfo.bindingLayout) ||
+ std::holds_alternative<StorageTextureBindingInfo>(bindingInfo.bindingLayout));
+
+ if (std::holds_alternative<StorageTextureBindingInfo>(bindingInfo.bindingLayout)) {
+ DAWN_ASSERT(std::get<StorageTextureBindingInfo>(bindingInfo.bindingLayout).access ==
+ wgpu::StorageTextureAccess::ReadOnly);
+ }
+
+ TextureView* view = ToBackend(group->GetBindingAsTextureView(bindingIndex));
+ ComPtr<ID3D11ShaderResourceView> srv;
+
+ if (DAWN_UNLIKELY(view->GetAspects() == Aspect::Stencil)) {
+ // For sampling from stencil, we have to use an internal mirror 'R8Uint' texture.
+ DAWN_TRY_ASSIGN(srv, ToBackend(view->GetTexture())->GetStencilSRV(mCommandContext, view));
+ } else {
+ DAWN_TRY_ASSIGN(srv, view->GetOrCreateD3D11ShaderResourceView());
+ }
+
+ return srv;
+}
+
+ResultOrError<ComPtr<ID3D11UnorderedAccessView>> BindGroupTracker::GetTextureUnorderedAccessView(
+ BindGroupBase* group,
+ BindingIndex bindingIndex) {
+ const BindingInfo& bindingInfo = group->GetLayout()->GetBindingInfo(bindingIndex);
+
+ DAWN_ASSERT(std::holds_alternative<StorageTextureBindingInfo>(bindingInfo.bindingLayout));
+
+ [[maybe_unused]] const auto& layout =
+ std::get<StorageTextureBindingInfo>(bindingInfo.bindingLayout);
+
+ DAWN_ASSERT(layout.access == wgpu::StorageTextureAccess::ReadWrite ||
+ layout.access == wgpu::StorageTextureAccess::WriteOnly);
+
+ TextureView* view = ToBackend(group->GetBindingAsTextureView(bindingIndex));
+ ComPtr<ID3D11UnorderedAccessView> uav;
+
+ DAWN_TRY_ASSIGN(uav, view->GetOrCreateD3D11UnorderedAccessView());
+
+ return uav;
+}
+
+template <typename T>
+ResultOrError<ComPtr<T>> BindGroupTracker::GetTextureD3DView(BindGroupBase* group,
+ BindingIndex bindingIndex) {
+ if constexpr (std::is_same_v<T, ID3D11ShaderResourceView>) {
+ return GetTextureShaderResourceView(group, bindingIndex);
+ } else if constexpr (std::is_same_v<T, ID3D11UnorderedAccessView>) {
+ return GetTextureUnorderedAccessView(group, bindingIndex);
+ }
+
+ DAWN_UNREACHABLE();
+ return ComPtr<T>();
+}
+
+ID3D11SamplerState* BindGroupTracker::GetSamplerState(BindGroupBase* group,
+ BindingIndex bindingIndex) {
+ const BindingInfo& bindingInfo = group->GetLayout()->GetBindingInfo(bindingIndex);
+ DAWN_ASSERT(std::holds_alternative<SamplerBindingInfo>(bindingInfo.bindingLayout));
+
+ Sampler* sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
+ return sampler->GetD3D11SamplerState();
+}
+
template <wgpu::ShaderStage kVisibleStage>
MaybeError BindGroupTracker::ApplyBindGroup(BindGroupIndex index) {
constexpr wgpu::ShaderStage kVisibleFragment = wgpu::ShaderStage::Fragment & kVisibleStage;
@@ -322,41 +462,28 @@
DAWN_TRY(MatchVariant(
bindingInfo.bindingLayout,
[&](const BufferBindingInfo& layout) -> MaybeError {
- BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
- auto offset = binding.offset;
- if (layout.hasDynamicOffset) {
- // Dynamic buffers are packed at the front of BindingIndices.
- offset += dynamicOffsets[bindingIndex];
- }
-
switch (layout.type) {
case wgpu::BufferBindingType::Uniform: {
- ID3D11Buffer* d3d11Buffer;
- DAWN_TRY_ASSIGN(d3d11Buffer, ToGPUUsableBuffer(binding.buffer)
- ->GetD3D11ConstantBuffer(mCommandContext));
- // https://learn.microsoft.com/en-us/windows/win32/api/d3d11_1/nf-d3d11_1-id3d11devicecontext1-vssetconstantbuffers1
- // Offset and size are measured in shader constants, which are 16 bytes
- // (4*32-bit components). And the offsets and counts must be multiples
- // of 16.
- // WebGPU's minUniformBufferOffsetAlignment is 256.
- DAWN_ASSERT(IsAligned(offset, 256));
- uint32_t firstConstant = static_cast<uint32_t>(offset / 16);
- uint32_t size = static_cast<uint32_t>(Align(binding.size, 16) / 16);
- uint32_t numConstants = Align(size, 16);
- DAWN_ASSERT(offset + numConstants * 16 <=
- binding.buffer->GetAllocatedSize());
+ ConstantBufferBinding bufferBinding;
+ DAWN_TRY_ASSIGN(bufferBinding,
+ this->GetConstantBufferBinding(group, bindingIndex, layout,
+ dynamicOffsets));
+ auto d3d11Buffer = bufferBinding.buffer.Get();
if (bindingVisibility & kVisibleVertex) {
- this->VSSetConstantBuffer(bindingSlotVS, d3d11Buffer, firstConstant,
- numConstants);
+ this->VSSetConstantBuffer(bindingSlotVS, d3d11Buffer,
+ bufferBinding.firstConstant,
+ bufferBinding.numConstants);
}
if (bindingVisibility & kVisibleFragment) {
- this->PSSetConstantBuffer(bindingSlotPS, d3d11Buffer, firstConstant,
- numConstants);
+ this->PSSetConstantBuffer(bindingSlotPS, d3d11Buffer,
+ bufferBinding.firstConstant,
+ bufferBinding.numConstants);
}
if (bindingVisibility & kVisibleCompute) {
- this->CSSetConstantBuffer(bindingSlotCS, d3d11Buffer, firstConstant,
- numConstants);
+ this->CSSetConstantBuffer(bindingSlotCS, d3d11Buffer,
+ bufferBinding.firstConstant,
+ bufferBinding.numConstants);
}
break;
}
@@ -368,8 +495,8 @@
if (bindingVisibility & wgpu::ShaderStage::Compute) {
ComPtr<ID3D11UnorderedAccessView> d3d11UAV;
DAWN_TRY_ASSIGN(d3d11UAV,
- ToGPUUsableBuffer(binding.buffer)
- ->UseAsUAV(mCommandContext, offset, binding.size));
+ GetBufferD3DView<ID3D11UnorderedAccessView>(
+ group, bindingIndex, layout, dynamicOffsets));
this->CSSetUnorderedAccessView(bindingSlotCS, d3d11UAV);
}
break;
@@ -377,9 +504,8 @@
case wgpu::BufferBindingType::ReadOnlyStorage:
case kInternalReadOnlyStorageBufferBinding: {
ComPtr<ID3D11ShaderResourceView> d3d11SRV;
- DAWN_TRY_ASSIGN(d3d11SRV,
- ToGPUUsableBuffer(binding.buffer)
- ->UseAsSRV(mCommandContext, offset, binding.size));
+ DAWN_TRY_ASSIGN(d3d11SRV, GetBufferD3DView<ID3D11ShaderResourceView>(
+ group, bindingIndex, layout, dynamicOffsets));
if (bindingVisibility & kVisibleVertex) {
this->VSSetShaderResource(bindingSlotVS, d3d11SRV);
}
@@ -404,8 +530,7 @@
return {};
},
[&](const SamplerBindingInfo&) -> MaybeError {
- Sampler* sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
- ID3D11SamplerState* d3d11SamplerState = sampler->GetD3D11SamplerState();
+ ID3D11SamplerState* d3d11SamplerState = GetSamplerState(group, bindingIndex);
if (bindingVisibility & kVisibleVertex) {
this->VSSetSampler(bindingSlotVS, d3d11SamplerState);
}
@@ -418,15 +543,10 @@
return {};
},
[&](const TextureBindingInfo&) -> MaybeError {
- TextureView* view = ToBackend(group->GetBindingAsTextureView(bindingIndex));
ComPtr<ID3D11ShaderResourceView> srv;
- // For sampling from stencil, we have to use an internal mirror 'R8Uint' texture.
- if (view->GetAspects() == Aspect::Stencil) {
- DAWN_TRY_ASSIGN(
- srv, ToBackend(view->GetTexture())->GetStencilSRV(mCommandContext, view));
- } else {
- DAWN_TRY_ASSIGN(srv, view->GetOrCreateD3D11ShaderResourceView());
- }
+ DAWN_TRY_ASSIGN(srv,
+ GetTextureD3DView<ID3D11ShaderResourceView>(group, bindingIndex));
+
if (bindingVisibility & kVisibleVertex) {
this->VSSetShaderResource(bindingSlotVS, srv);
}
@@ -439,7 +559,6 @@
return {};
},
[&](const StorageTextureBindingInfo& layout) -> MaybeError {
- TextureView* view = ToBackend(group->GetBindingAsTextureView(bindingIndex));
switch (layout.access) {
case wgpu::StorageTextureAccess::WriteOnly:
case wgpu::StorageTextureAccess::ReadWrite: {
@@ -447,15 +566,17 @@
// OMSetRenderTargetsAndUnorderedAccessViews call to set all UAVs.
// Delegate to RenderPassBindGroupTracker::Apply.
if (bindingVisibility & kVisibleCompute) {
- ID3D11UnorderedAccessView* d3d11UAV = nullptr;
- DAWN_TRY_ASSIGN(d3d11UAV, view->GetOrCreateD3D11UnorderedAccessView());
+ ComPtr<ID3D11UnorderedAccessView> d3d11UAV = nullptr;
+ DAWN_TRY_ASSIGN(d3d11UAV, GetTextureD3DView<ID3D11UnorderedAccessView>(
+ group, bindingIndex));
this->CSSetUnorderedAccessView(bindingSlotCS, d3d11UAV);
}
break;
}
case wgpu::StorageTextureAccess::ReadOnly: {
- ID3D11ShaderResourceView* d3d11SRV = nullptr;
- DAWN_TRY_ASSIGN(d3d11SRV, view->GetOrCreateD3D11ShaderResourceView());
+ ComPtr<ID3D11ShaderResourceView> d3d11SRV = nullptr;
+ DAWN_TRY_ASSIGN(d3d11SRV, GetTextureD3DView<ID3D11ShaderResourceView>(
+ group, bindingIndex));
if (bindingVisibility & kVisibleVertex) {
this->VSSetShaderResource(bindingSlotVS, d3d11SRV);
}
@@ -480,10 +601,6 @@
return {};
}
-const ScopedSwapStateCommandRecordingContext* BindGroupTracker::GetCommandContext() const {
- return mCommandContext.get();
-}
-
ComputePassBindGroupTracker::ComputePassBindGroupTracker(
const ScopedSwapStateCommandRecordingContext* commandContext)
: BindGroupTracker(commandContext) {}
@@ -621,20 +738,13 @@
DAWN_TRY(MatchVariant(
bindingInfo.bindingLayout,
[&](const BufferBindingInfo& layout) -> MaybeError {
- BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
- auto offset = binding.offset;
- if (layout.hasDynamicOffset) {
- // Dynamic buffers are packed at the front of BindingIndices.
- offset += dynamicOffsets[bindingIndex];
- }
-
switch (layout.type) {
case wgpu::BufferBindingType::Storage:
case kInternalStorageBufferBinding: {
ComPtr<ID3D11UnorderedAccessView> d3d11UAV;
- DAWN_TRY_ASSIGN(d3d11UAV, ToGPUUsableBuffer(binding.buffer)
- ->UseAsUAV(GetCommandContext(), offset,
- binding.size));
+ DAWN_TRY_ASSIGN(d3d11UAV,
+ GetBufferD3DView<ID3D11UnorderedAccessView>(
+ group, bindingIndex, layout, dynamicOffsets));
uavsInBindGroup[pos] = std::move(d3d11UAV);
break;
}
@@ -654,9 +764,8 @@
case wgpu::StorageTextureAccess::WriteOnly:
case wgpu::StorageTextureAccess::ReadWrite: {
ComPtr<ID3D11UnorderedAccessView> d3d11UAV;
- TextureView* view =
- ToBackend(group->GetBindingAsTextureView(bindingIndex));
- DAWN_TRY_ASSIGN(d3d11UAV, view->GetOrCreateD3D11UnorderedAccessView());
+ DAWN_TRY_ASSIGN(d3d11UAV, GetTextureD3DView<ID3D11UnorderedAccessView>(
+ group, bindingIndex));
uavsInBindGroup[pos] = std::move(d3d11UAV);
break;
}
diff --git a/src/dawn/native/d3d11/BindGroupTrackerD3D11.h b/src/dawn/native/d3d11/BindGroupTrackerD3D11.h
index fa7f6ac..cf892a4 100644
--- a/src/dawn/native/d3d11/BindGroupTrackerD3D11.h
+++ b/src/dawn/native/d3d11/BindGroupTrackerD3D11.h
@@ -50,8 +50,6 @@
explicit BindGroupTracker(const ScopedSwapStateCommandRecordingContext* commandContext);
virtual ~BindGroupTracker();
- const ScopedSwapStateCommandRecordingContext* GetCommandContext() const;
-
protected:
template <wgpu::ShaderStage kVisibleStage>
MaybeError ApplyBindGroup(BindGroupIndex index);
@@ -79,7 +77,43 @@
uint32_t count,
ID3D11UnorderedAccessView* const* uavs);
+ template <typename T>
+ ResultOrError<ComPtr<T>> GetBufferD3DView(
+ BindGroupBase* group,
+ BindingIndex bindingIndex,
+ const BufferBindingInfo& layout,
+ const ityp::vector<BindingIndex, uint64_t>& dynamicOffsets);
+
+ template <typename T>
+ ResultOrError<ComPtr<T>> GetTextureD3DView(BindGroupBase* group, BindingIndex bindingIndex);
+
private:
+ struct ConstantBufferBinding {
+ bool operator==(const ConstantBufferBinding& rhs) const {
+ return buffer.Get() == rhs.buffer.Get() && firstConstant == rhs.firstConstant &&
+ numConstants == rhs.numConstants;
+ }
+
+ ComPtr<ID3D11Buffer> buffer;
+ UINT firstConstant = 0;
+ UINT numConstants = 0;
+ };
+
+ ResultOrError<ConstantBufferBinding> GetConstantBufferBinding(
+ BindGroupBase* group,
+ BindingIndex bindingIndex,
+ const BufferBindingInfo& layout,
+ const ityp::vector<BindingIndex, uint64_t>& dynamicOffsets);
+
+ ResultOrError<ComPtr<ID3D11ShaderResourceView>> GetTextureShaderResourceView(
+ BindGroupBase* group,
+ BindingIndex bindingIndex);
+ ResultOrError<ComPtr<ID3D11UnorderedAccessView>> GetTextureUnorderedAccessView(
+ BindGroupBase* group,
+ BindingIndex bindingIndex);
+
+ ID3D11SamplerState* GetSamplerState(BindGroupBase* group, BindingIndex bindingIndex);
+
raw_ptr<const ScopedSwapStateCommandRecordingContext> mCommandContext;
// This class will track the current bound resources and prevent redundant bindings.
@@ -95,17 +129,6 @@
absl::InlinedVector<T, InitialCapacity> mBoundSlots;
};
- struct ConstantBufferBinding {
- bool operator==(const ConstantBufferBinding& rhs) const {
- return buffer.Get() == rhs.buffer.Get() && firstConstant == rhs.firstConstant &&
- numConstants == rhs.numConstants;
- }
-
- ComPtr<ID3D11Buffer> buffer;
- UINT firstConstant = 0;
- UINT numConstants = 0;
- };
-
BindingSlot<ConstantBufferBinding, 4> mVSConstantBufferSlots;
BindingSlot<ConstantBufferBinding, 4> mPSConstantBufferSlots;
BindingSlot<ConstantBufferBinding, 4> mCSConstantBufferSlots;
diff --git a/src/dawn/native/d3d11/BufferD3D11.cpp b/src/dawn/native/d3d11/BufferD3D11.cpp
index 36a5fb9..fca418e 100644
--- a/src/dawn/native/d3d11/BufferD3D11.cpp
+++ b/src/dawn/native/d3d11/BufferD3D11.cpp
@@ -1257,7 +1257,7 @@
->CreateShaderResourceView(d3d11Buffer, &desc, &srv),
"ShaderResourceView creation"));
- return srv;
+ return std::move(srv);
}
ResultOrError<ComPtr<ID3D11UnorderedAccessView1>>
@@ -1284,7 +1284,7 @@
->CreateUnorderedAccessView1(d3d11Buffer, &desc, &uav),
"UnorderedAccessView creation"));
- return uav;
+ return std::move(uav);
}
ResultOrError<ComPtr<ID3D11ShaderResourceView>> GPUUsableBuffer::UseAsSRV(
@@ -1306,10 +1306,10 @@
mSRVCache[key] = srv;
- return srv;
+ return std::move(srv);
}
-ResultOrError<ComPtr<ID3D11UnorderedAccessView1>> GPUUsableBuffer::UseAsUAV(
+ResultOrError<ComPtr<ID3D11UnorderedAccessView>> GPUUsableBuffer::UseAsUAV(
const ScopedCommandRecordingContext* commandContext,
uint64_t offset,
uint64_t size) {
@@ -1333,7 +1333,7 @@
// Since UAV will modify the storage's content, increment its revision.
IncrStorageRevAndMakeLatest(commandContext, storage);
- return uav;
+ return ComPtr<ID3D11UnorderedAccessView>(std::move(uav));
}
MaybeError GPUUsableBuffer::UpdateD3D11ConstantBuffer(
diff --git a/src/dawn/native/d3d11/BufferD3D11.h b/src/dawn/native/d3d11/BufferD3D11.h
index bf09d4d..a7f04fb 100644
--- a/src/dawn/native/d3d11/BufferD3D11.h
+++ b/src/dawn/native/d3d11/BufferD3D11.h
@@ -215,7 +215,7 @@
ResultOrError<ComPtr<ID3D11ShaderResourceView>>
UseAsSRV(const ScopedCommandRecordingContext* commandContext, uint64_t offset, uint64_t size);
- ResultOrError<ComPtr<ID3D11UnorderedAccessView1>>
+ ResultOrError<ComPtr<ID3D11UnorderedAccessView>>
UseAsUAV(const ScopedCommandRecordingContext* commandContext, uint64_t offset, uint64_t size);
MaybeError PredicatedClear(const ScopedSwapStateCommandRecordingContext* commandContext,