Implement GPU-based validation for dispatchIndirect
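
Each DispatchIndirect in a compute pass is now preceded by an internal
validation dispatch: a small compute shader reads the client's indirect
parameters, zeroes out any workgroup count that exceeds
maxComputeWorkgroupsPerDimension, and writes the validated parameters
to a scratch buffer. The DispatchIndirect command is then rewritten to
read from the scratch buffer at offset 0. The encoder's pipeline and
bind group state (including dynamic offsets) is snapshotted before the
validation commands are encoded and restored afterwards.

Sketch of the command stream produced for one DispatchIndirect (names
are illustrative; the exact sequence is asserted in the new
CommandBufferEncodingTests):

    SetComputePipeline(validationPipeline)
    SetBindGroup(0, validationBindGroup)  // params, client buffer, scratch buffer
    Dispatch(1, 1, 1)                     // clamp out-of-range counts to zero
    SetComputePipeline(userPipeline)      // restored state
    SetBindGroup(...)                     // restored bind groups + dynamic offsets
    DispatchIndirect(scratchBuffer, 0)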

Bug: dawn:1039
Change-Id: I1b77244d33b178c8e4d4b7d72dc038ccb9d65c48
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67142
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/CommandBuffer.cpp b/src/dawn_native/CommandBuffer.cpp
index 18fef0d..a964f40 100644
--- a/src/dawn_native/CommandBuffer.cpp
+++ b/src/dawn_native/CommandBuffer.cpp
@@ -73,6 +73,10 @@
         return mResourceUsages;
     }
 
+    CommandIterator* CommandBufferBase::GetCommandIteratorForTesting() {
+        return &mCommands;
+    }
+
     bool IsCompleteSubresourceCopiedTo(const TextureBase* texture,
                                        const Extent3D copySize,
                                        const uint32_t mipLevel) {
diff --git a/src/dawn_native/CommandBuffer.h b/src/dawn_native/CommandBuffer.h
index 2800929..c6d47ae 100644
--- a/src/dawn_native/CommandBuffer.h
+++ b/src/dawn_native/CommandBuffer.h
@@ -43,6 +43,8 @@
 
         const CommandBufferResourceUsage& GetResourceUsages() const;
 
+        CommandIterator* GetCommandIteratorForTesting();
+
       protected:
         ~CommandBufferBase() override;
 
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 5210936..45892a6 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -17,8 +17,10 @@
 #include "common/Assert.h"
 #include "common/BitSetIterator.h"
 #include "dawn_native/BindGroup.h"
+#include "dawn_native/ComputePassEncoder.h"
 #include "dawn_native/ComputePipeline.h"
 #include "dawn_native/Forward.h"
+#include "dawn_native/ObjectType_autogen.h"
 #include "dawn_native/PipelineLayout.h"
 #include "dawn_native/RenderPipeline.h"
 
@@ -83,13 +85,15 @@
     MaybeError CommandBufferStateTracker::ValidateBufferInRangeForVertexBuffer(
         uint32_t vertexCount,
         uint32_t firstVertex) {
+        RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+
         const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
             vertexBufferSlotsUsedAsVertexBuffer =
-                mLastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
+                lastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
 
         for (auto usedSlotVertex : IterateBitSet(vertexBufferSlotsUsedAsVertexBuffer)) {
             const VertexBufferInfo& vertexBuffer =
-                mLastRenderPipeline->GetVertexBuffer(usedSlotVertex);
+                lastRenderPipeline->GetVertexBuffer(usedSlotVertex);
             uint64_t arrayStride = vertexBuffer.arrayStride;
             uint64_t bufferSize = mVertexBufferSizes[usedSlotVertex];
 
@@ -120,13 +124,15 @@
     MaybeError CommandBufferStateTracker::ValidateBufferInRangeForInstanceBuffer(
         uint32_t instanceCount,
         uint32_t firstInstance) {
+        RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+
         const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
             vertexBufferSlotsUsedAsInstanceBuffer =
-                mLastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
+                lastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
 
         for (auto usedSlotInstance : IterateBitSet(vertexBufferSlotsUsedAsInstanceBuffer)) {
             const VertexBufferInfo& vertexBuffer =
-                mLastRenderPipeline->GetVertexBuffer(usedSlotInstance);
+                lastRenderPipeline->GetVertexBuffer(usedSlotInstance);
             uint64_t arrayStride = vertexBuffer.arrayStride;
             uint64_t bufferSize = mVertexBufferSizes[usedSlotInstance];
             if (arrayStride == 0) {
@@ -209,18 +215,19 @@
         }
 
         if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
-            ASSERT(mLastRenderPipeline != nullptr);
+            RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
 
             const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& requiredVertexBuffers =
-                mLastRenderPipeline->GetVertexBufferSlotsUsed();
+                lastRenderPipeline->GetVertexBufferSlotsUsed();
             if (IsSubset(requiredVertexBuffers, mVertexBufferSlotsUsed)) {
                 mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
             }
         }
 
         if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) {
-            if (!IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) ||
-                mIndexFormat == mLastRenderPipeline->GetStripIndexFormat()) {
+            RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+            if (!IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) ||
+                mIndexFormat == lastRenderPipeline->GetStripIndexFormat()) {
                 mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
             }
         }
@@ -234,12 +241,13 @@
         if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_INDEX_BUFFER])) {
             DAWN_INVALID_IF(!mIndexBufferSet, "Index buffer was not set.");
 
-            wgpu::IndexFormat pipelineIndexFormat = mLastRenderPipeline->GetStripIndexFormat();
+            RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+            wgpu::IndexFormat pipelineIndexFormat = lastRenderPipeline->GetStripIndexFormat();
             DAWN_INVALID_IF(
-                IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) &&
+                IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) &&
                     mIndexFormat != pipelineIndexFormat,
                 "Strip index format (%s) of %s does not match index buffer format (%s).",
-                pipelineIndexFormat, mLastRenderPipeline, mIndexFormat);
+                pipelineIndexFormat, lastRenderPipeline, mIndexFormat);
 
             // The chunk of code above should be similar to the one in |RecomputeLazyAspects|.
             // It returns the first invalid state found. We shouldn't be able to reach this line
@@ -251,7 +259,7 @@
 
         // TODO(dawn:563): Indicate which slots were not set.
         DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_VERTEX_BUFFERS],
-                        "Vertex buffer slots required by %s were not set.", mLastRenderPipeline);
+                        "Vertex buffer slots required by %s were not set.", GetRenderPipeline());
 
         if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_BIND_GROUPS])) {
             for (BindGroupIndex i : IterateBitSet(mLastPipelineLayout->GetBindGroupLayoutsMask())) {
@@ -290,12 +298,15 @@
     }
 
     void CommandBufferStateTracker::SetRenderPipeline(RenderPipelineBase* pipeline) {
-        mLastRenderPipeline = pipeline;
         SetPipelineCommon(pipeline);
     }
 
-    void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup) {
+    void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index,
+                                                 BindGroupBase* bindgroup,
+                                                 uint32_t dynamicOffsetCount,
+                                                 const uint32_t* dynamicOffsets) {
         mBindgroups[index] = bindgroup;
+        mDynamicOffsets[index].assign(dynamicOffsets, dynamicOffsets + dynamicOffsetCount);
         mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS);
     }
 
@@ -311,8 +322,9 @@
     }
 
     void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
-        mLastPipelineLayout = pipeline->GetLayout();
-        mMinBufferSizes = &pipeline->GetMinBufferSizes();
+        mLastPipeline = pipeline;
+        mLastPipelineLayout = pipeline != nullptr ? pipeline->GetLayout() : nullptr;
+        mMinBufferSizes = pipeline != nullptr ? &pipeline->GetMinBufferSizes() : nullptr;
 
         mAspects.set(VALIDATION_ASPECT_PIPELINE);
 
@@ -324,6 +336,25 @@
         return mBindgroups[index];
     }
 
+    const std::vector<uint32_t>& CommandBufferStateTracker::GetDynamicOffsets(
+        BindGroupIndex index) const {
+        return mDynamicOffsets[index];
+    }
+
+    bool CommandBufferStateTracker::HasPipeline() const {
+        return mLastPipeline != nullptr;
+    }
+
+    RenderPipelineBase* CommandBufferStateTracker::GetRenderPipeline() const {
+        ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::RenderPipeline);
+        return static_cast<RenderPipelineBase*>(mLastPipeline);
+    }
+
+    ComputePipelineBase* CommandBufferStateTracker::GetComputePipeline() const {
+        ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::ComputePipeline);
+        return static_cast<ComputePipelineBase*>(mLastPipeline);
+    }
+
     PipelineLayoutBase* CommandBufferStateTracker::GetPipelineLayout() const {
         return mLastPipelineLayout;
     }
diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h
index 0a6c587..5686956 100644
--- a/src/dawn_native/CommandBufferStateTracker.h
+++ b/src/dawn_native/CommandBufferStateTracker.h
@@ -38,7 +38,10 @@
         // State-modifying methods
         void SetComputePipeline(ComputePipelineBase* pipeline);
         void SetRenderPipeline(RenderPipelineBase* pipeline);
-        void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup);
+        void SetBindGroup(BindGroupIndex index,
+                          BindGroupBase* bindgroup,
+                          uint32_t dynamicOffsetCount,
+                          const uint32_t* dynamicOffsets);
         void SetIndexBuffer(wgpu::IndexFormat format, uint64_t size);
         void SetVertexBuffer(VertexBufferSlot slot, uint64_t size);
 
@@ -46,6 +49,10 @@
         using ValidationAspects = std::bitset<kNumAspects>;
 
         BindGroupBase* GetBindGroup(BindGroupIndex index) const;
+        const std::vector<uint32_t>& GetDynamicOffsets(BindGroupIndex index) const;
+        bool HasPipeline() const;
+        RenderPipelineBase* GetRenderPipeline() const;
+        ComputePipelineBase* GetComputePipeline() const;
         PipelineLayoutBase* GetPipelineLayout() const;
         wgpu::IndexFormat GetIndexFormat() const;
         uint64_t GetIndexBufferSize() const;
@@ -60,6 +67,7 @@
         ValidationAspects mAspects;
 
         ityp::array<BindGroupIndex, BindGroupBase*, kMaxBindGroups> mBindgroups = {};
+        ityp::array<BindGroupIndex, std::vector<uint32_t>, kMaxBindGroups> mDynamicOffsets = {};
         ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mVertexBufferSlotsUsed;
         bool mIndexBufferSet = false;
         wgpu::IndexFormat mIndexFormat;
@@ -68,7 +76,7 @@
         ityp::array<VertexBufferSlot, uint64_t, kMaxVertexBuffers> mVertexBufferSizes = {};
 
         PipelineLayoutBase* mLastPipelineLayout = nullptr;
-        RenderPipelineBase* mLastRenderPipeline = nullptr;
+        PipelineBase* mLastPipeline = nullptr;
 
         const RequiredBufferSizes* mMinBufferSizes = nullptr;
     };
diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp
index 1aa4845..05c68fb 100644
--- a/src/dawn_native/ComputePassEncoder.cpp
+++ b/src/dawn_native/ComputePassEncoder.cpp
@@ -14,18 +14,107 @@
 
 #include "dawn_native/ComputePassEncoder.h"
 
+#include "dawn_native/BindGroup.h"
+#include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Buffer.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/CommandValidation.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/ComputePipeline.h"
 #include "dawn_native/Device.h"
+#include "dawn_native/InternalPipelineStore.h"
 #include "dawn_native/ObjectType_autogen.h"
 #include "dawn_native/PassResourceUsageTracker.h"
 #include "dawn_native/QuerySet.h"
 
 namespace dawn_native {
 
+    namespace {
+
+        ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
+            DeviceBase* device) {
+            InternalPipelineStore* store = device->GetInternalPipelineStore();
+
+            if (store->dispatchIndirectValidationPipeline != nullptr) {
+                return store->dispatchIndirectValidationPipeline.Get();
+            }
+
+            ShaderModuleDescriptor descriptor;
+            ShaderModuleWGSLDescriptor wgslDesc;
+            descriptor.nextInChain = reinterpret_cast<ChainedStruct*>(&wgslDesc);
+
+            // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
+            // shader in various failure modes.
+            wgslDesc.source = R"(
+                [[block]] struct UniformParams {
+                    maxComputeWorkgroupsPerDimension: u32;
+                    clientOffsetInU32: u32;
+                };
+
+                [[block]] struct IndirectParams {
+                    data: array<u32>;
+                };
+
+                [[block]] struct ValidatedParams {
+                    data: array<u32, 3>;
+                };
+
+                [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
+                [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
+                [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;
+
+                [[stage(compute), workgroup_size(1, 1, 1)]]
+                fn main() {
+                    for (var i = 0u; i < 3u; i = i + 1u) {
+                        var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
+                        if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
+                            numWorkgroups = 0u;
+                        }
+                        validatedParams.data[i] = numWorkgroups;
+                    }
+                }
+            )";
+
+            Ref<ShaderModuleBase> shaderModule;
+            DAWN_TRY_ASSIGN(shaderModule, device->CreateShaderModule(&descriptor));
+
+            std::array<BindGroupLayoutEntry, 3> entries;
+            entries[0].binding = 0;
+            entries[0].visibility = wgpu::ShaderStage::Compute;
+            entries[0].buffer.type = wgpu::BufferBindingType::Uniform;
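+            // Note: the client's indirect buffer was not necessarily created with Storage
+            // usage, so an internal storage binding type is used to allow binding it for
+            // storage access in the validation shader.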
+            entries[1].binding = 1;
+            entries[1].visibility = wgpu::ShaderStage::Compute;
+            entries[1].buffer.type = kInternalStorageBufferBinding;
+            entries[2].binding = 2;
+            entries[2].visibility = wgpu::ShaderStage::Compute;
+            entries[2].buffer.type = wgpu::BufferBindingType::Storage;
+
+            BindGroupLayoutDescriptor bindGroupLayoutDescriptor;
+            bindGroupLayoutDescriptor.entryCount = entries.size();
+            bindGroupLayoutDescriptor.entries = entries.data();
+            Ref<BindGroupLayoutBase> bindGroupLayout;
+            DAWN_TRY_ASSIGN(bindGroupLayout,
+                            device->CreateBindGroupLayout(&bindGroupLayoutDescriptor, true));
+
+            PipelineLayoutDescriptor pipelineDescriptor;
+            pipelineDescriptor.bindGroupLayoutCount = 1;
+            pipelineDescriptor.bindGroupLayouts = &bindGroupLayout.Get();
+            Ref<PipelineLayoutBase> pipelineLayout;
+            DAWN_TRY_ASSIGN(pipelineLayout, device->CreatePipelineLayout(&pipelineDescriptor));
+
+            ComputePipelineDescriptor computePipelineDescriptor = {};
+            computePipelineDescriptor.layout = pipelineLayout.Get();
+            computePipelineDescriptor.compute.module = shaderModule.Get();
+            computePipelineDescriptor.compute.entryPoint = "main";
+
+            DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
+                            device->CreateComputePipeline(&computePipelineDescriptor));
+
+            return store->dispatchIndirectValidationPipeline.Get();
+        }
+
+    }  // namespace
+
     ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                            CommandEncoder* commandEncoder,
                                            EncodingContext* encodingContext)
@@ -107,6 +196,95 @@
             "encoding Dispatch (x: %u, y: %u, z: %u)", x, y, z);
     }
 
+    ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
+    ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer,
+                                                 uint64_t indirectOffset) {
+        DeviceBase* device = GetDevice();
+        auto* const store = device->GetInternalPipelineStore();
+
+        Ref<ComputePipelineBase> validationPipeline;
+        DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));
+
+        Ref<BindGroupLayoutBase> layout;
+        DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));
+
+        uint32_t storageBufferOffsetAlignment =
+            device->GetLimits().v1.minStorageBufferOffsetAlignment;
+
+        std::array<BindGroupEntry, 3> bindings;
+
+        // Storage binding holding the client's indirect buffer.
+        BindGroupEntry& clientIndirectBinding = bindings[0];
+        clientIndirectBinding.binding = 1;
+        clientIndirectBinding.buffer = indirectBuffer;
+
+        // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|.
+        const uint32_t clientOffsetFromAlignedBoundary =
+            indirectOffset % storageBufferOffsetAlignment;
+        const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
+        clientIndirectBinding.offset = clientOffsetAlignedDown;
+
+        // Let the size of the binding be the additional offset, plus the size.
+        clientIndirectBinding.size = kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
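+        // As a sketch, assuming minStorageBufferOffsetAlignment is 256: an indirectOffset of
+        // 264 yields a binding offset of 256, clientOffsetFromAlignedBoundary of 8, and a
+        // binding size of kDispatchIndirectSize (12) + 8 = 20 bytes; the shader then reads
+        // the params starting at clientOffsetInU32 = 2.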
+
+        struct UniformParams {
+            uint32_t maxComputeWorkgroupsPerDimension;
+            uint32_t clientOffsetInU32;
+        };
+
+        // Create a uniform buffer to hold parameters for the shader.
+        Ref<BufferBase> uniformBuffer;
+        {
+            BufferDescriptor uniformDesc = {};
+            uniformDesc.size = sizeof(UniformParams);
+            uniformDesc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
+            uniformDesc.mappedAtCreation = true;
+            DAWN_TRY_ASSIGN(uniformBuffer, device->CreateBuffer(&uniformDesc));
+
+            UniformParams* params = static_cast<UniformParams*>(
+                uniformBuffer->GetMappedRange(0, sizeof(UniformParams)));
+            params->maxComputeWorkgroupsPerDimension =
+                device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
+            params->clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
+            uniformBuffer->Unmap();
+        }
+
+        // Uniform buffer binding pointing to the uniform parameters.
+        BindGroupEntry& uniformBinding = bindings[1];
+        uniformBinding.binding = 0;
+        uniformBinding.buffer = uniformBuffer.Get();
+        uniformBinding.offset = 0;
+        uniformBinding.size = sizeof(UniformParams);
+
+        // Reserve space in the scratch buffer to hold the validated indirect params.
+        ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
+        DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize));
+        Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
+
+        // Binding for the validated indirect params.
+        BindGroupEntry& validatedParamsBinding = bindings[2];
+        validatedParamsBinding.binding = 2;
+        validatedParamsBinding.buffer = validatedIndirectBuffer.Get();
+        validatedParamsBinding.offset = 0;
+        validatedParamsBinding.size = kDispatchIndirectSize;
+
+        BindGroupDescriptor bindGroupDescriptor = {};
+        bindGroupDescriptor.layout = layout.Get();
+        bindGroupDescriptor.entryCount = bindings.size();
+        bindGroupDescriptor.entries = bindings.data();
+
+        Ref<BindGroupBase> validationBindGroup;
+        DAWN_TRY_ASSIGN(validationBindGroup, device->CreateBindGroup(&bindGroupDescriptor));
+
+        // Issue commands to validate the indirect buffer.
+        APISetPipeline(validationPipeline.Get());
+        APISetBindGroup(0, validationBindGroup.Get());
+        APIDispatch(1);
+
+        // Return the new indirect buffer and indirect buffer offset.
+        return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
+    }
+
     void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
                                                  uint64_t indirectOffset) {
         mEncodingContext->TryEncode(
@@ -136,18 +314,46 @@
                         indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
                 }
 
-                // Record the synchronization scope for Dispatch, both the bindgroups and the
-                // indirect buffer.
                 SyncScopeUsageTracker scope;
                 scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
                 mUsageTracker.AddReferencedBuffer(indirectBuffer);
+                // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
+                // is needed for correct usage validation even though it will only be bound for
+                // storage. This will unnecessarily transition the |indirectBuffer| in
+                // the backend.
+
+                Ref<BufferBase> indirectBufferRef = indirectBuffer;
+                if (IsValidationEnabled()) {
+                    // Save the previous command buffer state so it can be restored after the
+                    // validation inserts additional commands.
+                    CommandBufferStateTracker previousState = mCommandBufferState;
+
+                    // Validate each indirect dispatch with a single dispatch to copy the indirect
+                    // buffer params into a scratch buffer if they're valid, and otherwise zero them
+                    // out. We could consider moving the validation earlier in the pass, after the
+                    // last point the indirect buffer was used with writable usage, as well as
+                    // batching validation for multiple dispatches into one, but inserting commands
+                    // at arbitrary points in the past is not possible right now.
+                    DAWN_TRY_ASSIGN(
+                        std::tie(indirectBufferRef, indirectOffset),
+                        ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset));
+
+                    // Restore the state.
+                    RestoreCommandBufferState(std::move(previousState));
+
+                    // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
+                    // synchronization scope.
+                    ASSERT(indirectBufferRef.Get() != indirectBuffer);
+                    scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
+                    mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
+                }
+
                 AddDispatchSyncScope(std::move(scope));
 
                 DispatchIndirectCmd* dispatch =
                     allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
-                dispatch->indirectBuffer = indirectBuffer;
+                dispatch->indirectBuffer = std::move(indirectBufferRef);
                 dispatch->indirectOffset = indirectOffset;
-
                 return {};
             },
             "encoding DispatchIndirect with %s", indirectBuffer);
@@ -187,10 +393,10 @@
                 }
 
                 mUsageTracker.AddResourcesReferencedByBindGroup(group);
-
                 RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                    dynamicOffsets);
-                mCommandBufferState.SetBindGroup(groupIndex, group);
+                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
+                                                 dynamicOffsets);
 
                 return {};
             },
@@ -226,4 +432,29 @@
         mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
     }
 
+    void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
+        // Encode commands for the backend to restore the pipeline and bind groups.
+        if (state.HasPipeline()) {
+            APISetPipeline(state.GetComputePipeline());
+        }
+        for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
+            BindGroupBase* bg = state.GetBindGroup(i);
+            if (bg != nullptr) {
+                const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
+                if (offsets.empty()) {
+                    APISetBindGroup(static_cast<uint32_t>(i), bg);
+                } else {
+                    APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data());
+                }
+            }
+        }
+
+        // Restore the frontend state tracking information.
+        mCommandBufferState = std::move(state);
+    }
+
+    CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
+        return &mCommandBufferState;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/ComputePassEncoder.h b/src/dawn_native/ComputePassEncoder.h
index b0962f4..03997ce 100644
--- a/src/dawn_native/ComputePassEncoder.h
+++ b/src/dawn_native/ComputePassEncoder.h
@@ -50,6 +50,11 @@
 
         void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex);
 
+        CommandBufferStateTracker* GetCommandBufferStateTrackerForTesting();
+        void RestoreCommandBufferStateForTesting(CommandBufferStateTracker state) {
+            RestoreCommandBufferState(std::move(state));
+        }
+
       protected:
         ComputePassEncoder(DeviceBase* device,
                            CommandEncoder* commandEncoder,
@@ -57,6 +62,12 @@
                            ErrorTag errorTag);
 
       private:
+        ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
+            BufferBase* indirectBuffer,
+            uint64_t indirectOffset);
+
+        void RestoreCommandBufferState(CommandBufferStateTracker state);
+
         CommandBufferStateTracker mCommandBufferState;
 
         // Adds the bindgroups used for the current dispatch to the SyncScopeResourceUsage and
diff --git a/src/dawn_native/InternalPipelineStore.h b/src/dawn_native/InternalPipelineStore.h
index acf3b13..803e0df 100644
--- a/src/dawn_native/InternalPipelineStore.h
+++ b/src/dawn_native/InternalPipelineStore.h
@@ -52,6 +52,7 @@
 
         Ref<ComputePipelineBase> renderValidationPipeline;
         Ref<ShaderModuleBase> renderValidationShader;
+        Ref<ComputePipelineBase> dispatchIndirectValidationPipeline;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/RenderEncoderBase.cpp b/src/dawn_native/RenderEncoderBase.cpp
index 0445a97..a8ef2ff 100644
--- a/src/dawn_native/RenderEncoderBase.cpp
+++ b/src/dawn_native/RenderEncoderBase.cpp
@@ -208,6 +208,9 @@
                         BufferLocation::New(indirectBuffer, indirectOffset);
                 }
 
+                // TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct
+                // usage validation, but it will unnecessarily transition the indirectBuffer to
+                // Indirect usage in the backend.
                 mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
 
                 return {};
@@ -404,7 +407,8 @@
 
                 RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                    dynamicOffsets);
-                mCommandBufferState.SetBindGroup(groupIndex, group);
+                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
+                                                 dynamicOffsets);
                 mUsageTracker.AddBindGroup(group);
 
                 return {};
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index c7f4b7a..8ffd921 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -221,6 +221,7 @@
     "unittests/SystemUtilsTests.cpp",
     "unittests/ToBackendTests.cpp",
     "unittests/TypedIntegerTests.cpp",
+    "unittests/native/CommandBufferEncodingTests.cpp",
     "unittests/native/DestroyObjectTests.cpp",
     "unittests/validation/BindGroupValidationTests.cpp",
     "unittests/validation/BufferValidationTests.cpp",
diff --git a/src/tests/DawnNativeTest.cpp b/src/tests/DawnNativeTest.cpp
index d39c8e0..28d69bf 100644
--- a/src/tests/DawnNativeTest.cpp
+++ b/src/tests/DawnNativeTest.cpp
@@ -12,9 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <gtest/gtest.h>
+#include "tests/DawnNativeTest.h"
 
 #include "absl/strings/str_cat.h"
+#include "common/Assert.h"
+#include "dawn/dawn_proc.h"
 #include "dawn_native/ErrorData.h"
 
 namespace dawn_native {
@@ -28,3 +30,54 @@
     }
 
 }  // namespace dawn_native
+
+DawnNativeTest::DawnNativeTest() {
+    dawnProcSetProcs(&dawn_native::GetProcs());
+}
+
+DawnNativeTest::~DawnNativeTest() {
+    device = wgpu::Device();
+    dawnProcSetProcs(nullptr);
+}
+
+void DawnNativeTest::SetUp() {
+    instance = std::make_unique<dawn_native::Instance>();
+    instance->DiscoverDefaultAdapters();
+
+    std::vector<dawn_native::Adapter> adapters = instance->GetAdapters();
+
+    // DawnNative unittests run against the null backend; find the corresponding adapter.
+    bool foundNullAdapter = false;
+    for (auto& currentAdapter : adapters) {
+        wgpu::AdapterProperties adapterProperties;
+        currentAdapter.GetProperties(&adapterProperties);
+
+        if (adapterProperties.backendType == wgpu::BackendType::Null) {
+            adapter = currentAdapter;
+            foundNullAdapter = true;
+            break;
+        }
+    }
+
+    ASSERT(foundNullAdapter);
+
+    device = wgpu::Device(CreateTestDevice());
+    device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
+}
+
+void DawnNativeTest::TearDown() {
+}
+
+WGPUDevice DawnNativeTest::CreateTestDevice() {
+    // Disable disallowing unsafe APIs so we can test them.
+    dawn_native::DeviceDescriptor deviceDescriptor;
+    deviceDescriptor.forceDisabledToggles.push_back("disallow_unsafe_apis");
+
+    return adapter.CreateDevice(&deviceDescriptor);
+}
+
+// static
+void DawnNativeTest::OnDeviceError(WGPUErrorType type, const char* message, void* userdata) {
+    ASSERT(type != WGPUErrorType_NoError);
+    FAIL() << "Unexpected error: " << message;
+}
diff --git a/src/tests/DawnNativeTest.h b/src/tests/DawnNativeTest.h
index 94fdafb..91904a3 100644
--- a/src/tests/DawnNativeTest.h
+++ b/src/tests/DawnNativeTest.h
@@ -17,6 +17,8 @@
 
 #include <gtest/gtest.h>
 
+#include "dawn/webgpu_cpp.h"
+#include "dawn_native/DawnNative.h"
 #include "dawn_native/ErrorData.h"
 
 namespace dawn_native {
@@ -29,4 +31,23 @@
 
 }  // namespace dawn_native
 
+class DawnNativeTest : public ::testing::Test {
+  public:
+    DawnNativeTest();
+    ~DawnNativeTest() override;
+
+    void SetUp() override;
+    void TearDown() override;
+
+    virtual WGPUDevice CreateTestDevice();
+
+  protected:
+    std::unique_ptr<dawn_native::Instance> instance;
+    dawn_native::Adapter adapter;
+    wgpu::Device device;
+
+  private:
+    static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata);
+};
+
 #endif  // TESTS_DAWNNATIVETEST_H_
diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp
index 1a8b163..cbc2d86 100644
--- a/src/tests/end2end/ComputeDispatchTests.cpp
+++ b/src/tests/end2end/ComputeDispatchTests.cpp
@@ -158,8 +158,14 @@
         queue.Submit(1, &commands);
 
         std::vector<uint32_t> expected;
+
+        uint32_t maxComputeWorkgroupsPerDimension =
+            GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
         if (indirectBufferData[indirectStart] == 0 || indirectBufferData[indirectStart + 1] == 0 ||
-            indirectBufferData[indirectStart + 2] == 0) {
+            indirectBufferData[indirectStart + 2] == 0 ||
+            indirectBufferData[indirectStart] > maxComputeWorkgroupsPerDimension ||
+            indirectBufferData[indirectStart + 1] > maxComputeWorkgroupsPerDimension ||
+            indirectBufferData[indirectStart + 2] > maxComputeWorkgroupsPerDimension) {
             expected = kSentinelData;
         } else {
             expected.assign(indirectBufferData.begin() + indirectStart,
@@ -221,6 +227,52 @@
     IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
 }
 
+// Test indirect dispatches at the max limit.
+TEST_P(ComputeDispatchTests, MaxWorkgroups) {
+    // TODO(crbug.com/dawn/1165): Fails with WARP
+    DAWN_SUPPRESS_TEST_IF(IsWARP());
+    uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+    // Test that the maximum works in each dimension.
+    // Note: Testing (max, max, max) is very slow.
+    IndirectTest({max, 3, 4}, 0);
+    IndirectTest({2, max, 4}, 0);
+    IndirectTest({2, 3, max}, 0);
+}
+
+// Test that indirect dispatches exceeding the max limit are noop-ed.
+TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
+    DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
+
+    uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+    // All dimensions are above the max
+    IndirectTest({max + 1, max + 1, max + 1}, 0);
+
+    // Only x dimension is above the max
+    IndirectTest({max + 1, 3, 4}, 0);
+    IndirectTest({2 * max, 3, 4}, 0);
+
+    // Only y dimension is above the max
+    IndirectTest({2, max + 1, 4}, 0);
+    IndirectTest({2, 2 * max, 4}, 0);
+
+    // Only z dimension is above the max
+    IndirectTest({2, 3, max + 1}, 0);
+    IndirectTest({2, 3, 2 * max}, 0);
+}
+
+// Test that indirect dispatches exceeding the max limit with an offset are noop-ed.
+TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
+    DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
+
+    uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+    IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
+    IndirectTest({1, 2, 3, max + 1, 4, 5}, 2 * sizeof(uint32_t));
+    IndirectTest({1, 2, 3, max + 1, 4, 5}, 3 * sizeof(uint32_t));
+}
+
 DAWN_INSTANTIATE_TEST(ComputeDispatchTests,
                       D3D12Backend(),
                       MetalBackend(),
diff --git a/src/tests/unittests/native/CommandBufferEncodingTests.cpp b/src/tests/unittests/native/CommandBufferEncodingTests.cpp
new file mode 100644
index 0000000..c1ca2d9
--- /dev/null
+++ b/src/tests/unittests/native/CommandBufferEncodingTests.cpp
@@ -0,0 +1,310 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnNativeTest.h"
+
+#include "dawn_native/CommandBuffer.h"
+#include "dawn_native/Commands.h"
+#include "dawn_native/ComputePassEncoder.h"
+#include "utils/WGPUHelpers.h"
+
+class CommandBufferEncodingTests : public DawnNativeTest {
+  protected:
+    void ExpectCommands(dawn_native::CommandIterator* commands,
+                        std::vector<std::pair<dawn_native::Command,
+                                              std::function<void(dawn_native::CommandIterator*)>>>
+                            expectedCommands) {
+        dawn_native::Command commandId;
+        for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
+            ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
+            ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
+                << "at command " << commandIndex;
+            expectedCommands[commandIndex].second(commands);
+        }
+    }
+};
+
+// Indirect dispatch validation changes the bind groups in the middle
+// of a pass. Test that bindings are restored after the validation runs.
+TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
+    using namespace dawn_native;
+
+    wgpu::BindGroupLayout staticLayout =
+        utils::MakeBindGroupLayout(device, {{
+                                               0,
+                                               wgpu::ShaderStage::Compute,
+                                               wgpu::BufferBindingType::Uniform,
+                                           }});
+
+    wgpu::BindGroupLayout dynamicLayout =
+        utils::MakeBindGroupLayout(device, {{
+                                               0,
+                                               wgpu::ShaderStage::Compute,
+                                               wgpu::BufferBindingType::Uniform,
+                                               true,
+                                           }});
+
+    // Create a simple pipeline
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.compute.module = utils::CreateShaderModule(device, R"(
+        [[stage(compute), workgroup_size(1, 1, 1)]]
+        fn main() {
+        })");
+    csDesc.compute.entryPoint = "main";
+
+    wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
+    csDesc.layout = pl0;
+    wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);
+
+    wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
+    csDesc.layout = pl1;
+    wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);
+
+    // Create buffers to use for both the indirect buffer and the bind groups.
+    wgpu::Buffer indirectBuffer =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});
+
+    wgpu::BufferDescriptor uniformBufferDesc = {};
+    uniformBufferDesc.size = 512;
+    uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
+    wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);
+
+    wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});
+
+    wgpu::BindGroup dynamicBG =
+        utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});
+
+    uint32_t dynamicOffset = 256;
+    std::vector<uint32_t> emptyDynamicOffsets = {};
+    std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};
+
+    // Begin encoding commands.
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+    CommandBufferStateTracker* stateTracker =
+        FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
+
+    // Perform a dispatch indirect which will be preceded by a validation dispatch.
+    pass.SetPipeline(pipeline0);
+    pass.SetBindGroup(0, staticBG);
+    pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
+    EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
+
+    pass.DispatchIndirect(indirectBuffer, 0);
+
+    // Expect restored state.
+    EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
+    EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
+    EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
+    EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
+
+    // Dispatch again to check that the restored state can be used.
+    // Also pass an indirect offset which should get replaced with the offset
+    // into the scratch indirect buffer (0).
+    pass.DispatchIndirect(indirectBuffer, 4);
+
+    // Expect restored state.
+    EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
+    EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
+    EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
+    EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
+
+    // Change the pipeline
+    pass.SetPipeline(pipeline1);
+    pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
+    pass.SetBindGroup(1, staticBG);
+    EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
+
+    pass.DispatchIndirect(indirectBuffer, 0);
+
+    // Expect restored state.
+    EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
+    EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
+    EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
+    EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
+    EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);
+
+    pass.EndPass();
+
+    wgpu::CommandBuffer commandBuffer = encoder.Finish();
+
+    auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
+        return [pipeline](CommandIterator* commands) {
+            auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
+            EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
+        };
+    };
+
+    auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
+                                 std::vector<uint32_t> offsets = {}) {
+        return [index, bg, offsets](CommandIterator* commands) {
+            auto* cmd = commands->NextCommand<SetBindGroupCmd>();
+            uint32_t* dynamicOffsets = nullptr;
+            if (cmd->dynamicOffsetCount > 0) {
+                dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
+            }
+
+            ASSERT_EQ(cmd->index, BindGroupIndex(index));
+            ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
+            ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
+            for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
+                ASSERT_EQ(dynamicOffsets[i], offsets[i]);
+            }
+        };
+    };
+
+    // Initialize as null. Once we know the pointer, we'll check
+    // that it's the same buffer every time.
+    WGPUBuffer indirectScratchBuffer = nullptr;
+    auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
+        auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
+        if (indirectScratchBuffer == nullptr) {
+            indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get());
+        }
+        ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer);
+        ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
+    };
+
+    // Initialize as null. Once we know the pointer, we'll check
+    // that it's the same pipeline every time.
+    WGPUComputePipeline validationPipeline = nullptr;
+    auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
+        auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
+        WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
+        if (validationPipeline != nullptr) {
+            EXPECT_EQ(pipeline, validationPipeline);
+        } else {
+            EXPECT_NE(pipeline, nullptr);
+            validationPipeline = pipeline;
+        }
+    };
+
+    auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
+        auto* cmd = commands->NextCommand<SetBindGroupCmd>();
+        ASSERT_EQ(cmd->index, BindGroupIndex(0));
+        ASSERT_NE(cmd->group.Get(), nullptr);
+        ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
+    };
+
+    auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
+        auto* cmd = commands->NextCommand<DispatchCmd>();
+        ASSERT_EQ(cmd->x, 1u);
+        ASSERT_EQ(cmd->y, 1u);
+        ASSERT_EQ(cmd->z, 1u);
+    };
+
+    ExpectCommands(
+        FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
+        {
+            {Command::BeginComputePass,
+             [&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }},
+            // Expect the state to be set.
+            {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
+            {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
+            {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
+
+            // Expect the validation.
+            {Command::SetComputePipeline, ExpectSetValidationPipeline},
+            {Command::SetBindGroup, ExpectSetValidationBindGroup},
+            {Command::Dispatch, ExpectSetValidationDispatch},
+
+            // Expect the state to be restored.
+            {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
+            {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
+            {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
+
+            // Expect the dispatchIndirect.
+            {Command::DispatchIndirect, ExpectDispatchIndirect},
+
+            // Expect the validation.
+            {Command::SetComputePipeline, ExpectSetValidationPipeline},
+            {Command::SetBindGroup, ExpectSetValidationBindGroup},
+            {Command::Dispatch, ExpectSetValidationDispatch},
+
+            // Expect the state to be restored.
+            {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
+            {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
+            {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
+
+            // Expect the dispatchIndirect.
+            {Command::DispatchIndirect, ExpectDispatchIndirect},
+
+            // Expect the state to be set (new pipeline).
+            {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
+            {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
+            {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
+
+            // Expect the validation.
+            {Command::SetComputePipeline, ExpectSetValidationPipeline},
+            {Command::SetBindGroup, ExpectSetValidationBindGroup},
+            {Command::Dispatch, ExpectSetValidationDispatch},
+
+            // Expect the state to be restored.
+            {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
+            {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
+            {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
+
+            // Expect the dispatchIndirect.
+            {Command::DispatchIndirect, ExpectDispatchIndirect},
+
+            {Command::EndComputePass,
+             [&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
+        });
+}
+
+// Test that after restoring state, it is fully applied to the state tracker
+// and does not leak state changes that occurred between a snapshot and the
+// state restoration.
+TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
+    using namespace dawn_native;
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+    CommandBufferStateTracker* stateTracker =
+        FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
+
+    // Snapshot the state.
+    CommandBufferStateTracker snapshot = *stateTracker;
+    // Expect no pipeline in the snapshot
+    EXPECT_FALSE(snapshot.HasPipeline());
+
+    // Create a simple pipeline
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.compute.module = utils::CreateShaderModule(device, R"(
+        [[stage(compute), workgroup_size(1, 1, 1)]]
+        fn main() {
+        })");
+    csDesc.compute.entryPoint = "main";
+    wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
+
+    // Set the pipeline.
+    pass.SetPipeline(pipeline);
+
+    // Expect the pipeline to be set.
+    EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get());
+
+    // Restore the state.
+    FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot));
+
+    // Expect no pipeline
+    EXPECT_FALSE(stateTracker->HasPipeline());
+}