dawn_native: deduplicate shader modules

BUG=dawn:143

Change-Id: I2c0fa63e3a6d77c137418f12b9807d16a0636d57
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6862
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 33bbdec..844feb0 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -49,6 +49,7 @@
     struct DeviceBase::Caches {
         ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
         ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
+        ContentLessObjectCache<ShaderModuleBase> shaderModules;
     };
 
     // DeviceBase
@@ -141,6 +142,27 @@
         ASSERT(removedCount == 1);
     }
 
+    ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
+        const ShaderModuleDescriptor* descriptor) {
+        ShaderModuleBase blueprint(this, descriptor, true);
+
+        auto iter = mCaches->shaderModules.find(&blueprint);
+        if (iter != mCaches->shaderModules.end()) {
+            (*iter)->Reference();
+            return *iter;
+        }
+
+        ShaderModuleBase* backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor));
+        mCaches->shaderModules.insert(backendObj);
+        return backendObj;
+    }
+
+    void DeviceBase::UncacheShaderModule(ShaderModuleBase* obj) {
+        size_t removedCount = mCaches->shaderModules.erase(obj);
+        ASSERT(removedCount == 1);
+    }
+
     // Object creation API methods
 
     BindGroupBase* DeviceBase::CreateBindGroup(const BindGroupDescriptor* descriptor) {
@@ -382,7 +404,7 @@
     MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
                                                       const ShaderModuleDescriptor* descriptor) {
         DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
-        DAWN_TRY_ASSIGN(*result, CreateShaderModuleImpl(descriptor));
+        DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 3f4a913..aae0751 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -88,6 +88,10 @@
             const PipelineLayoutDescriptor* descriptor);
         void UncachePipelineLayout(PipelineLayoutBase* obj);
 
+        ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
+            const ShaderModuleDescriptor* descriptor);
+        void UncacheShaderModule(ShaderModuleBase* obj);
+
         // Dawn API
         BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor);
         BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor);
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index ca6548d..1cdf18f 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/ShaderModule.h"
 
+#include "common/HashUtils.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/Pipeline.h"
@@ -67,14 +68,26 @@
 
     // ShaderModuleBase
 
-    ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor*)
-        : ObjectBase(device) {
+    ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
+                                       const ShaderModuleDescriptor* descriptor,
+                                       bool blueprint)
+        : ObjectBase(device),
+          mCode(descriptor->code, descriptor->code + descriptor->codeSize),
+          mIsBlueprint(blueprint) {
     }
 
     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
         : ObjectBase(device, tag) {
     }
 
+    ShaderModuleBase::~ShaderModuleBase() {
+        // Do not uncache the actual cached object if we are a blueprint
+        if (!mIsBlueprint) {
+            ASSERT(!IsError());
+            GetDevice()->UncacheShaderModule(this);
+        }
+    }
+
     // static
     ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) {
         return new ShaderModuleBase(device, ObjectBase::kError);
@@ -287,4 +300,19 @@
         return true;
     }
 
+    size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
+        size_t hash = 0;
+
+        for (uint32_t word : module->mCode) {
+            HashCombine(&hash, word);
+        }
+
+        return hash;
+    }
+
+    bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
+                                                    const ShaderModuleBase* b) const {
+        return a->mCode == b->mCode;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index b8020f9..ab00c27 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -37,7 +37,10 @@
 
     class ShaderModuleBase : public ObjectBase {
       public:
-        ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
+        ShaderModuleBase(DeviceBase* device,
+                         const ShaderModuleDescriptor* descriptor,
+                         bool blueprint = false);
+        ~ShaderModuleBase() override;
 
         static ShaderModuleBase* MakeError(DeviceBase* device);
 
@@ -68,11 +71,24 @@
 
         bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout);
 
+        // Functors necessary for the unordered_set<ShaderModuleBase*>-based cache.
+        struct HashFunc {
+            size_t operator()(const ShaderModuleBase* module) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
+        };
+
       private:
         ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
         bool IsCompatibleWithBindGroupLayout(size_t group, const BindGroupLayoutBase* layout);
 
+        // TODO(cwallez@chromium.org): The code is only stored for deduplication. We could maybe
+        // store a cryptographic hash of the code instead?
+        std::vector<uint32_t> mCode;
+        bool mIsBlueprint = false;
+
         PushConstantInfo mPushConstants = {};
         ModuleBindingInfo mBindingInfo;
         std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index d2a591b..64b6ae1 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -48,4 +48,31 @@
     EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
 }
 
+// Test that ShaderModules are correctly deduplicated.
+TEST_P(ObjectCachingTest, ShaderModuleDeduplication) {
+    dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            layout(location = 0) out vec4 fragColor;
+            void main() {
+                fragColor = vec4(0.0, 1.0, 0.0, 1.0);
+            })");
+    dawn::ShaderModule sameModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            layout(location = 0) out vec4 fragColor;
+            void main() {
+                fragColor = vec4(0.0, 1.0, 0.0, 1.0);
+            })");
+    dawn::ShaderModule otherModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            layout(location = 0) out vec4 fragColor;
+            void main() {
+                fragColor = vec4(0.0);
+            })");
+
+    EXPECT_NE(module.Get(), otherModule.Get());
+    EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+}
+
 DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);