Metal: Pack vertex buffers just after the pipeline layout

WebGPU has a 2D pipeline layout plus a vertex buffer table while Metal
has a single vertex buffer table that contains everything (including
uniform and storage buffers). Previously the space for vertex buffers
was statically allocated in that table which made the last vertex buffer
go out of bounds of the Metal vertex buffer table.

This fixes the issue by packing all the vertex buffers that are used
right after the vertex buffers used by the bind groups. This is a
drive-by fix found while looking at reserving Metal vertex buffer 30 to
contain the shader storage buffer lengths.

BUG=dawn:195

Change-Id: If5c67bbc0d15c976793ef43889e50e4a360217d7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/9387
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/metal/BufferMTL.h b/src/dawn_native/metal/BufferMTL.h
index c4395b2..f05f38e 100644
--- a/src/dawn_native/metal/BufferMTL.h
+++ b/src/dawn_native/metal/BufferMTL.h
@@ -29,7 +29,7 @@
         Buffer(Device* device, const BufferDescriptor* descriptor);
         ~Buffer();
 
-        id<MTLBuffer> GetMTLBuffer();
+        id<MTLBuffer> GetMTLBuffer() const;
 
         void OnMapCommandSerialFinished(uint32_t mapSerial, bool isWrite);
 
diff --git a/src/dawn_native/metal/BufferMTL.mm b/src/dawn_native/metal/BufferMTL.mm
index d4cb82f..0bdc09d 100644
--- a/src/dawn_native/metal/BufferMTL.mm
+++ b/src/dawn_native/metal/BufferMTL.mm
@@ -34,7 +34,7 @@
         DestroyInternal();
     }
 
-    id<MTLBuffer> Buffer::GetMTLBuffer() {
+    id<MTLBuffer> Buffer::GetMTLBuffer() const {
         return mMtlBuffer;
     }
 
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 5f6331f..dd65f3e 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -437,6 +437,55 @@
 
             return copy;
         }
+
+        // Keeps track of the dirty vertex buffer values so they can be lazily applied when we know
+        // all the relevant state.
+        class VertexInputBufferTracker {
+          public:
+            void OnSetVertexBuffers(uint32_t startSlot,
+                                    uint32_t count,
+                                    const Ref<BufferBase>* buffers,
+                                    const uint64_t* offsets) {
+                for (uint32_t i = 0; i < count; ++i) {
+                    uint32_t slot = startSlot + i;
+                    mVertexBuffers[slot] = ToBackend(buffers[i].Get())->GetMTLBuffer();
+                    mVertexBufferOffsets[slot] = offsets[i];
+                }
+
+                // Use 64-bit masks and make sure there is no shift UB
+                static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, "");
+                mDirtyVertexBuffers |= ((1ull << count) - 1ull) << startSlot;
+            }
+
+            void OnSetPipeline(RenderPipeline* lastPipeline, RenderPipeline* pipeline) {
+                // When a new pipeline is bound we must set all the vertex buffers again because
+                // they might have been offset by the pipeline layout, and they might be packed
+                // differently from the previous pipeline.
+                mDirtyVertexBuffers |= pipeline->GetInputsSetMask();
+            }
+
+            void Apply(id<MTLRenderCommandEncoder> encoder, RenderPipeline* pipeline) {
+                std::bitset<kMaxVertexBuffers> vertexBuffersToApply =
+                    mDirtyVertexBuffers & pipeline->GetInputsSetMask();
+
+                for (uint32_t dawnIndex : IterateBitSet(vertexBuffersToApply)) {
+                    uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(dawnIndex);
+
+                    [encoder setVertexBuffers:&mVertexBuffers[dawnIndex]
+                                      offsets:&mVertexBufferOffsets[dawnIndex]
+                                    withRange:NSMakeRange(metalIndex, 1)];
+                }
+
+                mDirtyVertexBuffers.reset();
+            }
+
+          private:
+            // All the indices in these arrays are Dawn vertex buffer indices
+            std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers;
+            std::array<id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers;
+            std::array<NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets;
+        };
+
     }  // anonymous namespace
 
     CommandBuffer::CommandBuffer(CommandEncoderBase* encoder,
@@ -718,6 +767,7 @@
         RenderPipeline* lastPipeline = nullptr;
         id<MTLBuffer> indexBuffer = nil;
         uint32_t indexBufferBaseOffset = 0;
+        VertexInputBufferTracker vertexInputBuffers;
 
         // This will be autoreleased
         id<MTLRenderCommandEncoder> encoder =
@@ -735,6 +785,8 @@
                 case Command::Draw: {
                     DrawCmd* draw = mCommands.NextCommand<DrawCmd>();
 
+                    vertexInputBuffers.Apply(encoder, lastPipeline);
+
                     // The instance count must be non-zero, otherwise no-op
                     if (draw->instanceCount != 0) {
                         [encoder drawPrimitives:lastPipeline->GetMTLPrimitiveTopology()
@@ -750,6 +802,8 @@
                     size_t formatSize =
                         IndexFormatSize(lastPipeline->GetVertexInputDescriptor()->indexFormat);
 
+                    vertexInputBuffers.Apply(encoder, lastPipeline);
+
                     // The index and instance count must be non-zero, otherwise no-op
                     if (draw->indexCount != 0 && draw->instanceCount != 0) {
                         [encoder drawIndexedPrimitives:lastPipeline->GetMTLPrimitiveTopology()
@@ -767,6 +821,8 @@
                 case Command::DrawIndirect: {
                     DrawIndirectCmd* draw = mCommands.NextCommand<DrawIndirectCmd>();
 
+                    vertexInputBuffers.Apply(encoder, lastPipeline);
+
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
                     [encoder drawPrimitives:lastPipeline->GetMTLPrimitiveTopology()
@@ -777,6 +833,8 @@
                 case Command::DrawIndexedIndirect: {
                     DrawIndirectCmd* draw = mCommands.NextCommand<DrawIndirectCmd>();
 
+                    vertexInputBuffers.Apply(encoder, lastPipeline);
+
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
                     [encoder drawIndexedPrimitives:lastPipeline->GetMTLPrimitiveTopology()
@@ -813,12 +871,15 @@
 
                 case Command::SetRenderPipeline: {
                     SetRenderPipelineCmd* cmd = mCommands.NextCommand<SetRenderPipelineCmd>();
-                    lastPipeline = ToBackend(cmd->pipeline).Get();
+                    RenderPipeline* newPipeline = ToBackend(cmd->pipeline).Get();
 
-                    [encoder setDepthStencilState:lastPipeline->GetMTLDepthStencilState()];
-                    [encoder setFrontFacingWinding:lastPipeline->GetMTLFrontFace()];
-                    [encoder setCullMode:lastPipeline->GetMTLCullMode()];
-                    lastPipeline->Encode(encoder);
+                    vertexInputBuffers.OnSetPipeline(lastPipeline, newPipeline);
+                    [encoder setDepthStencilState:newPipeline->GetMTLDepthStencilState()];
+                    [encoder setFrontFacingWinding:newPipeline->GetMTLFrontFace()];
+                    [encoder setCullMode:newPipeline->GetMTLCullMode()];
+                    newPipeline->Encode(encoder);
+
+                    lastPipeline = newPipeline;
                 } break;
 
                 case Command::SetStencilReference: {
@@ -888,24 +949,12 @@
 
                 case Command::SetVertexBuffers: {
                     SetVertexBuffersCmd* cmd = mCommands.NextCommand<SetVertexBuffersCmd>();
-                    auto buffers = mCommands.NextData<Ref<BufferBase>>(cmd->count);
-                    auto offsets = mCommands.NextData<uint64_t>(cmd->count);
+                    const Ref<BufferBase>* buffers =
+                        mCommands.NextData<Ref<BufferBase>>(cmd->count);
+                    const uint64_t* offsets = mCommands.NextData<uint64_t>(cmd->count);
 
-                    std::array<id<MTLBuffer>, kMaxVertexBuffers> mtlBuffers;
-                    std::array<NSUInteger, kMaxVertexBuffers> mtlOffsets;
-
-                    // Perhaps an "array of vertex buffers(+offsets?)" should be
-                    // a Dawn API primitive to avoid reconstructing this array?
-                    for (uint32_t i = 0; i < cmd->count; ++i) {
-                        Buffer* buffer = ToBackend(buffers[i].Get());
-                        mtlBuffers[i] = buffer->GetMTLBuffer();
-                        mtlOffsets[i] = offsets[i];
-                    }
-
-                    [encoder setVertexBuffers:mtlBuffers.data()
-                                      offsets:mtlOffsets.data()
-                                    withRange:NSMakeRange(kMaxBindingsPerGroup + cmd->startSlot,
-                                                          cmd->count)];
+                    vertexInputBuffers.OnSetVertexBuffers(cmd->startSlot, cmd->count, buffers,
+                                                          offsets);
                 } break;
 
                 default: { UNREACHABLE(); } break;
diff --git a/src/dawn_native/metal/PipelineLayoutMTL.h b/src/dawn_native/metal/PipelineLayoutMTL.h
index 77610a5..59ba3b7 100644
--- a/src/dawn_native/metal/PipelineLayoutMTL.h
+++ b/src/dawn_native/metal/PipelineLayoutMTL.h
@@ -37,8 +37,12 @@
             std::array<std::array<uint32_t, kMaxBindingsPerGroup>, kMaxBindGroups>;
         const BindingIndexInfo& GetBindingIndexInfo(ShaderStage stage) const;
 
+        // The number of Metal buffer bindings the whole pipeline layout uses in the given stage.
+        uint32_t GetBufferBindingCount(ShaderStage stage);
+
       private:
         PerStage<BindingIndexInfo> mIndexInfo;
+        PerStage<uint32_t> mBufferBindingCount;
     };
 
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/PipelineLayoutMTL.mm b/src/dawn_native/metal/PipelineLayoutMTL.mm
index 282559f..af08c5b 100644
--- a/src/dawn_native/metal/PipelineLayoutMTL.mm
+++ b/src/dawn_native/metal/PipelineLayoutMTL.mm
@@ -60,6 +60,8 @@
                     }
                 }
             }
+
+            mBufferBindingCount[stage] = bufferIndex;
         }
     }
 
@@ -68,4 +70,8 @@
         return mIndexInfo[stage];
     }
 
+    uint32_t PipelineLayout::GetBufferBindingCount(ShaderStage stage) {
+        return mBufferBindingCount[stage];
+    }
+
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/RenderPipelineMTL.h b/src/dawn_native/metal/RenderPipelineMTL.h
index a027ffc..10d7525 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.h
+++ b/src/dawn_native/metal/RenderPipelineMTL.h
@@ -37,6 +37,10 @@
 
         id<MTLDepthStencilState> GetMTLDepthStencilState();
 
+        // For each Dawn vertex buffer, give the index in which it will be positioned in the Metal
+        // vertex buffer table.
+        uint32_t GetMtlVertexBufferIndex(uint32_t dawnIndex) const;
+
       private:
         MTLVertexDescriptor* MakeVertexDesc();
 
@@ -46,6 +50,7 @@
         MTLCullMode mMtlCullMode;
         id<MTLRenderPipelineState> mMtlRenderPipelineState = nil;
         id<MTLDepthStencilState> mMtlDepthStencilState = nil;
+        std::array<uint32_t, kMaxVertexBuffers> mMtlVertexBufferIndices;
     };
 
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 0bb3a55..ac6ab35 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -400,24 +400,22 @@
         return mMtlDepthStencilState;
     }
 
+    uint32_t RenderPipeline::GetMtlVertexBufferIndex(uint32_t dawnIndex) const {
+        ASSERT(dawnIndex < kMaxVertexBuffers);
+        return mMtlVertexBufferIndices[dawnIndex];
+    }
+
     MTLVertexDescriptor* RenderPipeline::MakeVertexDesc() {
         MTLVertexDescriptor* mtlVertexDescriptor = [MTLVertexDescriptor new];
 
-        for (uint32_t i : IterateBitSet(GetAttributesSetMask())) {
-            const VertexAttributeInfo& info = GetAttribute(i);
+        // Vertex buffers are packed after all the buffers for the bind groups.
+        uint32_t mtlVertexBufferIndex =
+            ToBackend(GetLayout())->GetBufferBindingCount(ShaderStage::Vertex);
 
-            auto attribDesc = [MTLVertexAttributeDescriptor new];
-            attribDesc.format = VertexFormatType(info.format);
-            attribDesc.offset = info.offset;
-            attribDesc.bufferIndex = kMaxBindingsPerGroup + info.inputSlot;
-            mtlVertexDescriptor.attributes[i] = attribDesc;
-            [attribDesc release];
-        }
+        for (uint32_t dawnVertexBufferIndex : IterateBitSet(GetInputsSetMask())) {
+            const VertexBufferInfo& info = GetInput(dawnVertexBufferIndex);
 
-        for (uint32_t vbInputSlot : IterateBitSet(GetInputsSetMask())) {
-            const VertexBufferInfo& info = GetInput(vbInputSlot);
-
-            auto layoutDesc = [MTLVertexBufferLayoutDescriptor new];
+            MTLVertexBufferLayoutDescriptor* layoutDesc = [MTLVertexBufferLayoutDescriptor new];
             if (info.stride == 0) {
                 // For MTLVertexStepFunctionConstant, the stepRate must be 0,
                 // but the stride must NOT be 0, so we made up it with
@@ -426,7 +424,7 @@
                 for (uint32_t attribIndex : IterateBitSet(GetAttributesSetMask())) {
                     const VertexAttributeInfo& attrib = GetAttribute(attribIndex);
                     // Only use the attributes that use the current input
-                    if (attrib.inputSlot != vbInputSlot) {
+                    if (attrib.inputSlot != dawnVertexBufferIndex) {
                         continue;
                     }
                     max_stride = std::max(max_stride,
@@ -442,10 +440,25 @@
                 layoutDesc.stepRate = 1;
                 layoutDesc.stride = info.stride;
             }
-            // TODO(cwallez@chromium.org): make the offset depend on the pipeline layout
-            mtlVertexDescriptor.layouts[kMaxBindingsPerGroup + vbInputSlot] = layoutDesc;
+
+            mtlVertexDescriptor.layouts[mtlVertexBufferIndex] = layoutDesc;
             [layoutDesc release];
+
+            mMtlVertexBufferIndices[dawnVertexBufferIndex] = mtlVertexBufferIndex;
+            mtlVertexBufferIndex++;
         }
+
+        for (uint32_t i : IterateBitSet(GetAttributesSetMask())) {
+            const VertexAttributeInfo& info = GetAttribute(i);
+
+            auto attribDesc = [MTLVertexAttributeDescriptor new];
+            attribDesc.format = VertexFormatType(info.format);
+            attribDesc.offset = info.offset;
+            attribDesc.bufferIndex = mMtlVertexBufferIndices[info.inputSlot];
+            mtlVertexDescriptor.attributes[i] = attribDesc;
+            [attribDesc release];
+        }
+
         return mtlVertexDescriptor;
     }
 
diff --git a/src/tests/end2end/VertexInputTests.cpp b/src/tests/end2end/VertexInputTests.cpp
index 9024b1e..6efce06 100644
--- a/src/tests/end2end/VertexInputTests.cpp
+++ b/src/tests/end2end/VertexInputTests.cpp
@@ -487,6 +487,28 @@
     CheckResult(1, 4);
 }
 
+// Checks that using the last vertex buffer doesn't overflow the vertex buffer table in Metal.
+TEST_P(VertexInputTest, LastAllowedVertexBuffer) {
+    constexpr uint32_t kBufferIndex = kMaxVertexBuffers - 1;
+
+    utils::ComboVertexInputDescriptor vertexInput;
+    // All the other vertex buffers default to no attributes
+    vertexInput.bufferCount = kMaxVertexBuffers;
+    vertexInput.cBuffers[kBufferIndex].stride = 4 * sizeof(float);
+    vertexInput.cBuffers[kBufferIndex].stepMode = InputStepMode::Vertex;
+    vertexInput.cBuffers[kBufferIndex].attributeCount = 1;
+    vertexInput.cBuffers[kBufferIndex].attributes = &vertexInput.cAttributes[0];
+    vertexInput.cAttributes[0].shaderLocation = 0;
+    vertexInput.cAttributes[0].offset = 0;
+    vertexInput.cAttributes[0].format = VertexFormat::Float4;
+
+    dawn::RenderPipeline pipeline =
+        MakeTestPipeline(vertexInput, 1, {{0, VertexFormat::Float4, InputStepMode::Vertex}});
+
+    dawn::Buffer buffer0 = MakeVertexBuffer<float>({0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5});
+    DoTestDraw(pipeline, 1, 1, {DrawVertexBuffer{kMaxVertexBuffers - 1, &buffer0}});
+}
+
 DAWN_INSTANTIATE_TEST(VertexInputTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
 
 // TODO for the input state: