Slab-allocate Metal bind groups

Store bind group bindings in a dynamically-sized region placed directly
after the bind group object, instead of in fixed kMaxBindingsPerGroup
arrays. SlabAllocatorImpl now computes its block layout from just an
object size and alignment, and gains move constructors so that each
BindGroupLayout can own a SlabAllocator by value. The Metal backend
gains concrete BindGroup and BindGroupLayout classes which allocate
bind groups out of the layout's slab allocator; the other backends keep
their binding data in a separate allocation for now via
BindGroupBaseOwnBindingData.

Bug: dawn:340
Change-Id: I6185e41d9c71c49953a4de91e5f3042968679fd6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/15862
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 6a4f381..c5dfa5d 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -360,6 +360,10 @@
       "src/dawn_native/Surface_metal.mm",
       "src/dawn_native/metal/BackendMTL.h",
       "src/dawn_native/metal/BackendMTL.mm",
+      "src/dawn_native/metal/BindGroupLayoutMTL.h",
+      "src/dawn_native/metal/BindGroupLayoutMTL.mm",
+      "src/dawn_native/metal/BindGroupMTL.h",
+      "src/dawn_native/metal/BindGroupMTL.mm",
       "src/dawn_native/metal/BufferMTL.h",
       "src/dawn_native/metal/BufferMTL.mm",
       "src/dawn_native/metal/CommandBufferMTL.h",
diff --git a/src/common/SlabAllocator.cpp b/src/common/SlabAllocator.cpp
index 75aada5..6194887 100644
--- a/src/common/SlabAllocator.cpp
+++ b/src/common/SlabAllocator.cpp
@@ -17,6 +17,7 @@
 #include "common/Assert.h"
 #include "common/Math.h"
 
+#include <algorithm>
 #include <cstdlib>
 #include <limits>
 #include <new>
@@ -37,9 +38,13 @@
       blocksInUse(0) {
 }
 
+SlabAllocatorImpl::Slab::Slab(Slab&& rhs) = default;
+
 SlabAllocatorImpl::SentinelSlab::SentinelSlab() : Slab(nullptr, nullptr) {
 }
 
+SlabAllocatorImpl::SentinelSlab::SentinelSlab(SentinelSlab&& rhs) = default;
+
 SlabAllocatorImpl::SentinelSlab::~SentinelSlab() {
     Slab* slab = this->next;
     while (slab != nullptr) {
@@ -56,14 +61,12 @@
     std::numeric_limits<SlabAllocatorImpl::Index>::max();
 
 SlabAllocatorImpl::SlabAllocatorImpl(Index blocksPerSlab,
-                                     uint32_t allocationAlignment,
-                                     uint32_t slabBlocksOffset,
-                                     uint32_t blockStride,
-                                     uint32_t indexLinkNodeOffset)
-    : mAllocationAlignment(allocationAlignment),
-      mSlabBlocksOffset(slabBlocksOffset),
-      mBlockStride(blockStride),
-      mIndexLinkNodeOffset(indexLinkNodeOffset),
+                                     uint32_t objectSize,
+                                     uint32_t objectAlignment)
+    : mAllocationAlignment(std::max(static_cast<uint32_t>(alignof(Slab)), objectAlignment)),
+      mSlabBlocksOffset(Align(sizeof(Slab), objectAlignment)),
+      mIndexLinkNodeOffset(Align(objectSize, alignof(IndexLinkNode))),
+      mBlockStride(Align(mIndexLinkNodeOffset + sizeof(IndexLinkNode), objectAlignment)),
       mBlocksPerSlab(blocksPerSlab),
       mTotalAllocationSize(
           // required allocation size
@@ -74,6 +77,18 @@
     ASSERT(IsPowerOfTwo(mAllocationAlignment));
 }
 
+SlabAllocatorImpl::SlabAllocatorImpl(SlabAllocatorImpl&& rhs)
+    : mAllocationAlignment(rhs.mAllocationAlignment),
+      mSlabBlocksOffset(rhs.mSlabBlocksOffset),
+      mIndexLinkNodeOffset(rhs.mIndexLinkNodeOffset),
+      mBlockStride(rhs.mBlockStride),
+      mBlocksPerSlab(rhs.mBlocksPerSlab),
+      mTotalAllocationSize(rhs.mTotalAllocationSize),
+      mAvailableSlabs(std::move(rhs.mAvailableSlabs)),
+      mFullSlabs(std::move(rhs.mFullSlabs)),
+      mRecycledSlabs(std::move(rhs.mRecycledSlabs)) {
+}
+
 SlabAllocatorImpl::~SlabAllocatorImpl() = default;
 
 SlabAllocatorImpl::IndexLinkNode* SlabAllocatorImpl::OffsetFrom(
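The constructor now derives every layout parameter from just |objectSize| and |objectAlignment|. A worked sketch of the arithmetic, using hypothetical sizes (say sizeof(Slab) == 40, sizeof(IndexLinkNode) == 4, alignof(IndexLinkNode) == 2 on a 64-bit build) with objectSize == 24 and objectAlignment == 8:

    // mAllocationAlignment = max(alignof(Slab), 8)           // slab header must also fit
    // mSlabBlocksOffset    = Align(40, 8)     = 40           // blocks begin after the header
    // mIndexLinkNodeOffset = Align(24, 2)     = 24           // free-list node packed after object
    // mBlockStride         = Align(24 + 4, 8) = 32           // re-aligned for the next object
    //
    // | pad | Slab | obj(24) node(4) pad(4) | obj(24) node(4) pad(4) | ...
    //       ^ mSlabBlocksOffset             ^ +mBlockStride
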
diff --git a/src/common/SlabAllocator.h b/src/common/SlabAllocator.h
index 59c9b63..939f1c0 100644
--- a/src/common/SlabAllocator.h
+++ b/src/common/SlabAllocator.h
@@ -60,6 +60,8 @@
     // TODO(enga): Is uint8_t sufficient?
     using Index = uint16_t;
 
+    SlabAllocatorImpl(SlabAllocatorImpl&& rhs);
+
   protected:
     // This is essentially a singly linked list using indices instead of pointers,
     // so we store the index of "this" in |this->index|.
@@ -76,6 +78,7 @@
         // | ---------- allocation --------- |
         // | pad | Slab | data ------------> |
         Slab(std::unique_ptr<char[]> allocation, IndexLinkNode* head);
+        Slab(Slab&& rhs);
 
         void Splice();
 
@@ -86,11 +89,7 @@
         Index blocksInUse;
     };
 
-    SlabAllocatorImpl(Index blocksPerSlab,
-                      uint32_t allocationAlignment,
-                      uint32_t slabBlocksOffset,
-                      uint32_t blockStride,
-                      uint32_t indexLinkNodeOffset);
+    SlabAllocatorImpl(Index blocksPerSlab, uint32_t objectSize, uint32_t objectAlignment);
     ~SlabAllocatorImpl();
 
     // Allocate a new block of memory.
@@ -136,13 +135,13 @@
     // the offset to the start of the aligned memory region.
     const uint32_t mSlabBlocksOffset;
 
+    // The IndexLinkNode is stored after the Allocation itself. This is the offset to it.
+    const uint32_t mIndexLinkNodeOffset;
+
     // Because alignment of allocations may introduce padding, |mBlockStride| is the
     // distance between aligned blocks of (Allocation + IndexLinkNode)
     const uint32_t mBlockStride;
 
-    // The IndexLinkNode is stored after the Allocation itself. This is the offset to it.
-    const uint32_t mIndexLinkNodeOffset;
-
     const Index mBlocksPerSlab;  // The total number of blocks in a slab.
 
     const size_t mTotalAllocationSize;
@@ -151,6 +150,8 @@
         SentinelSlab();
         ~SentinelSlab();
 
+        SentinelSlab(SentinelSlab&& rhs);
+
         void Prepend(Slab* slab);
     };
 
@@ -160,32 +161,13 @@
                                    // we don't thrash the current "active" slab.
 };
 
-template <typename T, size_t ObjectSize = 0>
+template <typename T>
 class SlabAllocator : public SlabAllocatorImpl {
-    // Helper struct for computing alignments
-    struct Storage {
-        Slab slab;
-        struct Block {
-            // If the size is unspecified, use sizeof(T) as default. Defined here and not as a
-            // default template parameter because T may be an incomplete type at the time of
-            // declaration.
-            static constexpr size_t kSize = ObjectSize == 0 ? sizeof(T) : ObjectSize;
-            static_assert(kSize >= sizeof(T), "");
-
-            alignas(alignof(T)) char object[kSize];
-            IndexLinkNode node;
-        } blocks[];
-    };
-
   public:
-    SlabAllocator(Index blocksPerSlab)
-        : SlabAllocatorImpl(
-              blocksPerSlab,
-              alignof(Storage),                                             // allocationAlignment
-              offsetof(Storage, blocks[0]),                                 // slabBlocksOffset
-              offsetof(Storage, blocks[1]) - offsetof(Storage, blocks[0]),  // blockStride
-              offsetof(typename Storage::Block, node)                       // indexLinkNodeOffset
-          ) {
+    SlabAllocator(size_t totalObjectBytes,
+                  uint32_t objectSize = sizeof(T),
+                  uint32_t objectAlignment = alignof(T))
+        : SlabAllocatorImpl(totalObjectBytes / objectSize, objectSize, objectAlignment) {
     }
 
     template <typename... Args>
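With the helper Storage struct gone, callers size the allocator in total object bytes rather than a block count, and blocksPerSlab falls out as totalObjectBytes / objectSize. A minimal usage sketch (Foo is a stand-in type, as in the unit tests):

    struct Foo {
        explicit Foo(int v) : value(v) {}
        int value;
    };

    SlabAllocator<Foo> allocator(4096);  // roughly 4KB worth of Foo blocks per slab

    Foo* foo = allocator.Allocate(42);   // placement-constructs a Foo inside a slab block
    allocator.Deallocate(foo);           // returns the block to the slab's free list
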
diff --git a/src/dawn_native/BindGroup.cpp b/src/dawn_native/BindGroup.cpp
index fe8166f..491d8f2 100644
--- a/src/dawn_native/BindGroup.cpp
+++ b/src/dawn_native/BindGroup.cpp
@@ -181,45 +181,84 @@
         return {};
     }
 
+    // OwnBindingDataHolder
+
+    OwnBindingDataHolder::OwnBindingDataHolder(size_t size)
+        : mBindingDataAllocation(malloc(size))  // malloc returns memory aligned for any
+                                                // fundamental type, enough for the binding data
+    {
+    }
+
+    OwnBindingDataHolder::~OwnBindingDataHolder() {
+        free(mBindingDataAllocation);
+    }
+
+    // BindGroupBaseOwnBindingData
+
+    BindGroupBaseOwnBindingData::BindGroupBaseOwnBindingData(DeviceBase* device,
+                                                             const BindGroupDescriptor* descriptor)
+        : OwnBindingDataHolder(descriptor->layout->GetBindingDataSize()),
+          BindGroupBase(device, descriptor, mBindingDataAllocation) {
+    }
+
     // BindGroup
 
-    BindGroupBase::BindGroupBase(DeviceBase* device, const BindGroupDescriptor* descriptor)
-        : ObjectBase(device), mLayout(descriptor->layout) {
+    BindGroupBase::BindGroupBase(DeviceBase* device,
+                                 const BindGroupDescriptor* descriptor,
+                                 void* bindingDataStart)
+        : ObjectBase(device),
+          mLayout(descriptor->layout),
+          mBindingData(mLayout->ComputeBindingDataPointers(bindingDataStart)) {
+        for (uint32_t i = 0; i < mLayout->GetBindingCount(); ++i) {
+            // TODO(enga): Shouldn't be needed when bindings are tightly packed.
+            // This is to fill Ref<ObjectBase> holes with nullptrs.
+            new (&mBindingData.bindings[i]) Ref<ObjectBase>();
+        }
+
         for (uint32_t i = 0; i < descriptor->bindingCount; ++i) {
             const BindGroupBinding& binding = descriptor->bindings[i];
 
             uint32_t bindingIndex = binding.binding;
-            ASSERT(bindingIndex < kMaxBindingsPerGroup);
+            ASSERT(bindingIndex < mLayout->GetBindingCount());
 
             // Only a single binding type should be set, so once we found it we can skip to the
             // next loop iteration.
 
             if (binding.buffer != nullptr) {
-                ASSERT(mBindings[bindingIndex].Get() == nullptr);
-                mBindings[bindingIndex] = binding.buffer;
-                mOffsets[bindingIndex] = binding.offset;
+                ASSERT(mBindingData.bindings[bindingIndex].Get() == nullptr);
+                mBindingData.bindings[bindingIndex] = binding.buffer;
+                mBindingData.bufferData[bindingIndex].offset = binding.offset;
                 uint64_t bufferSize =
                     (binding.size == wgpu::kWholeSize) ? binding.buffer->GetSize() : binding.size;
-                mSizes[bindingIndex] = bufferSize;
+                mBindingData.bufferData[bindingIndex].size = bufferSize;
                 continue;
             }
 
             if (binding.textureView != nullptr) {
-                ASSERT(mBindings[bindingIndex].Get() == nullptr);
-                mBindings[bindingIndex] = binding.textureView;
+                ASSERT(mBindingData.bindings[bindingIndex].Get() == nullptr);
+                mBindingData.bindings[bindingIndex] = binding.textureView;
                 continue;
             }
 
             if (binding.sampler != nullptr) {
-                ASSERT(mBindings[bindingIndex].Get() == nullptr);
-                mBindings[bindingIndex] = binding.sampler;
+                ASSERT(mBindingData.bindings[bindingIndex].Get() == nullptr);
+                mBindingData.bindings[bindingIndex] = binding.sampler;
                 continue;
             }
         }
     }
 
+    BindGroupBase::~BindGroupBase() {
+        if (mLayout) {
+            ASSERT(!IsError());
+            for (uint32_t i = 0; i < mLayout->GetBindingCount(); ++i) {
+                mBindingData.bindings[i].~Ref<ObjectBase>();
+            }
+        }
+    }
+
     BindGroupBase::BindGroupBase(DeviceBase* device, ObjectBase::ErrorTag tag)
-        : ObjectBase(device, tag) {
+        : ObjectBase(device, tag), mBindingData() {
     }
 
     // static
@@ -240,8 +279,9 @@
                mLayout->GetBindingInfo().types[binding] == wgpu::BindingType::StorageBuffer ||
                mLayout->GetBindingInfo().types[binding] ==
                    wgpu::BindingType::ReadonlyStorageBuffer);
-        BufferBase* buffer = static_cast<BufferBase*>(mBindings[binding].Get());
-        return {buffer, mOffsets[binding], mSizes[binding]};
+        BufferBase* buffer = static_cast<BufferBase*>(mBindingData.bindings[binding].Get());
+        return {buffer, mBindingData.bufferData[binding].offset,
+                mBindingData.bufferData[binding].size};
     }
 
     SamplerBase* BindGroupBase::GetBindingAsSampler(size_t binding) {
@@ -249,7 +289,7 @@
         ASSERT(binding < kMaxBindingsPerGroup);
         ASSERT(mLayout->GetBindingInfo().mask[binding]);
         ASSERT(mLayout->GetBindingInfo().types[binding] == wgpu::BindingType::Sampler);
-        return static_cast<SamplerBase*>(mBindings[binding].Get());
+        return static_cast<SamplerBase*>(mBindingData.bindings[binding].Get());
     }
 
     TextureViewBase* BindGroupBase::GetBindingAsTextureView(size_t binding) {
@@ -257,7 +297,7 @@
         ASSERT(binding < kMaxBindingsPerGroup);
         ASSERT(mLayout->GetBindingInfo().mask[binding]);
         ASSERT(mLayout->GetBindingInfo().types[binding] == wgpu::BindingType::SampledTexture);
-        return static_cast<TextureViewBase*>(mBindings[binding].Get());
+        return static_cast<TextureViewBase*>(mBindingData.bindings[binding].Get());
     }
 
 }  // namespace dawn_native
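OwnBindingDataHolder relies on base classes being initialized in declaration order: the private holder runs first and performs the malloc, so |mBindingDataAllocation| is already valid by the time the BindGroupBase constructor consumes it. A minimal sketch of the idiom in isolation (all names hypothetical):

    #include <cstdlib>

    struct Holder {
        explicit Holder(size_t size) : memory(malloc(size)) {}
        ~Holder() { free(memory); }
        void* memory;
    };

    struct Base {
        explicit Base(void* data) : mData(data) {}
        void* mData;
    };

    // Holder comes first in the base list, so Holder::memory is initialized
    // before Base's constructor reads it; swapping the bases would pass garbage.
    struct Derived : private Holder, public Base {
        explicit Derived(size_t size) : Holder(size), Base(memory) {}
    };
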
diff --git a/src/dawn_native/BindGroup.h b/src/dawn_native/BindGroup.h
index fae804d..3255171 100644
--- a/src/dawn_native/BindGroup.h
+++ b/src/dawn_native/BindGroup.h
@@ -16,6 +16,7 @@
 #define DAWNNATIVE_BINDGROUP_H_
 
 #include "common/Constants.h"
+#include "common/Math.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Error.h"
 #include "dawn_native/Forward.h"
@@ -40,7 +41,7 @@
 
     class BindGroupBase : public ObjectBase {
       public:
-        BindGroupBase(DeviceBase* device, const BindGroupDescriptor* descriptor);
+        ~BindGroupBase() override;
 
         static BindGroupBase* MakeError(DeviceBase* device);
 
@@ -49,13 +50,49 @@
         SamplerBase* GetBindingAsSampler(size_t binding);
         TextureViewBase* GetBindingAsTextureView(size_t binding);
 
+      protected:
+        // To save memory, the size of a bind group is dynamically determined and the bind group is
+        // placement-allocated into memory big enough to hold the bind group with its
+        // dynamically-sized bindings after it. A pointer to the beginning of that binding data
+        // memory should be passed as |bindingDataStart|.
+        BindGroupBase(DeviceBase* device,
+                      const BindGroupDescriptor* descriptor,
+                      void* bindingDataStart);
+
+        // Helper to instantiate BindGroupBase. We pass in |derived| because BindGroupBase may not
+        // be first in the allocation. The binding data is stored after the Derived class.
+        template <typename Derived>
+        BindGroupBase(Derived* derived, DeviceBase* device, const BindGroupDescriptor* descriptor)
+            : BindGroupBase(device,
+                            descriptor,
+                            AlignPtr(reinterpret_cast<char*>(derived) + sizeof(Derived),
+                                     descriptor->layout->GetBindingDataAlignment())) {
+            static_assert(std::is_base_of<BindGroupBase, Derived>::value, "");
+        }
+
       private:
         BindGroupBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
         Ref<BindGroupLayoutBase> mLayout;
-        std::array<Ref<ObjectBase>, kMaxBindingsPerGroup> mBindings;
-        std::array<uint32_t, kMaxBindingsPerGroup> mOffsets;
-        std::array<uint32_t, kMaxBindingsPerGroup> mSizes;
+        BindGroupLayoutBase::BindingDataPointers mBindingData;
+    };
+
+    // Helper class so |BindGroupBaseOwnBindingData| can allocate memory for its binding data
+    // before calling the BindGroupBase base class constructor.
+    class OwnBindingDataHolder {
+      protected:
+        explicit OwnBindingDataHolder(size_t size);
+        ~OwnBindingDataHolder();
+
+        void* mBindingDataAllocation;
+    };
+
+    // Not all backends need the complexity of placement-allocated bind group data. This class
+    // keeps the binding data in a separate allocation for simplicity.
+    class BindGroupBaseOwnBindingData : private OwnBindingDataHolder, public BindGroupBase {
+      public:
+        BindGroupBaseOwnBindingData(DeviceBase* device, const BindGroupDescriptor* descriptor);
+        ~BindGroupBaseOwnBindingData() override = default;
     };
 
 }  // namespace dawn_native
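The templated constructor is what lets any backend opt into placement allocation: the binding data always lands immediately after the most-derived object, aligned up as needed. The resulting block looks like:

    // | Derived (backend BindGroup) | pad | binding data -------------------> |
    // ^ derived == this                   ^ bindingDataStart
    //
    // bindingDataStart = AlignPtr(reinterpret_cast<char*>(derived) + sizeof(Derived),
    //                             descriptor->layout->GetBindingDataAlignment())
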
diff --git a/src/dawn_native/BindGroupLayout.cpp b/src/dawn_native/BindGroupLayout.cpp
index c583137..9039789 100644
--- a/src/dawn_native/BindGroupLayout.cpp
+++ b/src/dawn_native/BindGroupLayout.cpp
@@ -178,6 +178,19 @@
             mBindingInfo.types[index] = binding.type;
             mBindingInfo.textureComponentTypes[index] = binding.textureComponentType;
 
+            // TODO(enga): This is a greedy computation because there may be holes in bindings.
+            // Fix this when we pack bindings.
+            mBindingCount = std::max(mBindingCount, index + 1);
+            switch (binding.type) {
+                case wgpu::BindingType::UniformBuffer:
+                case wgpu::BindingType::StorageBuffer:
+                case wgpu::BindingType::ReadonlyStorageBuffer:
+                    mBufferCount = std::max(mBufferCount, index + 1);
+                    break;
+                default:
+                    break;
+            }
+
             if (binding.textureDimension == wgpu::TextureViewDimension::Undefined) {
                 mBindingInfo.textureDimensions[index] = wgpu::TextureViewDimension::e2D;
             } else {
@@ -240,6 +253,10 @@
         return a->mBindingInfo == b->mBindingInfo;
     }
 
+    uint32_t BindGroupLayoutBase::GetBindingCount() const {
+        return mBindingCount;
+    }
+
     uint32_t BindGroupLayoutBase::GetDynamicBufferCount() const {
         return mDynamicStorageBufferCount + mDynamicUniformBufferCount;
     }
@@ -252,4 +269,23 @@
         return mDynamicStorageBufferCount;
     }
 
+    size_t BindGroupLayoutBase::GetBindingDataSize() const {
+        // | ------ buffer-specific ----------| ------------ object pointers -------------|
+        // | --- offsets + sizes -------------| --------------- Ref<ObjectBase> ----------|
+        size_t objectPointerStart = mBufferCount * sizeof(BufferBindingData);
+        ASSERT(IsAligned(objectPointerStart, alignof(Ref<ObjectBase>)));
+        return objectPointerStart + mBindingCount * sizeof(Ref<ObjectBase>);
+    }
+
+    BindGroupLayoutBase::BindingDataPointers BindGroupLayoutBase::ComputeBindingDataPointers(
+        void* dataStart) const {
+        BufferBindingData* bufferData = reinterpret_cast<BufferBindingData*>(dataStart);
+        auto bindings = reinterpret_cast<Ref<ObjectBase>*>(bufferData + mBufferCount);
+
+        ASSERT(IsPtrAligned(bufferData, alignof(BufferBindingData)));
+        ASSERT(IsPtrAligned(bindings, alignof(Ref<ObjectBase>)));
+
+        return {bufferData, bindings};
+    }
+
 }  // namespace dawn_native
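A worked example of the packing, for a hypothetical layout with mBufferCount == 2 and mBindingCount == 3, assuming sizeof(Ref<ObjectBase>) == 8 on a 64-bit build (BufferBindingData is two uint64_t, so 16 bytes):

    // | BufferBindingData[2] = 32 bytes | Ref<ObjectBase>[3] = 24 bytes |
    // ^ dataStart                       ^ objectPointerStart == 32
    //
    // GetBindingDataSize() == 32 + 24 == 56 bytes
    //
    // Both arrays are indexed by binding index. Because the counts are computed
    // greedily as (max index + 1), holes in the binding numbering become padded,
    // unused slots until bindings are packed.
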
diff --git a/src/dawn_native/BindGroupLayout.h b/src/dawn_native/BindGroupLayout.h
index 02d8453..24bae39 100644
--- a/src/dawn_native/BindGroupLayout.h
+++ b/src/dawn_native/BindGroupLayout.h
@@ -16,6 +16,8 @@
 #define DAWNNATIVE_BINDGROUPLAYOUT_H_
 
 #include "common/Constants.h"
+#include "common/Math.h"
+#include "common/SlabAllocator.h"
 #include "dawn_native/CachedObject.h"
 #include "dawn_native/Error.h"
 #include "dawn_native/Forward.h"
@@ -60,14 +62,47 @@
             bool operator()(const BindGroupLayoutBase* a, const BindGroupLayoutBase* b) const;
         };
 
+        uint32_t GetBindingCount() const;
         uint32_t GetDynamicBufferCount() const;
         uint32_t GetDynamicUniformBufferCount() const;
         uint32_t GetDynamicStorageBufferCount() const;
 
+        struct BufferBindingData {
+            uint64_t offset;
+            uint64_t size;
+        };
+
+        struct BindingDataPointers {
+            BufferBindingData* const bufferData = nullptr;
+            Ref<ObjectBase>* const bindings = nullptr;
+        };
+
+        // Compute the amount of space / alignment required to store bindings for a bind group of
+        // this layout.
+        size_t GetBindingDataSize() const;
+        static constexpr size_t GetBindingDataAlignment() {
+            static_assert(alignof(Ref<ObjectBase>) <= alignof(BufferBindingData), "");
+            return alignof(BufferBindingData);
+        }
+
+        BindingDataPointers ComputeBindingDataPointers(void* dataStart) const;
+
+      protected:
+        template <typename BindGroup>
+        SlabAllocator<BindGroup> MakeFrontendBindGroupAllocator(size_t size) {
+            return SlabAllocator<BindGroup>(
+                size,  // bytes
+                Align(sizeof(BindGroup), GetBindingDataAlignment()) + GetBindingDataSize(),  // size
+                std::max(alignof(BindGroup), GetBindingDataAlignment())  // alignment
+            );
+        }
+
       private:
         BindGroupLayoutBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
         LayoutBindingInfo mBindingInfo;
+        uint32_t mBindingCount = 0;
+        uint32_t mBufferCount = 0;
         uint32_t mDynamicUniformBufferCount = 0;
         uint32_t mDynamicStorageBufferCount = 0;
     };
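MakeFrontendBindGroupAllocator sizes each slab block to hold one BindGroup object plus its trailing binding data. Continuing the numbers above, with a hypothetical sizeof(BindGroup) == 48 and an 8-byte binding data alignment:

    // objectSize      = Align(48, 8) + 56 = 104 bytes per bind group
    // objectAlignment = max(alignof(BindGroup), 8)
    // blocksPerSlab   = 4096 / 104 = 39 bind groups per slab
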
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index b3e4a89..4b45307 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -229,6 +229,10 @@
         "Surface_metal.mm"
         "metal/BackendMTL.h"
         "metal/BackendMTL.mm"
+        "metal/BindGroupLayoutMTL.h"
+        "metal/BindGroupLayoutMTL.mm"
+        "metal/BindGroupMTL.h"
+        "metal/BindGroupMTL.mm"
         "metal/BufferMTL.h"
         "metal/BufferMTL.mm"
         "metal/CommandBufferMTL.h"
diff --git a/src/dawn_native/d3d12/BindGroupD3D12.cpp b/src/dawn_native/d3d12/BindGroupD3D12.cpp
index 0f9ea86..29def10 100644
--- a/src/dawn_native/d3d12/BindGroupD3D12.cpp
+++ b/src/dawn_native/d3d12/BindGroupD3D12.cpp
@@ -24,10 +24,6 @@
 
 namespace dawn_native { namespace d3d12 {
 
-    BindGroup::BindGroup(Device* device, const BindGroupDescriptor* descriptor)
-        : BindGroupBase(device, descriptor) {
-    }
-
     ResultOrError<bool> BindGroup::Populate(ShaderVisibleDescriptorAllocator* allocator) {
         Device* device = ToBackend(GetDevice());
 
diff --git a/src/dawn_native/d3d12/BindGroupD3D12.h b/src/dawn_native/d3d12/BindGroupD3D12.h
index e00a662..b65af78 100644
--- a/src/dawn_native/d3d12/BindGroupD3D12.h
+++ b/src/dawn_native/d3d12/BindGroupD3D12.h
@@ -24,9 +24,9 @@
     class Device;
     class ShaderVisibleDescriptorAllocator;
 
-    class BindGroup : public BindGroupBase {
+    class BindGroup : public BindGroupBaseOwnBindingData {
       public:
-        BindGroup(Device* device, const BindGroupDescriptor* descriptor);
+        using BindGroupBaseOwnBindingData::BindGroupBaseOwnBindingData;
 
         // Returns true if the BindGroup was successfully populated.
         ResultOrError<bool> Populate(ShaderVisibleDescriptorAllocator* allocator);
diff --git a/src/dawn_native/metal/BindGroupLayoutMTL.h b/src/dawn_native/metal/BindGroupLayoutMTL.h
new file mode 100644
index 0000000..7911835
--- /dev/null
+++ b/src/dawn_native/metal/BindGroupLayoutMTL.h
@@ -0,0 +1,39 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_METAL_BINDGROUPLAYOUTMTL_H_
+#define DAWNNATIVE_METAL_BINDGROUPLAYOUTMTL_H_
+
+#include "common/SlabAllocator.h"
+#include "dawn_native/BindGroupLayout.h"
+
+namespace dawn_native { namespace metal {
+
+    class BindGroup;
+    class Device;
+
+    class BindGroupLayout : public BindGroupLayoutBase {
+      public:
+        BindGroupLayout(DeviceBase* device, const BindGroupLayoutDescriptor* descriptor);
+
+        BindGroup* AllocateBindGroup(Device* device, const BindGroupDescriptor* descriptor);
+        void DeallocateBindGroup(BindGroup* bindGroup);
+
+      private:
+        SlabAllocator<BindGroup> mBindGroupAllocator;
+    };
+
+}}  // namespace dawn_native::metal
+
+#endif  // DAWNNATIVE_METAL_BINDGROUPLAYOUTMTL_H_
diff --git a/src/dawn_native/metal/BindGroupLayoutMTL.mm b/src/dawn_native/metal/BindGroupLayoutMTL.mm
new file mode 100644
index 0000000..70beb5d
--- /dev/null
+++ b/src/dawn_native/metal/BindGroupLayoutMTL.mm
@@ -0,0 +1,36 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/metal/BindGroupLayoutMTL.h"
+
+#include "dawn_native/metal/BindGroupMTL.h"
+
+namespace dawn_native { namespace metal {
+
+    BindGroupLayout::BindGroupLayout(DeviceBase* device,
+                                     const BindGroupLayoutDescriptor* descriptor)
+        : BindGroupLayoutBase(device, descriptor),
+          mBindGroupAllocator(MakeFrontendBindGroupAllocator<BindGroup>(4096)) {
+    }
+
+    BindGroup* BindGroupLayout::AllocateBindGroup(Device* device,
+                                                  const BindGroupDescriptor* descriptor) {
+        return mBindGroupAllocator.Allocate(device, descriptor);
+    }
+
+    void BindGroupLayout::DeallocateBindGroup(BindGroup* bindGroup) {
+        mBindGroupAllocator.Deallocate(bindGroup);
+    }
+
+}}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/BindGroupMTL.h b/src/dawn_native/metal/BindGroupMTL.h
new file mode 100644
index 0000000..4a0a229
--- /dev/null
+++ b/src/dawn_native/metal/BindGroupMTL.h
@@ -0,0 +1,36 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_METAL_BINDGROUPMTL_H_
+#define DAWNNATIVE_METAL_BINDGROUPMTL_H_
+
+#include "common/PlacementAllocated.h"
+#include "dawn_native/BindGroup.h"
+
+namespace dawn_native { namespace metal {
+
+    class BindGroupLayout;
+    class Device;
+
+    class BindGroup : public BindGroupBase, public PlacementAllocated {
+      public:
+        BindGroup(Device* device, const BindGroupDescriptor* descriptor);
+        ~BindGroup() override;
+
+        static BindGroup* Create(Device* device, const BindGroupDescriptor* descriptor);
+    };
+
+}}  // namespace dawn_native::metal
+
+#endif  // DAWNNATIVE_METAL_BINDGROUPMTL_H_
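PlacementAllocated comes from common/PlacementAllocated.h, which is not part of this change. Its expected shape (an assumption, since the header is not shown here) is to forbid regular heap allocation and make operator delete a no-op, so that RefCounted's `delete this` only runs the destructor while the slab retains ownership of the memory:

    class PlacementAllocated {
      public:
        void* operator new(size_t) = delete;   // must be placement-allocated
        void* operator new(size_t, void* ptr) {
            return ptr;                        // standard placement-new behavior
        }
        void operator delete(void*) {
            // No-op: the memory belongs to the slab, not to this object.
        }
    };
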
diff --git a/src/dawn_native/metal/BindGroupMTL.mm b/src/dawn_native/metal/BindGroupMTL.mm
new file mode 100644
index 0000000..d8bcd51
--- /dev/null
+++ b/src/dawn_native/metal/BindGroupMTL.mm
@@ -0,0 +1,35 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/metal/BindGroupMTL.h"
+
+#include "dawn_native/metal/BindGroupLayoutMTL.h"
+#include "dawn_native/metal/DeviceMTL.h"
+
+namespace dawn_native { namespace metal {
+
+    BindGroup::BindGroup(Device* device, const BindGroupDescriptor* descriptor)
+        : BindGroupBase(this, device, descriptor) {
+    }
+
+    BindGroup::~BindGroup() {
+        ToBackend(GetLayout())->DeallocateBindGroup(this);
+    }
+
+    // static
+    BindGroup* BindGroup::Create(Device* device, const BindGroupDescriptor* descriptor) {
+        return ToBackend(descriptor->layout)->AllocateBindGroup(device, descriptor);
+    }
+
+}}  // namespace dawn_native::metal
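Putting the pieces together, the destruction path for a Metal bind group (assuming RefCounted releases objects with `delete this`, as elsewhere in Dawn):

    // Ref<BindGroup> count reaches zero
    //   -> RefCounted invokes `delete this`
    //   -> ~BindGroup() runs and calls ToBackend(GetLayout())->DeallocateBindGroup(this),
    //      returning the block to the layout's slab free list
    //   -> PlacementAllocated's no-op operator delete leaves the slab memory untouched
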
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 7958217..8d85724 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -14,11 +14,11 @@
 
 #include "dawn_native/metal/CommandBufferMTL.h"
 
-#include "dawn_native/BindGroup.h"
 #include "dawn_native/BindGroupTracker.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/RenderBundle.h"
+#include "dawn_native/metal/BindGroupMTL.h"
 #include "dawn_native/metal/BufferMTL.h"
 #include "dawn_native/metal/ComputePipelineMTL.h"
 #include "dawn_native/metal/DeviceMTL.h"
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 8869b68..f91fd95 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -15,10 +15,11 @@
 #include "dawn_native/metal/DeviceMTL.h"
 
 #include "dawn_native/BackendConnection.h"
-#include "dawn_native/BindGroup.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/DynamicUploader.h"
 #include "dawn_native/ErrorData.h"
+#include "dawn_native/metal/BindGroupLayoutMTL.h"
+#include "dawn_native/metal/BindGroupMTL.h"
 #include "dawn_native/metal/BufferMTL.h"
 #include "dawn_native/metal/CommandBufferMTL.h"
 #include "dawn_native/metal/ComputePipelineMTL.h"
@@ -93,7 +94,7 @@
 
     ResultOrError<BindGroupBase*> Device::CreateBindGroupImpl(
         const BindGroupDescriptor* descriptor) {
-        return new BindGroup(this, descriptor);
+        return BindGroup::Create(this, descriptor);
     }
     ResultOrError<BindGroupLayoutBase*> Device::CreateBindGroupLayoutImpl(
         const BindGroupLayoutDescriptor* descriptor) {
diff --git a/src/dawn_native/metal/Forward.h b/src/dawn_native/metal/Forward.h
index 4e889cd..a773a18 100644
--- a/src/dawn_native/metal/Forward.h
+++ b/src/dawn_native/metal/Forward.h
@@ -17,16 +17,11 @@
 
 #include "dawn_native/ToBackend.h"
 
-namespace {
-    class BindGroupBase;
-    class BindGroup;
-}  // namespace
-
 namespace dawn_native { namespace metal {
 
     class Adapter;
-    using BindGroup = BindGroupBase;
-    using BindGroupLayout = BindGroupLayoutBase;
+    class BindGroup;
+    class BindGroupLayout;
     class Buffer;
     class CommandBuffer;
     class ComputePipeline;
diff --git a/src/dawn_native/null/DeviceNull.h b/src/dawn_native/null/DeviceNull.h
index dee944a..cd52971 100644
--- a/src/dawn_native/null/DeviceNull.h
+++ b/src/dawn_native/null/DeviceNull.h
@@ -38,7 +38,7 @@
 namespace dawn_native { namespace null {
 
     class Adapter;
-    using BindGroup = BindGroupBase;
+    using BindGroup = BindGroupBaseOwnBindingData;
     using BindGroupLayout = BindGroupLayoutBase;
     class Buffer;
     class CommandBuffer;
diff --git a/src/dawn_native/opengl/Forward.h b/src/dawn_native/opengl/Forward.h
index 6542ff9..8b4f20b 100644
--- a/src/dawn_native/opengl/Forward.h
+++ b/src/dawn_native/opengl/Forward.h
@@ -17,16 +17,16 @@
 
 #include "dawn_native/ToBackend.h"
 
-namespace {
-    class BindGroupBase;
-    class BindGroup;
-    class RenderPassDescriptor;
-}  // namespace
+namespace dawn_native {
+    class BindGroupBaseOwnBindingData;
+    class BindGroupLayoutBase;
+    struct RenderPassDescriptor;
+}  // namespace dawn_native
 
 namespace dawn_native { namespace opengl {
 
     class Adapter;
-    using BindGroup = BindGroupBase;
+    using BindGroup = BindGroupBaseOwnBindingData;
     using BindGroupLayout = BindGroupLayoutBase;
     class Buffer;
     class CommandBuffer;
diff --git a/src/dawn_native/vulkan/BindGroupVk.h b/src/dawn_native/vulkan/BindGroupVk.h
index 9fa857b..727f959 100644
--- a/src/dawn_native/vulkan/BindGroupVk.h
+++ b/src/dawn_native/vulkan/BindGroupVk.h
@@ -24,7 +24,7 @@
 
     class Device;
 
-    class BindGroup : public BindGroupBase {
+    class BindGroup : public BindGroupBaseOwnBindingData {
       public:
         static ResultOrError<BindGroup*> Create(Device* device,
                                                 const BindGroupDescriptor* descriptor);
@@ -33,7 +33,7 @@
         VkDescriptorSet GetHandle() const;
 
       private:
-        using BindGroupBase::BindGroupBase;
+        using BindGroupBaseOwnBindingData::BindGroupBaseOwnBindingData;
         MaybeError Initialize();
 
         // The descriptor set in this allocation outlives the BindGroup because it is owned by
diff --git a/src/tests/unittests/SlabAllocatorTests.cpp b/src/tests/unittests/SlabAllocatorTests.cpp
index 12da010..45011fc 100644
--- a/src/tests/unittests/SlabAllocatorTests.cpp
+++ b/src/tests/unittests/SlabAllocatorTests.cpp
@@ -34,7 +34,7 @@
 
 // Test that a slab allocator of a single object works.
 TEST(SlabAllocatorTests, Single) {
-    SlabAllocator<Foo> allocator(1);
+    SlabAllocator<Foo> allocator(1 * sizeof(Foo));
 
     Foo* obj = allocator.Allocate(4);
     EXPECT_EQ(obj->value, 4);
@@ -46,7 +46,7 @@
 TEST(SlabAllocatorTests, AllocateSequential) {
     // Check small alignment
     {
-        SlabAllocator<Foo> allocator(5);
+        SlabAllocator<Foo> allocator(5 * sizeof(Foo));
 
         std::vector<Foo*> objects;
         for (int i = 0; i < 10; ++i) {
@@ -71,7 +71,7 @@
 
     // Check large alignment
     {
-        SlabAllocator<AlignedFoo> allocator(9);
+        SlabAllocator<AlignedFoo> allocator(9 * sizeof(AlignedFoo));
 
         std::vector<AlignedFoo*> objects;
         for (int i = 0; i < 21; ++i) {
@@ -97,7 +97,7 @@
 
 // Test that when reallocating a number of objects <= pool size, all memory is reused.
 TEST(SlabAllocatorTests, ReusesFreedMemory) {
-    SlabAllocator<Foo> allocator(17);
+    SlabAllocator<Foo> allocator(17 * sizeof(Foo));
 
     // Allocate a number of objects.
     std::set<Foo*> objects;
@@ -127,7 +127,7 @@
 // Test many allocations and deallocations. Meant to catch corner cases with partially
 // empty slabs.
 TEST(SlabAllocatorTests, AllocateDeallocateMany) {
-    SlabAllocator<Foo> allocator(17);
+    SlabAllocator<Foo> allocator(17 * sizeof(Foo));
 
     std::set<Foo*> objects;
     std::set<Foo*> set3;