[msl] Move VertexPulling into the backend

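The MSL writer now runs VertexPulling itself (on both the AST and IR paths)
instead of Dawn running it beforehand. VertexPulling must run before the
BindingRemapper and Robustness transforms.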

Fixed: 380044409
Change-Id: I6ef4b37780323ceee4976e20997f39eee8fbde1e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/222021
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 10c31cc..a418a8f 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -677,8 +677,6 @@
       "metal/UtilsMetal.mm",
     ]
 
-    deps += [ "${dawn_root}/src/tint/lang/core/ir/transform" ]
-
     # If a "build with ARC" config is present, remove it.
     if (filter_include(configs, [ "//build/config/compiler:enable_arc" ]) !=
         []) {
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index 940e39f..6b78950 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -44,7 +44,6 @@
 #include "dawn/platform/DawnPlatform.h"
 #include "dawn/platform/metrics/HistogramMacros.h"
 #include "dawn/platform/tracing/TraceEvent.h"
-#include "src/tint/lang/core/ir/transform/vertex_pulling.h"
 
 #include <tint/tint.h>
 
@@ -61,7 +60,6 @@
 #define MSL_COMPILATION_REQUEST_MEMBERS(X)                                                       \
     X(SingleShaderStage, stage)                                                                  \
     X(const tint::Program*, inputProgram)                                                        \
-    X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig)                        \
     X(std::optional<tint::ast::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
     X(LimitsForCompilationRequest, limits)                                                       \
     X(CacheKey::UnsafeUnkeyedValue<const AdapterBase*>, adapter)                                 \
@@ -274,7 +272,6 @@
     req.stage = stage;
     auto tintProgram = programmableStage.module->GetTintProgram();
     req.inputProgram = &(tintProgram->program);
-    req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig);
     req.substituteOverrideConfig = std::move(substituteOverrideConfig);
     req.entryPointName = programmableStage.entryPoint.c_str();
     req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
@@ -299,6 +296,7 @@
     req.tintOptions.bindings = std::move(bindings);
     req.tintOptions.disable_polyfill_integer_div_mod =
         device->IsToggleEnabled(Toggle::DisablePolyfillsOnIntegerDivisonAndModulo);
+    req.tintOptions.vertex_pulling_config = std::move(vertexPullingTransformConfig);
 
     const CombinedLimits& limits = device->GetLimits();
     req.limits = LimitsForCompilationRequest::Create(limits.v1);
@@ -326,15 +324,6 @@
                     r.disableSymbolRenaming ? tint::ast::transform::Renamer::Target::kMslKeywords
                                             : tint::ast::transform::Renamer::Target::kAll,
                     std::move(requestedNames));
-
-                if (r.vertexPullingTransformConfig) {
-                    tint::ast::transform::VertexPulling::Config config;
-                    config.pulling_group = r.vertexPullingTransformConfig->pulling_group;
-                    config.vertex_state = r.vertexPullingTransformConfig->vertex_state;
-                    transformManager.Add<tint::ast::transform::VertexPulling>();
-                    transformInputs.Add<tint::ast::transform::VertexPulling::Config>(
-                        std::move(config));
-                }
             }
 
             if (r.substituteOverrideConfig) {
@@ -365,15 +354,6 @@
                                 "An error occurred while generating Tint IR\n%s",
                                 ir.Failure().reason.Str());
 
-                // TODO(380044409): Move this into the backend.
-                if (r.vertexPullingTransformConfig) {
-                    auto vertex_pulling_result = tint::core::ir::transform::VertexPulling(
-                        ir.Get(), *r.vertexPullingTransformConfig);
-                    DAWN_INVALID_IF(vertex_pulling_result != tint::Success,
-                                    "An error occurred while running vertex pulling:\n%s",
-                                    vertex_pulling_result.Failure().reason.Str());
-                }
-
                 result = tint::msl::writer::Generate(ir.Get(), r.tintOptions);
 
                 // Workgroup validation has to come after `Generate` because it may require
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index 7236d60..39c1e69 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -85,6 +85,7 @@
 #include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
 #include "src/tint/lang/wgsl/ast/transform/unshadow.h"
 #include "src/tint/lang/wgsl/ast/transform/vectorize_scalar_matrix_initializers.h"
+#include "src/tint/lang/wgsl/ast/transform/vertex_pulling.h"
 #include "src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.h"
 #include "src/tint/lang/wgsl/ast/variable_decl_statement.h"
 #include "src/tint/lang/wgsl/helpers/check_supported_extensions.h"
@@ -162,6 +163,15 @@
 
     manager.Add<ast::transform::PromoteSideEffectsToDecl>();
 
+    // VertexPulling must come before BindingRemapper and Robustness.
+    if (options.vertex_pulling_config) {
+        ast::transform::VertexPulling::Config config;
+        config.pulling_group = options.vertex_pulling_config->pulling_group;
+        config.vertex_state = options.vertex_pulling_config->vertex_state;
+        manager.Add<ast::transform::VertexPulling>();
+        data.Add<ast::transform::VertexPulling::Config>(std::move(config));
+    }
+
     if (!options.disable_robustness) {
         // Robustness must come after PromoteSideEffectsToDecl
         // Robustness must come before BuiltinPolyfill and CanonicalizeEntryPointIO
diff --git a/src/tint/lang/msl/writer/common/options.h b/src/tint/lang/msl/writer/common/options.h
index bd13057..5bc74c2 100644
--- a/src/tint/lang/msl/writer/common/options.h
+++ b/src/tint/lang/msl/writer/common/options.h
@@ -33,6 +33,7 @@
 #include <unordered_map>
 
 #include "src/tint/api/common/binding_point.h"
+#include "src/tint/api/common/vertex_pulling_config.h"
 #include "src/tint/utils/reflection.h"
 
 namespace tint::msl::writer {
@@ -174,7 +175,10 @@
     /// from which to load buffer sizes.
     ArrayLengthFromUniformOptions array_length_from_uniform = {};
 
-    /// The bindings
+    /// The optional vertex pulling configuration.
+    std::optional<VertexPullingConfig> vertex_pulling_config = {};
+
+    /// The bindings.
     Bindings bindings;
 
     /// Reflect the fields of this class so that it can be used by tint::ForeachField()
@@ -190,6 +194,7 @@
                  fixed_sample_mask,
                  pixel_local_attachments,
                  array_length_from_uniform,
+                 vertex_pulling_config,
                  bindings);
 };
 
diff --git a/src/tint/lang/msl/writer/raise/raise.cc b/src/tint/lang/msl/writer/raise/raise.cc
index 7a8533f..3e94640 100644
--- a/src/tint/lang/msl/writer/raise/raise.cc
+++ b/src/tint/lang/msl/writer/raise/raise.cc
@@ -45,6 +45,7 @@
 #include "src/tint/lang/core/ir/transform/robustness.h"
 #include "src/tint/lang/core/ir/transform/value_to_let.h"
 #include "src/tint/lang/core/ir/transform/vectorize_scalar_matrix_constructors.h"
+#include "src/tint/lang/core/ir/transform/vertex_pulling.h"
 #include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
 #include "src/tint/lang/msl/writer/common/option_helpers.h"
 #include "src/tint/lang/msl/writer/raise/binary_polyfill.h"
@@ -68,6 +69,11 @@
 
     RaiseResult raise_result;
 
+    // VertexPulling must come before BindingRemapper and Robustness.
+    if (options.vertex_pulling_config) {
+        RUN_TRANSFORM(core::ir::transform::VertexPulling, module, *options.vertex_pulling_config);
+    }
+
     tint::transform::multiplanar::BindingsMap multiplanar_map{};
     RemapperData remapper_data{};
     ArrayLengthFromUniformOptions array_length_from_uniform_options{};
diff --git a/src/tint/lang/msl/writer/writer_test.cc b/src/tint/lang/msl/writer/writer_test.cc
index 7826186..1cb9e27 100644
--- a/src/tint/lang/msl/writer/writer_test.cc
+++ b/src/tint/lang/msl/writer/writer_test.cc
@@ -231,5 +231,65 @@
 )");
 }
 
+TEST_F(MslWriterTest, VertexPulling) {  // Generate() applies Options::vertex_pulling_config.
+    auto* ep = b.Function("main", ty.vec4<f32>(), core::ir::Function::PipelineStage::kVertex);
+    ep->SetReturnBuiltin(core::BuiltinValue::kPosition);
+    auto* attr = b.FunctionParam<vec4<f32>>("attr");
+    attr->SetLocation(1);  // Matches the attribute's shader location (1) in vertex_state.
+    ep->SetParams({attr});
+    b.Append(ep->Block(), [&] {  //
+        b.Return(ep, attr);
+    });
+
+    VertexPullingConfig vertex_pulling_config;
+    vertex_pulling_config.pulling_group = 4u;  // Pulled buffers are bound in group 4.
+    vertex_pulling_config.vertex_state = {  // stride 4, per-vertex, f32 @ offset 0 -> loc 1
+        {{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 1}}}}};
+    ArrayLengthFromUniformOptions array_length_config;  // Supplies sizes for the min() clamp.
+    array_length_config.ubo_binding = 30u;  // Emitted as [[buffer(30)]].
+    array_length_config.bindpoint_to_size_index.insert({BindingPoint{0u, 1u}, 0u});
+    Options options;  // NOTE(review): {0u, 1u} above looks like the remapped {4u, 0u}; confirm.
+    options.bindings.storage.emplace(BindingPoint{4u, 0u}, tint::msl::writer::binding::Storage{1u});
+    options.vertex_pulling_config = std::move(vertex_pulling_config);
+    options.array_length_from_uniform = std::move(array_length_config);
+
+    ASSERT_TRUE(Generate(options)) << err_ << output_.msl;
+    EXPECT_EQ(output_.msl, R"(#include <metal_stdlib>
+using namespace metal;
+
+template<typename T, size_t N>
+struct tint_array {
+  const constant T& operator[](size_t i) const constant { return elements[i]; }
+  device T& operator[](size_t i) device { return elements[i]; }
+  const device T& operator[](size_t i) const device { return elements[i]; }
+  thread T& operator[](size_t i) thread { return elements[i]; }
+  const thread T& operator[](size_t i) const thread { return elements[i]; }
+  threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+  const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+  T elements[N];
+};
+
+struct tint_module_vars_struct {
+  const device tint_array<uint, 1>* tint_vertex_buffer_0;
+  const constant tint_array<uint4, 1>* tint_storage_buffer_sizes;
+};
+
+struct main_outputs {
+  float4 tint_symbol [[position]];
+};
+
+float4 main_inner(uint tint_vertex_index, tint_module_vars_struct tint_module_vars) {
+  return float4(as_type<float>((*tint_module_vars.tint_vertex_buffer_0)[min(tint_vertex_index, (((*tint_module_vars.tint_storage_buffer_sizes)[0u].x / 4u) - 1u))]), 0.0f, 0.0f, 1.0f);
+}
+
+vertex main_outputs v(uint tint_vertex_index [[vertex_id]], const device tint_array<uint, 1>* tint_vertex_buffer_0 [[buffer(1)]], const constant tint_array<uint4, 1>* tint_storage_buffer_sizes [[buffer(30)]]) {
+  tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.tint_vertex_buffer_0=tint_vertex_buffer_0, .tint_storage_buffer_sizes=tint_storage_buffer_sizes};
+  main_outputs tint_wrapper_result = {};
+  tint_wrapper_result.tint_symbol = main_inner(tint_vertex_index, tint_module_vars);
+  return tint_wrapper_result;
+}
+)");
+}
+
 }  // namespace
 }  // namespace tint::msl::writer