dawn_native: deduplicate render pipelines

BUG=dawn:143

Change-Id: I2f66387f95bcb44dc20f308b4a582b878803dbe8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6864
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index c11a67e..753f709 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -50,6 +50,7 @@
         ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
         ContentLessObjectCache<ComputePipelineBase> computePipelines;
         ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
+        ContentLessObjectCache<RenderPipelineBase> renderPipelines;
         ContentLessObjectCache<ShaderModuleBase> shaderModules;
     };
 
@@ -164,6 +165,27 @@
         ASSERT(removedCount == 1);
     }
 
+    ResultOrError<RenderPipelineBase*> DeviceBase::GetOrCreateRenderPipeline(
+        const RenderPipelineDescriptor* descriptor) {
+        RenderPipelineBase blueprint(this, descriptor, true);
+
+        auto iter = mCaches->renderPipelines.find(&blueprint);
+        if (iter != mCaches->renderPipelines.end()) {
+            (*iter)->Reference();
+            return *iter;
+        }
+
+        RenderPipelineBase* backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreateRenderPipelineImpl(descriptor));
+        mCaches->renderPipelines.insert(backendObj);
+        return backendObj;
+    }
+
+    void DeviceBase::UncacheRenderPipeline(RenderPipelineBase* obj) {
+        size_t removedCount = mCaches->renderPipelines.erase(obj);
+        ASSERT(removedCount == 1);
+    }
+
     ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
         const ShaderModuleDescriptor* descriptor) {
         ShaderModuleBase blueprint(this, descriptor, true);
@@ -412,7 +434,7 @@
         RenderPipelineBase** result,
         const RenderPipelineDescriptor* descriptor) {
         DAWN_TRY(ValidateRenderPipelineDescriptor(this, descriptor));
-        DAWN_TRY_ASSIGN(*result, CreateRenderPipelineImpl(descriptor));
+        DAWN_TRY_ASSIGN(*result, GetOrCreateRenderPipeline(descriptor));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 0addd78..b9039e4 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -92,6 +92,10 @@
             const PipelineLayoutDescriptor* descriptor);
         void UncachePipelineLayout(PipelineLayoutBase* obj);
 
+        ResultOrError<RenderPipelineBase*> GetOrCreateRenderPipeline(
+            const RenderPipelineDescriptor* descriptor);
+        void UncacheRenderPipeline(RenderPipelineBase* obj);
+
         ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
             const ShaderModuleDescriptor* descriptor);
         void UncacheShaderModule(ShaderModuleBase* obj);
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 84344ee..e44d287 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -15,6 +15,7 @@
 #include "dawn_native/RenderPipeline.h"
 
 #include "common/BitSetIterator.h"
+#include "common/HashUtils.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/Texture.h"
@@ -328,15 +329,21 @@
     // RenderPipelineBase
 
     RenderPipelineBase::RenderPipelineBase(DeviceBase* device,
-                                           const RenderPipelineDescriptor* descriptor)
+                                           const RenderPipelineDescriptor* descriptor,
+                                           bool blueprint)
         : PipelineBase(device,
                        descriptor->layout,
                        dawn::ShaderStageBit::Vertex | dawn::ShaderStageBit::Fragment),
           mInputState(*descriptor->inputState),
+          mHasDepthStencilAttachment(descriptor->depthStencilState != nullptr),
           mPrimitiveTopology(descriptor->primitiveTopology),
           mRasterizationState(*descriptor->rasterizationState),
-          mHasDepthStencilAttachment(descriptor->depthStencilState != nullptr),
-          mSampleCount(descriptor->sampleCount) {
+          mSampleCount(descriptor->sampleCount),
+          mVertexModule(descriptor->vertexStage->module),
+          mVertexEntryPoint(descriptor->vertexStage->entryPoint),
+          mFragmentModule(descriptor->fragmentStage->module),
+          mFragmentEntryPoint(descriptor->fragmentStage->entryPoint),
+          mIsBlueprint(blueprint) {
         uint32_t location = 0;
         for (uint32_t i = 0; i < mInputState.numAttributes; ++i) {
             location = mInputState.attributes[i].shaderLocation;
@@ -391,6 +398,13 @@
         return new RenderPipelineBase(device, ObjectBase::kError);
     }
 
+    RenderPipelineBase::~RenderPipelineBase() {
+        // Do not uncache the actual cached object if we are a blueprint
+        if (!mIsBlueprint && !IsError()) {
+            GetDevice()->UncacheRenderPipeline(this);
+        }
+    }
+
     const InputStateDescriptor* RenderPipelineBase::GetInputStateDescriptor() const {
         ASSERT(!IsError());
         return &mInputState;
@@ -419,13 +433,13 @@
     }
 
     const ColorStateDescriptor* RenderPipelineBase::GetColorStateDescriptor(
-        uint32_t attachmentSlot) {
+        uint32_t attachmentSlot) const {
         ASSERT(!IsError());
         ASSERT(attachmentSlot < mColorStates.size());
         return &mColorStates[attachmentSlot];
     }
 
-    const DepthStencilStateDescriptor* RenderPipelineBase::GetDepthStencilStateDescriptor() {
+    const DepthStencilStateDescriptor* RenderPipelineBase::GetDepthStencilStateDescriptor() const {
         ASSERT(!IsError());
         return &mDepthStencilState;
     }
@@ -509,4 +523,175 @@
         return attributesUsingInput[slot];
     }
 
+    size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const {
+        size_t hash = 0;
+
+        // Hash modules and layout
+        HashCombine(&hash, pipeline->GetLayout());
+        HashCombine(&hash, pipeline->mVertexModule.Get(), pipeline->mFragmentEntryPoint);
+        HashCombine(&hash, pipeline->mFragmentModule.Get(), pipeline->mFragmentEntryPoint);
+
+        // Hash attachments
+        HashCombine(&hash, pipeline->mColorAttachmentsSet);
+        for (uint32_t i : IterateBitSet(pipeline->mColorAttachmentsSet)) {
+            const ColorStateDescriptor& desc = *pipeline->GetColorStateDescriptor(i);
+            HashCombine(&hash, desc.format, desc.writeMask);
+            HashCombine(&hash, desc.colorBlend.operation, desc.colorBlend.srcFactor,
+                        desc.colorBlend.dstFactor);
+            HashCombine(&hash, desc.alphaBlend.operation, desc.alphaBlend.srcFactor,
+                        desc.alphaBlend.dstFactor);
+        }
+
+        if (pipeline->mHasDepthStencilAttachment) {
+            const DepthStencilStateDescriptor& desc = pipeline->mDepthStencilState;
+            HashCombine(&hash, desc.format, desc.depthWriteEnabled, desc.depthCompare);
+            HashCombine(&hash, desc.stencilReadMask, desc.stencilWriteMask);
+            HashCombine(&hash, desc.stencilFront.compare, desc.stencilFront.failOp,
+                        desc.stencilFront.depthFailOp, desc.stencilFront.passOp);
+            HashCombine(&hash, desc.stencilBack.compare, desc.stencilBack.failOp,
+                        desc.stencilBack.depthFailOp, desc.stencilBack.passOp);
+        }
+
+        // Hash vertex input state
+        HashCombine(&hash, pipeline->mAttributesSetMask);
+        for (uint32_t i : IterateBitSet(pipeline->mAttributesSetMask)) {
+            const VertexAttributeDescriptor& desc = pipeline->GetAttribute(i);
+            HashCombine(&hash, desc.shaderLocation, desc.inputSlot, desc.offset, desc.format);
+        }
+
+        HashCombine(&hash, pipeline->mInputsSetMask);
+        for (uint32_t i : IterateBitSet(pipeline->mInputsSetMask)) {
+            const VertexInputDescriptor& desc = pipeline->GetInput(i);
+            HashCombine(&hash, desc.inputSlot, desc.stride, desc.stepMode);
+        }
+
+        HashCombine(&hash, pipeline->mInputState.indexFormat);
+
+        // Hash rasterization state
+        {
+            const RasterizationStateDescriptor& desc = pipeline->mRasterizationState;
+            HashCombine(&hash, desc.frontFace, desc.cullMode);
+            HashCombine(&hash, desc.depthBias, desc.depthBiasSlopeScale, desc.depthBiasClamp);
+        }
+
+        // Hash other state
+        HashCombine(&hash, pipeline->mSampleCount, pipeline->mPrimitiveTopology);
+
+        return hash;
+    }
+
+    bool RenderPipelineBase::EqualityFunc::operator()(const RenderPipelineBase* a,
+                                                      const RenderPipelineBase* b) const {
+        // Check modules and layout
+        if (a->GetLayout() != b->GetLayout() || a->mVertexModule.Get() != b->mVertexModule.Get() ||
+            a->mVertexEntryPoint != b->mVertexEntryPoint ||
+            a->mFragmentModule.Get() != b->mFragmentModule.Get() ||
+            a->mFragmentEntryPoint != b->mFragmentEntryPoint) {
+            return false;
+        }
+
+        // Check attachments
+        if (a->mColorAttachmentsSet != b->mColorAttachmentsSet ||
+            a->mHasDepthStencilAttachment != b->mHasDepthStencilAttachment) {
+            return false;
+        }
+
+        for (uint32_t i : IterateBitSet(a->mColorAttachmentsSet)) {
+            const ColorStateDescriptor& descA = *a->GetColorStateDescriptor(i);
+            const ColorStateDescriptor& descB = *b->GetColorStateDescriptor(i);
+            if (descA.format != descB.format || descA.writeMask != descB.writeMask) {
+                return false;
+            }
+            if (descA.colorBlend.operation != descB.colorBlend.operation ||
+                descA.colorBlend.srcFactor != descB.colorBlend.srcFactor ||
+                descA.colorBlend.dstFactor != descB.colorBlend.dstFactor) {
+                return false;
+            }
+            if (descA.alphaBlend.operation != descB.alphaBlend.operation ||
+                descA.alphaBlend.srcFactor != descB.alphaBlend.srcFactor ||
+                descA.alphaBlend.dstFactor != descB.alphaBlend.dstFactor) {
+                return false;
+            }
+        }
+
+        if (a->mHasDepthStencilAttachment) {
+            const DepthStencilStateDescriptor& descA = a->mDepthStencilState;
+            const DepthStencilStateDescriptor& descB = b->mDepthStencilState;
+            if (descA.format != descB.format ||
+                descA.depthWriteEnabled != descB.depthWriteEnabled ||
+                descA.depthCompare != descB.depthCompare) {
+                return false;
+            }
+            if (descA.stencilReadMask != descB.stencilReadMask ||
+                descA.stencilWriteMask != descB.stencilWriteMask) {
+                return false;
+            }
+            if (descA.stencilFront.compare != descB.stencilFront.compare ||
+                descA.stencilFront.failOp != descB.stencilFront.failOp ||
+                descA.stencilFront.depthFailOp != descB.stencilFront.depthFailOp ||
+                descA.stencilFront.passOp != descB.stencilFront.passOp) {
+                return false;
+            }
+            if (descA.stencilBack.compare != descB.stencilBack.compare ||
+                descA.stencilBack.failOp != descB.stencilBack.failOp ||
+                descA.stencilBack.depthFailOp != descB.stencilBack.depthFailOp ||
+                descA.stencilBack.passOp != descB.stencilBack.passOp) {
+                return false;
+            }
+        }
+
+        // Check vertex input state
+        if (a->mAttributesSetMask != b->mAttributesSetMask) {
+            return false;
+        }
+
+        for (uint32_t i : IterateBitSet(a->mAttributesSetMask)) {
+            const VertexAttributeDescriptor& descA = a->GetAttribute(i);
+            const VertexAttributeDescriptor& descB = b->GetAttribute(i);
+            if (descA.shaderLocation != descB.shaderLocation ||
+                descA.inputSlot != descB.inputSlot || descA.offset != descB.offset ||
+                descA.format != descB.format) {
+                return false;
+            }
+        }
+
+        if (a->mInputsSetMask != b->mInputsSetMask) {
+            return false;
+        }
+
+        for (uint32_t i : IterateBitSet(a->mInputsSetMask)) {
+            const VertexInputDescriptor& descA = a->GetInput(i);
+            const VertexInputDescriptor& descB = b->GetInput(i);
+            if (descA.inputSlot != descB.inputSlot || descA.stride != descB.stride ||
+                descA.stepMode != descB.stepMode) {
+                return false;
+            }
+        }
+
+        if (a->mInputState.indexFormat != b->mInputState.indexFormat) {
+            return false;
+        }
+
+        // Check rasterization state
+        {
+            const RasterizationStateDescriptor& descA = a->mRasterizationState;
+            const RasterizationStateDescriptor& descB = b->mRasterizationState;
+            if (descA.frontFace != descB.frontFace || descA.cullMode != descB.cullMode) {
+                return false;
+            }
+            if (descA.depthBias != descB.depthBias ||
+                descA.depthBiasSlopeScale != descB.depthBiasSlopeScale ||
+                descA.depthBiasClamp != descB.depthBiasClamp) {
+                return false;
+            }
+        }
+
+        // Check other state
+        if (a->mSampleCount != b->mSampleCount || a->mPrimitiveTopology != b->mPrimitiveTopology) {
+            return false;
+        }
+
+        return true;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index f72afd7..98fc1ad 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -40,7 +40,10 @@
 
     class RenderPipelineBase : public PipelineBase {
       public:
-        RenderPipelineBase(DeviceBase* device, const RenderPipelineDescriptor* descriptor);
+        RenderPipelineBase(DeviceBase* device,
+                           const RenderPipelineDescriptor* descriptor,
+                           bool blueprint = false);
+        ~RenderPipelineBase() override;
 
         static RenderPipelineBase* MakeError(DeviceBase* device);
 
@@ -50,8 +53,8 @@
         const std::bitset<kMaxVertexInputs>& GetInputsSetMask() const;
         const VertexInputDescriptor& GetInput(uint32_t slot) const;
 
-        const ColorStateDescriptor* GetColorStateDescriptor(uint32_t attachmentSlot);
-        const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor();
+        const ColorStateDescriptor* GetColorStateDescriptor(uint32_t attachmentSlot) const;
+        const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor() const;
         dawn::PrimitiveTopology GetPrimitiveTopology() const;
         dawn::CullMode GetCullMode() const;
         dawn::FrontFace GetFrontFace() const;
@@ -68,23 +71,43 @@
         std::bitset<kMaxVertexAttributes> GetAttributesUsingInput(uint32_t slot) const;
         std::array<std::bitset<kMaxVertexAttributes>, kMaxVertexInputs> attributesUsingInput;
 
+        // Functors necessary for the unordered_set<RenderPipelineBase*>-based cache.
+        struct HashFunc {
+            size_t operator()(const RenderPipelineBase* pipeline) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const RenderPipelineBase* a, const RenderPipelineBase* b) const;
+        };
+
       private:
         RenderPipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
+        // Vertex input
         InputStateDescriptor mInputState;
         std::bitset<kMaxVertexAttributes> mAttributesSetMask;
         std::array<VertexAttributeDescriptor, kMaxVertexAttributes> mAttributeInfos;
         std::bitset<kMaxVertexInputs> mInputsSetMask;
         std::array<VertexInputDescriptor, kMaxVertexInputs> mInputInfos;
-        dawn::PrimitiveTopology mPrimitiveTopology;
-        RasterizationStateDescriptor mRasterizationState;
+
+        // Attachments
+        bool mHasDepthStencilAttachment = false;
         DepthStencilStateDescriptor mDepthStencilState;
+        std::bitset<kMaxColorAttachments> mColorAttachmentsSet;
         std::array<ColorStateDescriptor, kMaxColorAttachments> mColorStates;
 
-        std::bitset<kMaxColorAttachments> mColorAttachmentsSet;
-        bool mHasDepthStencilAttachment = false;
-
+        // Other state
+        dawn::PrimitiveTopology mPrimitiveTopology;
+        RasterizationStateDescriptor mRasterizationState;
         uint32_t mSampleCount;
+
+        // Stage information
+        // TODO(cwallez@chromium.org): Store a crypto hash of the modules instead.
+        Ref<ShaderModuleBase> mVertexModule;
+        std::string mVertexEntryPoint;
+        Ref<ShaderModuleBase> mFragmentModule;
+        std::string mFragmentEntryPoint;
+
+        bool mIsBlueprint = false;
     };
 
 }  // namespace dawn_native
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index 5172a6b..8b8e300 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -14,6 +14,7 @@
 
 #include "tests/DawnTest.h"
 
+#include "utils/ComboRenderPipelineDescriptor.h"
 #include "utils/DawnHelpers.h"
 
 class ObjectCachingTest : public DawnTest {};
@@ -165,4 +166,124 @@
     EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
 }
 
+// Test that RenderPipelines are correctly deduplicated wrt. their layout
+TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnLayout) {
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+    dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+    dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr);
+
+    EXPECT_NE(pl.Get(), otherPl.Get());
+    EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
+
+    utils::ComboRenderPipelineDescriptor desc(device);
+    desc.cVertexStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+            #version 450
+            void main() {
+                gl_Position = vec4(0.0);
+            })");
+    desc.cFragmentStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            void main() {
+            })");
+
+    desc.layout = pl;
+    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
+
+    desc.layout = samePl;
+    dawn::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
+
+    desc.layout = otherPl;
+    dawn::RenderPipeline otherPipeline = device.CreateRenderPipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
+// Test that RenderPipelines are correctly deduplicated wrt. their vertex module
+TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnVertexModule) {
+    dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+            #version 450
+            void main() {
+                gl_Position = vec4(0.0);
+            })");
+    dawn::ShaderModule sameModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+            #version 450
+            void main() {
+                gl_Position = vec4(0.0);
+            })");
+    dawn::ShaderModule otherModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+            #version 450
+            void main() {
+                gl_Position = vec4(1.0);
+            })");
+
+    EXPECT_NE(module.Get(), otherModule.Get());
+    EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+
+    utils::ComboRenderPipelineDescriptor desc(device);
+    desc.cFragmentStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            void main() {
+            })");
+
+    desc.cVertexStage.module = module;
+    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cVertexStage.module = sameModule;
+    dawn::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cVertexStage.module = otherModule;
+    dawn::RenderPipeline otherPipeline = device.CreateRenderPipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
+// Test that RenderPipelines are correctly deduplicated wrt. their fragment module
+TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
+    dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            void main() {
+            })");
+    dawn::ShaderModule sameModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            void main() {
+            })");
+    dawn::ShaderModule otherModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+            #version 450
+            void main() {
+                int i = 0;
+            })");
+
+    EXPECT_NE(module.Get(), otherModule.Get());
+    EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+
+    utils::ComboRenderPipelineDescriptor desc(device);
+    desc.cVertexStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+            #version 450
+            void main() {
+                gl_Position = vec4(0.0);
+            })");
+
+    desc.cFragmentStage.module = module;
+    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cFragmentStage.module = sameModule;
+    dawn::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cFragmentStage.module = otherModule;
+    dawn::RenderPipeline otherPipeline = device.CreateRenderPipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
 DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);