dawn_native: deduplicate pipeline layouts
This is the first step to do pipeline deduplication. It also introduces
tests for deduplication.
BUG=dawn:143
Change-Id: Ib22496f543f8d1f9cfde04f725612504132c7d72
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6861
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index b5365d3..44709b1 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -643,6 +643,7 @@
"src/tests/end2end/IndexFormatTests.cpp",
"src/tests/end2end/InputStateTests.cpp",
"src/tests/end2end/MultisampledRenderingTests.cpp",
+ "src/tests/end2end/ObjectCachingTests.cpp",
"src/tests/end2end/PrimitiveTopologyTests.cpp",
"src/tests/end2end/PushConstantTests.cpp",
"src/tests/end2end/RenderPassLoadOpTests.cpp",
diff --git a/src/dawn_native/BindGroupLayout.cpp b/src/dawn_native/BindGroupLayout.cpp
index dc3d65a..b5acd8b 100644
--- a/src/dawn_native/BindGroupLayout.cpp
+++ b/src/dawn_native/BindGroupLayout.cpp
@@ -114,15 +114,13 @@
return mBindingInfo;
}
- // BindGroupLayoutCacheFuncs
-
- size_t BindGroupLayoutCacheFuncs::operator()(const BindGroupLayoutBase* bgl) const {
- return HashBindingInfo(bgl->GetBindingInfo());
+ size_t BindGroupLayoutBase::HashFunc::operator()(const BindGroupLayoutBase* bgl) const {
+ return HashBindingInfo(bgl->mBindingInfo);
}
- bool BindGroupLayoutCacheFuncs::operator()(const BindGroupLayoutBase* a,
- const BindGroupLayoutBase* b) const {
- return a->GetBindingInfo() == b->GetBindingInfo();
+ bool BindGroupLayoutBase::EqualityFunc::operator()(const BindGroupLayoutBase* a,
+ const BindGroupLayoutBase* b) const {
+ return a->mBindingInfo == b->mBindingInfo;
}
} // namespace dawn_native
diff --git a/src/dawn_native/BindGroupLayout.h b/src/dawn_native/BindGroupLayout.h
index 08f38fa..047228f 100644
--- a/src/dawn_native/BindGroupLayout.h
+++ b/src/dawn_native/BindGroupLayout.h
@@ -46,6 +46,14 @@
};
const LayoutBindingInfo& GetBindingInfo() const;
+ // Functors necessary for the unordered_set<BGLBase*>-based cache.
+ struct HashFunc {
+ size_t operator()(const BindGroupLayoutBase* bgl) const;
+ };
+ struct EqualityFunc {
+ bool operator()(const BindGroupLayoutBase* a, const BindGroupLayoutBase* b) const;
+ };
+
private:
BindGroupLayoutBase(DeviceBase* device, ObjectBase::ErrorTag tag);
@@ -53,15 +61,6 @@
bool mIsBlueprint = false;
};
- // Implements the functors necessary for the unordered_set<BGL*>-based cache.
- struct BindGroupLayoutCacheFuncs {
- // The hash function
- size_t operator()(const BindGroupLayoutBase* bgl) const;
-
- // The equality predicate
- bool operator()(const BindGroupLayoutBase* a, const BindGroupLayoutBase* b) const;
- };
-
} // namespace dawn_native
#endif // DAWNNATIVE_BINDGROUPLAYOUT_H_
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index b7aa559..33bbdec 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -42,11 +42,13 @@
// The caches are unordered_sets of pointers with special hash and compare functions
// to compare the value of the objects, instead of the pointers.
- using BindGroupLayoutCache = std::
- unordered_set<BindGroupLayoutBase*, BindGroupLayoutCacheFuncs, BindGroupLayoutCacheFuncs>;
+ template <typename Object>
+ using ContentLessObjectCache =
+ std::unordered_set<Object*, typename Object::HashFunc, typename Object::EqualityFunc>;
struct DeviceBase::Caches {
- BindGroupLayoutCache bindGroupLayouts;
+ ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
+ ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
};
// DeviceBase
@@ -114,7 +116,29 @@
}
void DeviceBase::UncacheBindGroupLayout(BindGroupLayoutBase* obj) {
- mCaches->bindGroupLayouts.erase(obj);
+ size_t removedCount = mCaches->bindGroupLayouts.erase(obj);
+ ASSERT(removedCount == 1);
+ }
+
+ ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
+ const PipelineLayoutDescriptor* descriptor) {
+ PipelineLayoutBase blueprint(this, descriptor, true);
+
+ auto iter = mCaches->pipelineLayouts.find(&blueprint);
+ if (iter != mCaches->pipelineLayouts.end()) {
+ (*iter)->Reference();
+ return *iter;
+ }
+
+ PipelineLayoutBase* backendObj;
+ DAWN_TRY_ASSIGN(backendObj, CreatePipelineLayoutImpl(descriptor));
+ mCaches->pipelineLayouts.insert(backendObj);
+ return backendObj;
+ }
+
+ void DeviceBase::UncachePipelineLayout(PipelineLayoutBase* obj) {
+ size_t removedCount = mCaches->pipelineLayouts.erase(obj);
+ ASSERT(removedCount == 1);
}
// Object creation API methods
@@ -331,7 +355,7 @@
PipelineLayoutBase** result,
const PipelineLayoutDescriptor* descriptor) {
DAWN_TRY(ValidatePipelineLayoutDescriptor(this, descriptor));
- DAWN_TRY_ASSIGN(*result, CreatePipelineLayoutImpl(descriptor));
+ DAWN_TRY_ASSIGN(*result, GetOrCreatePipelineLayout(descriptor));
return {};
}
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 583488f..3f4a913 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -84,6 +84,10 @@
const BindGroupLayoutDescriptor* descriptor);
void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
+ ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
+ const PipelineLayoutDescriptor* descriptor);
+ void UncachePipelineLayout(PipelineLayoutBase* obj);
+
// Dawn API
BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor);
BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor);
diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp
index fd23df9..1658206 100644
--- a/src/dawn_native/PipelineLayout.cpp
+++ b/src/dawn_native/PipelineLayout.cpp
@@ -15,6 +15,8 @@
#include "dawn_native/PipelineLayout.h"
#include "common/Assert.h"
+#include "common/BitSetIterator.h"
+#include "common/HashUtils.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Device.h"
@@ -39,8 +41,9 @@
// PipelineLayoutBase
PipelineLayoutBase::PipelineLayoutBase(DeviceBase* device,
- const PipelineLayoutDescriptor* descriptor)
- : ObjectBase(device) {
+ const PipelineLayoutDescriptor* descriptor,
+ bool blueprint)
+ : ObjectBase(device), mIsBlueprint(blueprint) {
ASSERT(descriptor->bindGroupLayoutCount <= kMaxBindGroups);
for (uint32_t group = 0; group < descriptor->bindGroupLayoutCount; ++group) {
mBindGroupLayouts[group] = descriptor->bindGroupLayouts[group];
@@ -52,6 +55,14 @@
: ObjectBase(device, tag) {
}
+ PipelineLayoutBase::~PipelineLayoutBase() {
+ // Do not uncache the actual cached object if we are a blueprint
+ if (!mIsBlueprint) {
+ ASSERT(!IsError());
+ GetDevice()->UncachePipelineLayout(this);
+ }
+ }
+
// static
PipelineLayoutBase* PipelineLayoutBase::MakeError(DeviceBase* device) {
return new PipelineLayoutBase(device, ObjectBase::kError);
@@ -86,4 +97,29 @@
return kMaxBindGroups;
}
+ size_t PipelineLayoutBase::HashFunc::operator()(const PipelineLayoutBase* pl) const {
+ size_t hash = Hash(pl->mMask);
+
+ for (uint32_t group : IterateBitSet(pl->mMask)) {
+ HashCombine(&hash, pl->GetBindGroupLayout(group));
+ }
+
+ return hash;
+ }
+
+ bool PipelineLayoutBase::EqualityFunc::operator()(const PipelineLayoutBase* a,
+ const PipelineLayoutBase* b) const {
+ if (a->mMask != b->mMask) {
+ return false;
+ }
+
+ for (uint32_t group : IterateBitSet(a->mMask)) {
+ if (a->GetBindGroupLayout(group) != b->GetBindGroupLayout(group)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/PipelineLayout.h b/src/dawn_native/PipelineLayout.h
index aa1e3e5..f101498 100644
--- a/src/dawn_native/PipelineLayout.h
+++ b/src/dawn_native/PipelineLayout.h
@@ -34,7 +34,10 @@
class PipelineLayoutBase : public ObjectBase {
public:
- PipelineLayoutBase(DeviceBase* device, const PipelineLayoutDescriptor* descriptor);
+ PipelineLayoutBase(DeviceBase* device,
+ const PipelineLayoutDescriptor* descriptor,
+ bool blueprint = false);
+ ~PipelineLayoutBase() override;
static PipelineLayoutBase* MakeError(DeviceBase* device);
@@ -49,11 +52,20 @@
// [1, kMaxBindGroups + 1]
uint32_t GroupsInheritUpTo(const PipelineLayoutBase* other) const;
+ // Functors necessary for the unordered_set<PipelineLayoutBase*>-based cache.
+ struct HashFunc {
+ size_t operator()(const PipelineLayoutBase* pl) const;
+ };
+ struct EqualityFunc {
+ bool operator()(const PipelineLayoutBase* a, const PipelineLayoutBase* b) const;
+ };
+
protected:
PipelineLayoutBase(DeviceBase* device, ObjectBase::ErrorTag tag);
BindGroupLayoutArray mBindGroupLayouts;
std::bitset<kMaxBindGroups> mMask;
+ bool mIsBlueprint = false;
};
} // namespace dawn_native
diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp
index ec1f7c6..ee8c3e7 100644
--- a/src/tests/DawnTest.cpp
+++ b/src/tests/DawnTest.cpp
@@ -166,7 +166,7 @@
std::cout << std::endl;
}
-bool DawnTestEnvironment::UseWire() const {
+bool DawnTestEnvironment::UsesWire() const {
return mUseWire;
}
@@ -266,6 +266,10 @@
#endif
}
+bool DawnTest::UsesWire() const {
+ return gTestEnv->UsesWire();
+}
+
void DawnTest::SetUp() {
// Get an adapter for the backend to use, and create the device.
dawn_native::Adapter backendAdapter;
@@ -314,7 +318,7 @@
DawnDevice cDevice = nullptr;
DawnProcTable procs;
- if (gTestEnv->UseWire()) {
+ if (gTestEnv->UsesWire()) {
mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();
mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
@@ -473,7 +477,7 @@
}
void DawnTest::FlushWire() {
- if (gTestEnv->UseWire()) {
+ if (gTestEnv->UsesWire()) {
bool C2SFlushed = mC2sBuf->Flush();
bool S2CFlushed = mS2cBuf->Flush();
ASSERT(C2SFlushed);
diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h
index c3c9d45..eee016d 100644
--- a/src/tests/DawnTest.h
+++ b/src/tests/DawnTest.h
@@ -106,7 +106,7 @@
void SetUp() override;
- bool UseWire() const;
+ bool UsesWire() const;
dawn_native::Instance* GetInstance() const;
GLFWwindow* GetWindowForBackend(dawn_native::BackendType type) const;
@@ -146,6 +146,8 @@
bool IsLinux() const;
bool IsMacOS() const;
+ bool UsesWire() const;
+
void StartExpectDeviceError();
bool EndExpectDeviceError();
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
new file mode 100644
index 0000000..d2a591b
--- /dev/null
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -0,0 +1,51 @@
+// Copyright 2019 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnTest.h"
+
+#include "utils/DawnHelpers.h"
+
+class ObjectCachingTest : public DawnTest {};
+
+// Test that BindGroupLayouts are correctly deduplicated.
+TEST_P(ObjectCachingTest, BindGroupLayoutDeduplication) {
+ dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+ device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+ dawn::BindGroupLayout sameBgl = utils::MakeBindGroupLayout(
+ device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+ dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+ device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+ EXPECT_NE(bgl.Get(), otherBgl.Get());
+ EXPECT_EQ(bgl.Get() == sameBgl.Get(), !UsesWire());
+}
+
+// Test that PipelineLayouts are correctly deduplicated.
+TEST_P(ObjectCachingTest, PipelineLayoutDeduplication) {
+ dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+ device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+ dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+ device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+ dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
+ dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
+ dawn::PipelineLayout otherPl1 = utils::MakeBasicPipelineLayout(device, nullptr);
+ dawn::PipelineLayout otherPl2 = utils::MakeBasicPipelineLayout(device, &otherBgl);
+
+ EXPECT_NE(pl.Get(), otherPl1.Get());
+ EXPECT_NE(pl.Get(), otherPl2.Get());
+ EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
+}
+
+DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
diff --git a/src/utils/DawnHelpers.cpp b/src/utils/DawnHelpers.cpp
index 0fa9441..a0d2121 100644
--- a/src/utils/DawnHelpers.cpp
+++ b/src/utils/DawnHelpers.cpp
@@ -278,7 +278,7 @@
dawn::PipelineLayout MakeBasicPipelineLayout(const dawn::Device& device,
const dawn::BindGroupLayout* bindGroupLayout) {
dawn::PipelineLayoutDescriptor descriptor;
- if (bindGroupLayout) {
+ if (bindGroupLayout != nullptr) {
descriptor.bindGroupLayoutCount = 1;
descriptor.bindGroupLayouts = bindGroupLayout;
} else {