Support setting bind groups before pipeline for all backends

This is to match WebGPU semantics.

Bug: dawn:201
Change-Id: I2aab671fc389edf1d2765395814a9c831afc653e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/11080
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index d5364a1..d225c28 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -136,6 +136,7 @@
     "src/dawn_native/BindGroup.h",
     "src/dawn_native/BindGroupLayout.cpp",
     "src/dawn_native/BindGroupLayout.h",
+    "src/dawn_native/BindGroupTracker.h",
     "src/dawn_native/BuddyAllocator.cpp",
     "src/dawn_native/BuddyAllocator.h",
     "src/dawn_native/Buffer.cpp",
diff --git a/src/dawn_native/BindGroupTracker.h b/src/dawn_native/BindGroupTracker.h
new file mode 100644
index 0000000..60e8ffc
--- /dev/null
+++ b/src/dawn_native/BindGroupTracker.h
@@ -0,0 +1,140 @@
+// Copyright 2019 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_BINDGROUPTRACKER_H_
+#define DAWNNATIVE_BINDGROUPTRACKER_H_
+
+#include "common/Constants.h"
+#include "dawn_native/BindGroupLayout.h"
+#include "dawn_native/Pipeline.h"
+#include "dawn_native/PipelineLayout.h"
+
+#include <array>
+#include <bitset>
+
+namespace dawn_native {
+
+    // Keeps track of the dirty bind groups so they can be lazily applied when we know the
+    // pipeline state or it changes.
+    // |BindGroup| is a template parameter so a backend may provide its backend-specific
+    // type or native handle.
+    // |DynamicOffset| is a template parameter because offsets in Vulkan are uint32_t but uint64_t
+    // in other backends.
+    template <typename BindGroup, bool CanInheritBindGroups, typename DynamicOffset = uint64_t>
+    class BindGroupTrackerBase {
+      public:
+        void OnSetBindGroup(uint32_t index,
+                            BindGroup bindGroup,
+                            uint32_t dynamicOffsetCount,
+                            uint64_t* dynamicOffsets) {
+            ASSERT(index < kMaxBindGroups);
+
+            if (mBindGroupLayoutsMask[index]) {
+                // It is okay to only dirty bind groups that are used by the current pipeline
+                // layout. If the pipeline layout changes, then the bind groups it uses will
+                // become dirty.
+
+                if (mBindGroups[index] != bindGroup) {
+                    mDirtyBindGroups.set(index);
+                    mDirtyBindGroupsObjectChangedOrIsDynamic.set(index);
+                }
+
+                if (dynamicOffsetCount > 0) {
+                    mDirtyBindGroupsObjectChangedOrIsDynamic.set(index);
+                }
+            }
+
+            mBindGroups[index] = bindGroup;
+            mDynamicOffsetCounts[index] = dynamicOffsetCount;
+            SetDynamicOffsets(mDynamicOffsets[index].data(), dynamicOffsetCount, dynamicOffsets);
+        }
+
+        void OnSetPipeline(PipelineBase* pipeline) {
+            mPipelineLayout = pipeline->GetLayout();
+            if (mLastAppliedPipelineLayout == mPipelineLayout) {
+                return;
+            }
+
+            // Keep track of the bind group layout mask to avoid marking unused bind groups as
+            // dirty. This also allows us to avoid computing the intersection of the dirty bind
+            // groups and bind group layout mask in Draw or Dispatch which is very hot code.
+            mBindGroupLayoutsMask = mPipelineLayout->GetBindGroupLayoutsMask();
+
+            // Changing the pipeline layout sets bind groups as dirty. If CanInheritBindGroups,
+            // the first |k| matching bind groups may be inherited.
+            if (CanInheritBindGroups && mLastAppliedPipelineLayout != nullptr) {
+                // Dirty bind groups that cannot be inherited.
+                std::bitset<kMaxBindGroups> dirtiedGroups =
+                    ~mPipelineLayout->InheritedGroupsMask(mLastAppliedPipelineLayout);
+
+                mDirtyBindGroups |= dirtiedGroups;
+                mDirtyBindGroupsObjectChangedOrIsDynamic |= dirtiedGroups;
+
+                // Clear any bind groups not in the mask.
+                mDirtyBindGroups &= mBindGroupLayoutsMask;
+                mDirtyBindGroupsObjectChangedOrIsDynamic &= mBindGroupLayoutsMask;
+            } else {
+                mDirtyBindGroups = mBindGroupLayoutsMask;
+                mDirtyBindGroupsObjectChangedOrIsDynamic = mBindGroupLayoutsMask;
+            }
+        }
+
+      protected:
+        // The Derived class should call this when it applies bind groups.
+        void DidApply() {
+            // Reset all dirty bind groups. Dirty bind groups not in the bind group layout mask
+            // will be dirtied again by the next pipeline change.
+            mDirtyBindGroups.reset();
+            mDirtyBindGroupsObjectChangedOrIsDynamic.reset();
+            mLastAppliedPipelineLayout = mPipelineLayout;
+        }
+
+        std::bitset<kMaxBindGroups> mDirtyBindGroups = 0;
+        std::bitset<kMaxBindGroups> mDirtyBindGroupsObjectChangedOrIsDynamic = 0;
+        std::bitset<kMaxBindGroups> mBindGroupLayoutsMask = 0;
+        std::array<BindGroup, kMaxBindGroups> mBindGroups = {};
+        std::array<uint32_t, kMaxBindGroups> mDynamicOffsetCounts = {};
+        std::array<std::array<DynamicOffset, kMaxBindingsPerGroup>, kMaxBindGroups>
+            mDynamicOffsets = {};
+
+        // |mPipelineLayout| is the current pipeline layout set on the command buffer.
+        // |mLastAppliedPipelineLayout| is the last pipeline layout for which we applied changes
+        // to the bind group bindings.
+        PipelineLayoutBase* mPipelineLayout = nullptr;
+        PipelineLayoutBase* mLastAppliedPipelineLayout = nullptr;
+
+      private:
+        // Vulkan backend use uint32_t as dynamic offsets type, it is not correct.
+        // Vulkan should use VkDeviceSize. Dawn vulkan backend has to handle this.
+        template <typename T>
+        static void SetDynamicOffsets(T* data,
+                                      uint32_t dynamicOffsetCount,
+                                      uint64_t* dynamicOffsets) {
+            for (uint32_t i = 0; i < dynamicOffsetCount; ++i) {
+                ASSERT(dynamicOffsets[i] <= std::numeric_limits<T>::max());
+                data[i] = static_cast<T>(dynamicOffsets[i]);
+            }
+        }
+
+        template <>
+        static void SetDynamicOffsets<uint64_t>(uint64_t* data,
+                                                uint32_t dynamicOffsetCount,
+                                                uint64_t* dynamicOffsets) {
+            memcpy(data, dynamicOffsets, sizeof(uint64_t) * dynamicOffsetCount);
+        }
+    };
+
+}  // namespace dawn_native
+
+#endif  // DAWNNATIVE_BINDGROUPTRACKER_H_
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 1dcc05f..951911e 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -15,6 +15,7 @@
 #include "dawn_native/d3d12/CommandBufferD3D12.h"
 
 #include "common/Assert.h"
+#include "dawn_native/BindGroupTracker.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/RenderBundle.h"
@@ -72,9 +73,9 @@
 
     }  // anonymous namespace
 
-    class BindGroupStateTracker {
+    class BindGroupStateTracker : public BindGroupTrackerBase<BindGroup*, false> {
       public:
-        BindGroupStateTracker(Device* device) : mDevice(device) {
+        BindGroupStateTracker(Device* device) : BindGroupTrackerBase(), mDevice(device) {
         }
 
         void SetInComputePass(bool inCompute_) {
@@ -100,7 +101,7 @@
 
             uint32_t cbvSrvUavDescriptorIndex = 0;
             uint32_t samplerDescriptorIndex = 0;
-            for (BindGroup* group : mBindGroupsList) {
+            for (BindGroup* group : mBindGroupsToAllocate) {
                 ASSERT(group);
                 ASSERT(cbvSrvUavDescriptorIndex +
                            ToBackend(group->GetLayout())->GetCbvUavSrvDescriptorCount() <=
@@ -120,38 +121,50 @@
         void TrackSetBindGroup(BindGroup* group, uint32_t index, uint32_t indexInSubmit) {
             if (mBindGroups[index] != group) {
                 mBindGroups[index] = group;
-
                 if (!group->TestAndSetCounted(mDevice->GetPendingCommandSerial(), indexInSubmit)) {
                     const BindGroupLayout* layout = ToBackend(group->GetLayout());
 
                     mCbvSrvUavDescriptorHeapSize += layout->GetCbvUavSrvDescriptorCount();
                     mSamplerDescriptorHeapSize += layout->GetSamplerDescriptorCount();
-                    mBindGroupsList.push_back(group);
+                    mBindGroupsToAllocate.push_back(group);
                 }
             }
         }
 
-        // This function must only be called before calling AllocateDescriptorHeaps().
-        void TrackInheritedGroups(PipelineLayout* oldLayout,
-                                  PipelineLayout* newLayout,
-                                  uint32_t indexInSubmit) {
-            if (oldLayout == nullptr) {
-                return;
+        void Apply(ID3D12GraphicsCommandList* commandList) {
+            for (uint32_t index : IterateBitSet(mDirtyBindGroupsObjectChangedOrIsDynamic)) {
+                ApplyBindGroup(commandList, ToBackend(mPipelineLayout), index, mBindGroups[index],
+                               mDynamicOffsetCounts[index], mDynamicOffsets[index].data());
             }
+            DidApply();
+        }
 
-            uint32_t inheritUntil = oldLayout->GroupsInheritUpTo(newLayout);
-            for (uint32_t i = 0; i < inheritUntil; ++i) {
-                TrackSetBindGroup(mBindGroups[i], i, indexInSubmit);
+        void Reset() {
+            for (uint32_t i = 0; i < kMaxBindGroups; ++i) {
+                mBindGroups[i] = nullptr;
             }
         }
 
-        void SetBindGroup(ComPtr<ID3D12GraphicsCommandList> commandList,
-                          PipelineLayout* pipelineLayout,
-                          BindGroup* group,
-                          uint32_t index,
-                          uint32_t dynamicOffsetCount,
-                          uint64_t* dynamicOffsets,
-                          bool force = false) {
+        void SetID3D12DescriptorHeaps(ComPtr<ID3D12GraphicsCommandList> commandList) {
+            ASSERT(commandList != nullptr);
+            ID3D12DescriptorHeap* descriptorHeaps[2] = {mCbvSrvUavGPUDescriptorHeap.Get(),
+                                                        mSamplerGPUDescriptorHeap.Get()};
+            if (descriptorHeaps[0] && descriptorHeaps[1]) {
+                commandList->SetDescriptorHeaps(2, descriptorHeaps);
+            } else if (descriptorHeaps[0]) {
+                commandList->SetDescriptorHeaps(1, descriptorHeaps);
+            } else if (descriptorHeaps[1]) {
+                commandList->SetDescriptorHeaps(1, &descriptorHeaps[1]);
+            }
+        }
+
+      private:
+        void ApplyBindGroup(ID3D12GraphicsCommandList* commandList,
+                            PipelineLayout* pipelineLayout,
+                            uint32_t index,
+                            BindGroup* group,
+                            uint32_t dynamicOffsetCount,
+                            uint64_t* dynamicOffsets) {
             // Usually, the application won't set the same offsets many times,
             // so always try to apply dynamic offsets even if the offsets stay the same
             if (dynamicOffsetCount) {
@@ -200,96 +213,50 @@
                             break;
                     }
 
-                    // Record current dynamic offsets for inheriting
-                    mLastDynamicOffsets[index][currentDynamicBufferIndex] = dynamicOffset;
                     ++currentDynamicBufferIndex;
                 }
             }
 
-            if (mBindGroups[index] != group || force) {
-                mBindGroups[index] = group;
-                uint32_t cbvUavSrvCount =
-                    ToBackend(group->GetLayout())->GetCbvUavSrvDescriptorCount();
-                uint32_t samplerCount = ToBackend(group->GetLayout())->GetSamplerDescriptorCount();
-
-                if (cbvUavSrvCount > 0) {
-                    uint32_t parameterIndex = pipelineLayout->GetCbvUavSrvRootParameterIndex(index);
-
-                    if (mInCompute) {
-                        commandList->SetComputeRootDescriptorTable(
-                            parameterIndex, mCbvSrvUavGPUDescriptorHeap.GetGPUHandle(
-                                                group->GetCbvUavSrvHeapOffset()));
-                    } else {
-                        commandList->SetGraphicsRootDescriptorTable(
-                            parameterIndex, mCbvSrvUavGPUDescriptorHeap.GetGPUHandle(
-                                                group->GetCbvUavSrvHeapOffset()));
-                    }
-                }
-
-                if (samplerCount > 0) {
-                    uint32_t parameterIndex = pipelineLayout->GetSamplerRootParameterIndex(index);
-
-                    if (mInCompute) {
-                        commandList->SetComputeRootDescriptorTable(
-                            parameterIndex,
-                            mSamplerGPUDescriptorHeap.GetGPUHandle(group->GetSamplerHeapOffset()));
-                    } else {
-                        commandList->SetGraphicsRootDescriptorTable(
-                            parameterIndex,
-                            mSamplerGPUDescriptorHeap.GetGPUHandle(group->GetSamplerHeapOffset()));
-                    }
-                }
-            }
-        }
-
-        void SetInheritedBindGroups(ComPtr<ID3D12GraphicsCommandList> commandList,
-                                    PipelineLayout* oldLayout,
-                                    PipelineLayout* newLayout) {
-            if (oldLayout == nullptr) {
+            // It's not necessary to update descriptor tables if only the dynamic offset changed.
+            if (!mDirtyBindGroups[index]) {
                 return;
             }
 
-            uint32_t inheritUntil = oldLayout->GroupsInheritUpTo(newLayout);
-            for (uint32_t i = 0; i < inheritUntil; ++i) {
-                const BindGroupLayout* layout = ToBackend(mBindGroups[i]->GetLayout());
-                const uint32_t dynamicBufferCount = layout->GetDynamicBufferCount();
+            uint32_t cbvUavSrvCount = ToBackend(group->GetLayout())->GetCbvUavSrvDescriptorCount();
+            uint32_t samplerCount = ToBackend(group->GetLayout())->GetSamplerDescriptorCount();
 
-                // Inherit dynamic offsets
-                if (dynamicBufferCount > 0) {
-                    SetBindGroup(commandList, newLayout, mBindGroups[i], i, dynamicBufferCount,
-                                 mLastDynamicOffsets[i].data(), true);
+            if (cbvUavSrvCount > 0) {
+                uint32_t parameterIndex = pipelineLayout->GetCbvUavSrvRootParameterIndex(index);
+
+                if (mInCompute) {
+                    commandList->SetComputeRootDescriptorTable(
+                        parameterIndex,
+                        mCbvSrvUavGPUDescriptorHeap.GetGPUHandle(group->GetCbvUavSrvHeapOffset()));
                 } else {
-                    SetBindGroup(commandList, newLayout, mBindGroups[i], i, 0, nullptr, true);
+                    commandList->SetGraphicsRootDescriptorTable(
+                        parameterIndex,
+                        mCbvSrvUavGPUDescriptorHeap.GetGPUHandle(group->GetCbvUavSrvHeapOffset()));
+                }
+            }
+
+            if (samplerCount > 0) {
+                uint32_t parameterIndex = pipelineLayout->GetSamplerRootParameterIndex(index);
+
+                if (mInCompute) {
+                    commandList->SetComputeRootDescriptorTable(
+                        parameterIndex,
+                        mSamplerGPUDescriptorHeap.GetGPUHandle(group->GetSamplerHeapOffset()));
+                } else {
+                    commandList->SetGraphicsRootDescriptorTable(
+                        parameterIndex,
+                        mSamplerGPUDescriptorHeap.GetGPUHandle(group->GetSamplerHeapOffset()));
                 }
             }
         }
 
-        void Reset() {
-            for (uint32_t i = 0; i < kMaxBindGroups; ++i) {
-                mBindGroups[i] = nullptr;
-            }
-        }
-
-        void SetID3D12DescriptorHeaps(ComPtr<ID3D12GraphicsCommandList> commandList) {
-            ASSERT(commandList != nullptr);
-            ID3D12DescriptorHeap* descriptorHeaps[2] = {mCbvSrvUavGPUDescriptorHeap.Get(),
-                                                        mSamplerGPUDescriptorHeap.Get()};
-            if (descriptorHeaps[0] && descriptorHeaps[1]) {
-                commandList->SetDescriptorHeaps(2, descriptorHeaps);
-            } else if (descriptorHeaps[0]) {
-                commandList->SetDescriptorHeaps(1, descriptorHeaps);
-            } else if (descriptorHeaps[1]) {
-                commandList->SetDescriptorHeaps(1, &descriptorHeaps[1]);
-            }
-        }
-
-      private:
         uint32_t mCbvSrvUavDescriptorHeapSize = 0;
         uint32_t mSamplerDescriptorHeapSize = 0;
-        std::array<BindGroup*, kMaxBindGroups> mBindGroups = {};
-        std::deque<BindGroup*> mBindGroupsList = {};
-        std::array<std::array<uint64_t, kMaxDynamicBufferCount>, kMaxBindGroups>
-            mLastDynamicOffsets;
+        std::deque<BindGroup*> mBindGroupsToAllocate = {};
         bool mInCompute = false;
 
         DescriptorHeapHandle mCbvSrvUavGPUDescriptorHeap = {};
@@ -485,26 +452,9 @@
                                            uint32_t indexInSubmit) {
             {
                 Command type;
-                PipelineLayout* lastLayout = nullptr;
 
                 auto HandleCommand = [&](CommandIterator* commands, Command type) {
                     switch (type) {
-                        case Command::SetComputePipeline: {
-                            SetComputePipelineCmd* cmd =
-                                commands->NextCommand<SetComputePipelineCmd>();
-                            PipelineLayout* layout = ToBackend(cmd->pipeline->GetLayout());
-                            bindingTracker->TrackInheritedGroups(lastLayout, layout, indexInSubmit);
-                            lastLayout = layout;
-                        } break;
-
-                        case Command::SetRenderPipeline: {
-                            SetRenderPipelineCmd* cmd =
-                                commands->NextCommand<SetRenderPipelineCmd>();
-                            PipelineLayout* layout = ToBackend(cmd->pipeline->GetLayout());
-                            bindingTracker->TrackInheritedGroups(lastLayout, layout, indexInSubmit);
-                            lastLayout = layout;
-                        } break;
-
                         case Command::SetBindGroup: {
                             SetBindGroupCmd* cmd = commands->NextCommand<SetBindGroupCmd>();
                             BindGroup* group = ToBackend(cmd->group.Get());
@@ -826,12 +776,15 @@
             switch (type) {
                 case Command::Dispatch: {
                     DispatchCmd* dispatch = mCommands.NextCommand<DispatchCmd>();
+
+                    bindingTracker->Apply(commandList.Get());
                     commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
                 } break;
 
                 case Command::DispatchIndirect: {
                     DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
 
+                    bindingTracker->Apply(commandList.Get());
                     Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
                     ComPtr<ID3D12CommandSignature> signature =
                         ToBackend(GetDevice())->GetDispatchIndirectSignature();
@@ -853,7 +806,8 @@
                     commandList->SetComputeRootSignature(layout->GetRootSignature().Get());
                     commandList->SetPipelineState(pipeline->GetPipelineState().Get());
 
-                    bindingTracker->SetInheritedBindGroups(commandList, lastLayout, layout);
+                    bindingTracker->OnSetPipeline(pipeline);
+
                     lastLayout = layout;
                 } break;
 
@@ -866,8 +820,8 @@
                         dynamicOffsets = mCommands.NextData<uint64_t>(cmd->dynamicOffsetCount);
                     }
 
-                    bindingTracker->SetBindGroup(commandList, lastLayout, group, cmd->index,
-                                                 cmd->dynamicOffsetCount, dynamicOffsets);
+                    bindingTracker->OnSetBindGroup(cmd->index, group, cmd->dynamicOffsetCount,
+                                                   dynamicOffsets);
                 } break;
 
                 case Command::InsertDebugMarker: {
@@ -1045,6 +999,7 @@
                 case Command::Draw: {
                     DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
+                    bindingTracker->Apply(commandList.Get());
                     vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
                                                draw->firstVertex, draw->firstInstance);
@@ -1053,6 +1008,7 @@
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
+                    bindingTracker->Apply(commandList.Get());
                     indexBufferTracker.Apply(commandList.Get());
                     vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
@@ -1063,6 +1019,7 @@
                 case Command::DrawIndirect: {
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
 
+                    bindingTracker->Apply(commandList.Get());
                     vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     ComPtr<ID3D12CommandSignature> signature =
@@ -1075,6 +1032,7 @@
                 case Command::DrawIndexedIndirect: {
                     DrawIndexedIndirectCmd* draw = iter->NextCommand<DrawIndexedIndirectCmd>();
 
+                    bindingTracker->Apply(commandList.Get());
                     indexBufferTracker.Apply(commandList.Get());
                     vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
@@ -1130,8 +1088,8 @@
                     commandList->SetPipelineState(pipeline->GetPipelineState().Get());
                     commandList->IASetPrimitiveTopology(pipeline->GetD3D12PrimitiveTopology());
 
+                    bindingTracker->OnSetPipeline(pipeline);
                     indexBufferTracker.OnSetPipeline(pipeline);
-                    bindingTracker->SetInheritedBindGroups(commandList, lastLayout, layout);
 
                     lastPipeline = pipeline;
                     lastLayout = layout;
@@ -1146,8 +1104,8 @@
                         dynamicOffsets = iter->NextData<uint64_t>(cmd->dynamicOffsetCount);
                     }
 
-                    bindingTracker->SetBindGroup(commandList, lastLayout, group, cmd->index,
-                                                 cmd->dynamicOffsetCount, dynamicOffsets);
+                    bindingTracker->OnSetBindGroup(cmd->index, group, cmd->dynamicOffsetCount,
+                                                   dynamicOffsets);
                 } break;
 
                 case Command::SetIndexBuffer: {
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 4e512f6..0ea8feb 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -15,6 +15,7 @@
 #include "dawn_native/metal/CommandBufferMTL.h"
 
 #include "dawn_native/BindGroup.h"
+#include "dawn_native/BindGroupTracker.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/RenderBundle.h"
@@ -391,66 +392,21 @@
 
         // Keeps track of the dirty bind groups so they can be lazily applied when we know the
         // pipeline state.
-        class BindGroupTracker {
+        // Bind groups may be inherited because bind groups are packed in the buffer /
+        // texture tables in contiguous order.
+        class BindGroupTracker : public BindGroupTrackerBase<BindGroup*, true> {
           public:
             explicit BindGroupTracker(StorageBufferLengthTracker* lengthTracker)
-                : mLengthTracker(lengthTracker) {
-            }
-
-            void OnSetBindGroup(uint32_t index,
-                                BindGroup* bindGroup,
-                                uint32_t dynamicOffsetCount,
-                                uint64_t* dynamicOffsets) {
-                ASSERT(index < kMaxBindGroups);
-
-                if (mBindGroupLayoutsMask[index]) {
-                    // It is okay to only dirty bind groups that are used by the current pipeline
-                    // layout. If the pipeline layout changes, then the bind groups it uses will
-                    // become dirty.
-                    mDirtyBindGroups.set(index);
-                }
-
-                mBindGroups[index] = bindGroup;
-                mDynamicOffsetCounts[index] = dynamicOffsetCount;
-                memcpy(mDynamicOffsets[index].data(), dynamicOffsets,
-                       sizeof(uint64_t) * dynamicOffsetCount);
-            }
-
-            void OnSetPipeline(PipelineBase* pipeline) {
-                mPipelineLayout = ToBackend(pipeline->GetLayout());
-                if (mLastAppliedPipelineLayout == mPipelineLayout) {
-                    return;
-                }
-
-                // Keep track of the bind group layout mask to avoid marking unused bind groups as
-                // dirty. This also allows us to avoid computing the intersection of the dirty bind
-                // groups and bind group layout mask in Draw or Dispatch which is very hot code.
-                mBindGroupLayoutsMask = mPipelineLayout->GetBindGroupLayoutsMask();
-
-                // Changing the pipeline layout sets bind groups as dirty. The first |k| matching
-                // bind groups may be inherited because bind groups are packed in the buffer /
-                // texture tables in contiguous order.
-                if (mLastAppliedPipelineLayout != nullptr) {
-                    // Dirty bind groups that cannot be inherited.
-                    mDirtyBindGroups |=
-                        ~mPipelineLayout->InheritedGroupsMask(mLastAppliedPipelineLayout);
-                    mDirtyBindGroups &= mBindGroupLayoutsMask;
-                } else {
-                    mDirtyBindGroups = mBindGroupLayoutsMask;
-                }
+                : BindGroupTrackerBase(), mLengthTracker(lengthTracker) {
             }
 
             template <typename Encoder>
             void Apply(Encoder encoder) {
-                for (uint32_t index : IterateBitSet(mDirtyBindGroups)) {
+                for (uint32_t index : IterateBitSet(mDirtyBindGroupsObjectChangedOrIsDynamic)) {
                     ApplyBindGroup(encoder, index, mBindGroups[index], mDynamicOffsetCounts[index],
-                                   mDynamicOffsets[index].data(), mPipelineLayout);
+                                   mDynamicOffsets[index].data(), ToBackend(mPipelineLayout));
                 }
-
-                // Reset all dirty bind groups. Dirty bind groups not in the bind group layout mask
-                // will be dirtied again by the next pipeline change.
-                mDirtyBindGroups.reset();
-                mLastAppliedPipelineLayout = mPipelineLayout;
+                DidApply();
             }
 
           private:
@@ -587,18 +543,6 @@
                 ApplyBindGroupImpl(nil, encoder, std::forward<Args&&>(args)...);
             }
 
-            std::bitset<kMaxBindGroups> mDirtyBindGroups;
-            std::bitset<kMaxBindGroups> mBindGroupLayoutsMask;
-            std::array<BindGroup*, kMaxBindGroups> mBindGroups;
-            std::array<uint32_t, kMaxBindGroups> mDynamicOffsetCounts;
-            std::array<std::array<uint64_t, kMaxBindingsPerGroup>, kMaxBindGroups> mDynamicOffsets;
-
-            // |mPipelineLayout| is the current pipeline layout set on the command buffer.
-            // |mLastAppliedPipelineLayout| is the last pipeline layout for which we applied changes
-            // to the bind group bindings.
-            PipelineLayout* mPipelineLayout = nullptr;
-            PipelineLayout* mLastAppliedPipelineLayout = nullptr;
-
             StorageBufferLengthTracker* mLengthTracker;
         };
 
diff --git a/src/dawn_native/opengl/CommandBufferGL.cpp b/src/dawn_native/opengl/CommandBufferGL.cpp
index ea753f5..5943a8d 100644
--- a/src/dawn_native/opengl/CommandBufferGL.cpp
+++ b/src/dawn_native/opengl/CommandBufferGL.cpp
@@ -15,6 +15,7 @@
 #include "dawn_native/opengl/CommandBufferGL.h"
 
 #include "dawn_native/BindGroup.h"
+#include "dawn_native/BindGroupTracker.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/RenderBundle.h"
@@ -215,88 +216,109 @@
             RenderPipelineBase* mLastPipeline = nullptr;
         };
 
-        // Handles SetBindGroup commands with the specifics of translating to OpenGL texture and
-        // buffer units
-        void ApplyBindGroup(const OpenGLFunctions& gl,
-                            uint32_t index,
-                            BindGroupBase* group,
-                            PipelineLayout* pipelineLayout,
-                            PipelineGL* pipeline,
-                            uint32_t dynamicOffsetCount,
-                            uint64_t* dynamicOffsets) {
-            const auto& indices = pipelineLayout->GetBindingIndexInfo()[index];
-            const auto& layout = group->GetLayout()->GetBindingInfo();
-            uint32_t currentDynamicIndex = 0;
+        class BindGroupTracker : public BindGroupTrackerBase<BindGroupBase*, false> {
+          public:
+            void OnSetPipeline(RenderPipeline* pipeline) {
+                BindGroupTrackerBase::OnSetPipeline(pipeline);
+                mPipeline = pipeline;
+            }
 
-            for (uint32_t bindingIndex : IterateBitSet(layout.mask)) {
-                switch (layout.types[bindingIndex]) {
-                    case dawn::BindingType::UniformBuffer: {
-                        BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
-                        GLuint buffer = ToBackend(binding.buffer)->GetHandle();
-                        GLuint uboIndex = indices[bindingIndex];
-                        GLuint offset = binding.offset;
+            void OnSetPipeline(ComputePipeline* pipeline) {
+                BindGroupTrackerBase::OnSetPipeline(pipeline);
+                mPipeline = pipeline;
+            }
 
-                        if (layout.dynamic[bindingIndex]) {
-                            offset += dynamicOffsets[currentDynamicIndex];
-                            ++currentDynamicIndex;
-                        }
+            void Apply(const OpenGLFunctions& gl) {
+                for (uint32_t index : IterateBitSet(mDirtyBindGroupsObjectChangedOrIsDynamic)) {
+                    ApplyBindGroup(gl, index, mBindGroups[index], mDynamicOffsetCounts[index],
+                                   mDynamicOffsets[index].data());
+                }
+                DidApply();
+            }
 
-                        gl.BindBufferRange(GL_UNIFORM_BUFFER, uboIndex, buffer, offset,
-                                           binding.size);
-                    } break;
+          private:
+            void ApplyBindGroup(const OpenGLFunctions& gl,
+                                uint32_t index,
+                                BindGroupBase* group,
+                                uint32_t dynamicOffsetCount,
+                                uint64_t* dynamicOffsets) {
+                const auto& indices = ToBackend(mPipelineLayout)->GetBindingIndexInfo()[index];
+                const auto& layout = group->GetLayout()->GetBindingInfo();
+                uint32_t currentDynamicIndex = 0;
 
-                    case dawn::BindingType::Sampler: {
-                        Sampler* sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
-                        GLuint samplerIndex = indices[bindingIndex];
+                for (uint32_t bindingIndex : IterateBitSet(layout.mask)) {
+                    switch (layout.types[bindingIndex]) {
+                        case dawn::BindingType::UniformBuffer: {
+                            BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
+                            GLuint buffer = ToBackend(binding.buffer)->GetHandle();
+                            GLuint uboIndex = indices[bindingIndex];
+                            GLuint offset = binding.offset;
 
-                        for (PipelineGL::SamplerUnit unit :
-                             pipeline->GetTextureUnitsForSampler(samplerIndex)) {
-                            // Only use filtering for certain texture units, because int and uint
-                            // texture are only complete without filtering
-                            if (unit.shouldUseFiltering) {
-                                gl.BindSampler(unit.unit, sampler->GetFilteringHandle());
-                            } else {
-                                gl.BindSampler(unit.unit, sampler->GetNonFilteringHandle());
+                            if (layout.dynamic[bindingIndex]) {
+                                offset += dynamicOffsets[currentDynamicIndex];
+                                ++currentDynamicIndex;
                             }
-                        }
-                    } break;
 
-                    case dawn::BindingType::SampledTexture: {
-                        TextureView* view = ToBackend(group->GetBindingAsTextureView(bindingIndex));
-                        GLuint handle = view->GetHandle();
-                        GLenum target = view->GetGLTarget();
-                        GLuint viewIndex = indices[bindingIndex];
+                            gl.BindBufferRange(GL_UNIFORM_BUFFER, uboIndex, buffer, offset,
+                                               binding.size);
+                        } break;
 
-                        for (auto unit : pipeline->GetTextureUnitsForTextureView(viewIndex)) {
-                            gl.ActiveTexture(GL_TEXTURE0 + unit);
-                            gl.BindTexture(target, handle);
-                        }
-                    } break;
+                        case dawn::BindingType::Sampler: {
+                            Sampler* sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
+                            GLuint samplerIndex = indices[bindingIndex];
 
-                    case dawn::BindingType::StorageBuffer: {
-                        BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
-                        GLuint buffer = ToBackend(binding.buffer)->GetHandle();
-                        GLuint ssboIndex = indices[bindingIndex];
-                        GLuint offset = binding.offset;
+                            for (PipelineGL::SamplerUnit unit :
+                                 mPipeline->GetTextureUnitsForSampler(samplerIndex)) {
+                                // Only use filtering for certain texture units, because int and
+                                // uint texture are only complete without filtering
+                                if (unit.shouldUseFiltering) {
+                                    gl.BindSampler(unit.unit, sampler->GetFilteringHandle());
+                                } else {
+                                    gl.BindSampler(unit.unit, sampler->GetNonFilteringHandle());
+                                }
+                            }
+                        } break;
 
-                        if (layout.dynamic[bindingIndex]) {
-                            offset += dynamicOffsets[currentDynamicIndex];
-                            ++currentDynamicIndex;
-                        }
+                        case dawn::BindingType::SampledTexture: {
+                            TextureView* view =
+                                ToBackend(group->GetBindingAsTextureView(bindingIndex));
+                            GLuint handle = view->GetHandle();
+                            GLenum target = view->GetGLTarget();
+                            GLuint viewIndex = indices[bindingIndex];
 
-                        gl.BindBufferRange(GL_SHADER_STORAGE_BUFFER, ssboIndex, buffer, offset,
-                                           binding.size);
-                    } break;
+                            for (auto unit : mPipeline->GetTextureUnitsForTextureView(viewIndex)) {
+                                gl.ActiveTexture(GL_TEXTURE0 + unit);
+                                gl.BindTexture(target, handle);
+                            }
+                        } break;
 
-                    case dawn::BindingType::StorageTexture:
-                    case dawn::BindingType::ReadonlyStorageBuffer:
-                        UNREACHABLE();
-                        break;
+                        case dawn::BindingType::StorageBuffer: {
+                            BufferBinding binding = group->GetBindingAsBufferBinding(bindingIndex);
+                            GLuint buffer = ToBackend(binding.buffer)->GetHandle();
+                            GLuint ssboIndex = indices[bindingIndex];
+                            GLuint offset = binding.offset;
 
-                        // TODO(shaobo.yan@intel.com): Implement dynamic buffer offset.
+                            if (layout.dynamic[bindingIndex]) {
+                                offset += dynamicOffsets[currentDynamicIndex];
+                                ++currentDynamicIndex;
+                            }
+
+                            gl.BindBufferRange(GL_SHADER_STORAGE_BUFFER, ssboIndex, buffer, offset,
+                                               binding.size);
+                        } break;
+
+                        case dawn::BindingType::StorageTexture:
+                        case dawn::BindingType::ReadonlyStorageBuffer:
+                            UNREACHABLE();
+                            break;
+
+                            // TODO(shaobo.yan@intel.com): Implement dynamic buffer offset.
+                    }
                 }
             }
-        }
+
+            PipelineGL* mPipeline = nullptr;
+        };
 
         void ResolveMultisampledRenderTargets(const OpenGLFunctions& gl,
                                               const BeginRenderPassCmd* renderPass) {
@@ -608,6 +630,7 @@
     void CommandBuffer::ExecuteComputePass() {
         const OpenGLFunctions& gl = ToBackend(GetDevice())->gl;
         ComputePipeline* lastPipeline = nullptr;
+        BindGroupTracker bindGroupTracker = {};
 
         Command type;
         while (mCommands.NextCommandId(&type)) {
@@ -619,6 +642,8 @@
 
                 case Command::Dispatch: {
                     DispatchCmd* dispatch = mCommands.NextCommand<DispatchCmd>();
+                    bindGroupTracker.Apply(gl);
+
                     gl.DispatchCompute(dispatch->x, dispatch->y, dispatch->z);
                     // TODO(cwallez@chromium.org): add barriers to the API
                     gl.MemoryBarrier(GL_ALL_BARRIER_BITS);
@@ -626,6 +651,7 @@
 
                 case Command::DispatchIndirect: {
                     DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
+                    bindGroupTracker.Apply(gl);
 
                     uint64_t indirectBufferOffset = dispatch->indirectOffset;
                     Buffer* indirectBuffer = ToBackend(dispatch->indirectBuffer.Get());
@@ -640,6 +666,8 @@
                     SetComputePipelineCmd* cmd = mCommands.NextCommand<SetComputePipelineCmd>();
                     lastPipeline = ToBackend(cmd->pipeline).Get();
                     lastPipeline->ApplyNow();
+
+                    bindGroupTracker.OnSetPipeline(lastPipeline);
                 } break;
 
                 case Command::SetBindGroup: {
@@ -648,9 +676,8 @@
                     if (cmd->dynamicOffsetCount > 0) {
                         dynamicOffsets = mCommands.NextData<uint64_t>(cmd->dynamicOffsetCount);
                     }
-                    ApplyBindGroup(gl, cmd->index, cmd->group.Get(),
-                                   ToBackend(lastPipeline->GetLayout()), lastPipeline,
-                                   cmd->dynamicOffsetCount, dynamicOffsets);
+                    bindGroupTracker.OnSetBindGroup(cmd->index, cmd->group.Get(),
+                                                    cmd->dynamicOffsetCount, dynamicOffsets);
                 } break;
 
                 case Command::InsertDebugMarker:
@@ -802,12 +829,14 @@
         uint64_t indexBufferBaseOffset = 0;
 
         InputBufferTracker inputBuffers;
+        BindGroupTracker bindGroupTracker = {};
 
         auto DoRenderBundleCommand = [&](CommandIterator* iter, Command type) {
             switch (type) {
                 case Command::Draw: {
                     DrawCmd* draw = iter->NextCommand<DrawCmd>();
                     inputBuffers.Apply(gl);
+                    bindGroupTracker.Apply(gl);
 
                     if (draw->firstInstance > 0) {
                         gl.DrawArraysInstancedBaseInstance(
@@ -824,6 +853,7 @@
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
                     inputBuffers.Apply(gl);
+                    bindGroupTracker.Apply(gl);
 
                     dawn::IndexFormat indexFormat =
                         lastPipeline->GetVertexInputDescriptor()->indexFormat;
@@ -849,6 +879,7 @@
                 case Command::DrawIndirect: {
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
                     inputBuffers.Apply(gl);
+                    bindGroupTracker.Apply(gl);
 
                     uint64_t indirectBufferOffset = draw->indirectOffset;
                     Buffer* indirectBuffer = ToBackend(draw->indirectBuffer.Get());
@@ -862,6 +893,7 @@
                 case Command::DrawIndexedIndirect: {
                     DrawIndexedIndirectCmd* draw = iter->NextCommand<DrawIndexedIndirectCmd>();
                     inputBuffers.Apply(gl);
+                    bindGroupTracker.Apply(gl);
 
                     dawn::IndexFormat indexFormat =
                         lastPipeline->GetVertexInputDescriptor()->indexFormat;
@@ -890,6 +922,7 @@
                     lastPipeline->ApplyNow(persistentPipelineState);
 
                     inputBuffers.OnSetPipeline(lastPipeline);
+                    bindGroupTracker.OnSetPipeline(lastPipeline);
                 } break;
 
                 case Command::SetBindGroup: {
@@ -898,9 +931,8 @@
                     if (cmd->dynamicOffsetCount > 0) {
                         dynamicOffsets = iter->NextData<uint64_t>(cmd->dynamicOffsetCount);
                     }
-                    ApplyBindGroup(gl, cmd->index, cmd->group.Get(),
-                                   ToBackend(lastPipeline->GetLayout()), lastPipeline,
-                                   cmd->dynamicOffsetCount, dynamicOffsets);
+                    bindGroupTracker.OnSetBindGroup(cmd->index, cmd->group.Get(),
+                                                    cmd->dynamicOffsetCount, dynamicOffsets);
                 } break;
 
                 case Command::SetIndexBuffer: {
diff --git a/src/dawn_native/vulkan/CommandBufferVk.cpp b/src/dawn_native/vulkan/CommandBufferVk.cpp
index 8eaca7e..71584fa 100644
--- a/src/dawn_native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn_native/vulkan/CommandBufferVk.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/vulkan/CommandBufferVk.h"
 
+#include "dawn_native/BindGroupTracker.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/RenderBundle.h"
@@ -89,61 +90,19 @@
             return region;
         }
 
-        class DescriptorSetTracker {
+        class DescriptorSetTracker : public BindGroupTrackerBase<VkDescriptorSet, true, uint32_t> {
           public:
-            void OnSetBindGroup(uint32_t index,
-                                VkDescriptorSet set,
-                                uint32_t dynamicOffsetCount,
-                                uint64_t* dynamicOffsets) {
-                mDirtySets.set(index);
-                mSets[index] = set;
-                mDynamicOffsetCounts[index] = dynamicOffsetCount;
-                if (dynamicOffsetCount > 0) {
-                    // Vulkan backend use uint32_t as dynamic offsets type, it is not correct.
-                    // Vulkan should use VkDeviceSize. Dawn vulkan backend has to handle this.
-                    for (uint32_t i = 0; i < dynamicOffsetCount; ++i) {
-                        ASSERT(dynamicOffsets[i] <= std::numeric_limits<uint32_t>::max());
-                        mDynamicOffsets[index][i] = static_cast<uint32_t>(dynamicOffsets[i]);
-                    }
-                }
-            }
-
-            void OnPipelineLayoutChange(PipelineLayout* layout) {
-                if (layout == mCurrentLayout) {
-                    return;
-                }
-
-                if (mCurrentLayout == nullptr) {
-                    // We're at the beginning of a pass so all bind groups will be set before any
-                    // draw / dispatch. Still clear the dirty sets to avoid leftover dirty sets
-                    // from previous passes.
-                    mDirtySets.reset();
-                } else {
-                    // Bindgroups that are not inherited will be set again before any draw or
-                    // dispatch. Resetting the bits also makes sure we don't have leftover dirty
-                    // bindgroups that don't exist in the pipeline layout.
-                    mDirtySets &= ~layout->InheritedGroupsMask(mCurrentLayout);
-                }
-                mCurrentLayout = layout;
-            }
-
-            void Flush(Device* device, VkCommandBuffer commands, VkPipelineBindPoint bindPoint) {
-                for (uint32_t dirtyIndex : IterateBitSet(mDirtySets)) {
+            void Apply(Device* device, VkCommandBuffer commands, VkPipelineBindPoint bindPoint) {
+                for (uint32_t dirtyIndex :
+                     IterateBitSet(mDirtyBindGroupsObjectChangedOrIsDynamic)) {
                     device->fn.CmdBindDescriptorSets(
-                        commands, bindPoint, mCurrentLayout->GetHandle(), dirtyIndex, 1,
-                        &mSets[dirtyIndex], mDynamicOffsetCounts[dirtyIndex],
+                        commands, bindPoint, ToBackend(mPipelineLayout)->GetHandle(), dirtyIndex, 1,
+                        &mBindGroups[dirtyIndex], mDynamicOffsetCounts[dirtyIndex],
                         mDynamicOffsetCounts[dirtyIndex] > 0 ? mDynamicOffsets[dirtyIndex].data()
                                                              : nullptr);
                 }
-                mDirtySets.reset();
+                DidApply();
             }
-
-          private:
-            PipelineLayout* mCurrentLayout = nullptr;
-            std::array<VkDescriptorSet, kMaxBindGroups> mSets;
-            std::bitset<kMaxBindGroups> mDirtySets;
-            std::array<uint32_t, kMaxBindGroups> mDynamicOffsetCounts;
-            std::array<std::array<uint32_t, kMaxBindingsPerGroup>, kMaxBindGroups> mDynamicOffsets;
         };
 
         void RecordBeginRenderPass(CommandRecordingContext* recordingContext,
@@ -574,7 +533,7 @@
         Device* device = ToBackend(GetDevice());
         VkCommandBuffer commands = recordingContext->commandBuffer;
 
-        DescriptorSetTracker descriptorSets;
+        DescriptorSetTracker descriptorSets = {};
 
         Command type;
         while (mCommands.NextCommandId(&type)) {
@@ -586,7 +545,7 @@
 
                 case Command::Dispatch: {
                     DispatchCmd* dispatch = mCommands.NextCommand<DispatchCmd>();
-                    descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_COMPUTE);
+                    descriptorSets.Apply(device, commands, VK_PIPELINE_BIND_POINT_COMPUTE);
                     device->fn.CmdDispatch(commands, dispatch->x, dispatch->y, dispatch->z);
                 } break;
 
@@ -594,7 +553,7 @@
                     DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
                     VkBuffer indirectBuffer = ToBackend(dispatch->indirectBuffer)->GetHandle();
 
-                    descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_COMPUTE);
+                    descriptorSets.Apply(device, commands, VK_PIPELINE_BIND_POINT_COMPUTE);
                     device->fn.CmdDispatchIndirect(
                         commands, indirectBuffer,
                         static_cast<VkDeviceSize>(dispatch->indirectOffset));
@@ -618,7 +577,7 @@
 
                     device->fn.CmdBindPipeline(commands, VK_PIPELINE_BIND_POINT_COMPUTE,
                                                pipeline->GetHandle());
-                    descriptorSets.OnPipelineLayoutChange(ToBackend(pipeline->GetLayout()));
+                    descriptorSets.OnSetPipeline(pipeline);
                 } break;
 
                 case Command::InsertDebugMarker: {
@@ -715,7 +674,7 @@
             device->fn.CmdSetScissor(commands, 0, 1, &scissorRect);
         }
 
-        DescriptorSetTracker descriptorSets;
+        DescriptorSetTracker descriptorSets = {};
         RenderPipeline* lastPipeline = nullptr;
 
         auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
@@ -723,7 +682,7 @@
                 case Command::Draw: {
                     DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
-                    descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                    descriptorSets.Apply(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
                     device->fn.CmdDraw(commands, draw->vertexCount, draw->instanceCount,
                                        draw->firstVertex, draw->firstInstance);
                 } break;
@@ -731,7 +690,7 @@
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
-                    descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                    descriptorSets.Apply(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
                     device->fn.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
                                               draw->firstIndex, draw->baseVertex,
                                               draw->firstInstance);
@@ -741,7 +700,7 @@
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
                     VkBuffer indirectBuffer = ToBackend(draw->indirectBuffer)->GetHandle();
 
-                    descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                    descriptorSets.Apply(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
                     device->fn.CmdDrawIndirect(commands, indirectBuffer,
                                                static_cast<VkDeviceSize>(draw->indirectOffset), 1,
                                                0);
@@ -751,7 +710,7 @@
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
                     VkBuffer indirectBuffer = ToBackend(draw->indirectBuffer)->GetHandle();
 
-                    descriptorSets.Flush(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                    descriptorSets.Apply(device, commands, VK_PIPELINE_BIND_POINT_GRAPHICS);
                     device->fn.CmdDrawIndexedIndirect(
                         commands, indirectBuffer, static_cast<VkDeviceSize>(draw->indirectOffset),
                         1, 0);
@@ -837,7 +796,7 @@
                                                pipeline->GetHandle());
                     lastPipeline = pipeline;
 
-                    descriptorSets.OnPipelineLayoutChange(ToBackend(pipeline->GetLayout()));
+                    descriptorSets.OnSetPipeline(pipeline);
                 } break;
 
                 case Command::SetVertexBuffers: {
diff --git a/src/tests/end2end/BindGroupTests.cpp b/src/tests/end2end/BindGroupTests.cpp
index 46e825a..df2251c 100644
--- a/src/tests/end2end/BindGroupTests.cpp
+++ b/src/tests/end2end/BindGroupTests.cpp
@@ -497,9 +497,6 @@
 
 // Test that bind groups can be set before the pipeline.
 TEST_P(BindGroupTests, SetBindGroupBeforePipeline) {
-    // TODO(crbug.com/dawn/201): Implement on all platforms.
-    DAWN_SKIP_TEST_IF(!IsMetal());
-
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
     // Create a bind group layout which uses a single uniform buffer.
@@ -542,9 +539,6 @@
 
 // Test that dynamic bind groups can be set before the pipeline.
 TEST_P(BindGroupTests, SetDynamicBindGroupBeforePipeline) {
-    // TODO(crbug.com/dawn/201): Implement on all platforms.
-    DAWN_SKIP_TEST_IF(!IsMetal());
-
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
     // Create a bind group layout which uses a single dynamic uniform buffer.
@@ -558,8 +552,8 @@
 
     // Prepare data RGBAunorm(1, 0, 0, 0.5) and RGBAunorm(0, 1, 0, 0.5). They will be added in the
     // shader.
-    std::array<float, 4> color0 = {1, 0, 0, 0.5};
-    std::array<float, 4> color1 = {0, 1, 0, 0.5};
+    std::array<float, 4> color0 = {1, 0, 0, 0.501};
+    std::array<float, 4> color1 = {0, 1, 0, 0.501};
 
     size_t color1Offset = Align(sizeof(color0), kMinDynamicBufferOffsetAlignment);
 
@@ -606,9 +600,6 @@
 
 // Test that bind groups set for one pipeline are still set when the pipeline changes.
 TEST_P(BindGroupTests, BindGroupsPersistAfterPipelineChange) {
-    // TODO(crbug.com/dawn/201): Implement on all platforms.
-    DAWN_SKIP_TEST_IF(!IsMetal());
-
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
     // Create a bind group layout which uses a single dynamic uniform buffer.
@@ -687,9 +678,6 @@
 // Do a successful draw. Then, change the pipeline and one bind group.
 // Draw to check that the all bind groups are set.
 TEST_P(BindGroupTests, DrawThenChangePipelineAndBindGroup) {
-    // TODO(crbug.com/dawn/201): Implement on all platforms.
-    DAWN_SKIP_TEST_IF(!IsMetal());
-
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
     // Create a bind group layout which uses a single dynamic uniform buffer.
@@ -715,7 +703,7 @@
     // The second draw will use { color0, color3, color2 }.
     // The pipeline uses additive color blending so the result of two draws should be
     // { 2 * color0 + color1 + color2 + color3} = RGBAunorm(1, 1, 1, 1)
-    std::array<float, 4> color0 = {0.5, 0, 0, 0};
+    std::array<float, 4> color0 = {0.501, 0, 0, 0};
     std::array<float, 4> color1 = {0, 1, 0, 0};
     std::array<float, 4> color2 = {0, 0, 0, 1};
     std::array<float, 4> color3 = {0, 0, 1, 0};