Metal: Remove vertexState in the parameter of TranslateToMSL

This patch removes the parameter "vertexState" from TranslateToMSL
and ShaderModule::CreateFunction as we have already been able to
get the vertex states from the RenderPipelineBase object.

BUG=dawn:529

Change-Id: I2971438bfd5e0f3fbea900e1f06c1b33349571da
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64140
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index f108c33..045b728 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -1035,30 +1035,36 @@
         return std::move(output.program);
     }
 
-    void AddVertexPullingTransformConfig(const VertexState& vertexState,
+    void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline,
                                          const std::string& entryPoint,
                                          BindGroupIndex pullingBufferBindingSet,
                                          tint::transform::DataMap* transformInputs) {
         tint::transform::VertexPulling::Config cfg;
         cfg.entry_point_name = entryPoint;
         cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
-        for (uint32_t i = 0; i < vertexState.bufferCount; ++i) {
-            const auto& vertexBuffer = vertexState.buffers[i];
-            tint::transform::VertexBufferLayoutDescriptor layout;
-            layout.array_stride = vertexBuffer.arrayStride;
-            layout.step_mode = ToTintVertexStepMode(vertexBuffer.stepMode);
 
-            for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
-                const auto& attribute = vertexBuffer.attributes[j];
-                tint::transform::VertexAttributeDescriptor attr;
-                attr.format = ToTintVertexFormat(attribute.format);
-                attr.offset = attribute.offset;
-                attr.shader_location = attribute.shaderLocation;
-
-                layout.attributes.push_back(std::move(attr));
+        const auto& vertexBufferSlotUsed = renderPipeline.GetVertexBufferSlotsUsed();
+        cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount());
+        for (uint8_t vertexBufferSlot = 0;
+             vertexBufferSlot < static_cast<uint8_t>(cfg.vertex_state.size()); ++vertexBufferSlot) {
+            if (vertexBufferSlotUsed[static_cast<VertexBufferSlot>(vertexBufferSlot)]) {
+                const auto& vertexBuffer =
+                    renderPipeline.GetVertexBuffer(static_cast<VertexBufferSlot>(vertexBufferSlot));
+                cfg.vertex_state[vertexBufferSlot].array_stride = vertexBuffer.arrayStride;
+                cfg.vertex_state[vertexBufferSlot].step_mode =
+                    ToTintVertexStepMode(vertexBuffer.stepMode);
             }
-
-            cfg.vertex_state.push_back(std::move(layout));
+        }
+        for (VertexAttributeLocation location :
+             IterateBitSet(renderPipeline.GetAttributeLocationsUsed())) {
+            const auto& attribute = renderPipeline.GetAttribute(location);
+            tint::transform::VertexAttributeDescriptor attr;
+            attr.format = ToTintVertexFormat(attribute.format);
+            attr.offset = attribute.offset;
+            attr.shader_location = static_cast<uint32_t>(static_cast<uint8_t>(location));
+            ASSERT(vertexBufferSlotUsed[attribute.vertexBufferSlot]);
+            uint8_t vertexBufferSlot = static_cast<uint8_t>(attribute.vertexBufferSlot);
+            cfg.vertex_state[vertexBufferSlot].attributes.push_back(attr);
         }
         transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
     }
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index a4e8d3d..d981d56 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -110,7 +110,7 @@
                                                OwnedCompilationMessages* messages);
 
     /// Creates and adds the tint::transform::VertexPulling::Config to transformInputs.
-    void AddVertexPullingTransformConfig(const VertexState& vertexState,
+    void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline,
                                          const std::string& entryPoint,
                                          BindGroupIndex pullingBufferBindingSet,
                                          tint::transform::DataMap* transformInputs);
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 6d78676..17706d7 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -341,17 +341,9 @@
         ShaderModule* vertexModule = ToBackend(descriptor->vertex.module);
         const char* vertexEntryPoint = descriptor->vertex.entryPoint;
         ShaderModule::MetalFunctionData vertexData;
-
-        const VertexState* vertexStatePtr = &descriptor->vertex;
-        VertexState vertexState;
-        if (vertexStatePtr == nullptr) {
-            vertexState = {};
-            vertexStatePtr = &vertexState;
-        }
-
         DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
-                                              ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this,
-                                              vertexStatePtr));
+                                              ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
+                                              this));
 
         descriptorMTL.vertexFunction = vertexData.function.Get();
         if (vertexData.needsStorageBufferLength) {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index aced62b..ab87929 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -43,8 +43,7 @@
                                   const PipelineLayout* layout,
                                   MetalFunctionData* out,
                                   uint32_t sampleMask = 0xFFFFFFFF,
-                                  const RenderPipeline* renderPipeline = nullptr,
-                                  const VertexState* vertexState = nullptr);
+                                  const RenderPipeline* renderPipeline = nullptr);
 
       private:
         ResultOrError<std::string> TranslateToMSL(const char* entryPointName,
@@ -52,7 +51,6 @@
                                                   const PipelineLayout* layout,
                                                   uint32_t sampleMask,
                                                   const RenderPipeline* renderPipeline,
-                                                  const VertexState* vertexState,
                                                   std::string* remappedEntryPointName,
                                                   bool* needsStorageBufferLength,
                                                   bool* hasInvariantAttribute);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index b52c903..6685f4f 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -49,7 +49,6 @@
                                                             const PipelineLayout* layout,
                                                             uint32_t sampleMask,
                                                             const RenderPipeline* renderPipeline,
-                                                            const VertexState* vertexState,
                                                             std::string* remappedEntryPointName,
                                                             bool* needsStorageBufferLength,
                                                             bool* hasInvariantAttribute) {
@@ -100,8 +99,8 @@
         if (stage == SingleShaderStage::Vertex &&
             GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
             transformManager.Add<tint::transform::VertexPulling>();
-            AddVertexPullingTransformConfig(*vertexState, entryPointName, kPullingBufferBindingSet,
-                                            &transformInputs);
+            AddVertexPullingTransformConfig(*renderPipeline, entryPointName,
+                                            kPullingBufferBindingSet, &transformInputs);
 
             for (VertexBufferSlot slot :
                  IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
@@ -176,15 +175,13 @@
                                             const PipelineLayout* layout,
                                             ShaderModule::MetalFunctionData* out,
                                             uint32_t sampleMask,
-                                            const RenderPipeline* renderPipeline,
-                                            const VertexState* vertexState) {
+                                            const RenderPipeline* renderPipeline) {
         ASSERT(!IsError());
         ASSERT(out);
 
-        // Vertex stages must specify a renderPipeline and vertexState
+        // Vertex stages must specify a renderPipeline
         if (stage == SingleShaderStage::Vertex) {
             ASSERT(renderPipeline != nullptr);
-            ASSERT(vertexState != nullptr);
         }
 
         std::string remappedEntryPointName;
@@ -192,8 +189,8 @@
         bool hasInvariantAttribute = false;
         DAWN_TRY_ASSIGN(msl,
                         TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
-                                       vertexState, &remappedEntryPointName,
-                                       &out->needsStorageBufferLength, &hasInvariantAttribute));
+                                       &remappedEntryPointName, &out->needsStorageBufferLength,
+                                       &hasInvariantAttribute));
 
         // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
         // category. -Wunused-variable in particular comes up a lot in generated code, and some