Make ShaderModule reflection go through EntryPointMetadata

PipelineBase now collects the EntryPointMetadata for all its
stages which makes the rest of the code agnostic to the entrypoint
name (except D3D12 and OpenGL that required transition hacks and
will be fixed in follow-up CLs).

Bug: dawn:216

Change-Id: I643da198cb2a20a9d94d805a2dc783d6d4346ae9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27260
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h
index c2f1188..c2b470a 100644
--- a/src/dawn_native/ComputePipeline.h
+++ b/src/dawn_native/ComputePipeline.h
@@ -20,6 +20,7 @@
 namespace dawn_native {
 
     class DeviceBase;
+    struct EntryPointMetadata;
 
     MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
                                                  const ComputePipelineDescriptor* descriptor);
@@ -31,6 +32,8 @@
 
         static ComputePipelineBase* MakeError(DeviceBase* device);
 
+        const EntryPointMetadata& GetMetadata() const;
+
         // Functors necessary for the unordered_set<ComputePipelineBase*>-based cache.
         struct HashFunc {
             size_t operator()(const ComputePipelineBase* pipeline) const;
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 1f76e49..5f48898 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -877,9 +877,9 @@
         if (descriptor->layout == nullptr) {
             ComputePipelineDescriptor descriptorWithDefaultLayout = *descriptor;
 
-            DAWN_TRY_ASSIGN(
-                descriptorWithDefaultLayout.layout,
-                PipelineLayoutBase::CreateDefault(this, &descriptor->computeStage.module, 1));
+            DAWN_TRY_ASSIGN(descriptorWithDefaultLayout.layout,
+                            PipelineLayoutBase::CreateDefault(
+                                this, {{SingleShaderStage::Compute, &descriptor->computeStage}}));
             // Ref will keep the pipeline layout alive until the end of the function where
             // the pipeline will take another reference.
             Ref<PipelineLayoutBase> layoutRef = AcquireRef(descriptorWithDefaultLayout.layout);
@@ -934,18 +934,14 @@
         if (descriptor->layout == nullptr) {
             RenderPipelineDescriptor descriptorWithDefaultLayout = *descriptor;
 
-            const ShaderModuleBase* modules[2];
-            modules[0] = descriptor->vertexStage.module;
-            uint32_t count;
-            if (descriptor->fragmentStage == nullptr) {
-                count = 1;
-            } else {
-                modules[1] = descriptor->fragmentStage->module;
-                count = 2;
+            std::vector<StageAndDescriptor> stages;
+            stages.emplace_back(SingleShaderStage::Vertex, &descriptor->vertexStage);
+            if (descriptor->fragmentStage != nullptr) {
+                stages.emplace_back(SingleShaderStage::Fragment, descriptor->fragmentStage);
             }
 
             DAWN_TRY_ASSIGN(descriptorWithDefaultLayout.layout,
-                            PipelineLayoutBase::CreateDefault(this, modules, count));
+                            PipelineLayoutBase::CreateDefault(this, std::move(stages)));
             // Ref will keep the pipeline layout alive until the end of the function where
             // the pipeline will take another reference.
             Ref<PipelineLayoutBase> layoutRef = AcquireRef(descriptorWithDefaultLayout.layout);
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index f37e30c..691e3a9 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -26,16 +26,17 @@
                                                    const ProgrammableStageDescriptor* descriptor,
                                                    const PipelineLayoutBase* layout,
                                                    SingleShaderStage stage) {
-        DAWN_TRY(device->ValidateObject(descriptor->module));
+        const ShaderModuleBase* module = descriptor->module;
+        DAWN_TRY(device->ValidateObject(module));
 
-        if (descriptor->entryPoint != std::string("main")) {
-            return DAWN_VALIDATION_ERROR("Entry point must be \"main\"");
+        if (!module->HasEntryPoint(descriptor->entryPoint, stage)) {
+            return DAWN_VALIDATION_ERROR("Entry point doesn't exist in the module");
         }
-        if (descriptor->module->GetExecutionModel() != stage) {
-            return DAWN_VALIDATION_ERROR("Setting module with wrong stages");
-        }
+
         if (layout != nullptr) {
-            DAWN_TRY(descriptor->module->ValidateCompatibilityWithPipelineLayout(layout));
+            const EntryPointMetadata& metadata =
+                module->GetEntryPoint(descriptor->entryPoint, stage);
+            DAWN_TRY(ValidateCompatibilityWithPipelineLayout(metadata, layout));
         }
         return {};
     }
@@ -49,13 +50,20 @@
         ASSERT(!stages.empty());
 
         for (const StageAndDescriptor& stage : stages) {
+            // Extract argument for this stage.
+            SingleShaderStage shaderStage = stage.first;
+            ShaderModuleBase* module = stage.second->module;
+            const char* entryPointName = stage.second->entryPoint;
+            const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName, shaderStage);
+
+            // Record them internally.
             bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
-            mStageMask |= StageBit(stage.first);
-            mStages[stage.first] = {stage.second->module, stage.second->entryPoint};
+            mStageMask |= StageBit(shaderStage);
+            mStages[shaderStage] = {module, entryPointName, &metadata};
 
             // Compute the max() of all minBufferSizes across all stages.
             RequiredBufferSizes stageMinBufferSizes =
-                stage.second->module->ComputeRequiredBufferSizesForLayout(layout);
+                ComputeRequiredBufferSizesForLayout(metadata, layout);
 
             if (isFirstStage) {
                 mMinBufferSizes = std::move(stageMinBufferSizes);
diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h
index 9df2db6..44e3f98 100644
--- a/src/dawn_native/Pipeline.h
+++ b/src/dawn_native/Pipeline.h
@@ -36,6 +36,9 @@
     struct ProgrammableStage {
         Ref<ShaderModuleBase> module;
         std::string entryPoint;
+
+        // The metadata lives as long as module, that's ref-ed in the same structure.
+        const EntryPointMetadata* metadata = nullptr;
     };
 
     class PipelineBase : public CachedObject {
@@ -52,8 +55,6 @@
         static bool EqualForCache(const PipelineBase* a, const PipelineBase* b);
 
       protected:
-        using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
-
         PipelineBase(DeviceBase* device,
                      PipelineLayoutBase* layout,
                      std::vector<StageAndDescriptor> stages);
diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp
index f34003a..61fdf74 100644
--- a/src/dawn_native/PipelineLayout.cpp
+++ b/src/dawn_native/PipelineLayout.cpp
@@ -114,9 +114,8 @@
     // static
     ResultOrError<PipelineLayoutBase*> PipelineLayoutBase::CreateDefault(
         DeviceBase* device,
-        const ShaderModuleBase* const* modules,
-        uint32_t count) {
-        ASSERT(count > 0);
+        std::vector<StageAndDescriptor> stages) {
+        ASSERT(!stages.empty());
 
         // Data which BindGroupLayoutDescriptor will point to for creation
         ityp::array<
@@ -134,20 +133,22 @@
 
         BindingCounts bindingCounts = {};
         BindGroupIndex bindGroupLayoutCount(0);
-        for (uint32_t moduleIndex = 0; moduleIndex < count; ++moduleIndex) {
-            const ShaderModuleBase* module = modules[moduleIndex];
-            const ShaderModuleBase::ModuleBindingInfo& info = module->GetBindingInfo();
+        for (const StageAndDescriptor& stage : stages) {
+            // Extract argument for this stage.
+            SingleShaderStage shaderStage = stage.first;
+            const EntryPointMetadata::BindingInfo& info =
+                stage.second->module->GetEntryPoint(stage.second->entryPoint, shaderStage).bindings;
 
             for (BindGroupIndex group(0); group < info.size(); ++group) {
                 for (const auto& it : info[group]) {
                     BindingNumber bindingNumber = it.first;
-                    const ShaderModuleBase::ShaderBindingInfo& bindingInfo = it.second;
+                    const EntryPointMetadata::ShaderBindingInfo& bindingInfo = it.second;
 
                     BindGroupLayoutEntry bindingSlot;
                     bindingSlot.binding = static_cast<uint32_t>(bindingNumber);
 
-                    DAWN_TRY(ValidateBindingTypeWithShaderStageVisibility(
-                        bindingInfo.type, StageBit(module->GetExecutionModel())));
+                    DAWN_TRY(ValidateBindingTypeWithShaderStageVisibility(bindingInfo.type,
+                                                                          StageBit(shaderStage)));
                     DAWN_TRY(ValidateStorageTextureFormat(device, bindingInfo.type,
                                                           bindingInfo.storageTextureFormat));
                     DAWN_TRY(ValidateStorageTextureViewDimension(bindingInfo.type,
@@ -239,10 +240,10 @@
             }
         }
 
-        for (uint32_t moduleIndex = 0; moduleIndex < count; ++moduleIndex) {
-            ASSERT(modules[moduleIndex]
-                       ->ValidateCompatibilityWithPipelineLayout(pipelineLayout)
-                       .IsSuccess());
+        for (const StageAndDescriptor& stage : stages) {
+            const EntryPointMetadata& metadata =
+                stage.second->module->GetEntryPoint(stage.second->entryPoint, stage.first);
+            ASSERT(ValidateCompatibilityWithPipelineLayout(metadata, pipelineLayout).IsSuccess());
         }
 
         return pipelineLayout;
diff --git a/src/dawn_native/PipelineLayout.h b/src/dawn_native/PipelineLayout.h
index 862caaf..be8a75c 100644
--- a/src/dawn_native/PipelineLayout.h
+++ b/src/dawn_native/PipelineLayout.h
@@ -37,14 +37,17 @@
         ityp::array<BindGroupIndex, Ref<BindGroupLayoutBase>, kMaxBindGroups>;
     using BindGroupLayoutMask = ityp::bitset<BindGroupIndex, kMaxBindGroups>;
 
+    using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
+
     class PipelineLayoutBase : public CachedObject {
       public:
         PipelineLayoutBase(DeviceBase* device, const PipelineLayoutDescriptor* descriptor);
         ~PipelineLayoutBase() override;
 
         static PipelineLayoutBase* MakeError(DeviceBase* device);
-        static ResultOrError<PipelineLayoutBase*>
-        CreateDefault(DeviceBase* device, const ShaderModuleBase* const* modules, uint32_t count);
+        static ResultOrError<PipelineLayoutBase*> CreateDefault(
+            DeviceBase* device,
+            std::vector<StageAndDescriptor> stages);
 
         const BindGroupLayoutBase* GetBindGroupLayout(BindGroupIndex group) const;
         BindGroupLayoutBase* GetBindGroupLayout(BindGroupIndex group);
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 2d91291..62ea4b3 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -333,8 +333,9 @@
             DAWN_TRY(ValidateRasterizationStateDescriptor(descriptor->rasterizationState));
         }
 
-        if ((descriptor->vertexStage.module->GetUsedVertexAttributes() & ~attributesSetMask)
-                .any()) {
+        const EntryPointMetadata& vertexMetadata = descriptor->vertexStage.module->GetEntryPoint(
+            descriptor->vertexStage.entryPoint, SingleShaderStage::Vertex);
+        if ((vertexMetadata.usedVertexAttributes & ~attributesSetMask).any()) {
             return DAWN_VALIDATION_ERROR(
                 "Pipeline vertex stage uses vertex buffers not in the vertex state");
         }
@@ -352,11 +353,13 @@
         }
 
         ASSERT(descriptor->fragmentStage != nullptr);
-        const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
-            descriptor->fragmentStage->module->GetFragmentOutputBaseTypes();
+        const EntryPointMetadata& fragmentMetadata =
+            descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint,
+                                                             SingleShaderStage::Fragment);
         for (uint32_t i = 0; i < descriptor->colorStateCount; ++i) {
-            DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i],
-                                                  fragmentOutputBaseTypes[i]));
+            DAWN_TRY(
+                ValidateColorStateDescriptor(device, descriptor->colorStates[i],
+                                             fragmentMetadata.fragmentOutputFormatBaseTypes[i]));
         }
 
         if (descriptor->depthStencilState) {
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index f06abff..002b330 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -28,6 +28,7 @@
     struct BeginRenderPassCmd;
 
     class DeviceBase;
+    struct EntryPointMetadata;
     class RenderBundleEncoder;
 
     MaybeError ValidateRenderPipelineDescriptor(DeviceBase* device,
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 03ddf03..96cd483 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -549,7 +549,7 @@
 #endif  // DAWN_ENABLE_WGSL
 
         std::vector<uint64_t> GetBindGroupMinBufferSizes(
-            const ShaderModuleBase::BindingInfoMap& shaderBindings,
+            const EntryPointMetadata::BindingGroupInfoMap& shaderBindings,
             const BindGroupLayoutBase* layout) {
             std::vector<uint64_t> requiredBufferSizes(layout->GetUnverifiedBufferCount());
             uint32_t packedIdx = 0;
@@ -578,17 +578,16 @@
             return requiredBufferSizes;
         }
 
-        MaybeError ValidateCompatibilityWithBindGroupLayout(
-            BindGroupIndex group,
-            const ShaderModuleBase::EntryPointMetadata& entryPoint,
-            const BindGroupLayoutBase* layout) {
+        MaybeError ValidateCompatibilityWithBindGroupLayout(BindGroupIndex group,
+                                                            const EntryPointMetadata& entryPoint,
+                                                            const BindGroupLayoutBase* layout) {
             const BindGroupLayoutBase::BindingMap& layoutBindings = layout->GetBindingMap();
 
             // Iterate over all bindings used by this group in the shader, and find the
             // corresponding binding in the BindGroupLayout, if it exists.
             for (const auto& it : entryPoint.bindings[group]) {
                 BindingNumber bindingNumber = it.first;
-                const ShaderModuleBase::ShaderBindingInfo& shaderInfo = it.second;
+                const EntryPointMetadata::ShaderBindingInfo& shaderInfo = it.second;
 
                 const auto& bindingIt = layoutBindings.find(bindingNumber);
                 if (bindingIt == layoutBindings.end()) {
@@ -732,9 +731,8 @@
         return {};
     }
 
-    RequiredBufferSizes ComputeRequiredBufferSizesForLayout(
-        const ShaderModuleBase::EntryPointMetadata& entryPoint,
-        const PipelineLayoutBase* layout) {
+    RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
+                                                            const PipelineLayoutBase* layout) {
         RequiredBufferSizes bufferSizes;
         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
             bufferSizes[group] = GetBindGroupMinBufferSizes(entryPoint.bindings[group],
@@ -744,9 +742,8 @@
         return bufferSizes;
     }
 
-    MaybeError ValidateCompatibilityWithPipelineLayout(
-        const ShaderModuleBase::EntryPointMetadata& entryPoint,
-        const PipelineLayoutBase* layout) {
+    MaybeError ValidateCompatibilityWithPipelineLayout(const EntryPointMetadata& entryPoint,
+                                                       const PipelineLayoutBase* layout) {
         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
             DAWN_TRY(ValidateCompatibilityWithBindGroupLayout(group, entryPoint,
                                                               layout->GetBindGroupLayout(group)));
@@ -766,7 +763,7 @@
 
     // EntryPointMetadata
 
-    ShaderModuleBase::EntryPointMetadata::EntryPointMetadata() {
+    EntryPointMetadata::EntryPointMetadata() {
         fragmentOutputFormatBaseTypes.fill(Format::Type::Other);
     }
 
@@ -814,6 +811,20 @@
         return new ShaderModuleBase(device, ObjectBase::kError);
     }
 
+    bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint,
+                                         SingleShaderStage stage) const {
+        // TODO(dawn:216): Properly extract all entryPoints from the shader module.
+        return entryPoint == "main" && stage == mMainEntryPoint->stage;
+    }
+
+    const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint,
+                                                              SingleShaderStage stage) const {
+        // TODO(dawn:216): Properly extract all entryPoints from the shader module.
+        ASSERT(entryPoint == "main");
+        ASSERT(stage == mMainEntryPoint->stage);
+        return *mMainEntryPoint;
+    }
+
     MaybeError ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) {
         ASSERT(!IsError());
         if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
@@ -824,7 +835,7 @@
         return {};
     }
 
-    ResultOrError<std::unique_ptr<ShaderModuleBase::EntryPointMetadata>>
+    ResultOrError<std::unique_ptr<EntryPointMetadata>>
     ShaderModuleBase::ExtractSpirvInfoWithSpvc() {
         DeviceBase* device = GetDevice();
         std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@@ -848,7 +859,7 @@
         // Fill in bindingInfo with the SPIRV bindings
         auto ExtractResourcesBinding =
             [](const DeviceBase* device, const std::vector<shaderc_spvc_binding_info>& spvcBindings,
-               ModuleBindingInfo* metadataBindings) -> MaybeError {
+               EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError {
             for (const shaderc_spvc_binding_info& binding : spvcBindings) {
                 BindGroupIndex bindGroupIndex(binding.set);
 
@@ -857,12 +868,12 @@
                 }
 
                 const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
-                    BindingNumber(binding.binding), ShaderBindingInfo{});
+                    BindingNumber(binding.binding), EntryPointMetadata::ShaderBindingInfo{});
                 if (!it.second) {
                     return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
                 }
 
-                ShaderBindingInfo* info = &it.first->second;
+                EntryPointMetadata::ShaderBindingInfo* info = &it.first->second;
                 info->id = binding.id;
                 info->base_type_id = binding.base_type_id;
                 info->type = ToWGPUBindingType(binding.binding_type);
@@ -994,7 +1005,7 @@
         return {std::move(metadata)};
     }
 
-    ResultOrError<std::unique_ptr<ShaderModuleBase::EntryPointMetadata>>
+    ResultOrError<std::unique_ptr<EntryPointMetadata>>
     ShaderModuleBase::ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler) {
         DeviceBase* device = GetDevice();
         std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@@ -1031,7 +1042,7 @@
             [](const DeviceBase* device,
                const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
                const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType,
-               ModuleBindingInfo* metadataBindings) -> MaybeError {
+               EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError {
             for (const auto& resource : resources) {
                 if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) {
                     return DAWN_VALIDATION_ERROR("No Binding decoration set for resource");
@@ -1051,13 +1062,13 @@
                     return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV");
                 }
 
-                const auto& it =
-                    (*metadataBindings)[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{});
+                const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
+                    bindingNumber, EntryPointMetadata::ShaderBindingInfo{});
                 if (!it.second) {
                     return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
                 }
 
-                ShaderBindingInfo* info = &it.first->second;
+                EntryPointMetadata::ShaderBindingInfo* info = &it.first->second;
                 info->id = resource.id;
                 info->base_type_id = resource.base_type_id;
 
@@ -1204,39 +1215,6 @@
         return {std::move(metadata)};
     }
 
-    const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const {
-        ASSERT(!IsError());
-        return mMainEntryPoint->bindings;
-    }
-
-    const std::bitset<kMaxVertexAttributes>& ShaderModuleBase::GetUsedVertexAttributes() const {
-        ASSERT(!IsError());
-        return mMainEntryPoint->usedVertexAttributes;
-    }
-
-    const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes()
-        const {
-        ASSERT(!IsError());
-        return mMainEntryPoint->fragmentOutputFormatBaseTypes;
-    }
-
-    SingleShaderStage ShaderModuleBase::GetExecutionModel() const {
-        ASSERT(!IsError());
-        return mMainEntryPoint->stage;
-    }
-
-    RequiredBufferSizes ShaderModuleBase::ComputeRequiredBufferSizesForLayout(
-        const PipelineLayoutBase* layout) const {
-        ASSERT(!IsError());
-        return ::dawn_native::ComputeRequiredBufferSizesForLayout(*mMainEntryPoint, layout);
-    }
-
-    MaybeError ShaderModuleBase::ValidateCompatibilityWithPipelineLayout(
-        const PipelineLayoutBase* layout) const {
-        ASSERT(!IsError());
-        return ::dawn_native::ValidateCompatibilityWithPipelineLayout(*mMainEntryPoint, layout);
-    }
-
     size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
         size_t hash = 0;
 
@@ -1298,4 +1276,10 @@
 
         return {};
     }
+
+    SingleShaderStage ShaderModuleBase::GetMainEntryPointStageForTransition() const {
+        ASSERT(!IsError());
+        return mMainEntryPoint->stage;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index f6779aa..15d31ae 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -38,18 +38,25 @@
 
 namespace dawn_native {
 
+    struct EntryPointMetadata;
+
     MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
                                               const ShaderModuleDescriptor* descriptor);
+    MaybeError ValidateCompatibilityWithPipelineLayout(const EntryPointMetadata& entryPoint,
+                                                       const PipelineLayoutBase* layout);
 
-    class ShaderModuleBase : public CachedObject {
-      public:
-        ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
-        ~ShaderModuleBase() override;
+    RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
+                                                            const PipelineLayoutBase* layout);
 
-        static ShaderModuleBase* MakeError(DeviceBase* device);
+    // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
+    // stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so
+    // pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the
+    // ShaderModuleBase.
+    struct EntryPointMetadata {
+        EntryPointMetadata();
 
-        MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler);
-
+        // Per-binding shader metadata contains some SPIRV specific information in addition to
+        // most of the frontend per-binding information.
         struct ShaderBindingInfo : BindingInfo {
             // The SPIRV ID of the resource.
             uint32_t id;
@@ -61,22 +68,42 @@
             using BindingInfo::visibility;
         };
 
-        using BindingInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
-        using ModuleBindingInfo = ityp::array<BindGroupIndex, BindingInfoMap, kMaxBindGroups>;
+        // bindings[G][B] is the reflection data for the binding defined with
+        // [[group=G, binding=B]] in WGSL / SPIRV.
+        using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
+        using BindingInfo = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
+        BindingInfo bindings;
 
-        const ModuleBindingInfo& GetBindingInfo() const;
-        const std::bitset<kMaxVertexAttributes>& GetUsedVertexAttributes() const;
-        SingleShaderStage GetExecutionModel() const;
+        // The set of vertex attributes this entryPoint uses.
+        std::bitset<kMaxVertexAttributes> usedVertexAttributes;
 
         // An array to record the basic types (float, int and uint) of the fragment shader outputs
         // or Format::Type::Other means the fragment shader output is unused.
         using FragmentOutputBaseTypes = std::array<Format::Type, kMaxColorAttachments>;
-        const FragmentOutputBaseTypes& GetFragmentOutputBaseTypes() const;
+        FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
 
-        MaybeError ValidateCompatibilityWithPipelineLayout(const PipelineLayoutBase* layout) const;
+        // The shader stage for this binding, TODO(dawn:216): can likely be removed once we
+        // properly support multiple entrypoints per ShaderModule.
+        SingleShaderStage stage;
+    };
 
-        RequiredBufferSizes ComputeRequiredBufferSizesForLayout(
-            const PipelineLayoutBase* layout) const;
+    class ShaderModuleBase : public CachedObject {
+      public:
+        ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
+        ~ShaderModuleBase() override;
+
+        static ShaderModuleBase* MakeError(DeviceBase* device);
+
+        // Return true iff the module has an entrypoint called `entryPoint` for stage `stage`.
+        bool HasEntryPoint(const std::string& entryPoint, SingleShaderStage stage) const;
+
+        // Returns the metadata for the given `entryPoint` and `stage`. HasEntryPoint with the same
+        // arguments must be true.
+        const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint,
+                                                SingleShaderStage stage) const;
+
+        // TODO make this member protected, it is only used outside of child classes in DeviceNull.
+        MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler);
 
         // Functors necessary for the unordered_set<ShaderModuleBase*>-based cache.
         struct HashFunc {
@@ -96,15 +123,6 @@
             uint32_t pullingBufferBindingSet) const;
 #endif
 
-        struct EntryPointMetadata {
-            EntryPointMetadata();
-
-            ModuleBindingInfo bindings;
-            std::bitset<kMaxVertexAttributes> usedVertexAttributes;
-            SingleShaderStage stage;
-            FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
-        };
-
       protected:
         static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
         shaderc_spvc::CompileOptions GetCompileOptions() const;
@@ -112,6 +130,11 @@
 
         shaderc_spvc::Context mSpvcContext;
 
+        // Allows backends to get the stage for the "main" entrypoint while they are transitioned to
+        // support multiple entrypoints.
+        // TODO(dawn:216): Remove this once the transition is complete.
+        SingleShaderStage GetMainEntryPointStageForTransition() const;
+
       private:
         ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 0bdbafa..3b16696 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -170,13 +170,15 @@
             compiler->set_hlsl_options(options_hlsl);
         }
 
-        const ModuleBindingInfo& moduleBindingInfo = GetBindingInfo();
+        const EntryPointMetadata::BindingInfo& moduleBindingInfo =
+            GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings;
+
         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
             const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
             const auto& bindingOffsets = bgl->GetBindingOffsets();
             const auto& groupBindingInfo = moduleBindingInfo[group];
             for (const auto& it : groupBindingInfo) {
-                const ShaderBindingInfo& bindingInfo = it.second;
+                const EntryPointMetadata::ShaderBindingInfo& bindingInfo = it.second;
                 BindingNumber bindingNumber = it.first;
                 BindingIndex bindingIndex = bgl->GetBindingIndex(bindingNumber);
 
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 5823ffe..88e905c 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -368,8 +368,8 @@
             }
         }
 
-        const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
-            descriptor->fragmentStage->module->GetFragmentOutputBaseTypes();
+        const EntryPointMetadata::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
+            GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputFormatBaseTypes;
         for (uint32_t i : IterateBitSet(GetColorAttachmentsMask())) {
             descriptorMTL.colorAttachments[i].pixelFormat =
                 MetalPixelFormat(GetColorAttachmentFormat(i));
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 208612e..1e227ac 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -259,11 +259,15 @@
             // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like
             // clean_func_name:
             // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
-            if (strcmp(functionName, "main") == 0) {
-                functionName = "main0";
+            const char* metalFunctionName = functionName;
+            if (strcmp(metalFunctionName, "main") == 0) {
+                metalFunctionName = "main0";
+            }
+            if (strcmp(metalFunctionName, "saturate") == 0) {
+                metalFunctionName = "saturate0";
             }
 
-            NSString* name = [[NSString alloc] initWithUTF8String:functionName];
+            NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName];
             out->function = [library newFunctionWithName:name];
             [library release];
         }
@@ -277,7 +281,7 @@
         }
 
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
-            functionStage == SingleShaderStage::Vertex && GetUsedVertexAttributes().any()) {
+            GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) {
             out->needsStorageBufferLength = true;
         }
 
diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp
index 53aa101..0b02d53 100644
--- a/src/dawn_native/opengl/ShaderModuleGL.cpp
+++ b/src/dawn_native/opengl/ShaderModuleGL.cpp
@@ -125,8 +125,6 @@
 
         DAWN_TRY(ExtractSpirvInfo(*compiler));
 
-        const ShaderModuleBase::ModuleBindingInfo& bindingInfo = GetBindingInfo();
-
         // Extract bindings names so that it can be used to get its location in program.
         // Now translate the separate sampler / textures into combined ones and store their info.
         // We need to do this before removing the set and binding decorations.
@@ -182,6 +180,9 @@
             }
         }
 
+        const EntryPointMetadata::BindingInfo& bindingInfo =
+            GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings;
+
         // Change binding names to be "dawn_binding_<group>_<binding>".
         // Also unsets the SPIRV "Binding" decoration as it outputs "layout(binding=)" which
         // isn't supported on OSX's OpenGL.
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp
index 271aaaa..2c18cfd 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp
@@ -425,8 +425,8 @@
         // Initialize the "blend state info" that will be chained in the "create info" from the data
         // pre-computed in the ColorState
         std::array<VkPipelineColorBlendAttachmentState, kMaxColorAttachments> colorBlendAttachments;
-        const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
-            descriptor->fragmentStage->module->GetFragmentOutputBaseTypes();
+        const EntryPointMetadata::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
+            GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputFormatBaseTypes;
         for (uint32_t i : IterateBitSet(GetColorAttachmentsMask())) {
             const ColorStateDescriptor* colorStateDescriptor = GetColorStateDescriptor(i);
             bool isDeclaredInFragmentShader = fragmentOutputBaseTypes[i] != Format::Type::Other;