Move MSL configuration for `ArrayLengthFromUniform` transform.

The configuration for the `ArrayLengthFromUniform` transform was pulled
out to generator options in a previous CL. The HLSL backend was updated
to pass this information into the generator. The MSL backend was using
the deprecated path of having the transform determine the values.

This CL updates the MSL backend to pass the information into the
generator and removes the deprecated code from the transform.

Bug: tint:1855 chromium:1421379

Change-Id: I679c57914d575a758a9ff03b9db27a051d55fe17
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/123880
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm
index 350b5d0..c841d00 100644
--- a/src/dawn/native/metal/ShaderModuleMTL.mm
+++ b/src/dawn/native/metal/ShaderModuleMTL.mm
@@ -38,6 +38,7 @@
 #define MSL_COMPILATION_REQUEST_MEMBERS(X)                                                  \
     X(SingleShaderStage, stage)                                                             \
     X(const tint::Program*, inputProgram)                                                   \
+    X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform)                  \
     X(tint::transform::BindingRemapper::BindingPoints, bindingPoints)                       \
     X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings)    \
     X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig)                   \
@@ -114,6 +115,9 @@
     using BindingPoint = tint::writer::BindingPoint;
     BindingRemapper::BindingPoints bindingPoints;
 
+    tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
+    arrayLengthFromUniform.ubo_binding = {0, kBufferLengthBufferSlot};
+
     for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
         const BindGroupLayoutBase::BindingMap& bindingMap =
             layout->GetBindGroupLayout(group)->GetBindingMap();
@@ -133,6 +137,16 @@
             if (srcBindingPoint != dstBindingPoint) {
                 bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
             }
+
+            // Use the ShaderIndex as the indices for the buffer size lookups in the array length
+            // uniform transform. This is used to compute the size of variable length arrays in
+            // storage buffers.
+            if (bindingInfo.buffer.type == wgpu::BufferBindingType::Storage ||
+                bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage ||
+                bindingInfo.buffer.type == kInternalStorageBufferBinding) {
+                arrayLengthFromUniform.bindpoint_to_size_index.emplace(dstBindingPoint,
+                                                                       dstBindingPoint.binding);
+            }
         }
     }
 
@@ -154,6 +168,11 @@
             if (srcBindingPoint != dstBindingPoint) {
                 bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
             }
+
+            // Use the ShaderIndex as the indices for the buffer size lookups in the array
+            // length uniform transform.
+            arrayLengthFromUniform.bindpoint_to_size_index.emplace(dstBindingPoint,
+                                                                   dstBindingPoint.binding);
         }
     }
 
@@ -177,6 +196,7 @@
     req.isRobustnessEnabled = device->IsRobustnessEnabled();
     req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
     req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
+    req.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
 
     const CombinedLimits& limits = device->GetLimits();
     req.limits = LimitsForCompilationRequest::Create(limits.v1);
@@ -264,6 +284,8 @@
             options.fixed_sample_mask = r.sampleMask;
             options.disable_workgroup_init = r.disableWorkgroupInit;
             options.emit_vertex_point_size = r.emitVertexPointSize;
+            options.array_length_from_uniform = r.arrayLengthFromUniform;
+
             TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "tint::writer::msl::Generate");
             auto result = tint::writer::msl::Generate(&program, options);
             DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.",
diff --git a/src/tint/cmd/main.cc b/src/tint/cmd/main.cc
index 898f0ac..45982f2 100644
--- a/src/tint/cmd/main.cc
+++ b/src/tint/cmd/main.cc
@@ -660,6 +660,11 @@
     gen_options.disable_workgroup_init = options.disable_workgroup_init;
     gen_options.external_texture_options.bindings_map =
         tint::cmd::GenerateExternalTextureBindings(input_program);
+    gen_options.array_length_from_uniform.ubo_binding = tint::writer::BindingPoint{0, 30};
+    gen_options.array_length_from_uniform.bindpoint_to_size_index.emplace(
+        tint::writer::BindingPoint{0, 0}, 0);
+    gen_options.array_length_from_uniform.bindpoint_to_size_index.emplace(
+        tint::writer::BindingPoint{0, 1}, 1);
     auto result = tint::writer::msl::Generate(input_program, gen_options);
     if (!result.success) {
         tint::cmd::PrintWGSL(std::cerr, *program);
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 268736f..ae49d5d 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -173,29 +173,6 @@
     // ExpandCompoundAssignment must come before BuiltinPolyfill
     manager.Add<transform::ExpandCompoundAssignment>();
 
-    // Build the config for the internal ArrayLengthFromUniform transform.
-    auto& array_length_from_uniform = options.array_length_from_uniform;
-    transform::ArrayLengthFromUniform::Config array_length_from_uniform_cfg(
-        array_length_from_uniform.ubo_binding);
-    if (!array_length_from_uniform.bindpoint_to_size_index.empty()) {
-        // If |array_length_from_uniform| bindings are provided, use that config.
-        array_length_from_uniform_cfg.bindpoint_to_size_index =
-            array_length_from_uniform.bindpoint_to_size_index;
-    } else {
-        // If the binding map is empty, use the deprecated |buffer_size_ubo_index|
-        // and automatically choose indices using the binding numbers.
-        array_length_from_uniform_cfg = transform::ArrayLengthFromUniform::Config(
-            sem::BindingPoint{0, options.buffer_size_ubo_index});
-        // Use the SSBO binding numbers as the indices for the buffer size lookups.
-        for (auto* var : in->AST().GlobalVariables()) {
-            auto* global = in->Sem().Get<sem::GlobalVariable>(var);
-            if (global && global->AddressSpace() == builtin::AddressSpace::kStorage) {
-                array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
-                    global->BindingPoint(), global->BindingPoint().binding);
-            }
-        }
-    }
-
     // Build the configs for the internal CanonicalizeEntryPointIO transform.
     auto entry_point_io_cfg = transform::CanonicalizeEntryPointIO::Config(
         transform::CanonicalizeEntryPointIO::ShaderStyle::kMsl, options.fixed_sample_mask,
@@ -210,6 +187,7 @@
     if (!options.disable_robustness) {
         // Robustness must come after PromoteSideEffectsToDecl
         // Robustness must come before BuiltinPolyfill and CanonicalizeEntryPointIO
+        // Robustness must come before ArrayLengthFromUniform
         manager.Add<transform::Robustness>();
     }
 
@@ -260,10 +238,16 @@
     // ArrayLengthFromUniform must come after SimplifyPointers, as
     // it assumes that the form of the array length argument is &var.array.
     manager.Add<transform::ArrayLengthFromUniform>();
+
+    transform::ArrayLengthFromUniform::Config array_length_cfg(
+        std::move(options.array_length_from_uniform.ubo_binding));
+    array_length_cfg.bindpoint_to_size_index =
+        std::move(options.array_length_from_uniform.bindpoint_to_size_index);
+    data.Add<transform::ArrayLengthFromUniform::Config>(array_length_cfg);
+
     // PackedVec3 must come after ExpandCompoundAssignment.
     manager.Add<transform::PackedVec3>();
     manager.Add<transform::ModuleScopeVarToEntryPointParam>();
-    data.Add<transform::ArrayLengthFromUniform::Config>(std::move(array_length_from_uniform_cfg));
     data.Add<transform::CanonicalizeEntryPointIO::Config>(std::move(entry_point_io_cfg));
     auto out = manager.Run(in, data);
 
diff --git a/src/tint/writer/msl/generator_impl_sanitizer_test.cc b/src/tint/writer/msl/generator_impl_sanitizer_test.cc
index f72e0e9..c75fe92 100644
--- a/src/tint/writer/msl/generator_impl_sanitizer_test.cc
+++ b/src/tint/writer/msl/generator_impl_sanitizer_test.cc
@@ -39,7 +39,10 @@
              Stage(ast::PipelineStage::kFragment),
          });
 
-    GeneratorImpl& gen = SanitizeAndBuild();
+    Options opts = DefaultOptions();
+    opts.array_length_from_uniform.ubo_binding = sem::BindingPoint{0, 30};
+    opts.array_length_from_uniform.bindpoint_to_size_index.emplace(sem::BindingPoint{2, 1}, 1);
+    GeneratorImpl& gen = SanitizeAndBuild(opts);
 
     ASSERT_TRUE(gen.Generate()) << gen.error();
 
@@ -93,7 +96,10 @@
              Stage(ast::PipelineStage::kFragment),
          });
 
-    GeneratorImpl& gen = SanitizeAndBuild();
+    Options opts = DefaultOptions();
+    opts.array_length_from_uniform.ubo_binding = sem::BindingPoint{0, 30};
+    opts.array_length_from_uniform.bindpoint_to_size_index.emplace(sem::BindingPoint{2, 1}, 1);
+    GeneratorImpl& gen = SanitizeAndBuild(opts);
 
     ASSERT_TRUE(gen.Generate()) << gen.error();
 
@@ -151,7 +157,10 @@
              Stage(ast::PipelineStage::kFragment),
          });
 
-    GeneratorImpl& gen = SanitizeAndBuild();
+    Options opts = DefaultOptions();
+    opts.array_length_from_uniform.ubo_binding = sem::BindingPoint{0, 30};
+    opts.array_length_from_uniform.bindpoint_to_size_index.emplace(sem::BindingPoint{2, 1}, 1);
+    GeneratorImpl& gen = SanitizeAndBuild(opts);
 
     ASSERT_TRUE(gen.Generate()) << gen.error();