D3D12: Only set the root signature when bind groups are applied

Dawn only reapplies bind groups that have changed on the layout
since the last draw. However, the D3D12 backend was always resetting
the root signature state upon switching pipelines. This led to a
bug where the root signature could be changed and dirtied without
reapplying the dirty bind groups.

Moving application of the root signature state to the same loop
that applies bind groups helps ensure the state stays in sync.

Fixed: dawn:1055
Change-Id: Iae89088560e83f6104c921d42de27c03095d654f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/61582
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Bryan Bernhart <bryan.bernhart@intel.com>
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 302d981..50c5236 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -326,20 +326,12 @@
             mInCompute = inCompute_;
         }
 
-        void OnSetPipeline(PipelineBase* pipeline) {
-            // Invalidate the root sampler tables previously set in the root signature.
-            // This is because changing the pipeline layout also changes the root signature.
-            const PipelineLayout* pipelineLayout = ToBackend(pipeline->GetLayout());
-            if (mLastAppliedPipelineLayout != pipelineLayout) {
-                mBoundRootSamplerTables = {};
-            }
-
-            Base::OnSetPipeline(pipeline);
-        }
-
         MaybeError Apply(CommandRecordingContext* commandContext) {
             BeforeApply();
 
+            ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
+            UpdateRootSignatureIfNecessary(commandList);
+
             // Bindgroups are allocated in shader-visible descriptor heaps which are managed by a
             // ringbuffer. There can be a single shader-visible descriptor heap of each type bound
             // at any given time. This means that when we switch heaps, all other currently bound
@@ -358,8 +350,6 @@
                 }
             }
 
-            ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
-
             if (!didCreateBindGroupViews || !didCreateBindGroupSamplers) {
                 if (!didCreateBindGroupViews) {
                     DAWN_TRY(mViewAllocator->AllocateAndSwitchShaderVisibleHeap());
@@ -406,6 +396,20 @@
         }
 
       private:
+        void UpdateRootSignatureIfNecessary(ID3D12GraphicsCommandList* commandList) {
+            if (mLastAppliedPipelineLayout != mPipelineLayout) {
+                if (mInCompute) {
+                    commandList->SetComputeRootSignature(
+                        ToBackend(mPipelineLayout)->GetRootSignature());
+                } else {
+                    commandList->SetGraphicsRootSignature(
+                        ToBackend(mPipelineLayout)->GetRootSignature());
+                }
+                // Invalidate the root sampler tables previously set in the root signature.
+                mBoundRootSamplerTables = {};
+            }
+        }
+
         void ApplyBindGroup(ID3D12GraphicsCommandList* commandList,
                             const PipelineLayout* pipelineLayout,
                             BindGroupIndex index,
@@ -1036,9 +1040,7 @@
                 case Command::SetComputePipeline: {
                     SetComputePipelineCmd* cmd = mCommands.NextCommand<SetComputePipelineCmd>();
                     ComputePipeline* pipeline = ToBackend(cmd->pipeline).Get();
-                    PipelineLayout* layout = ToBackend(pipeline->GetLayout());
 
-                    commandList->SetComputeRootSignature(layout->GetRootSignature());
                     commandList->SetPipelineState(pipeline->GetPipelineState());
 
                     bindingTracker->OnSetPipeline(pipeline);
@@ -1305,7 +1307,6 @@
         }
 
         RenderPipeline* lastPipeline = nullptr;
-        PipelineLayout* lastLayout = nullptr;
         VertexBufferTracker vertexBufferTracker = {};
 
         auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) -> MaybeError {
@@ -1403,16 +1404,13 @@
                 case Command::SetRenderPipeline: {
                     SetRenderPipelineCmd* cmd = iter->NextCommand<SetRenderPipelineCmd>();
                     RenderPipeline* pipeline = ToBackend(cmd->pipeline).Get();
-                    PipelineLayout* layout = ToBackend(pipeline->GetLayout());
 
-                    commandList->SetGraphicsRootSignature(layout->GetRootSignature());
                     commandList->SetPipelineState(pipeline->GetPipelineState());
                     commandList->IASetPrimitiveTopology(pipeline->GetD3D12PrimitiveTopology());
 
                     bindingTracker->OnSetPipeline(pipeline);
 
                     lastPipeline = pipeline;
-                    lastLayout = layout;
                     break;
                 }
 
diff --git a/src/tests/end2end/BindGroupTests.cpp b/src/tests/end2end/BindGroupTests.cpp
index 5b8c5fe..90ecbed 100644
--- a/src/tests/end2end/BindGroupTests.cpp
+++ b/src/tests/end2end/BindGroupTests.cpp
@@ -758,9 +758,6 @@
 // Test for crbug.com/dawn/1049, where setting a pipeline without drawing can prevent
 // bind groups from being applied later
 TEST_P(BindGroupTests, DrawThenChangePipelineTwiceAndBindGroup) {
-    // TODO(crbug.com/dawn/1055) find out why this test fails on Windows Intel D3D12 drivers.
-    DAWN_SUPPRESS_TEST_IF(IsIntel() && IsWindows() && IsD3D12());
-
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
     // Create a bind group layout which uses a single dynamic uniform buffer.
@@ -810,7 +807,7 @@
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
     wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
 
-    // Set the pipeline to (uniform, uniform, storage)
+    // Set the pipeline to (uniform, uniform, uniform)
     pass.SetPipeline(pipeline0);
 
     // Set the first bind group to color0 in the dynamic uniform buffer.
@@ -841,8 +838,9 @@
     // Revert to pipeline 0
     pass.SetPipeline(pipeline0);
 
-    // Internally this will not re-apply the bind groups, because we already
-    // drew with this pipeline (setting pipeline 1 did not dirty the bind groups).
+    // Internally this should re-apply bind group 2. Because we already
+    // drew with this pipeline, and setting pipeline 1 did not dirty the bind groups,
+    // bind groups 0 and 1 should still be valid.
     pass.Draw(3);
 
     pass.EndPass();