d3d12: track graphics/compute state independently

Fixes a bug where Dawn failed to re-apply binding state when
transitioning between compute and render passes. If a compute
pipeline and a render pipeline share the same pipeline layout,
all of the resources for the graphics pipeline still need to be
rebound, since the graphics state in D3D12 is disjoint from the
compute state.

Fixed: dawn:1689
Change-Id: I7d25a1c7954039c4130e67b682ebc05324353e9a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124540
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
diff --git a/src/dawn/native/d3d12/CommandBufferD3D12.cpp b/src/dawn/native/d3d12/CommandBufferD3D12.cpp
index 8ebdc73..a57fccc 100644
--- a/src/dawn/native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn/native/d3d12/CommandBufferD3D12.cpp
@@ -383,18 +383,20 @@
 
 }  // anonymous namespace
 
+class DescriptorHeapState;
+
 class BindGroupStateTracker : public BindGroupTrackerBase<false, uint64_t> {
     using Base = BindGroupTrackerBase;
 
   public:
-    explicit BindGroupStateTracker(Device* device)
+    BindGroupStateTracker(Device* device, DescriptorHeapState* heapState, bool inCompute)
         : BindGroupTrackerBase(),
           mDevice(device),
+          mHeapState(heapState),
+          mInCompute(inCompute),
           mViewAllocator(device->GetViewShaderVisibleDescriptorAllocator()),
           mSamplerAllocator(device->GetSamplerShaderVisibleDescriptorAllocator()) {}
 
-    void SetInComputePass(bool inCompute_) { mInCompute = inCompute_; }
-
     MaybeError Apply(CommandRecordingContext* commandContext) {
         BeforeApply();
 
@@ -454,20 +456,9 @@
         return {};
     }
 
-    void SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList) {
-        ASSERT(commandList != nullptr);
-        std::array<ID3D12DescriptorHeap*, 2> descriptorHeaps = {
-            mViewAllocator->GetShaderVisibleHeap(), mSamplerAllocator->GetShaderVisibleHeap()};
-        ASSERT(descriptorHeaps[0] != nullptr);
-        ASSERT(descriptorHeaps[1] != nullptr);
-        commandList->SetDescriptorHeaps(descriptorHeaps.size(), descriptorHeaps.data());
+    void ResetRootSamplerTables() { mBoundRootSamplerTables = {}; }
 
-        // Descriptor table state is undefined at the beginning of a command list and after
-        // descriptor heaps are changed on a command list. Invalidate the root sampler tables to
-        // reset the root descriptor table for samplers, otherwise the shader cannot access the
-        // descriptor heaps.
-        mBoundRootSamplerTables = {};
-    }
+    void SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList);
 
   private:
     void UpdateRootSignatureIfNecessary(ID3D12GraphicsCommandList* commandList) {
@@ -480,7 +471,7 @@
                     ToBackend(mPipelineLayout)->GetRootSignature());
             }
             // Invalidate the root sampler tables previously set in the root signature.
-            mBoundRootSamplerTables = {};
+            ResetRootSamplerTables();
         }
     }
 
@@ -607,6 +598,7 @@
     }
 
     Device* mDevice;
+    DescriptorHeapState* mHeapState;
 
     bool mInCompute = false;
 
@@ -617,6 +609,43 @@
     ShaderVisibleDescriptorAllocator* mSamplerAllocator;
 };
 
+class DescriptorHeapState {
+  public:
+    explicit DescriptorHeapState(Device* device)
+        : mDevice(device),
+          mComputeBindingTracker(device, this, true),
+          mGraphicsBindingTracker(device, this, false) {}
+
+    void SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList) {
+        ASSERT(commandList != nullptr);
+        std::array<ID3D12DescriptorHeap*, 2> descriptorHeaps = {
+            mDevice->GetViewShaderVisibleDescriptorAllocator()->GetShaderVisibleHeap(),
+            mDevice->GetSamplerShaderVisibleDescriptorAllocator()->GetShaderVisibleHeap()};
+        ASSERT(descriptorHeaps[0] != nullptr);
+        ASSERT(descriptorHeaps[1] != nullptr);
+        commandList->SetDescriptorHeaps(descriptorHeaps.size(), descriptorHeaps.data());
+
+        // Descriptor table state is undefined at the beginning of a command list and after
+        // descriptor heaps are changed on a command list. Invalidate the root sampler tables to
+        // reset the root descriptor table for samplers, otherwise the shader cannot access the
+        // descriptor heaps.
+        mComputeBindingTracker.ResetRootSamplerTables();
+        mGraphicsBindingTracker.ResetRootSamplerTables();
+    }
+
+    BindGroupStateTracker* GetComputeBindingTracker() { return &mComputeBindingTracker; }
+    BindGroupStateTracker* GetGraphicsBindingTracker() { return &mGraphicsBindingTracker; }
+
+  private:
+    Device* mDevice;
+    BindGroupStateTracker mComputeBindingTracker;
+    BindGroupStateTracker mGraphicsBindingTracker;
+};
+
+void BindGroupStateTracker::SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList) {
+    mHeapState->SetID3D12DescriptorHeaps(commandList);
+}
+
 namespace {
 class VertexBufferTracker {
   public:
@@ -726,13 +755,12 @@
 
 MaybeError CommandBuffer::RecordCommands(CommandRecordingContext* commandContext) {
     Device* device = ToBackend(GetDevice());
-    BindGroupStateTracker bindingTracker(device);
 
-    ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
-
+    DescriptorHeapState descriptorHeapState(device);
     // Make sure we use the correct descriptors for this command list. Could be done once per
     // actual command list but here is ok because there should be few command buffers.
-    bindingTracker.SetID3D12DescriptorHeaps(commandList);
+    ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
+    descriptorHeapState.SetID3D12DescriptorHeaps(commandList);
 
     size_t nextComputePassNumber = 0;
     size_t nextRenderPassNumber = 0;
@@ -743,11 +771,9 @@
             case Command::BeginComputePass: {
                 BeginComputePassCmd* cmd = mCommands.NextCommand<BeginComputePassCmd>();
 
-                bindingTracker.SetInComputePass(true);
-
-                DAWN_TRY(
-                    RecordComputePass(commandContext, &bindingTracker, cmd,
-                                      GetResourceUsages().computePasses[nextComputePassNumber]));
+                DAWN_TRY(RecordComputePass(
+                    commandContext, descriptorHeapState.GetComputeBindingTracker(), cmd,
+                    GetResourceUsages().computePasses[nextComputePassNumber]));
 
                 nextComputePassNumber++;
                 break;
@@ -761,11 +787,11 @@
                 DAWN_TRY(TransitionAndClearForSyncScope(
                     commandContext, GetResourceUsages().renderPasses[nextRenderPassNumber],
                     &passHasUAV));
-                bindingTracker.SetInComputePass(false);
 
                 LazyClearRenderPassAttachments(beginRenderPassCmd);
-                DAWN_TRY(RecordRenderPass(commandContext, &bindingTracker, beginRenderPassCmd,
-                                          passHasUAV));
+                DAWN_TRY(RecordRenderPass(commandContext,
+                                          descriptorHeapState.GetGraphicsBindingTracker(),
+                                          beginRenderPassCmd, passHasUAV));
 
                 nextRenderPassNumber++;
                 break;
diff --git a/src/dawn/tests/end2end/PipelineLayoutTests.cpp b/src/dawn/tests/end2end/PipelineLayoutTests.cpp
index 2b5a9c7..dca4be3 100644
--- a/src/dawn/tests/end2end/PipelineLayoutTests.cpp
+++ b/src/dawn/tests/end2end/PipelineLayoutTests.cpp
@@ -16,6 +16,7 @@
 
 #include "dawn/common/Constants.h"
 #include "dawn/tests/DawnTest.h"
+#include "dawn/utils/WGPUHelpers.h"
 
 class PipelineLayoutTests : public DawnTest {};
 
@@ -68,6 +69,80 @@
     device.CreatePipelineLayout(&descriptor);
 }
 
+// Regression test for crbug.com/dawn/1689. Test using a compute pass and a render pass,
+// where the two pipelines have the same pipeline layout.
+TEST_P(PipelineLayoutTests, ComputeAndRenderSamePipelineLayout) {
+    wgpu::TextureFormat format = wgpu::TextureFormat::RGBA8Unorm;
+    wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, R"(
+        @compute @workgroup_size(8, 8)
+        fn computeMain() {}
+
+        @vertex fn vertexMain() -> @builtin(position) vec4f {
+            return vec4f(0.0);
+        }
+
+        @fragment fn fragmentMain() -> @location(0) vec4f {
+            return vec4f(0.0, 0.0, 0.0, 1.0);
+        }
+    )");
+
+    wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device, {{0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform}});
+
+    wgpu::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
+    wgpu::ComputePipeline computePipeline;
+    {
+        wgpu::ComputePipelineDescriptor desc = {};
+        desc.layout = pl;
+        desc.compute.module = shaderModule;
+        desc.compute.entryPoint = "computeMain";
+        computePipeline = device.CreateComputePipeline(&desc);
+    }
+    wgpu::RenderPipeline renderPipeline;
+    {
+        wgpu::RenderPipelineDescriptor desc = {};
+        desc.layout = pl;
+        desc.vertex.module = shaderModule;
+        desc.vertex.entryPoint = "vertexMain";
+
+        wgpu::FragmentState fragment = {};
+        desc.fragment = &fragment;
+        fragment.module = shaderModule;
+        fragment.entryPoint = "fragmentMain";
+        fragment.targetCount = 1;
+
+        wgpu::ColorTargetState colorTargetState = {};
+        colorTargetState.format = format;
+        fragment.targets = &colorTargetState;
+
+        renderPipeline = device.CreateRenderPipeline(&desc);
+    }
+
+    wgpu::Buffer buffer = utils::CreateBufferFromData(device, wgpu::BufferUsage::Uniform, {1});
+    wgpu::BindGroup bg0 = utils::MakeBindGroup(device, bgl, {{0, buffer}});
+    wgpu::BindGroup bg1 = utils::MakeBindGroup(device, bgl, {{0, buffer}});
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    {
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(computePipeline);
+        pass.SetBindGroup(0, bg0);
+        pass.DispatchWorkgroups(1);
+        pass.End();
+    }
+    {
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 4, 4, format);
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetPipeline(renderPipeline);
+        pass.SetBindGroup(0, bg1);
+        pass.Draw(1);
+        pass.End();
+    }
+
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+}
+
 DAWN_INSTANTIATE_TEST(PipelineLayoutTests,
                       D3D12Backend(),
                       MetalBackend(),