dawn_native: deduplicate shader modules
BUG=dawn:143
Change-Id: I2c0fa63e3a6d77c137418f12b9807d16a0636d57
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6862
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 33bbdec..844feb0 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -49,6 +49,7 @@
struct DeviceBase::Caches {
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
+ ContentLessObjectCache<ShaderModuleBase> shaderModules;
};
// DeviceBase
@@ -141,6 +142,27 @@
ASSERT(removedCount == 1);
}
+ ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
+ const ShaderModuleDescriptor* descriptor) {
+ ShaderModuleBase blueprint(this, descriptor, true);
+
+ auto iter = mCaches->shaderModules.find(&blueprint);
+ if (iter != mCaches->shaderModules.end()) {
+ (*iter)->Reference();
+ return *iter;
+ }
+
+ ShaderModuleBase* backendObj;
+ DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor));
+ mCaches->shaderModules.insert(backendObj);
+ return backendObj;
+ }
+
+ void DeviceBase::UncacheShaderModule(ShaderModuleBase* obj) {
+ size_t removedCount = mCaches->shaderModules.erase(obj);
+ ASSERT(removedCount == 1);
+ }
+
// Object creation API methods
BindGroupBase* DeviceBase::CreateBindGroup(const BindGroupDescriptor* descriptor) {
@@ -382,7 +404,7 @@
MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
const ShaderModuleDescriptor* descriptor) {
DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
- DAWN_TRY_ASSIGN(*result, CreateShaderModuleImpl(descriptor));
+ DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor));
return {};
}
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 3f4a913..aae0751 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -88,6 +88,10 @@
const PipelineLayoutDescriptor* descriptor);
void UncachePipelineLayout(PipelineLayoutBase* obj);
+ ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
+ const ShaderModuleDescriptor* descriptor);
+ void UncacheShaderModule(ShaderModuleBase* obj);
+
// Dawn API
BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor);
BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor);
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index ca6548d..1cdf18f 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -14,6 +14,7 @@
#include "dawn_native/ShaderModule.h"
+#include "common/HashUtils.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Device.h"
#include "dawn_native/Pipeline.h"
@@ -67,14 +68,26 @@
// ShaderModuleBase
- ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor*)
- : ObjectBase(device) {
+ ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
+ const ShaderModuleDescriptor* descriptor,
+ bool blueprint)
+ : ObjectBase(device),
+ mCode(descriptor->code, descriptor->code + descriptor->codeSize),
+ mIsBlueprint(blueprint) {
}
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
: ObjectBase(device, tag) {
}
+ ShaderModuleBase::~ShaderModuleBase() {
+ // Do not uncache the actual cached object if we are a blueprint
+ if (!mIsBlueprint) {
+ ASSERT(!IsError());
+ GetDevice()->UncacheShaderModule(this);
+ }
+ }
+
// static
ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) {
return new ShaderModuleBase(device, ObjectBase::kError);
@@ -287,4 +300,19 @@
return true;
}
+ size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
+ size_t hash = 0;
+
+ for (uint32_t word : module->mCode) {
+ HashCombine(&hash, word);
+ }
+
+ return hash;
+ }
+
+ bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
+ const ShaderModuleBase* b) const {
+ return a->mCode == b->mCode;
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index b8020f9..ab00c27 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -37,7 +37,10 @@
class ShaderModuleBase : public ObjectBase {
public:
- ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
+ ShaderModuleBase(DeviceBase* device,
+ const ShaderModuleDescriptor* descriptor,
+ bool blueprint = false);
+ ~ShaderModuleBase() override;
static ShaderModuleBase* MakeError(DeviceBase* device);
@@ -68,11 +71,24 @@
bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout);
+ // Functors necessary for the unordered_set<ShaderModuleBase*>-based cache.
+ struct HashFunc {
+ size_t operator()(const ShaderModuleBase* module) const;
+ };
+ struct EqualityFunc {
+ bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
+ };
+
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
bool IsCompatibleWithBindGroupLayout(size_t group, const BindGroupLayoutBase* layout);
+ // TODO(cwallez@chromium.org): The code is only stored for deduplication. We could maybe
+ // store a cryptographic hash of the code instead?
+ std::vector<uint32_t> mCode;
+ bool mIsBlueprint = false;
+
PushConstantInfo mPushConstants = {};
ModuleBindingInfo mBindingInfo;
std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index d2a591b..64b6ae1 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -48,4 +48,31 @@
EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
}
+// Test that ShaderModules are correctly deduplicated.
+TEST_P(ObjectCachingTest, ShaderModuleDeduplication) {
+ dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+ #version 450
+ layout(location = 0) out vec4 fragColor;
+ void main() {
+ fragColor = vec4(0.0, 1.0, 0.0, 1.0);
+ })");
+ dawn::ShaderModule sameModule =
+ utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+ #version 450
+ layout(location = 0) out vec4 fragColor;
+ void main() {
+ fragColor = vec4(0.0, 1.0, 0.0, 1.0);
+ })");
+ dawn::ShaderModule otherModule =
+ utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+ #version 450
+ layout(location = 0) out vec4 fragColor;
+ void main() {
+ fragColor = vec4(0.0);
+ })");
+
+ EXPECT_NE(module.Get(), otherModule.Get());
+ EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+}
+
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);