dawn_native: deduplicate pipeline layouts

This is the first step to do pipeline deduplication. It also introduces
tests for deduplication.

BUG=dawn:143

Change-Id: Ib22496f543f8d1f9cfde04f725612504132c7d72
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6861
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index b5365d3..44709b1 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -643,6 +643,7 @@
     "src/tests/end2end/IndexFormatTests.cpp",
     "src/tests/end2end/InputStateTests.cpp",
     "src/tests/end2end/MultisampledRenderingTests.cpp",
+    "src/tests/end2end/ObjectCachingTests.cpp",
     "src/tests/end2end/PrimitiveTopologyTests.cpp",
     "src/tests/end2end/PushConstantTests.cpp",
     "src/tests/end2end/RenderPassLoadOpTests.cpp",
diff --git a/src/dawn_native/BindGroupLayout.cpp b/src/dawn_native/BindGroupLayout.cpp
index dc3d65a..b5acd8b 100644
--- a/src/dawn_native/BindGroupLayout.cpp
+++ b/src/dawn_native/BindGroupLayout.cpp
@@ -114,15 +114,13 @@
         return mBindingInfo;
     }
 
-    // BindGroupLayoutCacheFuncs
-
-    size_t BindGroupLayoutCacheFuncs::operator()(const BindGroupLayoutBase* bgl) const {
-        return HashBindingInfo(bgl->GetBindingInfo());
+    size_t BindGroupLayoutBase::HashFunc::operator()(const BindGroupLayoutBase* bgl) const {
+        return HashBindingInfo(bgl->mBindingInfo);
     }
 
-    bool BindGroupLayoutCacheFuncs::operator()(const BindGroupLayoutBase* a,
-                                               const BindGroupLayoutBase* b) const {
-        return a->GetBindingInfo() == b->GetBindingInfo();
+    bool BindGroupLayoutBase::EqualityFunc::operator()(const BindGroupLayoutBase* a,
+                                                       const BindGroupLayoutBase* b) const {
+        return a->mBindingInfo == b->mBindingInfo;
     }
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/BindGroupLayout.h b/src/dawn_native/BindGroupLayout.h
index 08f38fa..047228f 100644
--- a/src/dawn_native/BindGroupLayout.h
+++ b/src/dawn_native/BindGroupLayout.h
@@ -46,6 +46,14 @@
         };
         const LayoutBindingInfo& GetBindingInfo() const;
 
+        // Functors necessary for the unordered_set<BGLBase*>-based cache.
+        struct HashFunc {
+            size_t operator()(const BindGroupLayoutBase* bgl) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const BindGroupLayoutBase* a, const BindGroupLayoutBase* b) const;
+        };
+
       private:
         BindGroupLayoutBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
@@ -53,15 +61,6 @@
         bool mIsBlueprint = false;
     };
 
-    // Implements the functors necessary for the unordered_set<BGL*>-based cache.
-    struct BindGroupLayoutCacheFuncs {
-        // The hash function
-        size_t operator()(const BindGroupLayoutBase* bgl) const;
-
-        // The equality predicate
-        bool operator()(const BindGroupLayoutBase* a, const BindGroupLayoutBase* b) const;
-    };
-
 }  // namespace dawn_native
 
 #endif  // DAWNNATIVE_BINDGROUPLAYOUT_H_
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index b7aa559..33bbdec 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -42,11 +42,13 @@
 
     // The caches are unordered_sets of pointers with special hash and compare functions
     // to compare the value of the objects, instead of the pointers.
-    using BindGroupLayoutCache = std::
-        unordered_set<BindGroupLayoutBase*, BindGroupLayoutCacheFuncs, BindGroupLayoutCacheFuncs>;
+    template <typename Object>
+    using ContentLessObjectCache =
+        std::unordered_set<Object*, typename Object::HashFunc, typename Object::EqualityFunc>;
 
     struct DeviceBase::Caches {
-        BindGroupLayoutCache bindGroupLayouts;
+        ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
+        ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
     };
 
     // DeviceBase
@@ -114,7 +116,29 @@
     }
 
     void DeviceBase::UncacheBindGroupLayout(BindGroupLayoutBase* obj) {
-        mCaches->bindGroupLayouts.erase(obj);
+        size_t removedCount = mCaches->bindGroupLayouts.erase(obj);
+        ASSERT(removedCount == 1);
+    }
+
+    ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
+        const PipelineLayoutDescriptor* descriptor) {
+        PipelineLayoutBase blueprint(this, descriptor, true);
+
+        auto iter = mCaches->pipelineLayouts.find(&blueprint);
+        if (iter != mCaches->pipelineLayouts.end()) {
+            (*iter)->Reference();
+            return *iter;
+        }
+
+        PipelineLayoutBase* backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreatePipelineLayoutImpl(descriptor));
+        mCaches->pipelineLayouts.insert(backendObj);
+        return backendObj;
+    }
+
+    void DeviceBase::UncachePipelineLayout(PipelineLayoutBase* obj) {
+        size_t removedCount = mCaches->pipelineLayouts.erase(obj);
+        ASSERT(removedCount == 1);
     }
 
     // Object creation API methods
@@ -331,7 +355,7 @@
         PipelineLayoutBase** result,
         const PipelineLayoutDescriptor* descriptor) {
         DAWN_TRY(ValidatePipelineLayoutDescriptor(this, descriptor));
-        DAWN_TRY_ASSIGN(*result, CreatePipelineLayoutImpl(descriptor));
+        DAWN_TRY_ASSIGN(*result, GetOrCreatePipelineLayout(descriptor));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 583488f..3f4a913 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -84,6 +84,10 @@
             const BindGroupLayoutDescriptor* descriptor);
         void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
 
+        ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
+            const PipelineLayoutDescriptor* descriptor);
+        void UncachePipelineLayout(PipelineLayoutBase* obj);
+
         // Dawn API
         BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor);
         BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor);
diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp
index fd23df9..1658206 100644
--- a/src/dawn_native/PipelineLayout.cpp
+++ b/src/dawn_native/PipelineLayout.cpp
@@ -15,6 +15,8 @@
 #include "dawn_native/PipelineLayout.h"
 
 #include "common/Assert.h"
+#include "common/BitSetIterator.h"
+#include "common/HashUtils.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Device.h"
 
@@ -39,8 +41,9 @@
     // PipelineLayoutBase
 
     PipelineLayoutBase::PipelineLayoutBase(DeviceBase* device,
-                                           const PipelineLayoutDescriptor* descriptor)
-        : ObjectBase(device) {
+                                           const PipelineLayoutDescriptor* descriptor,
+                                           bool blueprint)
+        : ObjectBase(device), mIsBlueprint(blueprint) {
         ASSERT(descriptor->bindGroupLayoutCount <= kMaxBindGroups);
         for (uint32_t group = 0; group < descriptor->bindGroupLayoutCount; ++group) {
             mBindGroupLayouts[group] = descriptor->bindGroupLayouts[group];
@@ -52,6 +55,14 @@
         : ObjectBase(device, tag) {
     }
 
+    PipelineLayoutBase::~PipelineLayoutBase() {
+        // Do not uncache the actual cached object if we are a blueprint
+        if (!mIsBlueprint) {
+            ASSERT(!IsError());
+            GetDevice()->UncachePipelineLayout(this);
+        }
+    }
+
     // static
     PipelineLayoutBase* PipelineLayoutBase::MakeError(DeviceBase* device) {
         return new PipelineLayoutBase(device, ObjectBase::kError);
@@ -86,4 +97,29 @@
         return kMaxBindGroups;
     }
 
+    size_t PipelineLayoutBase::HashFunc::operator()(const PipelineLayoutBase* pl) const {
+        size_t hash = Hash(pl->mMask);
+
+        for (uint32_t group : IterateBitSet(pl->mMask)) {
+            HashCombine(&hash, pl->GetBindGroupLayout(group));
+        }
+
+        return hash;
+    }
+
+    bool PipelineLayoutBase::EqualityFunc::operator()(const PipelineLayoutBase* a,
+                                                      const PipelineLayoutBase* b) const {
+        if (a->mMask != b->mMask) {
+            return false;
+        }
+
+        for (uint32_t group : IterateBitSet(a->mMask)) {
+            if (a->GetBindGroupLayout(group) != b->GetBindGroupLayout(group)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/PipelineLayout.h b/src/dawn_native/PipelineLayout.h
index aa1e3e5..f101498 100644
--- a/src/dawn_native/PipelineLayout.h
+++ b/src/dawn_native/PipelineLayout.h
@@ -34,7 +34,10 @@
 
     class PipelineLayoutBase : public ObjectBase {
       public:
-        PipelineLayoutBase(DeviceBase* device, const PipelineLayoutDescriptor* descriptor);
+        PipelineLayoutBase(DeviceBase* device,
+                           const PipelineLayoutDescriptor* descriptor,
+                           bool blueprint = false);
+        ~PipelineLayoutBase() override;
 
         static PipelineLayoutBase* MakeError(DeviceBase* device);
 
@@ -49,11 +52,20 @@
         // [1, kMaxBindGroups + 1]
         uint32_t GroupsInheritUpTo(const PipelineLayoutBase* other) const;
 
+        // Functors necessary for the unordered_set<PipelineLayoutBase*>-based cache.
+        struct HashFunc {
+            size_t operator()(const PipelineLayoutBase* pl) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const PipelineLayoutBase* a, const PipelineLayoutBase* b) const;
+        };
+
       protected:
         PipelineLayoutBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
         BindGroupLayoutArray mBindGroupLayouts;
         std::bitset<kMaxBindGroups> mMask;
+        bool mIsBlueprint = false;
     };
 
 }  // namespace dawn_native
diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp
index ec1f7c6..ee8c3e7 100644
--- a/src/tests/DawnTest.cpp
+++ b/src/tests/DawnTest.cpp
@@ -166,7 +166,7 @@
     std::cout << std::endl;
 }
 
-bool DawnTestEnvironment::UseWire() const {
+bool DawnTestEnvironment::UsesWire() const {
     return mUseWire;
 }
 
@@ -266,6 +266,10 @@
 #endif
 }
 
+bool DawnTest::UsesWire() const {
+    return gTestEnv->UsesWire();
+}
+
 void DawnTest::SetUp() {
     // Get an adapter for the backend to use, and create the device.
     dawn_native::Adapter backendAdapter;
@@ -314,7 +318,7 @@
     DawnDevice cDevice = nullptr;
     DawnProcTable procs;
 
-    if (gTestEnv->UseWire()) {
+    if (gTestEnv->UsesWire()) {
         mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();
         mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
 
@@ -473,7 +477,7 @@
 }
 
 void DawnTest::FlushWire() {
-    if (gTestEnv->UseWire()) {
+    if (gTestEnv->UsesWire()) {
         bool C2SFlushed = mC2sBuf->Flush();
         bool S2CFlushed = mS2cBuf->Flush();
         ASSERT(C2SFlushed);
diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h
index c3c9d45..eee016d 100644
--- a/src/tests/DawnTest.h
+++ b/src/tests/DawnTest.h
@@ -106,7 +106,7 @@
 
     void SetUp() override;
 
-    bool UseWire() const;
+    bool UsesWire() const;
     dawn_native::Instance* GetInstance() const;
     GLFWwindow* GetWindowForBackend(dawn_native::BackendType type) const;
 
@@ -146,6 +146,8 @@
     bool IsLinux() const;
     bool IsMacOS() const;
 
+    bool UsesWire() const;
+
     void StartExpectDeviceError();
     bool EndExpectDeviceError();
 
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
new file mode 100644
index 0000000..d2a591b
--- /dev/null
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -0,0 +1,51 @@
+// Copyright 2019 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnTest.h"
+
+#include "utils/DawnHelpers.h"
+
+class ObjectCachingTest : public DawnTest {};
+
+// Test that BindGroupLayouts are correctly deduplicated.
+TEST_P(ObjectCachingTest, BindGroupLayoutDeduplication) {
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+    dawn::BindGroupLayout sameBgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+    dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+    EXPECT_NE(bgl.Get(), otherBgl.Get());
+    EXPECT_EQ(bgl.Get() == sameBgl.Get(), !UsesWire());
+}
+
+// Test that PipelineLayouts are correctly deduplicated.
+TEST_P(ObjectCachingTest, PipelineLayoutDeduplication) {
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+    dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+    dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout otherPl1 = utils::MakeBasicPipelineLayout(device, nullptr);
+    dawn::PipelineLayout otherPl2 = utils::MakeBasicPipelineLayout(device, &otherBgl);
+
+    EXPECT_NE(pl.Get(), otherPl1.Get());
+    EXPECT_NE(pl.Get(), otherPl2.Get());
+    EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
+}
+
+DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
diff --git a/src/utils/DawnHelpers.cpp b/src/utils/DawnHelpers.cpp
index 0fa9441..a0d2121 100644
--- a/src/utils/DawnHelpers.cpp
+++ b/src/utils/DawnHelpers.cpp
@@ -278,7 +278,7 @@
     dawn::PipelineLayout MakeBasicPipelineLayout(const dawn::Device& device,
                                                  const dawn::BindGroupLayout* bindGroupLayout) {
         dawn::PipelineLayoutDescriptor descriptor;
-        if (bindGroupLayout) {
+        if (bindGroupLayout != nullptr) {
             descriptor.bindGroupLayoutCount = 1;
             descriptor.bindGroupLayouts = bindGroupLayout;
         } else {