Unify ProgrammableStageDescriptor handling in PipelineBase

Previously both Render and Compute pipelines handled extracting data
from the ProgrammableStageDescriptors. Unify them in PipelineBase in
preparation for gathering EntryPointMetadata in the PipelineBase.

Bug: dawn:216
Change-Id: I633dd2d8c9fdd0c08bb34cbf18955445951e312f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27263
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 3e806a1..3418753 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -26,11 +26,11 @@
 
     namespace {
         bool BufferSizesAtLeastAsBig(const ityp::span<uint32_t, uint64_t> unverifiedBufferSizes,
-                                     const std::vector<uint64_t>& pipelineMinimumBufferSizes) {
-            ASSERT(unverifiedBufferSizes.size() == pipelineMinimumBufferSizes.size());
+                                     const std::vector<uint64_t>& pipelineMinBufferSizes) {
+            ASSERT(unverifiedBufferSizes.size() == pipelineMinBufferSizes.size());
 
             for (uint32_t i = 0; i < unverifiedBufferSizes.size(); ++i) {
-                if (unverifiedBufferSizes[i] < pipelineMinimumBufferSizes[i]) {
+                if (unverifiedBufferSizes[i] < pipelineMinBufferSizes[i]) {
                     return false;
                 }
             }
@@ -105,7 +105,7 @@
                 if (mBindgroups[i] == nullptr ||
                     mLastPipelineLayout->GetBindGroupLayout(i) != mBindgroups[i]->GetLayout() ||
                     !BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(),
-                                             (*mMinimumBufferSizes)[i])) {
+                                             (*mMinBufferSizes)[i])) {
                     matches = false;
                     break;
                 }
@@ -190,7 +190,7 @@
                         "Pipeline and bind group layout doesn't match for bind group " +
                         std::to_string(static_cast<uint32_t>(i)));
                 } else if (!BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(),
-                                                    (*mMinimumBufferSizes)[i])) {
+                                                    (*mMinBufferSizes)[i])) {
                     return DAWN_VALIDATION_ERROR("Binding sizes too small for bind group " +
                                                  std::to_string(static_cast<uint32_t>(i)));
                 }
@@ -236,7 +236,7 @@
 
     void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
         mLastPipelineLayout = pipeline->GetLayout();
-        mMinimumBufferSizes = &pipeline->GetMinimumBufferSizes();
+        mMinBufferSizes = &pipeline->GetMinBufferSizes();
 
         mAspects.set(VALIDATION_ASPECT_PIPELINE);
 
diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h
index 67645e4..146214d 100644
--- a/src/dawn_native/CommandBufferStateTracker.h
+++ b/src/dawn_native/CommandBufferStateTracker.h
@@ -61,7 +61,7 @@
         PipelineLayoutBase* mLastPipelineLayout = nullptr;
         RenderPipelineBase* mLastRenderPipeline = nullptr;
 
-        const RequiredBufferSizes* mMinimumBufferSizes = nullptr;
+        const RequiredBufferSizes* mMinBufferSizes = nullptr;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp
index ee49b11..793765d 100644
--- a/src/dawn_native/ComputePipeline.cpp
+++ b/src/dawn_native/ComputePipeline.cpp
@@ -19,13 +19,6 @@
 
 namespace dawn_native {
 
-    namespace {
-        RequiredBufferSizes ComputeMinBufferSizes(const ComputePipelineDescriptor* descriptor) {
-            return descriptor->computeStage.module->ComputeRequiredBufferSizesForLayout(
-                descriptor->layout);
-        }
-    }  // anonymous namespace
-
     MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
                                                  const ComputePipelineDescriptor* descriptor) {
         if (descriptor->nextInChain != nullptr) {
@@ -47,10 +40,7 @@
                                              const ComputePipelineDescriptor* descriptor)
         : PipelineBase(device,
                        descriptor->layout,
-                       wgpu::ShaderStage::Compute,
-                       ComputeMinBufferSizes(descriptor)),
-          mModule(descriptor->computeStage.module),
-          mEntryPoint(descriptor->computeStage.entryPoint) {
+                       {{SingleShaderStage::Compute, &descriptor->computeStage}}) {
     }
 
     ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
@@ -70,15 +60,12 @@
     }
 
     size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
-        size_t hash = 0;
-        HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
-        return hash;
+        return PipelineBase::HashForCache(pipeline);
     }
 
     bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
                                                        const ComputePipelineBase* b) const {
-        return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
-               a->GetLayout() == b->GetLayout();
+        return PipelineBase::EqualForCache(a, b);
     }
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h
index 43d7966..c2f1188 100644
--- a/src/dawn_native/ComputePipeline.h
+++ b/src/dawn_native/ComputePipeline.h
@@ -41,10 +41,6 @@
 
       private:
         ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
-
-        // TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
-        Ref<ShaderModuleBase> mModule;
-        std::string mEntryPoint;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index ae05c02..f37e30c 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/Pipeline.h"
 
+#include "common/HashUtils.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/PipelineLayout.h"
@@ -43,23 +44,38 @@
 
     PipelineBase::PipelineBase(DeviceBase* device,
                                PipelineLayoutBase* layout,
-                               wgpu::ShaderStage stages,
-                               RequiredBufferSizes minimumBufferSizes)
-        : CachedObject(device),
-          mStageMask(stages),
-          mLayout(layout),
-          mMinimumBufferSizes(std::move(minimumBufferSizes)) {
+                               std::vector<StageAndDescriptor> stages)
+        : CachedObject(device), mLayout(layout) {
+        ASSERT(!stages.empty());
+
+        for (const StageAndDescriptor& stage : stages) {
+            bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
+            mStageMask |= StageBit(stage.first);
+            mStages[stage.first] = {stage.second->module, stage.second->entryPoint};
+
+            // Compute the max() of all minBufferSizes across all stages.
+            RequiredBufferSizes stageMinBufferSizes =
+                stage.second->module->ComputeRequiredBufferSizesForLayout(layout);
+
+            if (isFirstStage) {
+                mMinBufferSizes = std::move(stageMinBufferSizes);
+            } else {
+                for (BindGroupIndex group(0); group < mMinBufferSizes.size(); ++group) {
+                    ASSERT(stageMinBufferSizes[group].size() == mMinBufferSizes[group].size());
+
+                    for (size_t i = 0; i < stageMinBufferSizes[group].size(); ++i) {
+                        mMinBufferSizes[group][i] =
+                            std::max(mMinBufferSizes[group][i], stageMinBufferSizes[group][i]);
+                    }
+                }
+            }
+        }
     }
 
     PipelineBase::PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
         : CachedObject(device, tag) {
     }
 
-    wgpu::ShaderStage PipelineBase::GetStageMask() const {
-        ASSERT(!IsError());
-        return mStageMask;
-    }
-
     PipelineLayoutBase* PipelineBase::GetLayout() {
         ASSERT(!IsError());
         return mLayout.Get();
@@ -70,9 +86,14 @@
         return mLayout.Get();
     }
 
-    const RequiredBufferSizes& PipelineBase::GetMinimumBufferSizes() const {
+    const RequiredBufferSizes& PipelineBase::GetMinBufferSizes() const {
         ASSERT(!IsError());
-        return mMinimumBufferSizes;
+        return mMinBufferSizes;
+    }
+
+    const ProgrammableStage& PipelineBase::GetStage(SingleShaderStage stage) const {
+        ASSERT(!IsError());
+        return mStages[stage];
     }
 
     MaybeError PipelineBase::ValidateGetBindGroupLayout(uint32_t groupIndex) {
@@ -102,4 +123,39 @@
         return bgl;
     }
 
+    // static
+    size_t PipelineBase::HashForCache(const PipelineBase* pipeline) {
+        size_t hash = 0;
+
+        // The layout is deduplicated so it can be hashed by pointer.
+        HashCombine(&hash, pipeline->mLayout.Get());
+
+        HashCombine(&hash, pipeline->mStageMask);
+        for (SingleShaderStage stage : IterateStages(pipeline->mStageMask)) {
+            // The module is deduplicated so it can be hashed by pointer.
+            HashCombine(&hash, pipeline->mStages[stage].module.Get());
+            HashCombine(&hash, pipeline->mStages[stage].entryPoint);
+        }
+
+        return hash;
+    }
+
+    // static
+    bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
+        // The layout is deduplicated so it can be compared by pointer.
+        if (a->mLayout.Get() != b->mLayout.Get() || a->mStageMask != b->mStageMask) {
+            return false;
+        }
+
+        for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
+            // The module is deduplicated so it can be compared by pointer.
+            if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
+                a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h
index bfc846b..9df2db6 100644
--- a/src/dawn_native/Pipeline.h
+++ b/src/dawn_native/Pipeline.h
@@ -33,27 +33,40 @@
                                                    const PipelineLayoutBase* layout,
                                                    SingleShaderStage stage);
 
+    struct ProgrammableStage {
+        Ref<ShaderModuleBase> module;
+        std::string entryPoint;
+    };
+
     class PipelineBase : public CachedObject {
       public:
-        wgpu::ShaderStage GetStageMask() const;
         PipelineLayoutBase* GetLayout();
         const PipelineLayoutBase* GetLayout() const;
+        const RequiredBufferSizes& GetMinBufferSizes() const;
+        const ProgrammableStage& GetStage(SingleShaderStage stage) const;
+
         BindGroupLayoutBase* GetBindGroupLayout(uint32_t groupIndex);
-        const RequiredBufferSizes& GetMinimumBufferSizes() const;
+
+        // Helper function for the functors for std::unordered_map-based pipeline caches.
+        static size_t HashForCache(const PipelineBase* pipeline);
+        static bool EqualForCache(const PipelineBase* a, const PipelineBase* b);
 
       protected:
+        using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
+
         PipelineBase(DeviceBase* device,
                      PipelineLayoutBase* layout,
-                     wgpu::ShaderStage stages,
-                     RequiredBufferSizes bufferSizes);
+                     std::vector<StageAndDescriptor> stages);
         PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
       private:
         MaybeError ValidateGetBindGroupLayout(uint32_t group);
 
-        wgpu::ShaderStage mStageMask;
+        wgpu::ShaderStage mStageMask = wgpu::ShaderStage::None;
+        PerStage<ProgrammableStage> mStages;
+
         Ref<PipelineLayoutBase> mLayout;
-        RequiredBufferSizes mMinimumBufferSizes;
+        RequiredBufferSizes mMinBufferSizes;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 83d9b11..1341ee5 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -193,29 +193,6 @@
             return {};
         }
 
-        RequiredBufferSizes ComputeMinBufferSizes(const RenderPipelineDescriptor* descriptor) {
-            RequiredBufferSizes bufferSizes =
-                descriptor->vertexStage.module->ComputeRequiredBufferSizesForLayout(
-                    descriptor->layout);
-
-            // Merge the two buffer size requirements by taking the larger element from each
-            if (descriptor->fragmentStage != nullptr) {
-                RequiredBufferSizes fragmentSizes =
-                    descriptor->fragmentStage->module->ComputeRequiredBufferSizesForLayout(
-                        descriptor->layout);
-
-                for (BindGroupIndex group(0); group < bufferSizes.size(); ++group) {
-                    ASSERT(bufferSizes[group].size() == fragmentSizes[group].size());
-                    for (size_t i = 0; i < bufferSizes[group].size(); ++i) {
-                        bufferSizes[group][i] =
-                            std::max(bufferSizes[group][i], fragmentSizes[group][i]);
-                    }
-                }
-            }
-
-            return bufferSizes;
-        }
-
     }  // anonymous namespace
 
     // Helper functions
@@ -411,16 +388,12 @@
                                            const RenderPipelineDescriptor* descriptor)
         : PipelineBase(device,
                        descriptor->layout,
-                       wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment,
-                       ComputeMinBufferSizes(descriptor)),
+                       {{SingleShaderStage::Vertex, &descriptor->vertexStage},
+                        {SingleShaderStage::Fragment, descriptor->fragmentStage}}),
           mAttachmentState(device->GetOrCreateAttachmentState(descriptor)),
           mPrimitiveTopology(descriptor->primitiveTopology),
           mSampleMask(descriptor->sampleMask),
-          mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled),
-          mVertexModule(descriptor->vertexStage.module),
-          mVertexEntryPoint(descriptor->vertexStage.entryPoint),
-          mFragmentModule(descriptor->fragmentStage->module),
-          mFragmentEntryPoint(descriptor->fragmentStage->entryPoint) {
+          mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled) {
         if (descriptor->vertexState != nullptr) {
             mVertexState = *descriptor->vertexState;
         } else {
@@ -608,12 +581,8 @@
     }
 
     size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const {
-        size_t hash = 0;
-
         // Hash modules and layout
-        HashCombine(&hash, pipeline->GetLayout());
-        HashCombine(&hash, pipeline->mVertexModule.Get(), pipeline->mFragmentEntryPoint);
-        HashCombine(&hash, pipeline->mFragmentModule.Get(), pipeline->mFragmentEntryPoint);
+        size_t hash = PipelineBase::HashForCache(pipeline);
 
         // Hierarchically hash the attachment state.
         // It contains the attachments set, texture formats, and sample count.
@@ -671,11 +640,8 @@
 
     bool RenderPipelineBase::EqualityFunc::operator()(const RenderPipelineBase* a,
                                                       const RenderPipelineBase* b) const {
-        // Check modules and layout
-        if (a->GetLayout() != b->GetLayout() || a->mVertexModule.Get() != b->mVertexModule.Get() ||
-            a->mVertexEntryPoint != b->mVertexEntryPoint ||
-            a->mFragmentModule.Get() != b->mFragmentModule.Get() ||
-            a->mFragmentEntryPoint != b->mFragmentEntryPoint) {
+        // Check the layout and shader stages.
+        if (!PipelineBase::EqualForCache(a, b)) {
             return false;
         }
 
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index 5c328f9..aee8d3d 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -114,13 +114,6 @@
         RasterizationStateDescriptor mRasterizationState;
         uint32_t mSampleMask;
         bool mAlphaToCoverageEnabled;
-
-        // Stage information
-        // TODO(cwallez@chromium.org): Store a crypto hash of the modules instead.
-        Ref<ShaderModuleBase> mVertexModule;
-        std::string mVertexEntryPoint;
-        Ref<ShaderModuleBase> mFragmentModule;
-        std::string mFragmentEntryPoint;
     };
 
 }  // namespace dawn_native