dawn_native: deduplicate samplers

Bug:dawn:143
Change-Id: I3aee914100fed87ea98cf22a7b90070c165780a2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/7361
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Idan Raiter <idanr@google.com>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 7d7c222..13c931a 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -51,6 +51,7 @@
         ContentLessObjectCache<ComputePipelineBase> computePipelines;
         ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
         ContentLessObjectCache<RenderPipelineBase> renderPipelines;
+        ContentLessObjectCache<SamplerBase> samplers;
         ContentLessObjectCache<ShaderModuleBase> shaderModules;
     };
 
@@ -186,6 +187,27 @@
         ASSERT(removedCount == 1);
     }
 
+    ResultOrError<SamplerBase*> DeviceBase::GetOrCreateSampler(
+        const SamplerDescriptor* descriptor) {
+        SamplerBase blueprint(this, descriptor, true);
+
+        auto iter = mCaches->samplers.find(&blueprint);
+        if (iter != mCaches->samplers.end()) {
+            (*iter)->Reference();
+            return *iter;
+        }
+
+        SamplerBase* backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreateSamplerImpl(descriptor));
+        mCaches->samplers.insert(backendObj);
+        return backendObj;
+    }
+
+    void DeviceBase::UncacheSampler(SamplerBase* obj) {
+        size_t removedCount = mCaches->samplers.erase(obj);
+        ASSERT(removedCount == 1);
+    }
+
     ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
         const ShaderModuleDescriptor* descriptor) {
         ShaderModuleBase blueprint(this, descriptor, true);
@@ -465,7 +487,7 @@
     MaybeError DeviceBase::CreateSamplerInternal(SamplerBase** result,
                                                  const SamplerDescriptor* descriptor) {
         DAWN_TRY(ValidateSamplerDescriptor(this, descriptor));
-        DAWN_TRY_ASSIGN(*result, CreateSamplerImpl(descriptor));
+        DAWN_TRY_ASSIGN(*result, GetOrCreateSampler(descriptor));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 998e035..939dcc0 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -96,6 +96,9 @@
             const RenderPipelineDescriptor* descriptor);
         void UncacheRenderPipeline(RenderPipelineBase* obj);
 
+        ResultOrError<SamplerBase*> GetOrCreateSampler(const SamplerDescriptor* descriptor);
+        void UncacheSampler(SamplerBase* obj);
+
         ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
             const ShaderModuleDescriptor* descriptor);
         void UncacheShaderModule(ShaderModuleBase* obj);
diff --git a/src/dawn_native/Sampler.cpp b/src/dawn_native/Sampler.cpp
index bbb224d..360397c 100644
--- a/src/dawn_native/Sampler.cpp
+++ b/src/dawn_native/Sampler.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/Sampler.h"
 
+#include "common/HashUtils.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/ValidationUtils_autogen.h"
 
@@ -45,16 +46,60 @@
 
     // SamplerBase
 
-    SamplerBase::SamplerBase(DeviceBase* device, const SamplerDescriptor*) : ObjectBase(device) {
+    SamplerBase::SamplerBase(DeviceBase* device,
+                             const SamplerDescriptor* descriptor,
+                             bool blueprint)
+        : ObjectBase(device),
+          mAddressModeU(descriptor->addressModeU),
+          mAddressModeV(descriptor->addressModeV),
+          mAddressModeW(descriptor->addressModeW),
+          mMagFilter(descriptor->magFilter),
+          mMinFilter(descriptor->minFilter),
+          mMipmapFilter(descriptor->mipmapFilter),
+          mLodMinClamp(descriptor->lodMinClamp),
+          mLodMaxClamp(descriptor->lodMaxClamp),
+          mCompareFunction(descriptor->compareFunction),
+          mIsBlueprint(blueprint) {
     }
 
     SamplerBase::SamplerBase(DeviceBase* device, ObjectBase::ErrorTag tag)
         : ObjectBase(device, tag) {
     }
 
+    SamplerBase::~SamplerBase() {
+        // Do not uncache the actual cached object if we are a blueprint
+        if (!mIsBlueprint && !IsError()) {
+            GetDevice()->UncacheSampler(this);
+        }
+    }
+
     // static
     SamplerBase* SamplerBase::MakeError(DeviceBase* device) {
         return new SamplerBase(device, ObjectBase::kError);
     }
 
+    size_t SamplerBase::HashFunc::operator()(const SamplerBase* module) const {
+        size_t hash = 0;
+
+        HashCombine(&hash, module->mAddressModeU);
+        HashCombine(&hash, module->mAddressModeV);
+        HashCombine(&hash, module->mAddressModeW);
+        HashCombine(&hash, module->mMagFilter);
+        HashCombine(&hash, module->mMinFilter);
+        HashCombine(&hash, module->mMipmapFilter);
+        HashCombine(&hash, module->mLodMinClamp);
+        HashCombine(&hash, module->mLodMaxClamp);
+        HashCombine(&hash, module->mCompareFunction);
+
+        return hash;
+    }
+
+    bool SamplerBase::EqualityFunc::operator()(const SamplerBase* a, const SamplerBase* b) const {
+        return a->mAddressModeU == b->mAddressModeU && a->mAddressModeV == b->mAddressModeV &&
+               a->mAddressModeW == b->mAddressModeW && a->mMagFilter == b->mMagFilter &&
+               a->mMinFilter == b->mMinFilter && a->mMipmapFilter == b->mMipmapFilter &&
+               a->mLodMinClamp == b->mLodMinClamp && a->mLodMaxClamp == b->mLodMaxClamp &&
+               a->mCompareFunction == b->mCompareFunction;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/Sampler.h b/src/dawn_native/Sampler.h
index cde32dd..202f3cd 100644
--- a/src/dawn_native/Sampler.h
+++ b/src/dawn_native/Sampler.h
@@ -28,12 +28,35 @@
 
     class SamplerBase : public ObjectBase {
       public:
-        SamplerBase(DeviceBase* device, const SamplerDescriptor* descriptor);
+        SamplerBase(DeviceBase* device,
+                    const SamplerDescriptor* descriptor,
+                    bool blueprint = false);
+        ~SamplerBase() override;
 
         static SamplerBase* MakeError(DeviceBase* device);
 
+        // Functors necessary for the unordered_set<SamplerBase*>-based cache.
+        struct HashFunc {
+            size_t operator()(const SamplerBase* module) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const SamplerBase* a, const SamplerBase* b) const;
+        };
+
       private:
         SamplerBase(DeviceBase* device, ObjectBase::ErrorTag tag);
+
+        // TODO(cwallez@chromium.org): Store a crypto hash of the items instead?
+        dawn::AddressMode mAddressModeU;
+        dawn::AddressMode mAddressModeV;
+        dawn::AddressMode mAddressModeW;
+        dawn::FilterMode mMagFilter;
+        dawn::FilterMode mMinFilter;
+        dawn::FilterMode mMipmapFilter;
+        float mLodMinClamp;
+        float mLodMaxClamp;
+        dawn::CompareFunction mCompareFunction;
+        bool mIsBlueprint = false;
     };
 
 }  // namespace dawn_native
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index 8b8e300..e7de879 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -286,4 +286,61 @@
     EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
 }
 
+// Test that Samplers are correctly deduplicated.
+TEST_P(ObjectCachingTest, SamplerDeduplication) {
+    dawn::SamplerDescriptor samplerDesc = utils::GetDefaultSamplerDescriptor();
+    dawn::Sampler sampler = device.CreateSampler(&samplerDesc);
+
+    dawn::SamplerDescriptor sameSamplerDesc = utils::GetDefaultSamplerDescriptor();
+    dawn::Sampler sameSampler = device.CreateSampler(&sameSamplerDesc);
+
+    dawn::SamplerDescriptor otherSamplerDescAddressModeU = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescAddressModeU.addressModeU = dawn::AddressMode::ClampToEdge;
+    dawn::Sampler otherSamplerAddressModeU = device.CreateSampler(&otherSamplerDescAddressModeU);
+
+    dawn::SamplerDescriptor otherSamplerDescAddressModeV = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescAddressModeV.addressModeV = dawn::AddressMode::ClampToEdge;
+    dawn::Sampler otherSamplerAddressModeV = device.CreateSampler(&otherSamplerDescAddressModeV);
+
+    dawn::SamplerDescriptor otherSamplerDescAddressModeW = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescAddressModeW.addressModeW = dawn::AddressMode::ClampToEdge;
+    dawn::Sampler otherSamplerAddressModeW = device.CreateSampler(&otherSamplerDescAddressModeW);
+
+    dawn::SamplerDescriptor otherSamplerDescMagFilter = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescMagFilter.magFilter = dawn::FilterMode::Nearest;
+    dawn::Sampler otherSamplerMagFilter = device.CreateSampler(&otherSamplerDescMagFilter);
+
+    dawn::SamplerDescriptor otherSamplerDescMinFilter = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescMinFilter.minFilter = dawn::FilterMode::Nearest;
+    dawn::Sampler otherSamplerMinFilter = device.CreateSampler(&otherSamplerDescMinFilter);
+
+    dawn::SamplerDescriptor otherSamplerDescMipmapFilter = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescMipmapFilter.mipmapFilter = dawn::FilterMode::Nearest;
+    dawn::Sampler otherSamplerMipmapFilter = device.CreateSampler(&otherSamplerDescMipmapFilter);
+
+    dawn::SamplerDescriptor otherSamplerDescLodMinClamp = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescLodMinClamp.lodMinClamp += 1;
+    dawn::Sampler otherSamplerLodMinClamp = device.CreateSampler(&otherSamplerDescLodMinClamp);
+
+    dawn::SamplerDescriptor otherSamplerDescLodMaxClamp = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescLodMaxClamp.lodMaxClamp += 1;
+    dawn::Sampler otherSamplerLodMaxClamp = device.CreateSampler(&otherSamplerDescLodMaxClamp);
+
+    dawn::SamplerDescriptor otherSamplerDescCompareFunction = utils::GetDefaultSamplerDescriptor();
+    otherSamplerDescCompareFunction.compareFunction = dawn::CompareFunction::Always;
+    dawn::Sampler otherSamplerCompareFunction =
+        device.CreateSampler(&otherSamplerDescCompareFunction);
+
+    EXPECT_NE(sampler.Get(), otherSamplerAddressModeU.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerAddressModeV.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerAddressModeW.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerMagFilter.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerMinFilter.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerMipmapFilter.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerLodMinClamp.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerLodMaxClamp.Get());
+    EXPECT_NE(sampler.Get(), otherSamplerCompareFunction.Get());
+    EXPECT_EQ(sampler.Get() == sameSampler.Get(), !UsesWire());
+}
+
 DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);