D3D11: Skip UAV count if binding is not visible in fragment stage
This change ensures that the UAV count is not incremented for bindings
that are not visible in fragment stage in BindGroupTrackerD3D11.
Adding a test to cover this case.
Bug:366291600
Change-Id: Ifc3460254936f5750b7b35d2ab64764f8b457384
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/244096
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Quyen Le <lehoangquyen@chromium.org>
Commit-Queue: Shaobo Yan <shaoboyan@microsoft.com>
diff --git a/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp b/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
index 3a8b5cf..84d801d 100644
--- a/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
+++ b/src/dawn/native/d3d11/BindGroupTrackerD3D11.cpp
@@ -758,6 +758,12 @@
// D3D11 uav slot allocated in reverse order.
for (BindingIndex bindingIndex : Range(group->GetLayout()->GetBindingCount())) {
const BindingInfo& bindingInfo = group->GetLayout()->GetBindingInfo(bindingIndex);
+
+ // Skip if this binding isn't visible in the fragment shader.
+ if (!(bindingInfo.visibility & wgpu::ShaderStage::Fragment)) {
+ continue;
+ }
+
uint32_t pos = indices[bindingIndex][kFragment] - uavStartSlot;
DAWN_TRY(MatchVariant(
bindingInfo.bindingLayout,
diff --git a/src/dawn/tests/end2end/BindGroupTests.cpp b/src/dawn/tests/end2end/BindGroupTests.cpp
index 72d7254..81bcf94 100644
--- a/src/dawn/tests/end2end/BindGroupTests.cpp
+++ b/src/dawn/tests/end2end/BindGroupTests.cpp
@@ -1620,6 +1620,98 @@
}
}
+// Test a bindgroup has a invisible binding in the fragment stage.
+// This test passes by not asserting or crashing.
+TEST_P(BindGroupTests, BindingInvisibleInFragmentStage) {
+ wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, R"(
+ struct VertexOut {
+ @builtin(position) position : vec4f,
+ }
+
+ @vertex fn vsMain(@builtin(vertex_index) VertexIndex : u32) -> VertexOut {
+ const pos = array(
+ vec2( 1.0, -1.0),
+ vec2(-1.0, -1.0),
+ vec2( 0.0, 1.0),
+ );
+ var output: VertexOut;
+ output.position = vec4f(pos[VertexIndex], 0.0, 1.0);
+ return output;
+ }
+
+ // to reuse the same pipeline layout
+ @fragment fn fsMain() -> @location(0) vec4f {
+ return vec4f(1.0);
+ }
+
+ @group(0) @binding(0) var<storage, read_write> output : u32;
+
+ @compute @workgroup_size(1, 1, 1)
+ fn csMain() {
+ output = 1u;
+ })");
+
+ // Create storage buffer.
+ wgpu::BufferDescriptor bufferDesc;
+ bufferDesc.size = sizeof(uint32_t) * 4;
+ bufferDesc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage;
+ wgpu::Buffer storageBuffer = device.CreateBuffer(&bufferDesc);
+
+ // Create bind group layout.
+ wgpu::BindGroupLayoutEntry entries[1];
+ entries[0].binding = 0;
+ entries[0].visibility = wgpu::ShaderStage::Compute;
+ entries[0].buffer.type = wgpu::BufferBindingType::Storage;
+
+ wgpu::BindGroupLayoutDescriptor bindGroupLayoutDesc;
+ bindGroupLayoutDesc.entryCount = 1;
+ bindGroupLayoutDesc.entries = entries;
+ wgpu::BindGroupLayout bindGroupLayout = device.CreateBindGroupLayout(&bindGroupLayoutDesc);
+
+ // Create bind group.
+ wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, bindGroupLayout, {{0, storageBuffer}});
+
+ // Create pipeline layout.
+ wgpu::PipelineLayoutDescriptor pipelineLayoutDesc;
+ pipelineLayoutDesc.bindGroupLayoutCount = 1;
+ pipelineLayoutDesc.bindGroupLayouts = &bindGroupLayout;
+ wgpu::PipelineLayout pipelineLayout = device.CreatePipelineLayout(&pipelineLayoutDesc);
+
+ // Create render pipeline.
+ utils::ComboRenderPipelineDescriptor renderPipelineDescriptor;
+ renderPipelineDescriptor.vertex.module = shaderModule;
+ renderPipelineDescriptor.cFragment.module = shaderModule;
+ renderPipelineDescriptor.cFragment.targetCount = 1;
+ renderPipelineDescriptor.layout = pipelineLayout;
+ wgpu::RenderPipeline renderPipeline = device.CreateRenderPipeline(&renderPipelineDescriptor);
+
+ // Create compute pipeline.
+ wgpu::ComputePipelineDescriptor computePipelineDescriptor;
+ computePipelineDescriptor.compute.module = shaderModule;
+ computePipelineDescriptor.layout = pipelineLayout;
+ wgpu::ComputePipeline computePipeline =
+ device.CreateComputePipeline(&computePipelineDescriptor);
+
+ // Encode commands to render and compute passes.
+ utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 1, 1);
+ wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+ wgpu::RenderPassEncoder renderPassEncoder =
+ commandEncoder.BeginRenderPass(&renderPass.renderPassInfo);
+ renderPassEncoder.SetPipeline(renderPipeline);
+ renderPassEncoder.SetBindGroup(0, bindGroup);
+ renderPassEncoder.Draw(3);
+ renderPassEncoder.End();
+ wgpu::ComputePassEncoder computePassEncoder = commandEncoder.BeginComputePass();
+ computePassEncoder.SetPipeline(computePipeline);
+ computePassEncoder.SetBindGroup(0, bindGroup);
+ computePassEncoder.DispatchWorkgroups(1);
+ computePassEncoder.End();
+ wgpu::CommandBuffer commands = commandEncoder.Finish();
+ queue.Submit(1, &commands);
+ EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8(255, 255, 255, 255), renderPass.color, 0, 0);
+ EXPECT_BUFFER_U32_EQ(1u, storageBuffer, 0);
+}
+
DAWN_INSTANTIATE_TEST(BindGroupTests,
D3D11Backend(),
D3D12Backend(),