D3D12: De-duplicate sampler heap allocations.

Allows bindgroups that use the same samplers to share
a descriptor heap allocation. This is particularly important
for sampler heaps which incur expensive pipeline flushes
due to the smaller size requiring more frequent switches.

The device doles out entries to a sampler heap allocation cache.
When the BindGroup is created, it does a lookup and refs the
allocation. This ensures the cache does not grow unbounded
or needlessly store unused entries.

This change is a follow-up of de-coupling heaps.

BUG=dawn:155

Change-Id: I3ab6f1bdb13a40905cb990cd7a2139e73da30303
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/20783
Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 8e97fae..b55c05f 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -328,6 +328,8 @@
       "d3d12/ResourceHeapAllocationD3D12.h",
       "d3d12/SamplerD3D12.cpp",
       "d3d12/SamplerD3D12.h",
+      "d3d12/SamplerHeapCacheD3D12.cpp",
+      "d3d12/SamplerHeapCacheD3D12.h",
       "d3d12/ShaderModuleD3D12.cpp",
       "d3d12/ShaderModuleD3D12.h",
       "d3d12/ShaderVisibleDescriptorAllocatorD3D12.cpp",
diff --git a/src/dawn_native/BindGroup.cpp b/src/dawn_native/BindGroup.cpp
index 840c404..647587c 100644
--- a/src/dawn_native/BindGroup.cpp
+++ b/src/dawn_native/BindGroup.cpp
@@ -297,6 +297,11 @@
         return mLayout.Get();
     }
 
+    const BindGroupLayoutBase* BindGroupBase::GetLayout() const {
+        ASSERT(!IsError());
+        return mLayout.Get();
+    }
+
     BufferBinding BindGroupBase::GetBindingAsBufferBinding(BindingIndex bindingIndex) {
         ASSERT(!IsError());
         ASSERT(bindingIndex < mLayout->GetBindingCount());
@@ -309,7 +314,7 @@
                 mBindingData.bufferData[bindingIndex].size};
     }
 
-    SamplerBase* BindGroupBase::GetBindingAsSampler(BindingIndex bindingIndex) {
+    SamplerBase* BindGroupBase::GetBindingAsSampler(BindingIndex bindingIndex) const {
         ASSERT(!IsError());
         ASSERT(bindingIndex < mLayout->GetBindingCount());
         ASSERT(mLayout->GetBindingInfo(bindingIndex).type == wgpu::BindingType::Sampler ||
diff --git a/src/dawn_native/BindGroup.h b/src/dawn_native/BindGroup.h
index 6afee61..1240246 100644
--- a/src/dawn_native/BindGroup.h
+++ b/src/dawn_native/BindGroup.h
@@ -44,8 +44,9 @@
         static BindGroupBase* MakeError(DeviceBase* device);
 
         BindGroupLayoutBase* GetLayout();
+        const BindGroupLayoutBase* GetLayout() const;
         BufferBinding GetBindingAsBufferBinding(BindingIndex bindingIndex);
-        SamplerBase* GetBindingAsSampler(BindingIndex bindingIndex);
+        SamplerBase* GetBindingAsSampler(BindingIndex bindingIndex) const;
         TextureViewBase* GetBindingAsTextureView(BindingIndex bindingIndex);
 
       protected:
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index 1465ad0..aa5aa70 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -209,6 +209,8 @@
         "d3d12/ResourceHeapAllocationD3D12.h"
         "d3d12/SamplerD3D12.cpp"
         "d3d12/SamplerD3D12.h"
+        "d3d12/SamplerHeapCacheD3D12.cpp"
+        "d3d12/SamplerHeapCacheD3D12.h"
         "d3d12/ShaderModuleD3D12.cpp"
         "d3d12/ShaderModuleD3D12.h"
         "d3d12/ShaderVisibleDescriptorAllocatorD3D12.cpp"
diff --git a/src/dawn_native/d3d12/BindGroupD3D12.cpp b/src/dawn_native/d3d12/BindGroupD3D12.cpp
index 5aeaf56..76fb028 100644
--- a/src/dawn_native/d3d12/BindGroupD3D12.cpp
+++ b/src/dawn_native/d3d12/BindGroupD3D12.cpp
@@ -18,7 +18,7 @@
 #include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
 #include "dawn_native/d3d12/BufferD3D12.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
-#include "dawn_native/d3d12/SamplerD3D12.h"
+#include "dawn_native/d3d12/SamplerHeapCacheD3D12.h"
 #include "dawn_native/d3d12/ShaderVisibleDescriptorAllocatorD3D12.h"
 #include "dawn_native/d3d12/TextureD3D12.h"
 
@@ -33,14 +33,11 @@
     BindGroup::BindGroup(Device* device,
                          const BindGroupDescriptor* descriptor,
                          uint32_t viewSizeIncrement,
-                         const CPUDescriptorHeapAllocation& viewAllocation,
-                         uint32_t samplerSizeIncrement,
-                         const CPUDescriptorHeapAllocation& samplerAllocation)
+                         const CPUDescriptorHeapAllocation& viewAllocation)
         : BindGroupBase(this, device, descriptor) {
         BindGroupLayout* bgl = ToBackend(GetLayout());
 
         mCPUViewAllocation = viewAllocation;
-        mCPUSamplerAllocation = samplerAllocation;
 
         const auto& bindingOffsets = bgl->GetBindingOffsets();
 
@@ -129,11 +126,7 @@
                 }
                 case wgpu::BindingType::Sampler:
                 case wgpu::BindingType::ComparisonSampler: {
-                    auto* sampler = ToBackend(GetBindingAsSampler(bindingIndex));
-                    auto& samplerDesc = sampler->GetSamplerDescriptor();
-                    d3d12Device->CreateSampler(
-                        &samplerDesc, samplerAllocation.OffsetFrom(samplerSizeIncrement,
-                                                                   bindingOffsets[bindingIndex]));
+                    // No-op as samplers will be later initialized by CreateSamplers().
                     break;
                 }
 
@@ -156,32 +149,15 @@
     }
 
     BindGroup::~BindGroup() {
-        ToBackend(GetLayout())
-            ->DeallocateBindGroup(this, &mCPUViewAllocation, &mCPUSamplerAllocation);
+        ToBackend(GetLayout())->DeallocateBindGroup(this, &mCPUViewAllocation);
         ASSERT(!mCPUViewAllocation.IsValid());
-        ASSERT(!mCPUSamplerAllocation.IsValid());
     }
 
     bool BindGroup::PopulateViews(ShaderVisibleDescriptorAllocator* viewAllocator) {
         const BindGroupLayout* bgl = ToBackend(GetLayout());
-        return Populate(viewAllocator, bgl->GetCbvUavSrvDescriptorCount(),
-                        D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, mCPUViewAllocation,
-                        &mGPUViewAllocation);
-    }
 
-    bool BindGroup::PopulateSamplers(ShaderVisibleDescriptorAllocator* samplerAllocator) {
-        const BindGroupLayout* bgl = ToBackend(GetLayout());
-        return Populate(samplerAllocator, bgl->GetSamplerDescriptorCount(),
-                        D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, mCPUSamplerAllocation,
-                        &mGPUSamplerAllocation);
-    }
-
-    bool BindGroup::Populate(ShaderVisibleDescriptorAllocator* allocator,
-                             uint32_t descriptorCount,
-                             D3D12_DESCRIPTOR_HEAP_TYPE heapType,
-                             const CPUDescriptorHeapAllocation& stagingAllocation,
-                             GPUDescriptorHeapAllocation* allocation) {
-        if (descriptorCount == 0 || allocator->IsAllocationStillValid(*allocation)) {
+        const uint32_t descriptorCount = bgl->GetCbvUavSrvDescriptorCount();
+        if (descriptorCount == 0 || viewAllocator->IsAllocationStillValid(mGPUViewAllocation)) {
             return true;
         }
 
@@ -190,16 +166,18 @@
         Device* device = ToBackend(GetDevice());
 
         D3D12_CPU_DESCRIPTOR_HANDLE baseCPUDescriptor;
-        if (!allocator->AllocateGPUDescriptors(descriptorCount, device->GetPendingCommandSerial(),
-                                               &baseCPUDescriptor, allocation)) {
+        if (!viewAllocator->AllocateGPUDescriptors(descriptorCount,
+                                                   device->GetPendingCommandSerial(),
+                                                   &baseCPUDescriptor, &mGPUViewAllocation)) {
             return false;
         }
 
         // CPU bindgroups are sparsely allocated across CPU heaps. Instead of doing
         // simple copies per bindgroup, a single non-simple copy could be issued.
         // TODO(dawn:155): Consider doing this optimization.
-        device->GetD3D12Device()->CopyDescriptorsSimple(
-            descriptorCount, baseCPUDescriptor, stagingAllocation.GetBaseDescriptor(), heapType);
+        device->GetD3D12Device()->CopyDescriptorsSimple(descriptorCount, baseCPUDescriptor,
+                                                        mCPUViewAllocation.GetBaseDescriptor(),
+                                                        D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
 
         return true;
     }
@@ -209,6 +187,19 @@
     }
 
     D3D12_GPU_DESCRIPTOR_HANDLE BindGroup::GetBaseSamplerDescriptor() const {
-        return mGPUSamplerAllocation.GetBaseDescriptor();
+        ASSERT(mSamplerAllocationEntry.Get() != nullptr);
+        return mSamplerAllocationEntry->GetBaseDescriptor();
+    }
+
+    bool BindGroup::PopulateSamplers(Device* device,
+                                     ShaderVisibleDescriptorAllocator* samplerAllocator) {
+        if (mSamplerAllocationEntry.Get() == nullptr) {
+            return true;
+        }
+        return mSamplerAllocationEntry->Populate(device, samplerAllocator);
+    }
+
+    void BindGroup::SetSamplerAllocationEntry(Ref<SamplerHeapCacheEntry> entry) {
+        mSamplerAllocationEntry = std::move(entry);
     }
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/BindGroupD3D12.h b/src/dawn_native/d3d12/BindGroupD3D12.h
index 05d67b4..54acb3d 100644
--- a/src/dawn_native/d3d12/BindGroupD3D12.h
+++ b/src/dawn_native/d3d12/BindGroupD3D12.h
@@ -24,7 +24,9 @@
 namespace dawn_native { namespace d3d12 {
 
     class Device;
+    class SamplerHeapCacheEntry;
     class ShaderVisibleDescriptorAllocator;
+    class StagingDescriptorAllocator;
 
     class BindGroup final : public BindGroupBase, public PlacementAllocated {
       public:
@@ -34,30 +36,23 @@
         BindGroup(Device* device,
                   const BindGroupDescriptor* descriptor,
                   uint32_t viewSizeIncrement,
-                  const CPUDescriptorHeapAllocation& viewAllocation,
-                  uint32_t samplerSizeIncrement,
-                  const CPUDescriptorHeapAllocation& samplerAllocation);
+                  const CPUDescriptorHeapAllocation& viewAllocation);
 
         // Returns true if the BindGroup was successfully populated.
         bool PopulateViews(ShaderVisibleDescriptorAllocator* viewAllocator);
-        bool PopulateSamplers(ShaderVisibleDescriptorAllocator* samplerAllocator);
+        bool PopulateSamplers(Device* device, ShaderVisibleDescriptorAllocator* samplerAllocator);
 
         D3D12_GPU_DESCRIPTOR_HANDLE GetBaseViewDescriptor() const;
         D3D12_GPU_DESCRIPTOR_HANDLE GetBaseSamplerDescriptor() const;
 
-      private:
-        bool Populate(ShaderVisibleDescriptorAllocator* allocator,
-                      uint32_t descriptorCount,
-                      D3D12_DESCRIPTOR_HEAP_TYPE heapType,
-                      const CPUDescriptorHeapAllocation& stagingAllocation,
-                      GPUDescriptorHeapAllocation* allocation);
+        void SetSamplerAllocationEntry(Ref<SamplerHeapCacheEntry> entry);
 
+      private:
         ~BindGroup() override;
 
-        GPUDescriptorHeapAllocation mGPUSamplerAllocation;
-        GPUDescriptorHeapAllocation mGPUViewAllocation;
+        Ref<SamplerHeapCacheEntry> mSamplerAllocationEntry;
 
-        CPUDescriptorHeapAllocation mCPUSamplerAllocation;
+        GPUDescriptorHeapAllocation mGPUViewAllocation;
         CPUDescriptorHeapAllocation mCPUViewAllocation;
     };
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp b/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp
index c5c7999..2bd7bae 100644
--- a/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp
+++ b/src/dawn_native/d3d12/BindGroupLayoutD3D12.cpp
@@ -17,6 +17,7 @@
 #include "common/BitSetIterator.h"
 #include "dawn_native/d3d12/BindGroupD3D12.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
+#include "dawn_native/d3d12/SamplerHeapCacheD3D12.h"
 #include "dawn_native/d3d12/StagingDescriptorAllocatorD3D12.h"
 
 namespace dawn_native { namespace d3d12 {
@@ -147,28 +148,25 @@
             viewSizeIncrement = mViewAllocator->GetSizeIncrement();
         }
 
-        uint32_t samplerSizeIncrement = 0;
-        CPUDescriptorHeapAllocation samplerAllocation;
+        Ref<BindGroup> bindGroup = AcquireRef<BindGroup>(
+            mBindGroupAllocator.Allocate(device, descriptor, viewSizeIncrement, viewAllocation));
+
         if (GetSamplerDescriptorCount() > 0) {
-            DAWN_TRY_ASSIGN(samplerAllocation, mSamplerAllocator->AllocateCPUDescriptors());
-            samplerSizeIncrement = mSamplerAllocator->GetSizeIncrement();
+            Ref<SamplerHeapCacheEntry> samplerHeapCacheEntry;
+            DAWN_TRY_ASSIGN(samplerHeapCacheEntry, device->GetSamplerHeapCache()->GetOrCreate(
+                                                       bindGroup.Get(), mSamplerAllocator));
+            bindGroup->SetSamplerAllocationEntry(std::move(samplerHeapCacheEntry));
         }
 
-        return mBindGroupAllocator.Allocate(device, descriptor, viewSizeIncrement, viewAllocation,
-                                            samplerSizeIncrement, samplerAllocation);
+        return bindGroup.Detach();
     }
 
     void BindGroupLayout::DeallocateBindGroup(BindGroup* bindGroup,
-                                              CPUDescriptorHeapAllocation* viewAllocation,
-                                              CPUDescriptorHeapAllocation* samplerAllocation) {
+                                              CPUDescriptorHeapAllocation* viewAllocation) {
         if (viewAllocation->IsValid()) {
             mViewAllocator->Deallocate(viewAllocation);
         }
 
-        if (samplerAllocation->IsValid()) {
-            mSamplerAllocator->Deallocate(samplerAllocation);
-        }
-
         mBindGroupAllocator.Deallocate(bindGroup);
     }
 
diff --git a/src/dawn_native/d3d12/BindGroupLayoutD3D12.h b/src/dawn_native/d3d12/BindGroupLayoutD3D12.h
index d04ab75..e739ca2 100644
--- a/src/dawn_native/d3d12/BindGroupLayoutD3D12.h
+++ b/src/dawn_native/d3d12/BindGroupLayoutD3D12.h
@@ -25,6 +25,7 @@
     class BindGroup;
     class CPUDescriptorHeapAllocation;
     class Device;
+    class SamplerHeapCacheEntry;
     class StagingDescriptorAllocator;
 
     class BindGroupLayout final : public BindGroupLayoutBase {
@@ -33,9 +34,7 @@
 
         ResultOrError<BindGroup*> AllocateBindGroup(Device* device,
                                                     const BindGroupDescriptor* descriptor);
-        void DeallocateBindGroup(BindGroup* bindGroup,
-                                 CPUDescriptorHeapAllocation* viewAllocation,
-                                 CPUDescriptorHeapAllocation* samplerAllocation);
+        void DeallocateBindGroup(BindGroup* bindGroup, CPUDescriptorHeapAllocation* viewAllocation);
 
         enum DescriptorType {
             CBV,
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 0ec73aa..ebbef1c 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -30,6 +30,7 @@
 #include "dawn_native/d3d12/RenderPassBuilderD3D12.h"
 #include "dawn_native/d3d12/RenderPipelineD3D12.h"
 #include "dawn_native/d3d12/SamplerD3D12.h"
+#include "dawn_native/d3d12/SamplerHeapCacheD3D12.h"
 #include "dawn_native/d3d12/ShaderVisibleDescriptorAllocatorD3D12.h"
 #include "dawn_native/d3d12/StagingDescriptorAllocatorD3D12.h"
 #include "dawn_native/d3d12/TextureCopySplitter.h"
@@ -95,6 +96,7 @@
       public:
         BindGroupStateTracker(Device* device)
             : BindGroupAndStorageBarrierTrackerBase(),
+              mDevice(device),
               mViewAllocator(device->GetViewShaderVisibleDescriptorAllocator()),
               mSamplerAllocator(device->GetSamplerShaderVisibleDescriptorAllocator()) {
         }
@@ -117,7 +119,7 @@
             for (uint32_t index : IterateBitSet(mDirtyBindGroups)) {
                 BindGroup* group = ToBackend(mBindGroups[index]);
                 didCreateBindGroupViews = group->PopulateViews(mViewAllocator);
-                didCreateBindGroupSamplers = group->PopulateSamplers(mSamplerAllocator);
+                didCreateBindGroupSamplers = group->PopulateSamplers(mDevice, mSamplerAllocator);
                 if (!didCreateBindGroupViews && !didCreateBindGroupSamplers) {
                     break;
                 }
@@ -143,7 +145,8 @@
                 for (uint32_t index : IterateBitSet(mBindGroupLayoutsMask)) {
                     BindGroup* group = ToBackend(mBindGroups[index]);
                     didCreateBindGroupViews = group->PopulateViews(mViewAllocator);
-                    didCreateBindGroupSamplers = group->PopulateSamplers(mSamplerAllocator);
+                    didCreateBindGroupSamplers =
+                        group->PopulateSamplers(mDevice, mSamplerAllocator);
                     ASSERT(didCreateBindGroupViews);
                     ASSERT(didCreateBindGroupSamplers);
                 }
@@ -310,6 +313,8 @@
             }
         }
 
+        Device* mDevice;
+
         bool mInCompute = false;
 
         ShaderVisibleDescriptorAllocator* mViewAllocator;
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index cfd109f..21f4aaf 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -33,6 +33,7 @@
 #include "dawn_native/d3d12/ResidencyManagerD3D12.h"
 #include "dawn_native/d3d12/ResourceAllocatorManagerD3D12.h"
 #include "dawn_native/d3d12/SamplerD3D12.h"
+#include "dawn_native/d3d12/SamplerHeapCacheD3D12.h"
 #include "dawn_native/d3d12/ShaderModuleD3D12.h"
 #include "dawn_native/d3d12/ShaderVisibleDescriptorAllocatorD3D12.h"
 #include "dawn_native/d3d12/StagingBufferD3D12.h"
@@ -109,6 +110,8 @@
         mDepthStencilViewAllocator = std::make_unique<StagingDescriptorAllocator>(
             this, 1, kAttachmentDescriptorHeapSize, D3D12_DESCRIPTOR_HEAP_TYPE_DSV);
 
+        mSamplerHeapCache = std::make_unique<SamplerHeapCache>(this);
+
         mMapRequestTracker = std::make_unique<MapRequestTracker>(this);
         mResidencyManager = std::make_unique<ResidencyManager>(this);
         mResourceAllocatorManager = std::make_unique<ResourceAllocatorManager>(this);
@@ -503,4 +506,8 @@
         return mDepthStencilViewAllocator.get();
     }
 
+    SamplerHeapCache* Device::GetSamplerHeapCache() {
+        return mSamplerHeapCache.get();
+    }
+
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 4505230..88011b9 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -35,6 +35,7 @@
     class PlatformFunctions;
     class ResidencyManager;
     class ResourceAllocatorManager;
+    class SamplerHeapCache;
     class ShaderVisibleDescriptorAllocator;
     class StagingDescriptorAllocator;
 
@@ -107,6 +108,8 @@
         StagingDescriptorAllocator* GetSamplerStagingDescriptorAllocator(
             uint32_t descriptorCount) const;
 
+        SamplerHeapCache* GetSamplerHeapCache();
+
         StagingDescriptorAllocator* GetRenderTargetViewAllocator() const;
 
         StagingDescriptorAllocator* GetDepthStencilViewAllocator() const;
@@ -194,6 +197,10 @@
         std::unique_ptr<ShaderVisibleDescriptorAllocator> mViewShaderVisibleDescriptorAllocator;
 
         std::unique_ptr<ShaderVisibleDescriptorAllocator> mSamplerShaderVisibleDescriptorAllocator;
+
+        // Sampler cache needs to be destroyed before the CPU sampler allocator to ensure the final
+        // release is called.
+        std::unique_ptr<SamplerHeapCache> mSamplerHeapCache;
     };
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/SamplerHeapCacheD3D12.cpp b/src/dawn_native/d3d12/SamplerHeapCacheD3D12.cpp
new file mode 100644
index 0000000..224051a
--- /dev/null
+++ b/src/dawn_native/d3d12/SamplerHeapCacheD3D12.cpp
@@ -0,0 +1,167 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/d3d12/SamplerHeapCacheD3D12.h"
+
+#include "common/Assert.h"
+#include "common/HashUtils.h"
+#include "dawn_native/d3d12/BindGroupD3D12.h"
+#include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
+#include "dawn_native/d3d12/DeviceD3D12.h"
+#include "dawn_native/d3d12/Forward.h"
+#include "dawn_native/d3d12/SamplerD3D12.h"
+#include "dawn_native/d3d12/ShaderVisibleDescriptorAllocatorD3D12.h"
+#include "dawn_native/d3d12/StagingDescriptorAllocatorD3D12.h"
+
+namespace dawn_native { namespace d3d12 {
+
+    SamplerHeapCacheEntry::SamplerHeapCacheEntry(std::vector<Sampler*> samplers)
+        : mSamplers(std::move(samplers)) {
+    }
+
+    SamplerHeapCacheEntry::SamplerHeapCacheEntry(SamplerHeapCache* cache,
+                                                 StagingDescriptorAllocator* allocator,
+                                                 std::vector<Sampler*> samplers,
+                                                 CPUDescriptorHeapAllocation allocation)
+        : mCPUAllocation(std::move(allocation)),
+          mSamplers(std::move(samplers)),
+          mAllocator(allocator),
+          mCache(cache) {
+        ASSERT(mCache != nullptr);
+        ASSERT(mCPUAllocation.IsValid());
+        ASSERT(!mSamplers.empty());
+    }
+
+    std::vector<Sampler*>&& SamplerHeapCacheEntry::AcquireSamplers() {
+        return std::move(mSamplers);
+    }
+
+    SamplerHeapCacheEntry::~SamplerHeapCacheEntry() {
+        // If this is a blueprint then the CPU allocation cannot exist and has no entry to remove.
+        if (mCPUAllocation.IsValid()) {
+            mCache->RemoveCacheEntry(this);
+            mAllocator->Deallocate(&mCPUAllocation);
+        }
+
+        ASSERT(!mCPUAllocation.IsValid());
+    }
+
+    bool SamplerHeapCacheEntry::Populate(Device* device,
+                                         ShaderVisibleDescriptorAllocator* allocator) {
+        if (allocator->IsAllocationStillValid(mGPUAllocation)) {
+            return true;
+        }
+
+        ASSERT(!mSamplers.empty());
+
+        // Attempt to allocate descriptors for the currently bound shader-visible heaps.
+        // If either failed, return early to re-allocate and switch the heaps.
+        const uint32_t descriptorCount = mSamplers.size();
+        D3D12_CPU_DESCRIPTOR_HANDLE baseCPUDescriptor;
+        if (!allocator->AllocateGPUDescriptors(descriptorCount, device->GetPendingCommandSerial(),
+                                               &baseCPUDescriptor, &mGPUAllocation)) {
+            return false;
+        }
+
+        // CPU bindgroups are sparsely allocated across CPU heaps. Instead of doing
+        // simple copies per bindgroup, a single non-simple copy could be issued.
+        // TODO(dawn:155): Consider doing this optimization.
+        device->GetD3D12Device()->CopyDescriptorsSimple(descriptorCount, baseCPUDescriptor,
+                                                        mCPUAllocation.GetBaseDescriptor(),
+                                                        D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
+
+        return true;
+    }
+
+    D3D12_GPU_DESCRIPTOR_HANDLE SamplerHeapCacheEntry::GetBaseDescriptor() const {
+        return mGPUAllocation.GetBaseDescriptor();
+    }
+
+    ResultOrError<Ref<SamplerHeapCacheEntry>> SamplerHeapCache::GetOrCreate(
+        const BindGroup* group,
+        StagingDescriptorAllocator* samplerAllocator) {
+        const BindGroupLayout* bgl = ToBackend(group->GetLayout());
+
+        // If a previously created bindgroup used the same samplers, the backing sampler heap
+        // allocation can be reused. The packed list of samplers acts as the key to lookup the
+        // allocation in a cache.
+        // TODO(dawn:155): Avoid re-allocating the vector each lookup.
+        std::vector<Sampler*> samplers;
+        samplers.reserve(bgl->GetSamplerDescriptorCount());
+
+        for (BindingIndex bindingIndex = bgl->GetDynamicBufferCount();
+             bindingIndex < bgl->GetBindingCount(); ++bindingIndex) {
+            const BindingInfo& bindingInfo = bgl->GetBindingInfo(bindingIndex);
+            if (bindingInfo.type == wgpu::BindingType::Sampler ||
+                bindingInfo.type == wgpu::BindingType::ComparisonSampler) {
+                samplers.push_back(ToBackend(group->GetBindingAsSampler(bindingIndex)));
+            }
+        }
+
+        // Check the cache if there exists a sampler heap allocation that corresponds to the
+        // samplers.
+        SamplerHeapCacheEntry blueprint(std::move(samplers));
+        auto iter = mCache.find(&blueprint);
+        if (iter != mCache.end()) {
+            return Ref<SamplerHeapCacheEntry>(*iter);
+        }
+
+        // Steal the sampler vector back from the blueprint to avoid creating a new copy for the
+        // real entry below.
+        samplers = std::move(blueprint.AcquireSamplers());
+
+        CPUDescriptorHeapAllocation allocation;
+        DAWN_TRY_ASSIGN(allocation, samplerAllocator->AllocateCPUDescriptors());
+
+        const uint32_t samplerSizeIncrement = samplerAllocator->GetSizeIncrement();
+        ID3D12Device* d3d12Device = mDevice->GetD3D12Device();
+
+        for (uint32_t i = 0; i < samplers.size(); ++i) {
+            const auto& samplerDesc = samplers[i]->GetSamplerDescriptor();
+            d3d12Device->CreateSampler(&samplerDesc,
+                                       allocation.OffsetFrom(samplerSizeIncrement, i));
+        }
+
+        Ref<SamplerHeapCacheEntry> entry = AcquireRef(new SamplerHeapCacheEntry(
+            this, samplerAllocator, std::move(samplers), std::move(allocation)));
+        mCache.insert(entry.Get());
+        return std::move(entry);
+    }
+
+    SamplerHeapCache::SamplerHeapCache(Device* device) : mDevice(device) {
+    }
+
+    SamplerHeapCache::~SamplerHeapCache() {
+        ASSERT(mCache.empty());
+    }
+
+    void SamplerHeapCache::RemoveCacheEntry(SamplerHeapCacheEntry* entry) {
+        ASSERT(entry->GetRefCountForTesting() == 0);
+        size_t removedCount = mCache.erase(entry);
+        ASSERT(removedCount == 1);
+    }
+
+    size_t SamplerHeapCacheEntry::HashFunc::operator()(const SamplerHeapCacheEntry* entry) const {
+        size_t hash = 0;
+        for (const Sampler* sampler : entry->mSamplers) {
+            HashCombine(&hash, sampler);
+        }
+        return hash;
+    }
+
+    bool SamplerHeapCacheEntry::EqualityFunc::operator()(const SamplerHeapCacheEntry* a,
+                                                         const SamplerHeapCacheEntry* b) const {
+        return a->mSamplers == b->mSamplers;
+    }
+}}  // namespace dawn_native::d3d12
\ No newline at end of file
diff --git a/src/dawn_native/d3d12/SamplerHeapCacheD3D12.h b/src/dawn_native/d3d12/SamplerHeapCacheD3D12.h
new file mode 100644
index 0000000..2f41086
--- /dev/null
+++ b/src/dawn_native/d3d12/SamplerHeapCacheD3D12.h
@@ -0,0 +1,108 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_D3D12_SAMPLERHEAPCACHE_H_
+#define DAWNNATIVE_D3D12_SAMPLERHEAPCACHE_H_
+
+#include "common/RefCounted.h"
+#include "dawn_native/BindingInfo.h"
+#include "dawn_native/d3d12/CPUDescriptorHeapAllocationD3D12.h"
+#include "dawn_native/d3d12/GPUDescriptorHeapAllocationD3D12.h"
+
+#include <unordered_set>
+#include <vector>
+
+// |SamplerHeapCacheEntry| maintains a cache of sampler descriptor heap allocations.
+// Each entry represents one or more sampler descriptors that co-exist in a CPU and
+// GPU descriptor heap. The CPU-side allocation is deallocated once the final reference
+// has been released while the GPU-side allocation is deallocated when the GPU is finished.
+//
+// The BindGroupLayout hands out these entries upon constructing the bindgroup. If no valid entry
+// exists, one will be allocated and initialized so it may be reused by another bindgroup.
+//
+// The cache is primarily needed for the GPU sampler heap, which is much smaller than the view heap
+// and switches incur expensive pipeline flushes.
+namespace dawn_native { namespace d3d12 {
+
+    class BindGroup;
+    class Device;
+    class Sampler;
+    class SamplerHeapCache;
+    class StagingDescriptorAllocator;
+    class ShaderVisibleDescriptorAllocator;
+
+    // Wraps sampler descriptor heap allocations in a cache.
+    class SamplerHeapCacheEntry : public RefCounted {
+      public:
+        SamplerHeapCacheEntry() = default;
+        SamplerHeapCacheEntry(std::vector<Sampler*> samplers);
+        SamplerHeapCacheEntry(SamplerHeapCache* cache,
+                              StagingDescriptorAllocator* allocator,
+                              std::vector<Sampler*> samplers,
+                              CPUDescriptorHeapAllocation allocation);
+        ~SamplerHeapCacheEntry() override;
+
+        D3D12_GPU_DESCRIPTOR_HANDLE GetBaseDescriptor() const;
+
+        std::vector<Sampler*>&& AcquireSamplers();
+
+        bool Populate(Device* device, ShaderVisibleDescriptorAllocator* allocator);
+
+        // Functors necessary for the unordered_map<SamplerHeapCacheEntry*>-based cache.
+        struct HashFunc {
+            size_t operator()(const SamplerHeapCacheEntry* entry) const;
+        };
+
+        struct EqualityFunc {
+            bool operator()(const SamplerHeapCacheEntry* a, const SamplerHeapCacheEntry* b) const;
+        };
+
+      private:
+        CPUDescriptorHeapAllocation mCPUAllocation;
+        GPUDescriptorHeapAllocation mGPUAllocation;
+
+        // Storing raw pointer because the sampler object will be already hashed
+        // by the device and will already be unique.
+        std::vector<Sampler*> mSamplers;
+
+        StagingDescriptorAllocator* mAllocator = nullptr;
+        SamplerHeapCache* mCache = nullptr;
+    };
+
+    // Cache descriptor heap allocations so that we don't create duplicate ones for every
+    // BindGroup.
+    class SamplerHeapCache {
+      public:
+        SamplerHeapCache(Device* device);
+        ~SamplerHeapCache();
+
+        ResultOrError<Ref<SamplerHeapCacheEntry>> GetOrCreate(
+            const BindGroup* group,
+            StagingDescriptorAllocator* samplerAllocator);
+
+        void RemoveCacheEntry(SamplerHeapCacheEntry* entry);
+
+      private:
+        Device* mDevice;
+
+        using Cache = std::unordered_set<SamplerHeapCacheEntry*,
+                                         SamplerHeapCacheEntry::HashFunc,
+                                         SamplerHeapCacheEntry::EqualityFunc>;
+
+        Cache mCache;
+    };
+
+}}  // namespace dawn_native::d3d12
+
+#endif  // DAWNNATIVE_D3D12_SAMPLERHEAPCACHE_H_
\ No newline at end of file
diff --git a/src/tests/white_box/D3D12DescriptorHeapTests.cpp b/src/tests/white_box/D3D12DescriptorHeapTests.cpp
index cbe7dcb..b6eb4a6 100644
--- a/src/tests/white_box/D3D12DescriptorHeapTests.cpp
+++ b/src/tests/white_box/D3D12DescriptorHeapTests.cpp
@@ -115,12 +115,60 @@
     StagingDescriptorAllocator mAllocator;
 };
 
-// Verify the shader visible sampler heap switch within a single submit.
-TEST_P(D3D12DescriptorHeapTests, SwitchOverSamplerHeap) {
+// Verify the shader visible view heaps switch over within a single submit.
+TEST_P(D3D12DescriptorHeapTests, SwitchOverViewHeap) {
+    DAWN_SKIP_TEST_IF(!mD3DDevice->IsToggleEnabled(
+        dawn_native::Toggle::UseD3D12SmallShaderVisibleHeapForTesting));
+
+    utils::ComboRenderPipelineDescriptor renderPipelineDescriptor(device);
+
+    // Fill in a view heap with "view only" bindgroups (1x view per group) by creating a
+    // view bindgroup each draw. After HEAP_SIZE + 1 draws, the heaps must switch over.
+    renderPipelineDescriptor.vertexStage.module = mSimpleVSModule;
+    renderPipelineDescriptor.cFragmentStage.module = mSimpleFSModule;
+
+    wgpu::RenderPipeline renderPipeline = device.CreateRenderPipeline(&renderPipelineDescriptor);
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    Device* d3dDevice = reinterpret_cast<Device*>(device.Get());
+    ShaderVisibleDescriptorAllocator* allocator =
+        d3dDevice->GetViewShaderVisibleDescriptorAllocator();
+    const uint64_t heapSize = allocator->GetShaderVisibleHeapSizeForTesting();
+
+    const Serial heapSerial = allocator->GetShaderVisibleHeapSerialForTesting();
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    {
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+        pass.SetPipeline(renderPipeline);
+
+        std::array<float, 4> redColor = {1, 0, 0, 1};
+        wgpu::Buffer uniformBuffer = utils::CreateBufferFromData(
+            device, &redColor, sizeof(redColor), wgpu::BufferUsage::Uniform);
+
+        for (uint32_t i = 0; i < heapSize + 1; ++i) {
+            pass.SetBindGroup(0, utils::MakeBindGroup(device, renderPipeline.GetBindGroupLayout(0),
+                                                      {{0, uniformBuffer, 0, sizeof(redColor)}}));
+            pass.Draw(3);
+        }
+
+        pass.EndPass();
+    }
+
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_EQ(allocator->GetShaderVisibleHeapSerialForTesting(), heapSerial + 1);
+}
+
+// Verify the shader visible sampler heap does not switch over within a single submit.
+TEST_P(D3D12DescriptorHeapTests, NoSwitchOverSamplerHeap) {
     utils::ComboRenderPipelineDescriptor renderPipelineDescriptor(device);
 
     // Fill in a sampler heap with "sampler only" bindgroups (1x sampler per group) by creating a
-    // sampler bindgroup each draw. After HEAP_SIZE + 1 draws, the heaps must switch over.
+    // sampler bindgroup each draw. After HEAP_SIZE + 1 draws, the heaps WILL NOT switch over
+    // because the sampler heap allocations are de-duplicated.
     renderPipelineDescriptor.vertexStage.module =
         utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, R"(
             #version 450
@@ -167,7 +215,7 @@
     wgpu::CommandBuffer commands = encoder.Finish();
     queue.Submit(1, &commands);
 
-    EXPECT_EQ(allocator->GetShaderVisibleHeapSerialForTesting(), heapSerial + 1);
+    EXPECT_EQ(allocator->GetShaderVisibleHeapSerialForTesting(), heapSerial);
 }
 
 // Verify shader-visible heaps can be recycled for multiple submits.
@@ -727,13 +775,8 @@
         EXPECT_EQ(viewAllocator->GetShaderVisibleHeapSerialForTesting(),
                   viewHeapSerial + kNumOfViewHeaps);
 
-        const uint32_t numOfSamplerHeaps =
-            numOfEncodedBindGroups /
-            samplerAllocator->GetShaderVisibleHeapSizeForTesting();  // 1 sampler per group.
-
-        EXPECT_EQ(samplerAllocator->GetShaderVisiblePoolSizeForTesting(), numOfSamplerHeaps);
-        EXPECT_EQ(samplerAllocator->GetShaderVisibleHeapSerialForTesting(),
-                  samplerHeapSerial + numOfSamplerHeaps);
+        EXPECT_EQ(samplerAllocator->GetShaderVisiblePoolSizeForTesting(), 0u);
+        EXPECT_EQ(samplerAllocator->GetShaderVisibleHeapSerialForTesting(), samplerHeapSerial);
     }
 }