dawn_native: deduplicate samplers
Bug:dawn:143
Change-Id: I3aee914100fed87ea98cf22a7b90070c165780a2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/7361
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Idan Raiter <idanr@google.com>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 7d7c222..13c931a 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -51,6 +51,7 @@
ContentLessObjectCache<ComputePipelineBase> computePipelines;
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
ContentLessObjectCache<RenderPipelineBase> renderPipelines;
+ ContentLessObjectCache<SamplerBase> samplers;
ContentLessObjectCache<ShaderModuleBase> shaderModules;
};
@@ -186,6 +187,27 @@
ASSERT(removedCount == 1);
}
+ ResultOrError<SamplerBase*> DeviceBase::GetOrCreateSampler(
+ const SamplerDescriptor* descriptor) {
+ SamplerBase blueprint(this, descriptor, true);
+
+ auto iter = mCaches->samplers.find(&blueprint);
+ if (iter != mCaches->samplers.end()) {
+ (*iter)->Reference();
+ return *iter;
+ }
+
+ SamplerBase* backendObj;
+ DAWN_TRY_ASSIGN(backendObj, CreateSamplerImpl(descriptor));
+ mCaches->samplers.insert(backendObj);
+ return backendObj;
+ }
+
+ void DeviceBase::UncacheSampler(SamplerBase* obj) {
+ size_t removedCount = mCaches->samplers.erase(obj);
+ ASSERT(removedCount == 1);
+ }
+
ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
const ShaderModuleDescriptor* descriptor) {
ShaderModuleBase blueprint(this, descriptor, true);
@@ -465,7 +487,7 @@
MaybeError DeviceBase::CreateSamplerInternal(SamplerBase** result,
const SamplerDescriptor* descriptor) {
DAWN_TRY(ValidateSamplerDescriptor(this, descriptor));
- DAWN_TRY_ASSIGN(*result, CreateSamplerImpl(descriptor));
+ DAWN_TRY_ASSIGN(*result, GetOrCreateSampler(descriptor));
return {};
}
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 998e035..939dcc0 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -96,6 +96,9 @@
const RenderPipelineDescriptor* descriptor);
void UncacheRenderPipeline(RenderPipelineBase* obj);
+ ResultOrError<SamplerBase*> GetOrCreateSampler(const SamplerDescriptor* descriptor);
+ void UncacheSampler(SamplerBase* obj);
+
ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
const ShaderModuleDescriptor* descriptor);
void UncacheShaderModule(ShaderModuleBase* obj);
diff --git a/src/dawn_native/Sampler.cpp b/src/dawn_native/Sampler.cpp
index bbb224d..360397c 100644
--- a/src/dawn_native/Sampler.cpp
+++ b/src/dawn_native/Sampler.cpp
@@ -14,6 +14,7 @@
#include "dawn_native/Sampler.h"
+#include "common/HashUtils.h"
#include "dawn_native/Device.h"
#include "dawn_native/ValidationUtils_autogen.h"
@@ -45,16 +46,60 @@
// SamplerBase
- SamplerBase::SamplerBase(DeviceBase* device, const SamplerDescriptor*) : ObjectBase(device) {
+ SamplerBase::SamplerBase(DeviceBase* device,
+ const SamplerDescriptor* descriptor,
+ bool blueprint)
+ : ObjectBase(device),
+ mAddressModeU(descriptor->addressModeU),
+ mAddressModeV(descriptor->addressModeV),
+ mAddressModeW(descriptor->addressModeW),
+ mMagFilter(descriptor->magFilter),
+ mMinFilter(descriptor->minFilter),
+ mMipmapFilter(descriptor->mipmapFilter),
+ mLodMinClamp(descriptor->lodMinClamp),
+ mLodMaxClamp(descriptor->lodMaxClamp),
+ mCompareFunction(descriptor->compareFunction),
+ mIsBlueprint(blueprint) {
}
SamplerBase::SamplerBase(DeviceBase* device, ObjectBase::ErrorTag tag)
: ObjectBase(device, tag) {
}
+ SamplerBase::~SamplerBase() {
+ // Do not uncache the actual cached object if we are a blueprint
+ if (!mIsBlueprint && !IsError()) {
+ GetDevice()->UncacheSampler(this);
+ }
+ }
+
// static
SamplerBase* SamplerBase::MakeError(DeviceBase* device) {
return new SamplerBase(device, ObjectBase::kError);
}
+ size_t SamplerBase::HashFunc::operator()(const SamplerBase* module) const {
+ size_t hash = 0;
+
+ HashCombine(&hash, module->mAddressModeU);
+ HashCombine(&hash, module->mAddressModeV);
+ HashCombine(&hash, module->mAddressModeW);
+ HashCombine(&hash, module->mMagFilter);
+ HashCombine(&hash, module->mMinFilter);
+ HashCombine(&hash, module->mMipmapFilter);
+ HashCombine(&hash, module->mLodMinClamp);
+ HashCombine(&hash, module->mLodMaxClamp);
+ HashCombine(&hash, module->mCompareFunction);
+
+ return hash;
+ }
+
+ bool SamplerBase::EqualityFunc::operator()(const SamplerBase* a, const SamplerBase* b) const {
+ return a->mAddressModeU == b->mAddressModeU && a->mAddressModeV == b->mAddressModeV &&
+ a->mAddressModeW == b->mAddressModeW && a->mMagFilter == b->mMagFilter &&
+ a->mMinFilter == b->mMinFilter && a->mMipmapFilter == b->mMipmapFilter &&
+ a->mLodMinClamp == b->mLodMinClamp && a->mLodMaxClamp == b->mLodMaxClamp &&
+ a->mCompareFunction == b->mCompareFunction;
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/Sampler.h b/src/dawn_native/Sampler.h
index cde32dd..202f3cd 100644
--- a/src/dawn_native/Sampler.h
+++ b/src/dawn_native/Sampler.h
@@ -28,12 +28,35 @@
class SamplerBase : public ObjectBase {
public:
- SamplerBase(DeviceBase* device, const SamplerDescriptor* descriptor);
+ SamplerBase(DeviceBase* device,
+ const SamplerDescriptor* descriptor,
+ bool blueprint = false);
+ ~SamplerBase() override;
static SamplerBase* MakeError(DeviceBase* device);
+ // Functors necessary for the unordered_set<SamplerBase*>-based cache.
+ struct HashFunc {
+ size_t operator()(const SamplerBase* module) const;
+ };
+ struct EqualityFunc {
+ bool operator()(const SamplerBase* a, const SamplerBase* b) const;
+ };
+
private:
SamplerBase(DeviceBase* device, ObjectBase::ErrorTag tag);
+
+ // TODO(cwallez@chromium.org): Store a crypto hash of the items instead?
+ dawn::AddressMode mAddressModeU;
+ dawn::AddressMode mAddressModeV;
+ dawn::AddressMode mAddressModeW;
+ dawn::FilterMode mMagFilter;
+ dawn::FilterMode mMinFilter;
+ dawn::FilterMode mMipmapFilter;
+ float mLodMinClamp;
+ float mLodMaxClamp;
+ dawn::CompareFunction mCompareFunction;
+ bool mIsBlueprint = false;
};
} // namespace dawn_native
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index 8b8e300..e7de879 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -286,4 +286,61 @@
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
+// Test that Samplers are correctly deduplicated.
+TEST_P(ObjectCachingTest, SamplerDeduplication) {
+ dawn::SamplerDescriptor samplerDesc = utils::GetDefaultSamplerDescriptor();
+ dawn::Sampler sampler = device.CreateSampler(&samplerDesc);
+
+ dawn::SamplerDescriptor sameSamplerDesc = utils::GetDefaultSamplerDescriptor();
+ dawn::Sampler sameSampler = device.CreateSampler(&sameSamplerDesc);
+
+ dawn::SamplerDescriptor otherSamplerDescAddressModeU = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescAddressModeU.addressModeU = dawn::AddressMode::ClampToEdge;
+ dawn::Sampler otherSamplerAddressModeU = device.CreateSampler(&otherSamplerDescAddressModeU);
+
+ dawn::SamplerDescriptor otherSamplerDescAddressModeV = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescAddressModeV.addressModeV = dawn::AddressMode::ClampToEdge;
+ dawn::Sampler otherSamplerAddressModeV = device.CreateSampler(&otherSamplerDescAddressModeV);
+
+ dawn::SamplerDescriptor otherSamplerDescAddressModeW = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescAddressModeW.addressModeW = dawn::AddressMode::ClampToEdge;
+ dawn::Sampler otherSamplerAddressModeW = device.CreateSampler(&otherSamplerDescAddressModeW);
+
+ dawn::SamplerDescriptor otherSamplerDescMagFilter = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescMagFilter.magFilter = dawn::FilterMode::Nearest;
+ dawn::Sampler otherSamplerMagFilter = device.CreateSampler(&otherSamplerDescMagFilter);
+
+ dawn::SamplerDescriptor otherSamplerDescMinFilter = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescMinFilter.minFilter = dawn::FilterMode::Nearest;
+ dawn::Sampler otherSamplerMinFilter = device.CreateSampler(&otherSamplerDescMinFilter);
+
+ dawn::SamplerDescriptor otherSamplerDescMipmapFilter = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescMipmapFilter.mipmapFilter = dawn::FilterMode::Nearest;
+ dawn::Sampler otherSamplerMipmapFilter = device.CreateSampler(&otherSamplerDescMipmapFilter);
+
+ dawn::SamplerDescriptor otherSamplerDescLodMinClamp = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescLodMinClamp.lodMinClamp += 1;
+ dawn::Sampler otherSamplerLodMinClamp = device.CreateSampler(&otherSamplerDescLodMinClamp);
+
+ dawn::SamplerDescriptor otherSamplerDescLodMaxClamp = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescLodMaxClamp.lodMaxClamp += 1;
+ dawn::Sampler otherSamplerLodMaxClamp = device.CreateSampler(&otherSamplerDescLodMaxClamp);
+
+ dawn::SamplerDescriptor otherSamplerDescCompareFunction = utils::GetDefaultSamplerDescriptor();
+ otherSamplerDescCompareFunction.compareFunction = dawn::CompareFunction::Always;
+ dawn::Sampler otherSamplerCompareFunction =
+ device.CreateSampler(&otherSamplerDescCompareFunction);
+
+ EXPECT_NE(sampler.Get(), otherSamplerAddressModeU.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerAddressModeV.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerAddressModeW.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerMagFilter.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerMinFilter.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerMipmapFilter.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerLodMinClamp.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerLodMaxClamp.Get());
+ EXPECT_NE(sampler.Get(), otherSamplerCompareFunction.Get());
+ EXPECT_EQ(sampler.Get() == sameSampler.Get(), !UsesWire());
+}
+
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);