D3D12: Refactor bind group descriptor tracking and descriptor heap allocation

Extract descriptor offset computation and CPU descriptor recording to
BindGroupLayout and BindGroup. Refactor descriptor heap allocation to
copy from a large CPU heap to a GPU heap.
diff --git a/src/backend/CMakeLists.txt b/src/backend/CMakeLists.txt
index f55a976..a19870e 100644
--- a/src/backend/CMakeLists.txt
+++ b/src/backend/CMakeLists.txt
@@ -237,6 +237,10 @@
     SetPIC(d3d12_autogen)
 
     list(APPEND BACKEND_SOURCES
+        ${D3D12_DIR}/BindGroupD3D12.cpp
+        ${D3D12_DIR}/BindGroupD3D12.h
+        ${D3D12_DIR}/BindGroupLayoutD3D12.cpp
+        ${D3D12_DIR}/BindGroupLayoutD3D12.h
         ${D3D12_DIR}/BufferD3D12.cpp
         ${D3D12_DIR}/BufferD3D12.h
         ${D3D12_DIR}/CommandAllocatorManager.cpp
diff --git a/src/backend/d3d12/BindGroupD3D12.cpp b/src/backend/d3d12/BindGroupD3D12.cpp
new file mode 100644
index 0000000..e278ed9
--- /dev/null
+++ b/src/backend/d3d12/BindGroupD3D12.cpp
@@ -0,0 +1,95 @@
+// Copyright 2017 The NXT Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "common/BitSetIterator.h"
+#include "BindGroupD3D12.h"
+#include "BindGroupLayoutD3D12.h"
+#include "BufferD3D12.h"
+#include "SamplerD3D12.h"
+#include "TextureD3D12.h"
+
+#include "D3D12Backend.h"
+
+namespace backend {
+namespace d3d12 {
+
+    BindGroup::BindGroup(Device* device, BindGroupBuilder* builder)
+        : BindGroupBase(builder), device(device) {
+    }
+
+    void BindGroup::RecordDescriptors(const DescriptorHeapHandle &cbvUavSrvHeapStart, uint32_t* cbvUavSrvHeapOffset, const DescriptorHeapHandle &samplerHeapStart, uint32_t* samplerHeapOffset, uint64_t serial) {
+        heapSerial = serial;
+
+        const auto* bgl = ToBackend(GetLayout());
+        const auto& layout = bgl->GetBindingInfo();
+
+        // Save the offset to the start of the descriptor table in the heap
+        this->cbvUavSrvHeapOffset = *cbvUavSrvHeapOffset;
+        this->samplerHeapOffset = *samplerHeapOffset;
+
+        const auto& bindingOffsets = bgl->GetBindingOffsets();
+
+        auto d3d12Device = device->GetD3D12Device();
+        for (uint32_t binding : IterateBitSet(layout.mask)) {
+            switch (layout.types[binding]) {
+                case nxt::BindingType::UniformBuffer:
+                    {
+                        auto* view = ToBackend(GetBindingAsBufferView(binding));
+                        auto& cbv = view->GetCBVDescriptor();
+                        d3d12Device->CreateConstantBufferView(&cbv, cbvUavSrvHeapStart.GetCPUHandle(*cbvUavSrvHeapOffset + bindingOffsets[binding]));
+                    }
+                    break;
+                case nxt::BindingType::StorageBuffer:
+                    {
+                        auto* view = ToBackend(GetBindingAsBufferView(binding));
+                        auto& uav = view->GetUAVDescriptor();
+                        d3d12Device->CreateUnorderedAccessView(ToBackend(view->GetBuffer())->GetD3D12Resource().Get(), nullptr, &uav, cbvUavSrvHeapStart.GetCPUHandle(*cbvUavSrvHeapOffset + bindingOffsets[binding]));
+                    }
+                    break;
+                case nxt::BindingType::SampledTexture:
+                    {
+                        auto* view = ToBackend(GetBindingAsTextureView(binding));
+                        auto& srv = view->GetSRVDescriptor();
+                        d3d12Device->CreateShaderResourceView(ToBackend(view->GetTexture())->GetD3D12Resource().Get(), &srv, cbvUavSrvHeapStart.GetCPUHandle(*cbvUavSrvHeapOffset + bindingOffsets[binding]));
+                    }
+                    break;
+                case nxt::BindingType::Sampler:
+                    {
+                        auto* sampler = ToBackend(GetBindingAsSampler(binding));
+                        auto& samplerDesc = sampler->GetSamplerDescriptor();
+                        d3d12Device->CreateSampler(&samplerDesc, samplerHeapStart.GetCPUHandle(*samplerHeapOffset + bindingOffsets[binding]));
+                    }
+                    break;
+            }
+        }
+
+        // Offset by the number of descriptors created
+        *cbvUavSrvHeapOffset += bgl->GetCbvUavSrvDescriptorCount();
+        *samplerHeapOffset += bgl->GetSamplerDescriptorCount();
+    }
+
+    uint32_t BindGroup::GetCbvUavSrvHeapOffset() const {
+        return cbvUavSrvHeapOffset;
+    }
+
+    uint32_t BindGroup::GetSamplerHeapOffset() const {
+        return samplerHeapOffset;
+    }
+
+    uint64_t BindGroup::GetHeapSerial() const {
+        return heapSerial;
+    }
+
+}
+}
diff --git a/src/backend/d3d12/BindGroupD3D12.h b/src/backend/d3d12/BindGroupD3D12.h
new file mode 100644
index 0000000..49d89fe
--- /dev/null
+++ b/src/backend/d3d12/BindGroupD3D12.h
@@ -0,0 +1,50 @@
+// Copyright 2017 The NXT Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef BACKEND_D3D12_BINDGROUPD3D12_H_
+#define BACKEND_D3D12_BINDGROUPD3D12_H_
+
+#include "common/BindGroup.h"
+
+#include "d3d12_platform.h"
+
+#include "DescriptorHeapAllocator.h"
+
+namespace backend {
+namespace d3d12 {
+
+    class Device;
+
+    class BindGroup : public BindGroupBase {
+        public:
+            BindGroup(Device* device, BindGroupBuilder* builder);
+
+            void RecordDescriptors(const DescriptorHeapHandle &cbvSrvUavHeapStart, uint32_t* cbvUavSrvHeapOffset, const DescriptorHeapHandle &samplerHeapStart, uint32_t* samplerHeapOffset, uint64_t serial);
+            uint32_t GetCbvUavSrvHeapOffset() const;
+            uint32_t GetSamplerHeapOffset() const;
+            uint64_t GetHeapSerial() const;
+
+        private:
+            Device* device;
+            uint32_t cbvUavSrvHeapOffset;
+            uint32_t samplerHeapOffset;
+            uint32_t cbvUavSrvCount = 0;
+            uint32_t samplerCount = 0;
+            uint64_t heapSerial = 0;
+    };
+
+}
+}
+
+#endif // BACKEND_D3D12_BINDGROUPD3D12_H_
diff --git a/src/backend/d3d12/BindGroupLayoutD3D12.cpp b/src/backend/d3d12/BindGroupLayoutD3D12.cpp
new file mode 100644
index 0000000..c7a5172
--- /dev/null
+++ b/src/backend/d3d12/BindGroupLayoutD3D12.cpp
@@ -0,0 +1,132 @@
+// Copyright 2017 The NXT Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "BindGroupLayoutD3D12.h"
+
+#include "common/BitSetIterator.h"
+#include "D3D12Backend.h"
+
+namespace backend {
+namespace d3d12 {
+
+    BindGroupLayout::BindGroupLayout(Device* device, BindGroupLayoutBuilder* builder)
+        : BindGroupLayoutBase(builder), device(device), descriptorCounts {}  {
+
+        const auto& groupInfo = GetBindingInfo();
+
+        for (uint32_t binding : IterateBitSet(groupInfo.mask)) {
+            switch (groupInfo.types[binding]) {
+                case nxt::BindingType::UniformBuffer:
+                    bindingOffsets[binding] = descriptorCounts[CBV]++;
+                    break;
+                case nxt::BindingType::StorageBuffer:
+                    bindingOffsets[binding] = descriptorCounts[UAV]++;
+                    break;
+                case nxt::BindingType::SampledTexture:
+                    bindingOffsets[binding] = descriptorCounts[SRV]++;
+                    break;
+                case nxt::BindingType::Sampler:
+                    bindingOffsets[binding] = descriptorCounts[Sampler]++;
+                    break;
+            }
+        }
+
+        auto SetDescriptorRange = [&](uint32_t index, uint32_t count, D3D12_DESCRIPTOR_RANGE_TYPE type) -> bool {
+            if (count == 0) {
+                return false;
+            }
+
+            auto& range = ranges[index];
+            range.RangeType = type;
+            range.NumDescriptors = count;
+            range.RegisterSpace = 0;
+            range.OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND;
+            // These ranges will be copied and range.BaseShaderRegister will be set in d3d12::PipelineLayout to account for bind group register offsets
+            return true;
+        };
+
+        uint32_t rangeIndex = 0;
+
+        // Ranges 0-2 contain the CBV, UAV, and SRV ranges, if they exist, tightly packed
+        // Range 3 contains the Sampler range, if there is one
+        if (SetDescriptorRange(rangeIndex, descriptorCounts[CBV], D3D12_DESCRIPTOR_RANGE_TYPE_CBV)) {
+            rangeIndex++;
+        }
+        if (SetDescriptorRange(rangeIndex, descriptorCounts[UAV], D3D12_DESCRIPTOR_RANGE_TYPE_UAV)) {
+            rangeIndex++;
+        }
+        if (SetDescriptorRange(rangeIndex, descriptorCounts[SRV], D3D12_DESCRIPTOR_RANGE_TYPE_SRV)) {
+            rangeIndex++;
+        }
+        SetDescriptorRange(Sampler, descriptorCounts[Sampler], D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER);
+
+        // descriptors ranges are offset by the offset + size of the previous range
+        std::array<uint32_t, DescriptorType::Count> descriptorOffsets;
+        descriptorOffsets[CBV] = 0;
+        descriptorOffsets[UAV] = descriptorOffsets[CBV] + descriptorCounts[CBV];
+        descriptorOffsets[SRV] = descriptorOffsets[UAV] + descriptorCounts[UAV];
+        descriptorOffsets[Sampler] = 0; // samplers are in a different heap
+
+        for (uint32_t binding : IterateBitSet(groupInfo.mask)) {
+            switch (groupInfo.types[binding]) {
+                case nxt::BindingType::UniformBuffer:
+                    bindingOffsets[binding] += descriptorOffsets[CBV];
+                    break;
+                case nxt::BindingType::StorageBuffer:
+                    bindingOffsets[binding] += descriptorOffsets[UAV];
+                    break;
+                case nxt::BindingType::SampledTexture:
+                    bindingOffsets[binding] += descriptorOffsets[SRV];
+                    break;
+                case nxt::BindingType::Sampler:
+                    bindingOffsets[binding] += descriptorOffsets[Sampler];
+                    break;
+            }
+        }
+    }
+
+    const std::array<uint32_t, kMaxBindingsPerGroup>& BindGroupLayout::GetBindingOffsets() const {
+        return bindingOffsets;
+    }
+
+    uint32_t BindGroupLayout::GetCbvUavSrvDescriptorTableSize() const {
+        return (
+            static_cast<uint32_t>(descriptorCounts[CBV] > 0) +
+            static_cast<uint32_t>(descriptorCounts[UAV] > 0) +
+            static_cast<uint32_t>(descriptorCounts[SRV] > 0)
+        );
+    }
+
+    uint32_t BindGroupLayout::GetSamplerDescriptorTableSize() const {
+        return descriptorCounts[Sampler] > 0;
+    }
+
+    uint32_t BindGroupLayout::GetCbvUavSrvDescriptorCount() const {
+        return descriptorCounts[CBV] + descriptorCounts[UAV] + descriptorCounts[SRV];
+    }
+
+    uint32_t BindGroupLayout::GetSamplerDescriptorCount() const {
+        return descriptorCounts[Sampler];
+    }
+
+    const D3D12_DESCRIPTOR_RANGE* BindGroupLayout::GetCbvUavSrvDescriptorRanges() const {
+        return ranges;
+    }
+
+    const D3D12_DESCRIPTOR_RANGE* BindGroupLayout::GetSamplerDescriptorRanges() const {
+        return &ranges[Sampler];
+    }
+
+}
+}
diff --git a/src/backend/d3d12/BindGroupLayoutD3D12.h b/src/backend/d3d12/BindGroupLayoutD3D12.h
new file mode 100644
index 0000000..1d310703
--- /dev/null
+++ b/src/backend/d3d12/BindGroupLayoutD3D12.h
@@ -0,0 +1,57 @@
+// Copyright 2017 The NXT Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef BACKEND_D3D12_BINDGROUPLAYOUTD3D12_H_
+#define BACKEND_D3D12_BINDGROUPLAYOUTD3D12_H_
+
+#include "common/BindGroupLayout.h"
+
+#include "d3d12_platform.h"
+
+namespace backend {
+namespace d3d12 {
+
+    class Device;
+
+    class BindGroupLayout : public BindGroupLayoutBase {
+        public:
+            BindGroupLayout(Device* device, BindGroupLayoutBuilder* builder);
+
+            enum DescriptorType {
+                CBV,
+                UAV,
+                SRV,
+                Sampler,
+                Count,
+            };
+
+            const std::array<uint32_t, kMaxBindingsPerGroup>& GetBindingOffsets() const;
+            uint32_t GetCbvUavSrvDescriptorTableSize() const;
+            uint32_t GetSamplerDescriptorTableSize() const;
+            uint32_t GetCbvUavSrvDescriptorCount() const;
+            uint32_t GetSamplerDescriptorCount() const;
+            const D3D12_DESCRIPTOR_RANGE* GetCbvUavSrvDescriptorRanges() const;
+            const D3D12_DESCRIPTOR_RANGE* GetSamplerDescriptorRanges() const;
+
+        private:
+            Device* device;
+            std::array<uint32_t, kMaxBindingsPerGroup> bindingOffsets;
+            std::array<uint32_t, DescriptorType::Count> descriptorCounts;
+            D3D12_DESCRIPTOR_RANGE ranges[DescriptorType::Count];
+    };
+
+}
+}
+
+#endif // BACKEND_D3D12_BINDGROUPLAYOUTD3D12_H_
diff --git a/src/backend/d3d12/CommandBufferD3D12.cpp b/src/backend/d3d12/CommandBufferD3D12.cpp
index 1d173ba..da1ce4f 100644
--- a/src/backend/d3d12/CommandBufferD3D12.cpp
+++ b/src/backend/d3d12/CommandBufferD3D12.cpp
@@ -16,8 +16,10 @@
 
 #include "common/Commands.h"
 #include "D3D12Backend.h"
-#include "DescriptorHeapAllocator.h"
+#include "BindGroupD3D12.h"
+#include "BindGroupLayoutD3D12.h"
 #include "BufferD3D12.h"
+#include "DescriptorHeapAllocator.h"
 #include "InputStateD3D12.h"
 #include "PipelineD3D12.h"
 #include "PipelineLayoutD3D12.h"
@@ -40,142 +42,152 @@
         struct BindGroupStateTracker {
             uint32_t cbvSrvUavDescriptorIndex = 0;
             uint32_t samplerDescriptorIndex = 0;
-            DescriptorHeapHandle cbvSrvUavDescriptorHeap;
-            DescriptorHeapHandle samplerDescriptorHeap;
+            DescriptorHeapHandle cbvSrvUavCPUDescriptorHeap;
+            DescriptorHeapHandle samplerCPUDescriptorHeap;
+            DescriptorHeapHandle cbvSrvUavGPUDescriptorHeap;
+            DescriptorHeapHandle samplerGPUDescriptorHeap;
             std::array<BindGroup*, kMaxBindGroups> bindGroups = {};
+
             Device* device;
 
             BindGroupStateTracker(Device* device) : device(device) {
             }
 
-            void TrackSetBindGroup(const BindGroupLayoutBase* bindGroupLayout) {
-                const auto& layout = bindGroupLayout->GetBindingInfo();
+            void TrackSetBindGroup(BindGroup* group, uint32_t index) {
+                if (bindGroups[index] != group) {
+                    bindGroups[index] = group;
 
-                for (size_t binding = 0; binding < layout.mask.size(); ++binding) {
-                    if (!layout.mask[binding]) {
-                        continue;
-                    }
-
-                    switch (layout.types[binding]) {
-                    case nxt::BindingType::UniformBuffer:
-                    case nxt::BindingType::StorageBuffer:
-                    case nxt::BindingType::SampledTexture:
-                        cbvSrvUavDescriptorIndex++;
-                    case nxt::BindingType::Sampler:
-                        samplerDescriptorIndex++;
+                    // Descriptors don't need to be recorded if they have already been recorded in the heap. Indices are only updated when descriptors are recorded
+                    const uint64_t serial = device->GetSerial();
+                    if (group->GetHeapSerial() != serial) {
+                        group->RecordDescriptors(cbvSrvUavCPUDescriptorHeap, &cbvSrvUavDescriptorIndex, samplerCPUDescriptorHeap, &samplerDescriptorIndex, serial);
                     }
                 }
             }
 
-            void SetBindGroup(Pipeline* pipeline, BindGroup* group, uint32_t index, ComPtr<ID3D12GraphicsCommandList> commandList) {
-                const auto& layout = group->GetLayout()->GetBindingInfo();
-
-                // these indices are the beginning of the descriptor table
-                uint32_t cbvSrvUavDescriptorStart = cbvSrvUavDescriptorIndex;
-                uint32_t samplerDescriptorStart = samplerDescriptorIndex;
-
-                bindGroups[index] = group;
-
-                PipelineLayout* pipelineLayout = ToBackend(pipeline->GetLayout());
-
-                // these indices are the offsets from the start of the descriptor table
-                uint32_t cbvIndex = pipelineLayout->GetDescriptorStartingIndex(index, PipelineLayout::Descriptor::Type::CBV);
-                uint32_t uavIndex = pipelineLayout->GetDescriptorStartingIndex(index, PipelineLayout::Descriptor::Type::UAV);
-                uint32_t srvIndex = pipelineLayout->GetDescriptorStartingIndex(index, PipelineLayout::Descriptor::Type::SRV);
-                uint32_t samplerIndex = pipelineLayout->GetDescriptorStartingIndex(index, PipelineLayout::Descriptor::Type::Sampler);
-
-                for (size_t binding = 0; binding < layout.mask.size(); ++binding) {
-                    if (!layout.mask[binding]) {
-                        continue;
-                    }
-
-                    switch (layout.types[binding]) {
-                        case nxt::BindingType::UniformBuffer:
-                            {
-                                auto* view = ToBackend(group->GetBindingAsBufferView(binding));
-                                auto* buffer = ToBackend(view->GetBuffer());
-                                auto& cbvDesc = view->GetCBVDescriptor();
-                                device->GetD3D12Device()->CreateConstantBufferView(&cbvDesc, cbvSrvUavDescriptorHeap.GetCPUHandle(cbvSrvUavDescriptorStart + cbvIndex++));
-                                cbvSrvUavDescriptorIndex++;
-                            }
-                            break;
-                        case nxt::BindingType::StorageBuffer:
-                            {
-                                auto* view = ToBackend(group->GetBindingAsBufferView(binding));
-                                auto* buffer = ToBackend(view->GetBuffer());
-                                auto& uavDesc = view->GetUAVDescriptor();
-                                device->GetD3D12Device()->CreateUnorderedAccessView(buffer->GetD3D12Resource().Get(), nullptr, &uavDesc, cbvSrvUavDescriptorHeap.GetCPUHandle(cbvSrvUavDescriptorStart + uavIndex++));
-                                cbvSrvUavDescriptorIndex++;
-                            }
-                            break;
-                        case nxt::BindingType::SampledTexture:
-                            {
-                                auto* texture = ToBackend(group->GetBindingAsTextureView(binding)->GetTexture());
-                                auto& srvDesc = texture->GetSRVDescriptor();
-                                device->GetD3D12Device()->CreateShaderResourceView(texture->GetD3D12Resource().Get(), &srvDesc, cbvSrvUavDescriptorHeap.GetCPUHandle(cbvSrvUavDescriptorStart + srvIndex++));
-                                cbvSrvUavDescriptorIndex++;
-                            }
-                            break;
-                        case nxt::BindingType::Sampler:
-                            {
-                                auto* sampler = ToBackend(group->GetBindingAsSampler(binding));
-                                auto& samplerDesc = sampler->GetSamplerDescriptor();
-                                device->GetD3D12Device()->CreateSampler(&samplerDesc, samplerDescriptorHeap.GetCPUHandle(samplerDescriptorStart + samplerIndex++));
-                                samplerDescriptorIndex++;
-                            }
-                            break;
-                    }
-                }
-
-                if (cbvSrvUavDescriptorStart != cbvSrvUavDescriptorIndex) {
-                    uint32_t parameterIndex = pipelineLayout->GetCBVSRVUAVRootParameterIndex(index);
-
-                    if (pipeline->IsCompute()) {
-                        commandList->SetComputeRootDescriptorTable(parameterIndex, cbvSrvUavDescriptorHeap.GetGPUHandle(cbvSrvUavDescriptorStart));
-                    } else {
-                        commandList->SetGraphicsRootDescriptorTable(parameterIndex, cbvSrvUavDescriptorHeap.GetGPUHandle(cbvSrvUavDescriptorStart));
-                    }
-                }
-
-                if (samplerDescriptorStart != samplerDescriptorIndex) {
-                    uint32_t parameterIndex = pipelineLayout->GetSamplerRootParameterIndex(index);
-
-                    if (pipeline->IsCompute()) {
-                        commandList->SetComputeRootDescriptorTable(parameterIndex, samplerDescriptorHeap.GetGPUHandle(samplerDescriptorStart));
-                    } else {
-                        commandList->SetGraphicsRootDescriptorTable(parameterIndex, samplerDescriptorHeap.GetGPUHandle(samplerDescriptorStart));
-                    }
-                }
-            }
-
-            void SetInheritedBindGroup(Pipeline* pipeline, uint32_t index, ComPtr<ID3D12GraphicsCommandList> commandList) {
+            void TrackSetBindInheritedGroup(uint32_t index) {
                 BindGroup* group = bindGroups[index];
-                ASSERT(group != nullptr);
-                SetBindGroup(pipeline, group, index, commandList);
+                if (group != nullptr) {
+                    TrackSetBindGroup(group, index);
+                }
             }
 
-            void AllocateAndSetDescriptorHeaps(Device* device, ComPtr<ID3D12GraphicsCommandList> commandList) {
-                cbvSrvUavDescriptorHeap = device->GetDescriptorHeapAllocator()->Allocate(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, cbvSrvUavDescriptorIndex);
-                samplerDescriptorHeap = device->GetDescriptorHeapAllocator()->Allocate(D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, samplerDescriptorIndex);
+            void SetBindGroup(ComPtr<ID3D12GraphicsCommandList> commandList, Pipeline* pipeline, BindGroup* group, uint32_t index, bool force = false) {
+                if (bindGroups[index] != group || force) {
+                    bindGroups[index] = group;
 
-                ID3D12DescriptorHeap* descriptorHeaps[2] = { cbvSrvUavDescriptorHeap.Get(), samplerDescriptorHeap.Get() };
-                if (descriptorHeaps[0] && descriptorHeaps[1]) {
-                    commandList->SetDescriptorHeaps(2, descriptorHeaps);
-                } else if (descriptorHeaps[0]) {
-                    commandList->SetDescriptorHeaps(1, descriptorHeaps);
-                } else if (descriptorHeaps[1]) {
-                    commandList->SetDescriptorHeaps(2, &descriptorHeaps[1]);
+                    PipelineLayout* pipelineLayout = ToBackend(pipeline->GetLayout());
+                    uint32_t cbvUavSrvCount = ToBackend(group->GetLayout())->GetCbvUavSrvDescriptorCount();
+                    uint32_t samplerCount = ToBackend(group->GetLayout())->GetSamplerDescriptorCount();
+
+                    if (cbvUavSrvCount > 0) {
+                        uint32_t parameterIndex = pipelineLayout->GetCbvUavSrvRootParameterIndex(index);
+
+                        if (pipeline->IsCompute()) {
+                            commandList->SetComputeRootDescriptorTable(parameterIndex, cbvSrvUavGPUDescriptorHeap.GetGPUHandle(group->GetCbvUavSrvHeapOffset()));
+                        } else {
+                            commandList->SetGraphicsRootDescriptorTable(parameterIndex, cbvSrvUavGPUDescriptorHeap.GetGPUHandle(group->GetCbvUavSrvHeapOffset()));
+                        }
+                    }
+
+                    if (samplerCount > 0) {
+                        uint32_t parameterIndex = pipelineLayout->GetSamplerRootParameterIndex(index);
+
+                        if (pipeline->IsCompute()) {
+                            commandList->SetComputeRootDescriptorTable(parameterIndex, samplerGPUDescriptorHeap.GetGPUHandle(group->GetSamplerHeapOffset()));
+                        } else {
+                            commandList->SetGraphicsRootDescriptorTable(parameterIndex, samplerGPUDescriptorHeap.GetGPUHandle(group->GetSamplerHeapOffset()));
+                        }
+                    }
+                }
+            }
+
+            void SetInheritedBindGroup(ComPtr<ID3D12GraphicsCommandList> commandList, Pipeline* pipeline, uint32_t index) {
+                BindGroup* group = bindGroups[index];
+                if (group != nullptr) {
+                    SetBindGroup(commandList, pipeline, group, index, true);
                 }
             }
 
             void Reset() {
-                cbvSrvUavDescriptorIndex = 0;
-                samplerDescriptorIndex = 0;
                 for (uint32_t i = 0; i < kMaxBindGroups; ++i) {
                     bindGroups[i] = nullptr;
                 }
             }
         };
+
+        void AllocateAndSetDescriptorHeaps(Device* device, BindGroupStateTracker* bindingTracker, CommandIterator* commands) {
+            auto* descriptorHeapAllocator = device->GetDescriptorHeapAllocator();
+
+            // TODO(enga@google.com): This currently allocates CPU heaps of arbitrarily chosen sizes
+            // This will not work if there are too many descriptors
+            bindingTracker->cbvSrvUavCPUDescriptorHeap = descriptorHeapAllocator->AllocateCPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, 8192);
+            bindingTracker->samplerCPUDescriptorHeap = descriptorHeapAllocator->AllocateCPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, 2048);
+
+            {
+                Command type;
+                Pipeline* lastPipeline = nullptr;
+                PipelineLayout* lastLayout = nullptr;
+
+                while (commands->NextCommandId(&type)) {
+                    switch (type) {
+                        case Command::SetPipeline:
+                        {
+                            SetPipelineCmd* cmd = commands->NextCommand<SetPipelineCmd>();
+                            Pipeline* pipeline = ToBackend(cmd->pipeline).Get();
+                            PipelineLayout* layout = ToBackend(pipeline->GetLayout());
+
+                            if (lastLayout) {
+                                auto mask = layout->GetBindGroupsLayoutMask();
+                                for (uint32_t i = 0; i < kMaxBindGroups; ++i) {
+                                    // matching bind groups are inherited until they differ
+                                    if (mask[i] && lastLayout->GetBindGroupLayout(i) == layout->GetBindGroupLayout(i)) {
+                                        bindingTracker->TrackSetBindInheritedGroup(i);
+                                    } else {
+                                        break;
+                                    }
+                                }
+                            }
+
+                            lastPipeline = pipeline;
+                            lastLayout = layout;
+                        }
+                        break;
+
+                        case Command::SetBindGroup:
+                        {
+                            SetBindGroupCmd* cmd = commands->NextCommand<SetBindGroupCmd>();
+                            BindGroup* group = ToBackend(cmd->group.Get());
+                            bindingTracker->TrackSetBindGroup(group, cmd->index);
+                        }
+                        break;
+                        default:
+                            SkipCommand(commands, type);
+                    }
+                }
+
+                commands->Reset();
+            }
+
+            if (bindingTracker->cbvSrvUavDescriptorIndex > 0) {
+                // Allocate a GPU-visible heap and copy from the CPU-only heap to the GPU-visible heap
+                bindingTracker->cbvSrvUavGPUDescriptorHeap = descriptorHeapAllocator->AllocateGPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, bindingTracker->cbvSrvUavDescriptorIndex);
+                device->GetD3D12Device()->CopyDescriptorsSimple(
+                    bindingTracker->cbvSrvUavDescriptorIndex,
+                    bindingTracker->cbvSrvUavGPUDescriptorHeap.GetCPUHandle(0),
+                    bindingTracker->cbvSrvUavCPUDescriptorHeap.GetCPUHandle(0),
+                    D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
+            }
+
+            if (bindingTracker->samplerDescriptorIndex > 0) {
+                bindingTracker->samplerGPUDescriptorHeap = descriptorHeapAllocator->AllocateGPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, bindingTracker->samplerDescriptorIndex);
+                device->GetD3D12Device()->CopyDescriptorsSimple(
+                    bindingTracker->samplerDescriptorIndex,
+                    bindingTracker->samplerGPUDescriptorHeap.GetCPUHandle(0),
+                    bindingTracker->samplerCPUDescriptorHeap.GetCPUHandle(0),
+                    D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
+            }
+        }
     }
 
     CommandBuffer::CommandBuffer(Device* device, CommandBufferBuilder* builder)
@@ -188,55 +200,18 @@
 
     void CommandBuffer::FillCommands(ComPtr<ID3D12GraphicsCommandList> commandList) {
         BindGroupStateTracker bindingTracker(device);
-
-        {
-            Command type;
-            Pipeline* lastPipeline = nullptr;
-            PipelineLayout* lastLayout = nullptr;
-
-            while(commands.NextCommandId(&type)) {
-                switch (type) {
-                    case Command::SetPipeline:
-                        {
-                            SetPipelineCmd* cmd = commands.NextCommand<SetPipelineCmd>();
-                            Pipeline* pipeline = ToBackend(cmd->pipeline).Get();
-                            PipelineLayout* layout = ToBackend(pipeline->GetLayout());
-
-                            if (lastLayout) {
-                                auto mask = layout->GetBindGroupsLayoutMask();
-                                for (uint32_t i = 0; i < kMaxBindGroups; ++i) {
-                                    // matching bind groups are inherited until they differ
-                                    if (mask[i] && lastLayout->GetBindGroupLayout(i) == layout->GetBindGroupLayout(i)) {
-                                        bindingTracker.TrackSetBindGroup(layout->GetBindGroupLayout(i));
-                                    } else {
-                                        break;
-                                    }
-                                }
-                            }
-
-                            lastPipeline = pipeline;
-                            lastLayout = layout;
-                        }
-                        break;
-
-                    case Command::SetBindGroup:
-                        {
-                            SetBindGroupCmd* cmd = commands.NextCommand<SetBindGroupCmd>();
-                            BindGroup* group = ToBackend(cmd->group.Get());
-                            bindingTracker.TrackSetBindGroup(group->GetLayout());
-                        }
-                        break;
-                    default:
-                        SkipCommand(&commands, type);
-                }
-            }
-
-            commands.Reset();
-        }
-
-        bindingTracker.AllocateAndSetDescriptorHeaps(device, commandList);
+        AllocateAndSetDescriptorHeaps(device, &bindingTracker, &commands);
         bindingTracker.Reset();
 
+        ID3D12DescriptorHeap* descriptorHeaps[2] = { bindingTracker.cbvSrvUavGPUDescriptorHeap.Get(), bindingTracker.samplerGPUDescriptorHeap.Get() };
+        if (descriptorHeaps[0] && descriptorHeaps[1]) {
+            commandList->SetDescriptorHeaps(2, descriptorHeaps);
+        } else if (descriptorHeaps[0]) {
+            commandList->SetDescriptorHeaps(1, descriptorHeaps);
+        } else if (descriptorHeaps[1]) {
+            commandList->SetDescriptorHeaps(2, &descriptorHeaps[1]);
+        }
+
         Command type;
         Pipeline* lastPipeline = nullptr;
         PipelineLayout* lastLayout = nullptr;
@@ -350,7 +325,7 @@
                             for (uint32_t i = 0; i < kMaxBindGroups; ++i) {
                                 // matching bind groups are inherited until they differ
                                 if (mask[i] && lastLayout->GetBindGroupLayout(i) == layout->GetBindGroupLayout(i)) {
-                                    bindingTracker.SetInheritedBindGroup(pipeline, i, commandList);
+                                    bindingTracker.SetInheritedBindGroup(commandList, pipeline, i);
                                 } else {
                                     break;
                                 }
@@ -379,7 +354,7 @@
                     {
                         SetBindGroupCmd* cmd = commands.NextCommand<SetBindGroupCmd>();
                         BindGroup* group = ToBackend(cmd->group.Get());
-                        bindingTracker.SetBindGroup(lastPipeline, group, cmd->index, commandList);
+                        bindingTracker.SetBindGroup(commandList, lastPipeline, group, cmd->index);
                     }
                     break;
 
diff --git a/src/backend/d3d12/D3D12Backend.cpp b/src/backend/d3d12/D3D12Backend.cpp
index c1c2586..36c330b 100644
--- a/src/backend/d3d12/D3D12Backend.cpp
+++ b/src/backend/d3d12/D3D12Backend.cpp
@@ -14,6 +14,8 @@
 
 #include "D3D12Backend.h"
 
+#include "BindGroupD3D12.h"
+#include "BindGroupLayoutD3D12.h"
 #include "BufferD3D12.h"
 #include "CommandBufferD3D12.h"
 #include "InputStateD3D12.h"
@@ -248,18 +250,6 @@
     void Device::Release() {
     }
 
-    // Bind Group
-
-    BindGroup::BindGroup(Device* device, BindGroupBuilder* builder)
-        : BindGroupBase(builder), device(device) {
-    }
-
-    // Bind Group Layout
-
-    BindGroupLayout::BindGroupLayout(Device* device, BindGroupLayoutBuilder* builder)
-        : BindGroupLayoutBase(builder), device(device) {
-    }
-
     // DepthStencilState
 
     DepthStencilState::DepthStencilState(Device* device, DepthStencilStateBuilder* builder)
diff --git a/src/backend/d3d12/D3D12Backend.h b/src/backend/d3d12/D3D12Backend.h
index e76d0b9..cae0fc8 100644
--- a/src/backend/d3d12/D3D12Backend.h
+++ b/src/backend/d3d12/D3D12Backend.h
@@ -155,23 +155,6 @@
             D3D12_CPU_DESCRIPTOR_HANDLE renderTargetDescriptor;
     };
 
-
-    class BindGroup : public BindGroupBase {
-        public:
-            BindGroup(Device* device, BindGroupBuilder* builder);
-
-        private:
-            Device* device;
-    };
-
-    class BindGroupLayout : public BindGroupLayoutBase {
-        public:
-            BindGroupLayout(Device* device, BindGroupLayoutBuilder* builder);
-
-        private:
-            Device* device;
-    };
-
     class Framebuffer : public FramebufferBase {
         public:
             Framebuffer(Device* device, FramebufferBuilder* builder);
diff --git a/src/backend/d3d12/DescriptorHeapAllocator.cpp b/src/backend/d3d12/DescriptorHeapAllocator.cpp
index e38a65a..6c93a22 100644
--- a/src/backend/d3d12/DescriptorHeapAllocator.cpp
+++ b/src/backend/d3d12/DescriptorHeapAllocator.cpp
@@ -55,40 +55,52 @@
           } {
     }
 
-    DescriptorHeapHandle DescriptorHeapAllocator::Allocate(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count) {
+    DescriptorHeapHandle DescriptorHeapAllocator::Allocate(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count, uint32_t allocationSize, DescriptorHeapInfo* heapInfo, D3D12_DESCRIPTOR_HEAP_FLAGS flags) {
+        // TODO(enga@google.com): This is just a linear allocator so the heap will quickly run out of space causing a new one to be allocated
+        // We should reuse heap subranges that have been released
         if (count == 0) {
             return DescriptorHeapHandle();
         }
 
-        auto& pools = descriptorHeapPools[type];
-        for (auto it : pools) {
-            auto& allocationInfo = it.second;
+        {
+            // If the current pool for this type has space, linearly allocate count bytes in the pool
+            auto& allocationInfo = heapInfo->second;
             if (allocationInfo.remaining >= count) {
-                DescriptorHeapHandle handle(it.first, sizeIncrements[type], allocationInfo.size - allocationInfo.remaining);
+                DescriptorHeapHandle handle(heapInfo->first, sizeIncrements[type], allocationInfo.size - allocationInfo.remaining);
                 allocationInfo.remaining -= count;
                 Release(handle);
                 return handle;
             }
         }
 
-        ASSERT(count <= 2048); // TODO(enga@google.com): Have a very large CPU heap that's copied to GPU-visible heaps
-        uint32_t descriptorHeapSize = 2048; // TODO(enga@google.com): Allocate much more and use this as a pool
+        // If the pool has no more space, replace the pool with a new one of the specified size
 
         D3D12_DESCRIPTOR_HEAP_DESC heapDescriptor;
         heapDescriptor.Type = type;
-        heapDescriptor.NumDescriptors = descriptorHeapSize;
-        heapDescriptor.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
+        heapDescriptor.NumDescriptors = allocationSize;
+        heapDescriptor.Flags = flags;
         heapDescriptor.NodeMask = 0;
         ComPtr<ID3D12DescriptorHeap> heap;
         ASSERT_SUCCESS(device->GetD3D12Device()->CreateDescriptorHeap(&heapDescriptor, IID_PPV_ARGS(&heap)));
-        AllocationInfo allocationInfo = { descriptorHeapSize, descriptorHeapSize - count };
-        pools.emplace_back(std::make_pair(heap, allocationInfo));
+
+        AllocationInfo allocationInfo = { allocationSize, allocationSize - count };
+        *heapInfo = std::make_pair(heap, allocationInfo);
 
         DescriptorHeapHandle handle(heap, sizeIncrements[type], 0);
         Release(handle);
         return handle;
     }
 
+    DescriptorHeapHandle DescriptorHeapAllocator::AllocateCPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count) {
+        return Allocate(type, count, count, &cpuDescriptorHeapInfos[type], D3D12_DESCRIPTOR_HEAP_FLAG_NONE);
+    }
+
+    DescriptorHeapHandle DescriptorHeapAllocator::AllocateGPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count) {
+        ASSERT(type == D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV || type == D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
+        unsigned int heapSize = (type == D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV ? kMaxCbvUavSrvHeapSize : kMaxSamplerHeapSize);
+        return Allocate(type, count, heapSize, &gpuDescriptorHeapInfos[type], D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE);
+    }
+
     void DescriptorHeapAllocator::FreeDescriptorHeaps(uint64_t lastCompletedSerial) {
         releasedHandles.ClearUpTo(lastCompletedSerial);
     }
diff --git a/src/backend/d3d12/DescriptorHeapAllocator.h b/src/backend/d3d12/DescriptorHeapAllocator.h
index 07ac478..30d8f0b 100644
--- a/src/backend/d3d12/DescriptorHeapAllocator.h
+++ b/src/backend/d3d12/DescriptorHeapAllocator.h
@@ -27,6 +27,7 @@
     class Device;
 
     class DescriptorHeapHandle {
+
         public:
             DescriptorHeapHandle();
             DescriptorHeapHandle(ComPtr<ID3D12DescriptorHeap> descriptorHeap, uint32_t sizeIncrement, uint32_t offset);
@@ -46,27 +47,30 @@
         public:
             DescriptorHeapAllocator(Device* device);
 
-            DescriptorHeapHandle Allocate(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count);
+            DescriptorHeapHandle AllocateGPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count);
+            DescriptorHeapHandle AllocateCPUHeap(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count);
             void FreeDescriptorHeaps(uint64_t lastCompletedSerial);
 
         private:
+            static constexpr unsigned int kMaxCbvUavSrvHeapSize = 1000000;
+            static constexpr unsigned int kMaxSamplerHeapSize = 2048;
+            static constexpr unsigned int kDescriptorHeapTypes = D3D12_DESCRIPTOR_HEAP_TYPE::D3D12_DESCRIPTOR_HEAP_TYPE_NUM_TYPES;
+
+            struct AllocationInfo {
+                uint32_t size = 0;
+                uint32_t remaining = 0;
+            };
+
+            using DescriptorHeapInfo = std::pair<ComPtr<ID3D12DescriptorHeap>, AllocationInfo>;
+            
+            DescriptorHeapHandle Allocate(D3D12_DESCRIPTOR_HEAP_TYPE type, uint32_t count, uint32_t allocationSize, DescriptorHeapInfo* heapInfo, D3D12_DESCRIPTOR_HEAP_FLAGS flags);
             void Release(DescriptorHeapHandle handle);
 
             Device* device;
 
-            static constexpr unsigned int kDescriptorHeapTypes = D3D12_DESCRIPTOR_HEAP_TYPE::D3D12_DESCRIPTOR_HEAP_TYPE_NUM_TYPES;
-
-            struct AllocationInfo {
-                uint32_t size;
-                uint32_t remaining;
-            };
-
-            using DescriptorHeapPool = std::pair<ComPtr<ID3D12DescriptorHeap>, AllocationInfo>;
-
-            using DescriptorHeapPoolList = std::vector<DescriptorHeapPool>;
-
             std::array<uint32_t, kDescriptorHeapTypes> sizeIncrements;
-            std::array<DescriptorHeapPoolList, kDescriptorHeapTypes> descriptorHeapPools;
+            std::array<DescriptorHeapInfo, kDescriptorHeapTypes> cpuDescriptorHeapInfos;
+            std::array<DescriptorHeapInfo, kDescriptorHeapTypes> gpuDescriptorHeapInfos;
             SerialQueue<DescriptorHeapHandle> releasedHandles;
     };
 
diff --git a/src/backend/d3d12/GeneratedCodeIncludes.h b/src/backend/d3d12/GeneratedCodeIncludes.h
index a4f6a83..fbf10566 100644
--- a/src/backend/d3d12/GeneratedCodeIncludes.h
+++ b/src/backend/d3d12/GeneratedCodeIncludes.h
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "D3D12Backend.h"
+#include "BindGroupD3D12.h"
+#include "BindGroupLayoutD3D12.h"
 #include "BufferD3D12.h"
 #include "CommandBufferD3D12.h"
 #include "InputStateD3D12.h"
diff --git a/src/backend/d3d12/PipelineLayoutD3D12.cpp b/src/backend/d3d12/PipelineLayoutD3D12.cpp
index bdf67ca..883cd2d 100644
--- a/src/backend/d3d12/PipelineLayoutD3D12.cpp
+++ b/src/backend/d3d12/PipelineLayoutD3D12.cpp
@@ -48,22 +48,24 @@
             // Set the root descriptor table parameter and copy ranges. Ranges are offset by the bind group index
             // Returns whether or not the parameter was set. A root parameter is not set if the number of ranges is 0
             auto SetRootDescriptorTable = [&](uint32_t rangeCount, const D3D12_DESCRIPTOR_RANGE* descriptorRanges) -> bool {
-                if (rangeCount > 0) {
-                    auto& rootParameter = rootParameters[parameterIndex];
-                    rootParameter.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
-                    rootParameter.DescriptorTable = rootParameterValues[parameterIndex].DescriptorTable;
-                    rootParameter.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
-                    rootParameter.DescriptorTable.NumDescriptorRanges = rangeCount;
-                    rootParameter.DescriptorTable.pDescriptorRanges = &ranges[rangeIndex];
+                if (rangeCount == 0) {
+                    return false;
                 }
 
+                auto& rootParameter = rootParameters[parameterIndex];
+                rootParameter.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
+                rootParameter.DescriptorTable = rootParameterValues[parameterIndex].DescriptorTable;
+                rootParameter.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
+                rootParameter.DescriptorTable.NumDescriptorRanges = rangeCount;
+                rootParameter.DescriptorTable.pDescriptorRanges = &ranges[rangeIndex];
+
                 for (uint32_t i = 0; i < rangeCount; ++i) {
                     ranges[rangeIndex] = descriptorRanges[i];
                     ranges[rangeIndex].BaseShaderRegister = group * kMaxBindingsPerGroup;
                     rangeIndex++;
                 }
 
-                return (rangeCount > 0);
+                return true;
             };
 
             if (SetRootDescriptorTable(bindGroupLayout->GetCbvUavSrvDescriptorTableSize(), bindGroupLayout->GetCbvUavSrvDescriptorRanges())) {
diff --git a/src/backend/d3d12/PipelineLayoutD3D12.h b/src/backend/d3d12/PipelineLayoutD3D12.h
index a8b24cd..fe87a05 100644
--- a/src/backend/d3d12/PipelineLayoutD3D12.h
+++ b/src/backend/d3d12/PipelineLayoutD3D12.h
@@ -28,34 +28,16 @@
         public:
             PipelineLayout(Device* device, PipelineLayoutBuilder* builder);
 
-            class Descriptor {
-                public:
-                    enum class Type {
-                        CBV,
-                        UAV,
-                        SRV,
-                        Sampler,
-                        Count
-                    };
-                    static constexpr unsigned int TypeCount = static_cast<typename std::underlying_type<Type>::type>(Type::Count);
-            };
-
             uint32_t GetCbvUavSrvRootParameterIndex(uint32_t group) const;
             uint32_t GetSamplerRootParameterIndex(uint32_t group) const;
 
             ComPtr<ID3D12RootSignature> GetRootSignature();
 
         private:
-
-            static constexpr unsigned int ToIndex(Descriptor::Type type) {
-                return static_cast<typename std::underlying_type<Descriptor::Type>::type>(type);
-            }
-
             Device* device;
 
             std::array<uint32_t, kMaxBindGroups> cbvUavSrvRootParameterInfo;
             std::array<uint32_t, kMaxBindGroups> samplerRootParameterInfo;
-            std::array<std::array<uint32_t, Descriptor::TypeCount>, kMaxBindGroups> descriptorCountInfo;
 
             ComPtr<ID3D12RootSignature> rootSignature;
     };
diff --git a/src/backend/d3d12/ShaderModuleD3D12.cpp b/src/backend/d3d12/ShaderModuleD3D12.cpp
index 54ed656..87a272e 100644
--- a/src/backend/d3d12/ShaderModuleD3D12.cpp
+++ b/src/backend/d3d12/ShaderModuleD3D12.cpp
@@ -40,24 +40,23 @@
             Count,
         };
 
-        std::array<uint32_t, RegisterType::Count * kMaxBindGroups> baseRegisters = {};
-
-        const auto& resources = compiler.get_shader_resources();
-
         // rename bindings so that each register type b/u/t/s starts at 0 and then offset by kMaxBindingsPerGroup * bindGroupIndex
-        auto RenumberBindings = [&](std::vector<spirv_cross::Resource> resources, uint32_t offset) {
+        auto RenumberBindings = [&](std::vector<spirv_cross::Resource> resources) {
+            std::array<uint32_t, kMaxBindGroups> baseRegisters = {};
+
             for (const auto& resource : resources) {
                 auto bindGroupIndex = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
-                auto& baseRegister = baseRegisters[RegisterType::Count * bindGroupIndex + offset];
+                auto& baseRegister = baseRegisters[bindGroupIndex];
                 auto bindGroupOffset = bindGroupIndex * kMaxBindingsPerGroup;
                 compiler.set_decoration(resource.id, spv::DecorationBinding, bindGroupOffset + baseRegister++);
             }
         };
 
-        RenumberBindings(resources.uniform_buffers, RegisterType::Buffer);
-        RenumberBindings(resources.storage_buffers, RegisterType::UnorderedAccess);
-        RenumberBindings(resources.separate_images, RegisterType::Texture);
-        RenumberBindings(resources.separate_samplers, RegisterType::Sampler);
+        const auto& resources = compiler.get_shader_resources();
+        RenumberBindings(resources.uniform_buffers);    // c
+        RenumberBindings(resources.storage_buffers);    // u
+        RenumberBindings(resources.separate_images);    // t
+        RenumberBindings(resources.separate_samplers);  // s
 
         hlslSource = compiler.compile();