Metal: Support setting bind groups before pipeline to match WebGPU semantics

Bug: dawn:201
Change-Id: I3bd03bbce3c38d0182e5e93f3898a43183bd647d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/10840
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index ef89ff2..4e512f6 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -206,7 +206,7 @@
             // MSL code generated by SPIRV-Cross expects.
             PerStage<std::array<uint32_t, kGenericMetalBufferSlots>> data;
 
-            void Apply(RenderPipeline* pipeline, id<MTLRenderCommandEncoder> render) {
+            void Apply(id<MTLRenderCommandEncoder> render, RenderPipeline* pipeline) {
                 dawn::ShaderStage stagesToApply =
                     dirtyStages & pipeline->GetStagesRequiringStorageBufferLength();
 
@@ -234,7 +234,7 @@
                 dirtyStages ^= stagesToApply;
             }
 
-            void Apply(ComputePipeline* pipeline, id<MTLComputeCommandEncoder> compute) {
+            void Apply(id<MTLComputeCommandEncoder> compute, ComputePipeline* pipeline) {
                 if (!(dirtyStages & dawn::ShaderStage::Compute)) {
                     return;
                 }
@@ -253,128 +253,6 @@
             }
         };
 
-        // Handles a call to SetBindGroup, directing the commands to the correct encoder.
-        // There is a single function that takes both encoders to factor code. Other approaches like
-        // templates wouldn't work because the name of methods are different between the two encoder
-        // types.
-        void ApplyBindGroup(uint32_t index,
-                            BindGroup* group,
-                            uint32_t dynamicOffsetCount,
-                            uint64_t* dynamicOffsets,
-                            PipelineLayout* pipelineLayout,
-                            StorageBufferLengthTracker* lengthTracker,
-                            id<MTLRenderCommandEncoder> render,
-                            id<MTLComputeCommandEncoder> compute) {
-            const auto& layout = group->GetLayout()->GetBindingInfo();
-            uint32_t currentDynamicBufferIndex = 0;
-
-            // TODO(kainino@chromium.org): Maintain buffers and offsets arrays in BindGroup
-            // so that we only have to do one setVertexBuffers and one setFragmentBuffers
-            // call here.
-            for (uint32_t bindingIndex : IterateBitSet(layout.mask)) {
-                auto stage = layout.visibilities[bindingIndex];
-                bool hasVertStage = stage & dawn::ShaderStage::Vertex && render != nil;
-                bool hasFragStage = stage & dawn::ShaderStage::Fragment && render != nil;
-                bool hasComputeStage = stage & dawn::ShaderStage::Compute && compute != nil;
-
-                uint32_t vertIndex = 0;
-                uint32_t fragIndex = 0;
-                uint32_t computeIndex = 0;
-
-                if (hasVertStage) {
-                    vertIndex = pipelineLayout->GetBindingIndexInfo(
-                        SingleShaderStage::Vertex)[index][bindingIndex];
-                }
-                if (hasFragStage) {
-                    fragIndex = pipelineLayout->GetBindingIndexInfo(
-                        SingleShaderStage::Fragment)[index][bindingIndex];
-                }
-                if (hasComputeStage) {
-                    computeIndex = pipelineLayout->GetBindingIndexInfo(
-                        SingleShaderStage::Compute)[index][bindingIndex];
-                }
-
-                switch (layout.types[bindingIndex]) {
-                    case dawn::BindingType::UniformBuffer:
-                    case dawn::BindingType::StorageBuffer: {
-                        const BufferBinding& binding =
-                            group->GetBindingAsBufferBinding(bindingIndex);
-                        const id<MTLBuffer> buffer = ToBackend(binding.buffer)->GetMTLBuffer();
-                        NSUInteger offset = binding.offset;
-
-                        // TODO(shaobo.yan@intel.com): Record bound buffer status to use
-                        // setBufferOffset to achieve better performance.
-                        if (layout.dynamic[bindingIndex]) {
-                            offset += dynamicOffsets[currentDynamicBufferIndex];
-                            currentDynamicBufferIndex++;
-                        }
-
-                        if (hasVertStage) {
-                            lengthTracker->data[SingleShaderStage::Vertex][vertIndex] =
-                                binding.size;
-                            lengthTracker->dirtyStages |= dawn::ShaderStage::Vertex;
-                            [render setVertexBuffers:&buffer
-                                             offsets:&offset
-                                           withRange:NSMakeRange(vertIndex, 1)];
-                        }
-                        if (hasFragStage) {
-                            lengthTracker->data[SingleShaderStage::Fragment][fragIndex] =
-                                binding.size;
-                            lengthTracker->dirtyStages |= dawn::ShaderStage::Fragment;
-                            [render setFragmentBuffers:&buffer
-                                               offsets:&offset
-                                             withRange:NSMakeRange(fragIndex, 1)];
-                        }
-                        if (hasComputeStage) {
-                            lengthTracker->data[SingleShaderStage::Compute][computeIndex] =
-                                binding.size;
-                            lengthTracker->dirtyStages |= dawn::ShaderStage::Compute;
-                            [compute setBuffers:&buffer
-                                        offsets:&offset
-                                      withRange:NSMakeRange(computeIndex, 1)];
-                        }
-
-                    } break;
-
-                    case dawn::BindingType::Sampler: {
-                        auto sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
-                        if (hasVertStage) {
-                            [render setVertexSamplerState:sampler->GetMTLSamplerState()
-                                                  atIndex:vertIndex];
-                        }
-                        if (hasFragStage) {
-                            [render setFragmentSamplerState:sampler->GetMTLSamplerState()
-                                                    atIndex:fragIndex];
-                        }
-                        if (hasComputeStage) {
-                            [compute setSamplerState:sampler->GetMTLSamplerState()
-                                             atIndex:computeIndex];
-                        }
-                    } break;
-
-                    case dawn::BindingType::SampledTexture: {
-                        auto textureView = ToBackend(group->GetBindingAsTextureView(bindingIndex));
-                        if (hasVertStage) {
-                            [render setVertexTexture:textureView->GetMTLTexture()
-                                             atIndex:vertIndex];
-                        }
-                        if (hasFragStage) {
-                            [render setFragmentTexture:textureView->GetMTLTexture()
-                                               atIndex:fragIndex];
-                        }
-                        if (hasComputeStage) {
-                            [compute setTexture:textureView->GetMTLTexture() atIndex:computeIndex];
-                        }
-                    } break;
-
-                    case dawn::BindingType::StorageTexture:
-                    case dawn::BindingType::ReadonlyStorageBuffer:
-                        UNREACHABLE();
-                        break;
-                }
-            }
-        }
-
         struct TextureBufferCopySplit {
             static constexpr uint32_t kMaxTextureBufferCopyRegions = 3;
 
@@ -511,6 +389,219 @@
             return copy;
         }
 
+        // Keeps track of the bind groups set on a pass so they can be lazily applied on the
+        // encoder once the pipeline state is known, right before each Draw or Dispatch.
+        class BindGroupTracker {
+          public:
+            explicit BindGroupTracker(StorageBufferLengthTracker* lengthTracker)
+                : mLengthTracker(lengthTracker) {
+            }
+
+            // Records the bind group and dynamic offsets for |index|. Apply does the encoding.
+            void OnSetBindGroup(uint32_t index,
+                                BindGroup* bindGroup,
+                                uint32_t dynamicOffsetCount,
+                                uint64_t* dynamicOffsets) {
+                ASSERT(index < kMaxBindGroups);
+                ASSERT(dynamicOffsetCount <= kMaxBindingsPerGroup);
+
+                // Dirty the group even if the current pipeline layout does not use it: another
+                // pipeline that does use it may be set before the next Draw or Dispatch.
+                mDirtyBindGroups.set(index);
+                mBindGroups[index] = bindGroup;
+                mDynamicOffsetCounts[index] = dynamicOffsetCount;
+                if (dynamicOffsetCount > 0) {  // Never hand memcpy a null |dynamicOffsets|.
+                    memcpy(mDynamicOffsets[index].data(), dynamicOffsets,
+                           sizeof(uint64_t) * dynamicOffsetCount);
+                }
+            }
+
+            // Records the new pipeline layout and dirties the bind groups it cannot inherit.
+            void OnSetPipeline(PipelineBase* pipeline) {
+                mPipelineLayout = ToBackend(pipeline->GetLayout());
+                // Refresh the mask on every pipeline change, even when switching back to the last
+                // applied layout: Apply uses it to select which dirty bind groups to encode.
+                mBindGroupLayoutsMask = mPipelineLayout->GetBindGroupLayoutsMask();
+                if (mLastAppliedPipelineLayout == mPipelineLayout) {
+                    return;
+                }
+
+                // Changing the pipeline layout sets bind groups as dirty. The first |k| matching
+                // bind groups may be inherited because bind groups are packed in the buffer /
+                // texture tables in contiguous order.
+                if (mLastAppliedPipelineLayout != nullptr) {
+                    mDirtyBindGroups |=
+                        ~mPipelineLayout->InheritedGroupsMask(mLastAppliedPipelineLayout);
+                } else {
+                    // Nothing has been applied on the encoder yet, so every group is stale.
+                    mDirtyBindGroups.set();
+                }
+            }
+
+            // Encodes the dirty bind groups used by the current pipeline layout on |encoder|.
+            template <typename Encoder>
+            void Apply(Encoder encoder) {
+                const auto groupsToApply = mDirtyBindGroups & mBindGroupLayoutsMask;
+                for (uint32_t index : IterateBitSet(groupsToApply)) {
+                    ApplyBindGroup(encoder, index, mBindGroups[index], mDynamicOffsetCounts[index],
+                                   mDynamicOffsets[index].data(), mPipelineLayout);
+                }
+
+                // Only the groups that were just encoded become clean. Groups outside of the
+                // current layout stay dirty until a pipeline layout that uses them is applied.
+                mDirtyBindGroups &= ~mBindGroupLayoutsMask;
+                mLastAppliedPipelineLayout = mPipelineLayout;
+            }
+
+          private:
+            // Encodes the bindings of |group| on whichever of |render| and |compute| is non-nil.
+            // There is a single function that takes both encoders to factor code. Other approaches
+            // like templates wouldn't work because the names of the methods are different between
+            // the two encoder types.
+            void ApplyBindGroupImpl(id<MTLRenderCommandEncoder> render,
+                                    id<MTLComputeCommandEncoder> compute,
+                                    uint32_t index,
+                                    BindGroup* group,
+                                    uint32_t dynamicOffsetCount,
+                                    uint64_t* dynamicOffsets,
+                                    PipelineLayout* pipelineLayout) {
+                const auto& layout = group->GetLayout()->GetBindingInfo();
+                uint32_t currentDynamicBufferIndex = 0;
+
+                // TODO(kainino@chromium.org): Maintain buffers and offsets arrays in BindGroup
+                // so that we only have to do one setVertexBuffers and one setFragmentBuffers
+                // call here.
+                for (uint32_t bindingIndex : IterateBitSet(layout.mask)) {
+                    auto stage = layout.visibilities[bindingIndex];
+                    bool hasVertStage = stage & dawn::ShaderStage::Vertex && render != nil;
+                    bool hasFragStage = stage & dawn::ShaderStage::Fragment && render != nil;
+                    bool hasComputeStage = stage & dawn::ShaderStage::Compute && compute != nil;
+
+                    uint32_t vertIndex = 0;
+                    uint32_t fragIndex = 0;
+                    uint32_t computeIndex = 0;
+
+                    if (hasVertStage) {
+                        vertIndex = pipelineLayout->GetBindingIndexInfo(
+                            SingleShaderStage::Vertex)[index][bindingIndex];
+                    }
+                    if (hasFragStage) {
+                        fragIndex = pipelineLayout->GetBindingIndexInfo(
+                            SingleShaderStage::Fragment)[index][bindingIndex];
+                    }
+                    if (hasComputeStage) {
+                        computeIndex = pipelineLayout->GetBindingIndexInfo(
+                            SingleShaderStage::Compute)[index][bindingIndex];
+                    }
+
+                    switch (layout.types[bindingIndex]) {
+                        case dawn::BindingType::UniformBuffer:
+                        case dawn::BindingType::StorageBuffer: {
+                            const BufferBinding& binding =
+                                group->GetBindingAsBufferBinding(bindingIndex);
+                            const id<MTLBuffer> buffer = ToBackend(binding.buffer)->GetMTLBuffer();
+                            NSUInteger offset = binding.offset;
+
+                            // TODO(shaobo.yan@intel.com): Record bound buffer status to use
+                            // setBufferOffset to achieve better performance.
+                            if (layout.dynamic[bindingIndex]) {
+                                offset += dynamicOffsets[currentDynamicBufferIndex];
+                                currentDynamicBufferIndex++;
+                            }
+
+                            if (hasVertStage) {
+                                mLengthTracker->data[SingleShaderStage::Vertex][vertIndex] =
+                                    binding.size;
+                                mLengthTracker->dirtyStages |= dawn::ShaderStage::Vertex;
+                                [render setVertexBuffers:&buffer
+                                                 offsets:&offset
+                                               withRange:NSMakeRange(vertIndex, 1)];
+                            }
+                            if (hasFragStage) {
+                                mLengthTracker->data[SingleShaderStage::Fragment][fragIndex] =
+                                    binding.size;
+                                mLengthTracker->dirtyStages |= dawn::ShaderStage::Fragment;
+                                [render setFragmentBuffers:&buffer
+                                                   offsets:&offset
+                                                 withRange:NSMakeRange(fragIndex, 1)];
+                            }
+                            if (hasComputeStage) {
+                                mLengthTracker->data[SingleShaderStage::Compute][computeIndex] =
+                                    binding.size;
+                                mLengthTracker->dirtyStages |= dawn::ShaderStage::Compute;
+                                [compute setBuffers:&buffer
+                                            offsets:&offset
+                                          withRange:NSMakeRange(computeIndex, 1)];
+                            }
+
+                        } break;
+
+                        case dawn::BindingType::Sampler: {
+                            auto sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
+                            if (hasVertStage) {
+                                [render setVertexSamplerState:sampler->GetMTLSamplerState()
+                                                      atIndex:vertIndex];
+                            }
+                            if (hasFragStage) {
+                                [render setFragmentSamplerState:sampler->GetMTLSamplerState()
+                                                        atIndex:fragIndex];
+                            }
+                            if (hasComputeStage) {
+                                [compute setSamplerState:sampler->GetMTLSamplerState()
+                                                 atIndex:computeIndex];
+                            }
+                        } break;
+
+                        case dawn::BindingType::SampledTexture: {
+                            auto textureView =
+                                ToBackend(group->GetBindingAsTextureView(bindingIndex));
+                            if (hasVertStage) {
+                                [render setVertexTexture:textureView->GetMTLTexture()
+                                                 atIndex:vertIndex];
+                            }
+                            if (hasFragStage) {
+                                [render setFragmentTexture:textureView->GetMTLTexture()
+                                                   atIndex:fragIndex];
+                            }
+                            if (hasComputeStage) {
+                                [compute setTexture:textureView->GetMTLTexture()
+                                            atIndex:computeIndex];
+                            }
+                        } break;
+
+                        case dawn::BindingType::StorageTexture:
+                        case dawn::BindingType::ReadonlyStorageBuffer:
+                            UNREACHABLE();
+                            break;
+                    }
+                }
+            }
+
+            template <typename... Args>
+            void ApplyBindGroup(id<MTLRenderCommandEncoder> encoder, Args&&... args) {
+                ApplyBindGroupImpl(encoder, nil, std::forward<Args>(args)...);
+            }
+
+            template <typename... Args>
+            void ApplyBindGroup(id<MTLComputeCommandEncoder> encoder, Args&&... args) {
+                ApplyBindGroupImpl(nil, encoder, std::forward<Args>(args)...);
+            }
+
+            std::bitset<kMaxBindGroups> mDirtyBindGroups;
+            std::bitset<kMaxBindGroups> mBindGroupLayoutsMask;
+            std::array<BindGroup*, kMaxBindGroups> mBindGroups;
+            std::array<uint32_t, kMaxBindGroups> mDynamicOffsetCounts;
+            std::array<std::array<uint64_t, kMaxBindingsPerGroup>, kMaxBindGroups> mDynamicOffsets;
+
+            // |mPipelineLayout| is the current pipeline layout set on the command buffer.
+            // |mLastAppliedPipelineLayout| is the last pipeline layout for which we applied changes
+            // to the bind group bindings.
+            PipelineLayout* mPipelineLayout = nullptr;
+            PipelineLayout* mLastAppliedPipelineLayout = nullptr;
+
+            StorageBufferLengthTracker* mLengthTracker;
+        };
+
         // Keeps track of the dirty vertex buffer values so they can be lazily applied when we know
         // all the relevant state.
         class VertexInputBufferTracker {
@@ -685,6 +776,7 @@
     void CommandBuffer::EncodeComputePass(id<MTLCommandBuffer> commandBuffer) {
         ComputePipeline* lastPipeline = nullptr;
         StorageBufferLengthTracker storageBufferLengths = {};
+        BindGroupTracker bindGroups(&storageBufferLengths);
 
         // Will be autoreleased
         id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
@@ -700,7 +792,9 @@
 
                 case Command::Dispatch: {
                     DispatchCmd* dispatch = mCommands.NextCommand<DispatchCmd>();
-                    storageBufferLengths.Apply(lastPipeline, encoder);
+
+                    bindGroups.Apply(encoder);
+                    storageBufferLengths.Apply(encoder, lastPipeline);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(dispatch->x, dispatch->y, dispatch->z)
                             threadsPerThreadgroup:lastPipeline->GetLocalWorkGroupSize()];
@@ -708,7 +802,9 @@
 
                 case Command::DispatchIndirect: {
                     DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
-                    storageBufferLengths.Apply(lastPipeline, encoder);
+
+                    bindGroups.Apply(encoder);
+                    storageBufferLengths.Apply(encoder, lastPipeline);
 
                     Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
@@ -722,6 +818,8 @@
                     SetComputePipelineCmd* cmd = mCommands.NextCommand<SetComputePipelineCmd>();
                     lastPipeline = ToBackend(cmd->pipeline).Get();
 
+                    bindGroups.OnSetPipeline(lastPipeline);
+
                     lastPipeline->Encode(encoder);
                 } break;
 
@@ -732,9 +830,8 @@
                         dynamicOffsets = mCommands.NextData<uint64_t>(cmd->dynamicOffsetCount);
                     }
 
-                    ApplyBindGroup(cmd->index, ToBackend(cmd->group.Get()), cmd->dynamicOffsetCount,
-                                   dynamicOffsets, ToBackend(lastPipeline->GetLayout()),
-                                   &storageBufferLengths, nil, encoder);
+                    bindGroups.OnSetBindGroup(cmd->index, ToBackend(cmd->group.Get()),
+                                              cmd->dynamicOffsetCount, dynamicOffsets);
                 } break;
 
                 case Command::InsertDebugMarker: {
@@ -870,6 +967,7 @@
         uint32_t indexBufferBaseOffset = 0;
         VertexInputBufferTracker vertexInputBuffers;
         StorageBufferLengthTracker storageBufferLengths = {};
+        BindGroupTracker bindGroups(&storageBufferLengths);
 
         // This will be autoreleased
         id<MTLRenderCommandEncoder> encoder =
@@ -881,7 +979,8 @@
                     DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
                     vertexInputBuffers.Apply(encoder, lastPipeline);
-                    storageBufferLengths.Apply(lastPipeline, encoder);
+                    bindGroups.Apply(encoder);
+                    storageBufferLengths.Apply(encoder, lastPipeline);
 
                     // The instance count must be non-zero, otherwise no-op
                     if (draw->instanceCount != 0) {
@@ -899,7 +998,8 @@
                         IndexFormatSize(lastPipeline->GetVertexInputDescriptor()->indexFormat);
 
                     vertexInputBuffers.Apply(encoder, lastPipeline);
-                    storageBufferLengths.Apply(lastPipeline, encoder);
+                    bindGroups.Apply(encoder);
+                    storageBufferLengths.Apply(encoder, lastPipeline);
 
                     // The index and instance count must be non-zero, otherwise no-op
                     if (draw->indexCount != 0 && draw->instanceCount != 0) {
@@ -919,7 +1019,8 @@
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
 
                     vertexInputBuffers.Apply(encoder, lastPipeline);
-                    storageBufferLengths.Apply(lastPipeline, encoder);
+                    bindGroups.Apply(encoder);
+                    storageBufferLengths.Apply(encoder, lastPipeline);
 
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
@@ -932,7 +1033,8 @@
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
 
                     vertexInputBuffers.Apply(encoder, lastPipeline);
-                    storageBufferLengths.Apply(lastPipeline, encoder);
+                    bindGroups.Apply(encoder);
+                    storageBufferLengths.Apply(encoder, lastPipeline);
 
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
@@ -973,6 +1075,8 @@
                     RenderPipeline* newPipeline = ToBackend(cmd->pipeline).Get();
 
                     vertexInputBuffers.OnSetPipeline(lastPipeline, newPipeline);
+                    bindGroups.OnSetPipeline(newPipeline);
+
                     [encoder setDepthStencilState:newPipeline->GetMTLDepthStencilState()];
                     [encoder setFrontFacingWinding:newPipeline->GetMTLFrontFace()];
                     [encoder setCullMode:newPipeline->GetMTLCullMode()];
@@ -988,9 +1092,8 @@
                         dynamicOffsets = iter->NextData<uint64_t>(cmd->dynamicOffsetCount);
                     }
 
-                    ApplyBindGroup(cmd->index, ToBackend(cmd->group.Get()), cmd->dynamicOffsetCount,
-                                   dynamicOffsets, ToBackend(lastPipeline->GetLayout()),
-                                   &storageBufferLengths, encoder, nil);
+                    bindGroups.OnSetBindGroup(cmd->index, ToBackend(cmd->group.Get()),
+                                              cmd->dynamicOffsetCount, dynamicOffsets);
                 } break;
 
                 case Command::SetIndexBuffer: {
diff --git a/src/tests/end2end/BindGroupTests.cpp b/src/tests/end2end/BindGroupTests.cpp
index 30f2194..46e825a 100644
--- a/src/tests/end2end/BindGroupTests.cpp
+++ b/src/tests/end2end/BindGroupTests.cpp
@@ -14,6 +14,7 @@
 
 #include "common/Assert.h"
 #include "common/Constants.h"
+#include "common/Math.h"
 #include "tests/DawnTest.h"
 #include "utils/ComboRenderPipelineDescriptor.h"
 #include "utils/DawnHelpers.h"
@@ -43,6 +44,75 @@
 
         return device.CreatePipelineLayout(&descriptor);
     }
+
+    dawn::ShaderModule MakeSimpleVSModule() const {  // Vertex shader with 3 hard-coded positions.
+        return utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, R"(
+        #version 450
+        void main() {
+            const vec2 pos[3] = vec2[3](vec2(-1.f, -1.f), vec2(1.f, -1.f), vec2(-1.f, 1.f));
+            gl_Position = vec4(pos[gl_VertexIndex], 0.f, 1.f);
+        })");
+    }
+
+    // Makes a fragment shader that outputs the sum of buffer<i>.color, one buffer per set.
+    dawn::ShaderModule MakeFSModule(std::vector<dawn::BindingType> bindingTypes) const {
+        ASSERT(bindingTypes.size() <= kMaxBindGroups);
+
+        std::ostringstream glsl;
+        auto declareBuffer = [&glsl](const char* packing, const char* block, size_t set) {
+            glsl << "layout (" << packing << ", set = " << set << ", binding = 0) " << block
+                 << set << R"( {
+                        vec4 color;
+                    } buffer)" << set << ";\n";
+        };
+
+        glsl << R"(
+        #version 450
+        layout(location = 0) out vec4 fragColor;
+        )";
+        for (size_t set = 0; set < bindingTypes.size(); ++set) {
+            if (bindingTypes[set] == dawn::BindingType::UniformBuffer) {
+                declareBuffer("std140", "uniform UniformBuffer", set);
+            } else if (bindingTypes[set] == dawn::BindingType::StorageBuffer) {
+                declareBuffer("std430", "buffer StorageBuffer", set);
+            } else {
+                UNREACHABLE();
+            }
+        }
+
+        glsl << R"(
+        void main() {
+            fragColor = vec4(0.0);
+        )";
+        for (size_t set = 0; set < bindingTypes.size(); ++set) {
+            glsl << "fragColor += buffer" << set << ".color;\n";
+        }
+        glsl << "}\n";
+        return utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment,
+                                         glsl.str().c_str());
+    }
+
+    // Creates a render pipeline for |bindGroupLayouts| with the MakeFSModule fragment shader.
+    dawn::RenderPipeline MakeTestPipeline(
+        const utils::BasicRenderPass& renderPass,
+        std::vector<dawn::BindingType> bindingTypes,
+        std::vector<dawn::BindGroupLayout> bindGroupLayouts) {
+        dawn::ShaderModule vertexShader = MakeSimpleVSModule();
+        dawn::ShaderModule fragmentShader = MakeFSModule(bindingTypes);
+        dawn::PipelineLayout layout = MakeBasicPipelineLayout(device, bindGroupLayouts);
+
+        utils::ComboRenderPipelineDescriptor descriptor(device);
+        descriptor.layout = layout;
+        descriptor.vertexStage.module = vertexShader;
+        descriptor.cFragmentStage.module = fragmentShader;
+        // Additive blending: operation Add with both the source and destination factors at One.
+        auto& colorState = *descriptor.cColorStates[0];
+        colorState.format = renderPass.colorFormat;
+        colorState.colorBlend.operation = dawn::BlendOperation::Add;
+        colorState.colorBlend.srcFactor = dawn::BlendFactor::One;
+        colorState.colorBlend.dstFactor = dawn::BlendFactor::One;
+        return device.CreateRenderPipeline(&descriptor);
+    }
 };
 
 // Test a bindgroup reused in two command buffers in the same call to queue.Submit().
@@ -380,60 +450,28 @@
 
 // This test reproduces an out-of-bound bug on D3D12 backends when calling draw command twice with
 // one pipeline that has 4 bind group sets in one render pass.
-TEST_P(BindGroupTests, DrawTwiceInSamePipelineWithFourBindGroupSets)
-{
+TEST_P(BindGroupTests, DrawTwiceInSamePipelineWithFourBindGroupSets) {
     utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
 
-    dawn::ShaderModule vsModule =
-        utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, R"(
-        #version 450
-        void main() {
-            const vec2 pos[3] = vec2[3](vec2(-1.f, -1.f), vec2(1.f, -1.f), vec2(-1.f, 1.f));
-            gl_Position = vec4(pos[gl_VertexIndex], 0.f, 1.f);
-        })");
-
-    dawn::ShaderModule fsModule =
-        utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, R"(
-        #version 450
-        layout (std140, set = 0, binding = 0) uniform fragmentUniformBuffer1 {
-            vec4 color1;
-        };
-        layout (std140, set = 1, binding = 0) uniform fragmentUniformBuffer2 {
-            vec4 color2;
-        };
-        layout (std140, set = 2, binding = 0) uniform fragmentUniformBuffer3 {
-            vec4 color3;
-        };
-        layout (std140, set = 3, binding = 0) uniform fragmentUniformBuffer4 {
-            vec4 color4;
-        };
-        layout(location = 0) out vec4 fragColor;
-        void main() {
-            fragColor = color1 + color2 + color3 + color4;
-        })");
-
     dawn::BindGroupLayout layout = utils::MakeBindGroupLayout(
         device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::UniformBuffer}});
-    dawn::PipelineLayout pipelineLayout = MakeBasicPipelineLayout(
-        device, { layout, layout, layout, layout });
 
-    utils::ComboRenderPipelineDescriptor pipelineDescriptor(device);
-    pipelineDescriptor.layout = pipelineLayout;
-    pipelineDescriptor.vertexStage.module = vsModule;
-    pipelineDescriptor.cFragmentStage.module = fsModule;
-    pipelineDescriptor.cColorStates[0]->format = renderPass.colorFormat;
+    dawn::RenderPipeline pipeline =
+        MakeTestPipeline(renderPass,
+                         {dawn::BindingType::UniformBuffer, dawn::BindingType::UniformBuffer,
+                          dawn::BindingType::UniformBuffer, dawn::BindingType::UniformBuffer},
+                         {layout, layout, layout, layout});
 
-    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&pipelineDescriptor);
     dawn::CommandEncoder encoder = device.CreateCommandEncoder();
     dawn::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
 
     pass.SetPipeline(pipeline);
 
-    std::array<float, 4> color = { 0.25, 0, 0, 0.25 };
+    std::array<float, 4> color = {0.25, 0, 0, 0.25};
     dawn::Buffer uniformBuffer =
         utils::CreateBufferFromData(device, &color, sizeof(color), dawn::BufferUsage::Uniform);
-    dawn::BindGroup bindGroup = utils::MakeBindGroup(
-        device, layout, { { 0, uniformBuffer, 0, sizeof(color) } });
+    dawn::BindGroup bindGroup =
+        utils::MakeBindGroup(device, layout, {{0, uniformBuffer, 0, sizeof(color)}});
 
     pass.SetBindGroup(0, bindGroup, 0, nullptr);
     pass.SetBindGroup(1, bindGroup, 0, nullptr);
@@ -457,4 +495,297 @@
     EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
 }
 
+// Test that a bind group set before the pipeline is still applied to the following draw.
+TEST_P(BindGroupTests, SetBindGroupBeforePipeline) {
+    // TODO(crbug.com/dawn/201): Implement on all platforms. Until then the test only runs on Metal.
+    DAWN_SKIP_TEST_IF(!IsMetal());
+
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    // Create a bind group layout which uses a single uniform buffer visible to the fragment stage.
+    dawn::BindGroupLayout layout = utils::MakeBindGroupLayout(
+        device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::UniformBuffer}});
+
+    // Create a pipeline whose only bind group slot uses the uniform bind group layout.
+    dawn::RenderPipeline pipeline =
+        MakeTestPipeline(renderPass, {dawn::BindingType::UniformBuffer}, {layout});
+
+    dawn::CommandEncoder encoder = device.CreateCommandEncoder();
+    dawn::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+    // Create a bind group with a uniform buffer and fill it with RGBAunorm(1, 0, 0, 1).
+    std::array<float, 4> color = {1, 0, 0, 1};
+    dawn::Buffer uniformBuffer =
+        utils::CreateBufferFromData(device, &color, sizeof(color), dawn::BufferUsage::Uniform);
+    dawn::BindGroup bindGroup =
+        utils::MakeBindGroup(device, layout, {{0, uniformBuffer, 0, sizeof(color)}});
+
+    // Set the bind group first, then the pipeline, and draw. This ordering is what is under test.
+    pass.SetBindGroup(0, bindGroup, 0, nullptr);
+    pass.SetPipeline(pipeline);
+    pass.Draw(3, 1, 0, 0);
+
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    // Three of the sampled corners should be red; the (max, max) corner is expected to stay clear.
+    RGBA8 filled(255, 0, 0, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+    int min = 1, max = kRTSize - 3;
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, max, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, max);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
+}
+
+// Test that bind groups with dynamic offsets can be set before the pipeline.
+TEST_P(BindGroupTests, SetDynamicBindGroupBeforePipeline) {
+    // TODO(crbug.com/dawn/201): Implement on all platforms. Until then the test only runs on Metal.
+    DAWN_SKIP_TEST_IF(!IsMetal());
+
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    // Create a bind group layout which uses a single dynamic uniform buffer.
+    dawn::BindGroupLayout layout = utils::MakeBindGroupLayout(
+        device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::UniformBuffer, true}});
+
+    // Create a pipeline that uses the dynamic uniform bind group layout for two bind groups.
+    dawn::RenderPipeline pipeline = MakeTestPipeline(
+        renderPass, {dawn::BindingType::UniformBuffer, dawn::BindingType::UniformBuffer},
+        {layout, layout});
+
+    // Prepare data RGBAunorm(1, 0, 0, 0.5) and RGBAunorm(0, 1, 0, 0.5). They will be added in the
+    // shader.
+    std::array<float, 4> color0 = {1, 0, 0, 0.5};
+    std::array<float, 4> color1 = {0, 1, 0, 0.5};
+
+    size_t color1Offset = Align(sizeof(color0), kMinDynamicBufferOffsetAlignment);
+
+    std::vector<uint8_t> data(color1Offset + sizeof(color1));
+    memcpy(data.data(), color0.data(), sizeof(color0));
+    memcpy(data.data() + color1Offset, color1.data(), sizeof(color1));
+
+    // Create a uniform buffer holding both colors and a bind group that views four floats of it.
+    // The same bind group is set twice below, each time with the dynamic offset of one color.
+    dawn::Buffer uniformBuffer =
+        utils::CreateBufferFromData(device, data.data(), data.size(), dawn::BufferUsage::Uniform);
+    dawn::BindGroup bindGroup =
+        utils::MakeBindGroup(device, layout, {{0, uniformBuffer, 0, 4 * sizeof(float)}});
+
+    dawn::CommandEncoder encoder = device.CreateCommandEncoder();
+    dawn::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+    // Set bind group 0 with dynamic offset 0, i.e. color0.
+    uint64_t dynamicOffset = 0;
+    pass.SetBindGroup(0, bindGroup, 1, &dynamicOffset);
+
+    // Set bind group 1 with the dynamic offset of color1.
+    dynamicOffset = color1Offset;
+    pass.SetBindGroup(1, bindGroup, 1, &dynamicOffset);
+
+    // Only now set the pipeline, and draw.
+    pass.SetPipeline(pipeline);
+    pass.Draw(3, 1, 0, 0);
+
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    // The result should be RGBAunorm(1, 0, 0, 0.5) + RGBAunorm(0, 1, 0, 0.5) = RGBAunorm(1, 1, 0, 1)
+    RGBA8 filled(255, 255, 0, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+    int min = 1, max = kRTSize - 3;
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, max, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, max);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
+}
+
+// Test that bind groups set for one pipeline are still set when the pipeline changes.
+TEST_P(BindGroupTests, BindGroupsPersistAfterPipelineChange) {
+    // TODO(crbug.com/dawn/201): Implement on all platforms. Until then the test only runs on Metal.
+    DAWN_SKIP_TEST_IF(!IsMetal());
+
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    // Create a bind group layout which uses a single dynamic uniform buffer.
+    dawn::BindGroupLayout uniformLayout = utils::MakeBindGroupLayout(
+        device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::UniformBuffer, true}});
+
+    // Create a bind group layout which uses a single dynamic storage buffer.
+    dawn::BindGroupLayout storageLayout = utils::MakeBindGroupLayout(
+        device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::StorageBuffer, true}});
+
+    // Create a pipeline with pipeline layout (uniform, storage).
+    dawn::RenderPipeline pipeline0 = MakeTestPipeline(
+        renderPass, {dawn::BindingType::UniformBuffer, dawn::BindingType::StorageBuffer},
+        {uniformLayout, storageLayout});
+
+    // Create a pipeline with pipeline layout (uniform, uniform).
+    dawn::RenderPipeline pipeline1 = MakeTestPipeline(
+        renderPass, {dawn::BindingType::UniformBuffer, dawn::BindingType::UniformBuffer},
+        {uniformLayout, uniformLayout});
+
+    // Prepare data RGBAunorm(1, 0, 0, 0.5) and RGBAunorm(0, 1, 0, 0.5). They will be added in the
+    // shader.
+    std::array<float, 4> color0 = {1, 0, 0, 0.5};
+    std::array<float, 4> color1 = {0, 1, 0, 0.5};
+
+    size_t color1Offset = Align(sizeof(color0), kMinDynamicBufferOffsetAlignment);
+
+    std::vector<uint8_t> data(color1Offset + sizeof(color1));
+    memcpy(data.data(), color0.data(), sizeof(color0));
+    memcpy(data.data() + color1Offset, color1.data(), sizeof(color1));
+
+    // Create a uniform buffer holding both colors and a bind group that views four floats of it.
+    // The same bind group is set twice below, each time with the dynamic offset of one color.
+    dawn::Buffer uniformBuffer =
+        utils::CreateBufferFromData(device, data.data(), data.size(), dawn::BufferUsage::Uniform);
+    dawn::BindGroup bindGroup =
+        utils::MakeBindGroup(device, uniformLayout, {{0, uniformBuffer, 0, 4 * sizeof(float)}});
+
+    dawn::CommandEncoder encoder = device.CreateCommandEncoder();
+    dawn::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+    // Set the first pipeline (uniform, storage).
+    pass.SetPipeline(pipeline0);
+
+    // Set bind group 0 at dynamic offset 0 (color0).
+    // This uniform bind group matches slot 0 of pipeline0's layout.
+    uint64_t dynamicOffset = 0;
+    pass.SetBindGroup(0, bindGroup, 1, &dynamicOffset);
+
+    // Set bind group 1 at the dynamic offset of color1.
+    // This uniform bind group does not match slot 1 of pipeline0's layout, which is storage.
+    dynamicOffset = color1Offset;
+    pass.SetBindGroup(1, bindGroup, 1, &dynamicOffset);
+
+    // Set the second pipeline (uniform, uniform).
+    // Both bind groups match the pipeline.
+    // They should persist and not need to be bound again.
+    pass.SetPipeline(pipeline1);
+    pass.Draw(3, 1, 0, 0);
+
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    // The result should be RGBAunorm(1, 0, 0, 0.5) + RGBAunorm(0, 1, 0, 0.5) = RGBAunorm(1, 1, 0, 1)
+    RGBA8 filled(255, 255, 0, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+    int min = 1, max = kRTSize - 3;
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, max, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, max);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
+}
+
+// Do a successful draw. Then, change the pipeline and one bind group.
+// Draw again to check that all the bind groups are set.
+TEST_P(BindGroupTests, DrawThenChangePipelineAndBindGroup) {
+    // TODO(crbug.com/dawn/201): Implement on all platforms. Until then the test only runs on Metal.
+    DAWN_SKIP_TEST_IF(!IsMetal());
+
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    // Create a bind group layout which uses a single dynamic uniform buffer.
+    dawn::BindGroupLayout uniformLayout = utils::MakeBindGroupLayout(
+        device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::UniformBuffer, true}});
+
+    // Create a bind group layout which uses a single dynamic storage buffer.
+    dawn::BindGroupLayout storageLayout = utils::MakeBindGroupLayout(
+        device, {{0, dawn::ShaderStage::Fragment, dawn::BindingType::StorageBuffer, true}});
+
+    // Create a pipeline with pipeline layout (uniform, uniform, storage).
+    dawn::RenderPipeline pipeline0 = MakeTestPipeline(
+        renderPass, {dawn::BindingType::UniformBuffer, dawn::BindingType::UniformBuffer, dawn::BindingType::StorageBuffer},
+        {uniformLayout, uniformLayout, storageLayout});
+
+    // Create a pipeline with pipeline layout (uniform, storage, storage).
+    dawn::RenderPipeline pipeline1 = MakeTestPipeline(
+        renderPass, {dawn::BindingType::UniformBuffer, dawn::BindingType::StorageBuffer, dawn::BindingType::StorageBuffer },
+        {uniformLayout, storageLayout, storageLayout});
+
+    // Prepare color data.
+    // The first draw will use { color0, color1, color2 }.
+    // The second draw will use { color0, color3, color2 }.
+    // The pipeline uses additive color blending so the result of two draws should be
+    // { 2 * color0 + color1 + 2 * color2 + color3 } = (1, 1, 1, 2), read back as RGBAunorm(1, 1, 1, 1)
+    std::array<float, 4> color0 = {0.5, 0, 0, 0};
+    std::array<float, 4> color1 = {0, 1, 0, 0};
+    std::array<float, 4> color2 = {0, 0, 0, 1};
+    std::array<float, 4> color3 = {0, 0, 1, 0};
+
+    size_t color1Offset = Align(sizeof(color0), kMinDynamicBufferOffsetAlignment);
+    size_t color2Offset = Align(color1Offset + sizeof(color1), kMinDynamicBufferOffsetAlignment);
+    size_t color3Offset = Align(color2Offset + sizeof(color2), kMinDynamicBufferOffsetAlignment);
+
+    std::vector<uint8_t> data(color3Offset + sizeof(color3), 0);
+    memcpy(data.data(), color0.data(), sizeof(color0));
+    memcpy(data.data() + color1Offset, color1.data(), sizeof(color1));
+    memcpy(data.data() + color2Offset, color2.data(), sizeof(color2));
+    memcpy(data.data() + color3Offset, color3.data(), sizeof(color3));
+
+    // Create a uniform buffer and a storage buffer with the same color data, plus one bind group each.
+    dawn::Buffer uniformBuffer =
+        utils::CreateBufferFromData(device, data.data(), data.size(), dawn::BufferUsage::Uniform);
+
+    dawn::Buffer storageBuffer =
+        utils::CreateBufferFromData(device, data.data(), data.size(), dawn::BufferUsage::Storage);
+
+    dawn::BindGroup uniformBindGroup =
+        utils::MakeBindGroup(device, uniformLayout, {{0, uniformBuffer, 0, 4 * sizeof(float)}});
+    dawn::BindGroup storageBindGroup =
+        utils::MakeBindGroup(device, storageLayout, {{0, storageBuffer, 0, 4 * sizeof(float)}});
+
+    dawn::CommandEncoder encoder = device.CreateCommandEncoder();
+    dawn::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+    // Set the pipeline to (uniform, uniform, storage)
+    pass.SetPipeline(pipeline0);
+
+    // Set the first bind group to color0 in the dynamic uniform buffer.
+    uint64_t dynamicOffset = 0;
+    pass.SetBindGroup(0, uniformBindGroup, 1, &dynamicOffset);
+
+    // Set the second bind group to color1 in the dynamic uniform buffer.
+    dynamicOffset = color1Offset;
+    pass.SetBindGroup(1, uniformBindGroup, 1, &dynamicOffset);
+
+    // Set the third bind group to color2 in the dynamic storage buffer.
+    dynamicOffset = color2Offset;
+    pass.SetBindGroup(2, storageBindGroup, 1, &dynamicOffset);
+
+    pass.Draw(3, 1, 0, 0);
+
+    // Set the pipeline to (uniform, storage, storage)
+    //  - The first bind group should persist (inherited on some backends)
+    //  - The second bind group needs to be set again to pass validation.
+    //    It changed from uniform to storage.
+    //  - The third bind group should persist. It should be set again by the backend internally.
+    pass.SetPipeline(pipeline1);
+
+    // Set the second bind group to color3 in the dynamic storage buffer.
+    dynamicOffset = color3Offset;
+    pass.SetBindGroup(1, storageBindGroup, 1, &dynamicOffset);
+
+    pass.Draw(3, 1, 0, 0);
+    pass.EndPass();
+
+    dawn::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    RGBA8 filled(255, 255, 255, 255);
+    RGBA8 notFilled(0, 0, 0, 0);
+    int min = 1, max = kRTSize - 3;
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, max, min);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, min, max);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, max, max);
+}
+
 DAWN_INSTANTIATE_TEST(BindGroupTests, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);