Unify ProgrammableStageDescriptor handling in PipelineBase
Previously both Render and Compute pipelines handled extracting data
from the ProgrammableStageDescriptors. Unify them in PipelineBase in
preparation for gathering EntryPointMetadata in the PipelineBase.
Bug: dawn:216
Change-Id: I633dd2d8c9fdd0c08bb34cbf18955445951e312f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27263
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 3e806a1..3418753 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -26,11 +26,11 @@
namespace {
bool BufferSizesAtLeastAsBig(const ityp::span<uint32_t, uint64_t> unverifiedBufferSizes,
- const std::vector<uint64_t>& pipelineMinimumBufferSizes) {
- ASSERT(unverifiedBufferSizes.size() == pipelineMinimumBufferSizes.size());
+ const std::vector<uint64_t>& pipelineMinBufferSizes) {
+ ASSERT(unverifiedBufferSizes.size() == pipelineMinBufferSizes.size());
for (uint32_t i = 0; i < unverifiedBufferSizes.size(); ++i) {
- if (unverifiedBufferSizes[i] < pipelineMinimumBufferSizes[i]) {
+ if (unverifiedBufferSizes[i] < pipelineMinBufferSizes[i]) {
return false;
}
}
@@ -105,7 +105,7 @@
if (mBindgroups[i] == nullptr ||
mLastPipelineLayout->GetBindGroupLayout(i) != mBindgroups[i]->GetLayout() ||
!BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(),
- (*mMinimumBufferSizes)[i])) {
+ (*mMinBufferSizes)[i])) {
matches = false;
break;
}
@@ -190,7 +190,7 @@
"Pipeline and bind group layout doesn't match for bind group " +
std::to_string(static_cast<uint32_t>(i)));
} else if (!BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(),
- (*mMinimumBufferSizes)[i])) {
+ (*mMinBufferSizes)[i])) {
return DAWN_VALIDATION_ERROR("Binding sizes too small for bind group " +
std::to_string(static_cast<uint32_t>(i)));
}
@@ -236,7 +236,7 @@
void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
mLastPipelineLayout = pipeline->GetLayout();
- mMinimumBufferSizes = &pipeline->GetMinimumBufferSizes();
+ mMinBufferSizes = &pipeline->GetMinBufferSizes();
mAspects.set(VALIDATION_ASPECT_PIPELINE);
diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h
index 67645e4..146214d 100644
--- a/src/dawn_native/CommandBufferStateTracker.h
+++ b/src/dawn_native/CommandBufferStateTracker.h
@@ -61,7 +61,7 @@
PipelineLayoutBase* mLastPipelineLayout = nullptr;
RenderPipelineBase* mLastRenderPipeline = nullptr;
- const RequiredBufferSizes* mMinimumBufferSizes = nullptr;
+ const RequiredBufferSizes* mMinBufferSizes = nullptr;
};
} // namespace dawn_native
diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp
index ee49b11..793765d 100644
--- a/src/dawn_native/ComputePipeline.cpp
+++ b/src/dawn_native/ComputePipeline.cpp
@@ -19,13 +19,6 @@
namespace dawn_native {
- namespace {
- RequiredBufferSizes ComputeMinBufferSizes(const ComputePipelineDescriptor* descriptor) {
- return descriptor->computeStage.module->ComputeRequiredBufferSizesForLayout(
- descriptor->layout);
- }
- } // anonymous namespace
-
MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
const ComputePipelineDescriptor* descriptor) {
if (descriptor->nextInChain != nullptr) {
@@ -47,10 +40,7 @@
const ComputePipelineDescriptor* descriptor)
: PipelineBase(device,
descriptor->layout,
- wgpu::ShaderStage::Compute,
- ComputeMinBufferSizes(descriptor)),
- mModule(descriptor->computeStage.module),
- mEntryPoint(descriptor->computeStage.entryPoint) {
+ {{SingleShaderStage::Compute, &descriptor->computeStage}}) {
}
ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
@@ -70,15 +60,12 @@
}
size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
- size_t hash = 0;
- HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
- return hash;
+ return PipelineBase::HashForCache(pipeline);
}
bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
const ComputePipelineBase* b) const {
- return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
- a->GetLayout() == b->GetLayout();
+ return PipelineBase::EqualForCache(a, b);
}
} // namespace dawn_native
diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h
index 43d7966..c2f1188 100644
--- a/src/dawn_native/ComputePipeline.h
+++ b/src/dawn_native/ComputePipeline.h
@@ -41,10 +41,6 @@
private:
ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
-
- // TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
- Ref<ShaderModuleBase> mModule;
- std::string mEntryPoint;
};
} // namespace dawn_native
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index ae05c02..f37e30c 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -14,6 +14,7 @@
#include "dawn_native/Pipeline.h"
+#include "common/HashUtils.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Device.h"
#include "dawn_native/PipelineLayout.h"
@@ -43,23 +44,38 @@
PipelineBase::PipelineBase(DeviceBase* device,
PipelineLayoutBase* layout,
- wgpu::ShaderStage stages,
- RequiredBufferSizes minimumBufferSizes)
- : CachedObject(device),
- mStageMask(stages),
- mLayout(layout),
- mMinimumBufferSizes(std::move(minimumBufferSizes)) {
+ std::vector<StageAndDescriptor> stages)
+ : CachedObject(device), mLayout(layout) {
+ ASSERT(!stages.empty());
+
+ for (const StageAndDescriptor& stage : stages) {
+ bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
+ mStageMask |= StageBit(stage.first);
+ mStages[stage.first] = {stage.second->module, stage.second->entryPoint};
+
+ // Compute the max() of all minBufferSizes across all stages.
+ RequiredBufferSizes stageMinBufferSizes =
+ stage.second->module->ComputeRequiredBufferSizesForLayout(layout);
+
+ if (isFirstStage) {
+ mMinBufferSizes = std::move(stageMinBufferSizes);
+ } else {
+ for (BindGroupIndex group(0); group < mMinBufferSizes.size(); ++group) {
+ ASSERT(stageMinBufferSizes[group].size() == mMinBufferSizes[group].size());
+
+ for (size_t i = 0; i < stageMinBufferSizes[group].size(); ++i) {
+ mMinBufferSizes[group][i] =
+ std::max(mMinBufferSizes[group][i], stageMinBufferSizes[group][i]);
+ }
+ }
+ }
+ }
}
PipelineBase::PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
: CachedObject(device, tag) {
}
- wgpu::ShaderStage PipelineBase::GetStageMask() const {
- ASSERT(!IsError());
- return mStageMask;
- }
-
PipelineLayoutBase* PipelineBase::GetLayout() {
ASSERT(!IsError());
return mLayout.Get();
@@ -70,9 +86,14 @@
return mLayout.Get();
}
- const RequiredBufferSizes& PipelineBase::GetMinimumBufferSizes() const {
+ const RequiredBufferSizes& PipelineBase::GetMinBufferSizes() const {
ASSERT(!IsError());
- return mMinimumBufferSizes;
+ return mMinBufferSizes;
+ }
+
+ const ProgrammableStage& PipelineBase::GetStage(SingleShaderStage stage) const {
+ ASSERT(!IsError());
+ return mStages[stage];
}
MaybeError PipelineBase::ValidateGetBindGroupLayout(uint32_t groupIndex) {
@@ -102,4 +123,39 @@
return bgl;
}
+ // static
+ size_t PipelineBase::HashForCache(const PipelineBase* pipeline) {
+ size_t hash = 0;
+
+ // The layout is deduplicated so it can be hashed by pointer.
+ HashCombine(&hash, pipeline->mLayout.Get());
+
+ HashCombine(&hash, pipeline->mStageMask);
+ for (SingleShaderStage stage : IterateStages(pipeline->mStageMask)) {
+ // The module is deduplicated so it can be hashed by pointer.
+ HashCombine(&hash, pipeline->mStages[stage].module.Get());
+ HashCombine(&hash, pipeline->mStages[stage].entryPoint);
+ }
+
+ return hash;
+ }
+
+ // static
+ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
+ // The layout is deduplicated so it can be compared by pointer.
+ if (a->mLayout.Get() != b->mLayout.Get() || a->mStageMask != b->mStageMask) {
+ return false;
+ }
+
+ for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
+ // The module is deduplicated so it can be compared by pointer.
+ if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
+ a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h
index bfc846b..9df2db6 100644
--- a/src/dawn_native/Pipeline.h
+++ b/src/dawn_native/Pipeline.h
@@ -33,27 +33,40 @@
const PipelineLayoutBase* layout,
SingleShaderStage stage);
+ struct ProgrammableStage {
+ Ref<ShaderModuleBase> module;
+ std::string entryPoint;
+ };
+
class PipelineBase : public CachedObject {
public:
- wgpu::ShaderStage GetStageMask() const;
PipelineLayoutBase* GetLayout();
const PipelineLayoutBase* GetLayout() const;
+ const RequiredBufferSizes& GetMinBufferSizes() const;
+ const ProgrammableStage& GetStage(SingleShaderStage stage) const;
+
BindGroupLayoutBase* GetBindGroupLayout(uint32_t groupIndex);
- const RequiredBufferSizes& GetMinimumBufferSizes() const;
+
+ // Helper function for the functors for std::unordered_map-based pipeline caches.
+ static size_t HashForCache(const PipelineBase* pipeline);
+ static bool EqualForCache(const PipelineBase* a, const PipelineBase* b);
protected:
+ using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
+
PipelineBase(DeviceBase* device,
PipelineLayoutBase* layout,
- wgpu::ShaderStage stages,
- RequiredBufferSizes bufferSizes);
+ std::vector<StageAndDescriptor> stages);
PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
private:
MaybeError ValidateGetBindGroupLayout(uint32_t group);
- wgpu::ShaderStage mStageMask;
+ wgpu::ShaderStage mStageMask = wgpu::ShaderStage::None;
+ PerStage<ProgrammableStage> mStages;
+
Ref<PipelineLayoutBase> mLayout;
- RequiredBufferSizes mMinimumBufferSizes;
+ RequiredBufferSizes mMinBufferSizes;
};
} // namespace dawn_native
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 83d9b11..1341ee5 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -193,29 +193,6 @@
return {};
}
- RequiredBufferSizes ComputeMinBufferSizes(const RenderPipelineDescriptor* descriptor) {
- RequiredBufferSizes bufferSizes =
- descriptor->vertexStage.module->ComputeRequiredBufferSizesForLayout(
- descriptor->layout);
-
- // Merge the two buffer size requirements by taking the larger element from each
- if (descriptor->fragmentStage != nullptr) {
- RequiredBufferSizes fragmentSizes =
- descriptor->fragmentStage->module->ComputeRequiredBufferSizesForLayout(
- descriptor->layout);
-
- for (BindGroupIndex group(0); group < bufferSizes.size(); ++group) {
- ASSERT(bufferSizes[group].size() == fragmentSizes[group].size());
- for (size_t i = 0; i < bufferSizes[group].size(); ++i) {
- bufferSizes[group][i] =
- std::max(bufferSizes[group][i], fragmentSizes[group][i]);
- }
- }
- }
-
- return bufferSizes;
- }
-
} // anonymous namespace
// Helper functions
@@ -411,16 +388,12 @@
const RenderPipelineDescriptor* descriptor)
: PipelineBase(device,
descriptor->layout,
- wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment,
- ComputeMinBufferSizes(descriptor)),
+ {{SingleShaderStage::Vertex, &descriptor->vertexStage},
+ {SingleShaderStage::Fragment, descriptor->fragmentStage}}),
mAttachmentState(device->GetOrCreateAttachmentState(descriptor)),
mPrimitiveTopology(descriptor->primitiveTopology),
mSampleMask(descriptor->sampleMask),
- mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled),
- mVertexModule(descriptor->vertexStage.module),
- mVertexEntryPoint(descriptor->vertexStage.entryPoint),
- mFragmentModule(descriptor->fragmentStage->module),
- mFragmentEntryPoint(descriptor->fragmentStage->entryPoint) {
+ mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled) {
if (descriptor->vertexState != nullptr) {
mVertexState = *descriptor->vertexState;
} else {
@@ -608,12 +581,8 @@
}
size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const {
- size_t hash = 0;
-
// Hash modules and layout
- HashCombine(&hash, pipeline->GetLayout());
- HashCombine(&hash, pipeline->mVertexModule.Get(), pipeline->mFragmentEntryPoint);
- HashCombine(&hash, pipeline->mFragmentModule.Get(), pipeline->mFragmentEntryPoint);
+ size_t hash = PipelineBase::HashForCache(pipeline);
// Hierarchically hash the attachment state.
// It contains the attachments set, texture formats, and sample count.
@@ -671,11 +640,8 @@
bool RenderPipelineBase::EqualityFunc::operator()(const RenderPipelineBase* a,
const RenderPipelineBase* b) const {
- // Check modules and layout
- if (a->GetLayout() != b->GetLayout() || a->mVertexModule.Get() != b->mVertexModule.Get() ||
- a->mVertexEntryPoint != b->mVertexEntryPoint ||
- a->mFragmentModule.Get() != b->mFragmentModule.Get() ||
- a->mFragmentEntryPoint != b->mFragmentEntryPoint) {
+ // Check the layout and shader stages.
+ if (!PipelineBase::EqualForCache(a, b)) {
return false;
}
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index 5c328f9..aee8d3d 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -114,13 +114,6 @@
RasterizationStateDescriptor mRasterizationState;
uint32_t mSampleMask;
bool mAlphaToCoverageEnabled;
-
- // Stage information
- // TODO(cwallez@chromium.org): Store a crypto hash of the modules instead.
- Ref<ShaderModuleBase> mVertexModule;
- std::string mVertexEntryPoint;
- Ref<ShaderModuleBase> mFragmentModule;
- std::string mFragmentEntryPoint;
};
} // namespace dawn_native