tint::transform::VertexPulling: require SingleEntryPoint

This change the vertex pulling transform to look for the single vertex
entry point in the module, instead of taking the entry point name in the
config. This is necessary because the renamer needs to run before
VertexPulling so that builtins like min() don't end up referring to the
input WGSL. Putting the renamer before VertexPulling makes the config
entry point name no longer match.

Bug: dawn:1583
Change-Id: I4c96eb83518e0d6fe8ce23b37e238f4a890eeb2f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107080
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp
index a2cf5dc..74f377e 100644
--- a/src/dawn/native/TintUtils.cpp
+++ b/src/dawn/native/TintUtils.cpp
@@ -154,10 +154,8 @@
 
 tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
     const RenderPipelineBase& renderPipeline,
-    const std::string_view& entryPoint,
     BindGroupIndex pullingBufferBindingSet) {
     tint::transform::VertexPulling::Config cfg;
-    cfg.entry_point_name = entryPoint;
     cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
 
     cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount());
diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h
index d6bea1d..fb73e4d 100644
--- a/src/dawn/native/TintUtils.h
+++ b/src/dawn/native/TintUtils.h
@@ -45,7 +45,6 @@
 
 tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig(
     const RenderPipelineBase& renderPipeline,
-    const std::string_view& entryPoint,
     BindGroupIndex pullingBufferBindingSet);
 
 tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig(
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index e07cd1d..8ab0628 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -141,8 +141,8 @@
     std::optional<tint::transform::VertexPulling::Config> vertexPullingTransformConfig;
     if (stage == SingleShaderStage::Vertex &&
         device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
-        vertexPullingTransformConfig = BuildVertexPullingTransformConfig(
-            *renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet);
+        vertexPullingTransformConfig =
+            BuildVertexPullingTransformConfig(*renderPipeline, kPullingBufferBindingSet);
 
         for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
             uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc
index 2ec12d5..00b5a06 100644
--- a/src/tint/transform/vertex_pulling.cc
+++ b/src/tint/transform/vertex_pulling.cc
@@ -882,8 +882,18 @@
     }
 
     // Find entry point
-    auto* func = ctx.src->AST().Functions().Find(ctx.src->Symbols().Get(cfg.entry_point_name),
-                                                 ast::PipelineStage::kVertex);
+    const ast::Function* func = nullptr;
+    for (auto* fn : ctx.src->AST().Functions()) {
+        if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
+            if (func != nullptr) {
+                ctx.dst->Diagnostics().add_error(
+                    diag::System::Transform,
+                    "VertexPulling found more than one vertex entry point");
+                return;
+            }
+            func = fn;
+        }
+    }
     if (func == nullptr) {
         ctx.dst->Diagnostics().add_error(diag::System::Transform,
                                          "Vertex stage entry point not found");
diff --git a/src/tint/transform/vertex_pulling.h b/src/tint/transform/vertex_pulling.h
index 255a49a..6dd35bc 100644
--- a/src/tint/transform/vertex_pulling.h
+++ b/src/tint/transform/vertex_pulling.h
@@ -135,6 +135,8 @@
 /// code, but these are types that the data may arrive as. We need to convert
 /// these smaller types into the base types such as `f32` and `u32` for the
 /// shader to use.
+///
+/// The SingleEntryPoint transform must have run before VertexPulling.
 class VertexPulling final : public Castable<VertexPulling, Transform> {
   public:
     /// Configuration options for the transform
@@ -152,9 +154,6 @@
         /// @returns this Config
         Config& operator=(const Config&);
 
-        /// The entry point to add assignments into
-        std::string entry_point_name;
-
         /// The vertex state descriptor, containing info about attributes
         VertexStateDescriptor vertex_state;
 
@@ -163,7 +162,7 @@
         uint32_t pulling_group = 4u;
 
         /// Reflect the fields of this class so that it can be used by tint::ForeachField()
-        TINT_REFLECT(entry_point_name, vertex_state, pulling_group);
+        TINT_REFLECT(vertex_state, pulling_group);
     };
 
     /// Constructor
diff --git a/src/tint/transform/vertex_pulling_test.cc b/src/tint/transform/vertex_pulling_test.cc
index 5fb8b1c..7c774c5 100644
--- a/src/tint/transform/vertex_pulling_test.cc
+++ b/src/tint/transform/vertex_pulling_test.cc
@@ -35,18 +35,21 @@
     EXPECT_EQ(expect, str(got));
 }
 
-TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
+TEST_F(VertexPullingTest, Error_MultipleEntryPoint) {
     auto* src = R"(
 @vertex
 fn main() -> @builtin(position) vec4<f32> {
   return vec4<f32>();
 }
+@vertex
+fn main2() -> @builtin(position) vec4<f32> {
+  return vec4<f32>();
+}
 )";
 
-    auto* expect = "error: Vertex stage entry point not found";
+    auto* expect = "error: VertexPulling found more than one vertex entry point";
 
     VertexPulling::Config cfg;
-    cfg.entry_point_name = "_";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -64,7 +67,6 @@
     auto* expect = "error: Vertex stage entry point not found";
 
     VertexPulling::Config cfg;
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -87,7 +89,6 @@
 
     VertexPulling::Config cfg;
     cfg.vertex_state = {{{15, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -116,7 +117,6 @@
 )";
 
     VertexPulling::Config cfg;
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -153,7 +153,6 @@
 
     VertexPulling::Config cfg;
     cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -190,7 +189,6 @@
 
     VertexPulling::Config cfg;
     cfg.vertex_state = {{{4, VertexStepMode::kInstance, {{VertexFormat::kFloat32, 0, 0}}}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -228,7 +226,6 @@
     VertexPulling::Config cfg;
     cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
     cfg.pulling_group = 5;
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -274,7 +271,6 @@
 
     VertexPulling::Config cfg;
     cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -332,7 +328,6 @@
             {{VertexFormat::kFloat32, 0, 1}},
         },
     }};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -411,7 +406,6 @@
             {{VertexFormat::kFloat32, 0, 1}},
         },
     }};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -490,7 +484,6 @@
             {{VertexFormat::kFloat32, 0, 1}},
         },
     }};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -566,7 +559,6 @@
             {{VertexFormat::kFloat32, 0, 1}},
         },
     }};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -642,7 +634,6 @@
             {{VertexFormat::kFloat32, 0, 1}},
         },
     }};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -684,7 +675,6 @@
     cfg.vertex_state = {{{16,
                           VertexStepMode::kVertex,
                           {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -738,7 +728,6 @@
         {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}},
         {16, VertexStepMode::kVertex, {{VertexFormat::kFloat32x4, 0, 2}}},
     }};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -788,7 +777,6 @@
     cfg.vertex_state = {{{16,
                           VertexStepMode::kVertex,
                           {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -933,7 +921,6 @@
               {VertexFormat::kSint32, 64, 26},    {VertexFormat::kSint32x2, 64, 27},
               {VertexFormat::kSint32x3, 64, 28},  {VertexFormat::kSint32x4, 64, 29},
           }}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -1079,7 +1066,6 @@
               {VertexFormat::kSint32, 63, 26},    {VertexFormat::kSint32x2, 63, 27},
               {VertexFormat::kSint32x3, 63, 28},  {VertexFormat::kSint32x4, 63, 29},
           }}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);
@@ -1224,7 +1210,6 @@
               {VertexFormat::kSint32, 64, 26},    {VertexFormat::kSint32x2, 64, 27},
               {VertexFormat::kSint32x3, 64, 28},  {VertexFormat::kSint32x4, 64, 29},
           }}}};
-    cfg.entry_point_name = "main";
 
     DataMap data;
     data.Add<VertexPulling::Config>(cfg);