dawn_native: deduplicate compute pipelines

BUG=dawn:143

Change-Id: I64e4660de2241bb72bb7c615a0bd1e675e043295
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6863
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp
index a2fb60f..0ae29119 100644
--- a/src/dawn_native/ComputePipeline.cpp
+++ b/src/dawn_native/ComputePipeline.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/ComputePipeline.h"
 
+#include "common/HashUtils.h"
 #include "dawn_native/Device.h"
 
 namespace dawn_native {
@@ -33,8 +34,12 @@
     // ComputePipelineBase
 
     ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
-                                             const ComputePipelineDescriptor* descriptor)
-        : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
+                                             const ComputePipelineDescriptor* descriptor,
+                                             bool blueprint)
+        : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute),
+          mModule(descriptor->computeStage->module),
+          mEntryPoint(descriptor->computeStage->entryPoint),
+          mIsBlueprint(blueprint) {
         ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
     }
 
@@ -42,9 +47,29 @@
         : PipelineBase(device, tag) {
     }
 
+    ComputePipelineBase::~ComputePipelineBase() {
+        // Do not uncache the actual cached object if we are a blueprint
+        if (!mIsBlueprint) {
+            ASSERT(!IsError());
+            GetDevice()->UncacheComputePipeline(this);
+        }
+    }
+
     // static
     ComputePipelineBase* ComputePipelineBase::MakeError(DeviceBase* device) {
         return new ComputePipelineBase(device, ObjectBase::kError);
     }
 
+    size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
+        size_t hash = 0;
+        HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
+        return hash;
+    }
+
+    bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
+                                                       const ComputePipelineBase* b) const {
+        return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
+               a->GetLayout() == b->GetLayout();
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h
index c1450d3..006c469 100644
--- a/src/dawn_native/ComputePipeline.h
+++ b/src/dawn_native/ComputePipeline.h
@@ -26,12 +26,28 @@
 
     class ComputePipelineBase : public PipelineBase {
       public:
-        ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor);
+        ComputePipelineBase(DeviceBase* device,
+                            const ComputePipelineDescriptor* descriptor,
+                            bool blueprint = false);
+        ~ComputePipelineBase() override;
 
         static ComputePipelineBase* MakeError(DeviceBase* device);
 
+        // Functors necessary for the unordered_set<ComputePipelineBase*>-based cache.
+        struct HashFunc {
+            size_t operator()(const ComputePipelineBase* pipeline) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const ComputePipelineBase* a, const ComputePipelineBase* b) const;
+        };
+
       private:
         ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
+
+        // TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
+        Ref<ShaderModuleBase> mModule;
+        std::string mEntryPoint;
+        bool mIsBlueprint = false;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 844feb0..c11a67e 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -48,6 +48,7 @@
 
     struct DeviceBase::Caches {
         ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
+        ContentLessObjectCache<ComputePipelineBase> computePipelines;
         ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
         ContentLessObjectCache<ShaderModuleBase> shaderModules;
     };
@@ -121,6 +122,27 @@
         ASSERT(removedCount == 1);
     }
 
+    ResultOrError<ComputePipelineBase*> DeviceBase::GetOrCreateComputePipeline(
+        const ComputePipelineDescriptor* descriptor) {
+        ComputePipelineBase blueprint(this, descriptor, true);
+
+        auto iter = mCaches->computePipelines.find(&blueprint);
+        if (iter != mCaches->computePipelines.end()) {
+            (*iter)->Reference();
+            return *iter;
+        }
+
+        ComputePipelineBase* backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreateComputePipelineImpl(descriptor));
+        mCaches->computePipelines.insert(backendObj);
+        return backendObj;
+    }
+
+    void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
+        size_t removedCount = mCaches->computePipelines.erase(obj);
+        ASSERT(removedCount == 1);
+    }
+
     ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
         const PipelineLayoutDescriptor* descriptor) {
         PipelineLayoutBase blueprint(this, descriptor, true);
@@ -369,7 +391,7 @@
         ComputePipelineBase** result,
         const ComputePipelineDescriptor* descriptor) {
         DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
-        DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor));
+        DAWN_TRY_ASSIGN(*result, GetOrCreateComputePipeline(descriptor));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index aae0751..0addd78 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -84,6 +84,10 @@
             const BindGroupLayoutDescriptor* descriptor);
         void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
 
+        ResultOrError<ComputePipelineBase*> GetOrCreateComputePipeline(
+            const ComputePipelineDescriptor* descriptor);
+        void UncacheComputePipeline(ComputePipelineBase* obj);
+
         ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
             const PipelineLayoutDescriptor* descriptor);
         void UncachePipelineLayout(PipelineLayoutBase* obj);
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index e839b1b..8c245e1 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -87,4 +87,9 @@
         return mLayout.Get();
     }
 
+    const PipelineLayoutBase* PipelineBase::GetLayout() const {
+        ASSERT(!IsError());
+        return mLayout.Get();
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h
index c917125..55d57bf 100644
--- a/src/dawn_native/Pipeline.h
+++ b/src/dawn_native/Pipeline.h
@@ -48,6 +48,7 @@
         const PushConstantInfo& GetPushConstants(dawn::ShaderStage stage) const;
         dawn::ShaderStageBit GetStageMask() const;
         PipelineLayoutBase* GetLayout();
+        const PipelineLayoutBase* GetLayout() const;
 
       protected:
         PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index 64b6ae1..43feae8 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -75,4 +75,86 @@
     EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
 }
 
+// Test that ComputePipeline are correctly deduplicated wrt. their ShaderModule
+TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
+    dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
+            #version 450
+            void main() {
+                int i = 0;
+            })");
+    dawn::ShaderModule sameModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
+            #version 450
+            void main() {
+                int i = 0;
+            })");
+    dawn::ShaderModule otherModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
+            #version 450
+            void main() {
+            })");
+
+    EXPECT_NE(module.Get(), otherModule.Get());
+    EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+
+    dawn::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
+
+    dawn::PipelineStageDescriptor stageDesc;
+    stageDesc.entryPoint = "main";
+    stageDesc.module = module;
+
+    dawn::ComputePipelineDescriptor desc;
+    desc.computeStage = &stageDesc;
+    desc.layout = layout;
+
+    dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
+
+    stageDesc.module = sameModule;
+    dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
+
+    stageDesc.module = otherModule;
+    dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
+// Test that ComputePipeline are correctly deduplicated wrt. their layout
+TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+    dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+    dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr);
+
+    EXPECT_NE(pl.Get(), otherPl.Get());
+    EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
+
+    dawn::PipelineStageDescriptor stageDesc;
+    stageDesc.entryPoint = "main";
+    stageDesc.module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
+            #version 450
+            void main() {
+                int i = 0;
+            })");
+
+    dawn::ComputePipelineDescriptor desc;
+    desc.computeStage = &stageDesc;
+
+    desc.layout = pl;
+    dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
+
+    desc.layout = samePl;
+    dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
+
+    desc.layout = otherPl;
+    dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
 DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);