[msl] Add support for chromium_experimental_subgroups

Use a transform to replace calls to `subgroupBallot()` with a helper
function that masks out the unused bits, as well as converting the
64-bit output from MSL's `simd_active_threads_mask()` function to a
vec4u.

Adds the ability to change the target MSL version used for validation,
which is automatically bumped to MSL 2.1 if the subgroups extension is
used.

Bug: tint:2000
Change-Id: I3161b249acd90378be52b18f6325c82fd9dc80b2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/146460
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index dfec136..30e4ddb 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -425,6 +425,8 @@
     "lang/wgsl/ast/transform/merge_return.h",
     "lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc",
     "lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h",
+    "lang/wgsl/ast/transform/msl_subgroup_ballot.cc",
+    "lang/wgsl/ast/transform/msl_subgroup_ballot.h",
     "lang/wgsl/ast/transform/multiplanar_external_texture.cc",
     "lang/wgsl/ast/transform/multiplanar_external_texture.h",
     "lang/wgsl/ast/transform/num_workgroups_from_uniform.cc",
@@ -1847,6 +1849,7 @@
       "lang/wgsl/ast/transform/manager_test.cc",
       "lang/wgsl/ast/transform/merge_return_test.cc",
       "lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc",
+      "lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc",
       "lang/wgsl/ast/transform/multiplanar_external_texture_test.cc",
       "lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc",
       "lang/wgsl/ast/transform/packed_vec3_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 70f1821..b36d852 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -360,6 +360,8 @@
   lang/wgsl/ast/transform/merge_return.h
   lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc
   lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h
+  lang/wgsl/ast/transform/msl_subgroup_ballot.cc
+  lang/wgsl/ast/transform/msl_subgroup_ballot.h
   lang/wgsl/ast/transform/multiplanar_external_texture.cc
   lang/wgsl/ast/transform/multiplanar_external_texture.h
   lang/wgsl/ast/transform/num_workgroups_from_uniform.cc
@@ -1498,6 +1500,7 @@
       lang/wgsl/ast/transform/localize_struct_array_assignment_test.cc
       lang/wgsl/ast/transform/merge_return_test.cc
       lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc
+      lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc
       lang/wgsl/ast/transform/multiplanar_external_texture_test.cc
       lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc
       lang/wgsl/ast/transform/packed_vec3_test.cc
diff --git a/src/tint/cmd/tint/main.cc b/src/tint/cmd/tint/main.cc
index 7434165..23a7e3f 100644
--- a/src/tint/cmd/tint/main.cc
+++ b/src/tint/cmd/tint/main.cc
@@ -671,10 +671,19 @@
         PrintHash(hash);
     }
 
+    // Default to validating against MSL 1.2.
+    // If subgroups are used, bump the version to 2.1.
+    auto msl_version = tint::msl::validate::MslVersion::kMsl_1_2;
+    for (auto* enable : program->AST().Enables()) {
+        if (enable->HasExtension(tint::core::Extension::kChromiumExperimentalSubgroups)) {
+            msl_version = tint::msl::validate::MslVersion::kMsl_2_1;
+        }
+    }
+
     if (options.validate && options.skip_hash.count(hash) == 0) {
         tint::msl::validate::Result res;
 #ifdef TINT_ENABLE_MSL_VALIDATION_USING_METAL_API
-        res = tint::msl::validate::UsingMetalAPI(result->msl);
+        res = tint::msl::validate::UsingMetalAPI(result->msl, msl_version);
 #else
 #ifdef _WIN32
         const char* default_xcrun_exe = "metal.exe";
@@ -684,7 +693,7 @@
         auto xcrun = tint::Command::LookPath(
             options.xcrun_path.empty() ? default_xcrun_exe : std::string(options.xcrun_path));
         if (xcrun.Found()) {
-            res = tint::msl::validate::Msl(xcrun.Path(), result->msl);
+            res = tint::msl::validate::Msl(xcrun.Path(), result->msl, msl_version);
         } else {
             res.output = "xcrun executable not found. Cannot validate.";
             res.failed = true;
diff --git a/src/tint/lang/msl/validate/msl.cc b/src/tint/lang/msl/validate/msl.cc
index b3d0000..fce70d6 100644
--- a/src/tint/lang/msl/validate/msl.cc
+++ b/src/tint/lang/msl/validate/msl.cc
@@ -21,7 +21,7 @@
 
 namespace tint::msl::validate {
 
-Result Msl(const std::string& xcrun_path, const std::string& source) {
+Result Msl(const std::string& xcrun_path, const std::string& source, MslVersion version) {
     Result result;
 
     auto xcrun = tint::Command(xcrun_path);
@@ -34,17 +34,27 @@
     tint::TmpFile file(".metal");
     file << source;
 
+    const char* version_str = nullptr;
+    switch (version) {
+        case MslVersion::kMsl_1_2:
+            version_str = "-std=macos-metal1.2";
+            break;
+        case MslVersion::kMsl_2_1:
+            version_str = "-std=macos-metal2.1";
+            break;
+    }
+
 #ifdef _WIN32
     // On Windows, we should actually be running metal.exe from the Metal
     // Developer Tools for Windows
-    auto res = xcrun("-x", "metal",        //
-                     "-o", "NUL",          //
-                     "-std=osx-metal1.2",  //
+    auto res = xcrun("-x", "metal",  //
+                     "-o", "NUL",    //
+                     version_str,    //
                      "-c", file.Path());
 #else
     auto res = xcrun("-sdk", "macosx", "metal",  //
                      "-o", "/dev/null",          //
-                     "-std=osx-metal1.2",        //
+                     version_str,                //
                      "-c", file.Path());
 #endif
     if (!res.out.empty()) {
diff --git a/src/tint/lang/msl/validate/msl_metal.mm b/src/tint/lang/msl/validate/msl_metal.mm
index 0b2724d..9999f2a 100644
--- a/src/tint/lang/msl/validate/msl_metal.mm
+++ b/src/tint/lang/msl/validate/msl_metal.mm
@@ -24,7 +24,7 @@
 
 namespace tint::msl::validate {
 
-Result UsingMetalAPI(const std::string& src) {
+Result UsingMetalAPI(const std::string& src, MslVersion version) {
     Result result;
 
     NSError* error = nil;
@@ -40,7 +40,14 @@
 
     MTLCompileOptions* compileOptions = [MTLCompileOptions new];
     compileOptions.fastMathEnabled = true;
-    compileOptions.languageVersion = MTLLanguageVersion1_2;
+    switch (version) {
+        case MslVersion::kMsl_1_2:
+            compileOptions.languageVersion = MTLLanguageVersion1_2;
+            break;
+        case MslVersion::kMsl_2_1:
+            compileOptions.languageVersion = MTLLanguageVersion2_1;
+            break;
+    }
 
     id<MTLLibrary> library = [device newLibraryWithSource:source
                                                   options:compileOptions
diff --git a/src/tint/lang/msl/validate/val.h b/src/tint/lang/msl/validate/val.h
index c4c0a15..498516b 100644
--- a/src/tint/lang/msl/validate/val.h
+++ b/src/tint/lang/msl/validate/val.h
@@ -30,6 +30,12 @@
 
 using EntryPointList = std::vector<std::pair<std::string, ast::PipelineStage>>;
 
+// The version of MSL to validate against.
+enum class MslVersion {
+    kMsl_1_2,
+    kMsl_2_1,
+};
+
 /// The return structure of Validate()
 struct Result {
     /// True if validation passed
@@ -42,15 +48,17 @@
 /// verifying that the shader compiles successfully.
 /// @param xcrun_path path to xcrun
 /// @param source the generated MSL source
+/// @param version the version of MSL to validate against
 /// @return the result of the compile
-Result Msl(const std::string& xcrun_path, const std::string& source);
+Result Msl(const std::string& xcrun_path, const std::string& source, MslVersion version);
 
 #ifdef TINT_ENABLE_MSL_VALIDATION_USING_METAL_API
 /// Msl attempts to compile the shader with the runtime Metal Shader Compiler
 /// API, verifying that the shader compiles successfully.
 /// @param source the generated MSL source
+/// @param version the version of MSL to validate against
 /// @return the result of the compile
-Result UsingMetalAPI(const std::string& source);
+Result UsingMetalAPI(const std::string& source, MslVersion version);
 #endif  // TINT_ENABLE_MSL_VALIDATION_USING_METAL_API
 
 }  // namespace tint::msl::validate
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index 86d611f..59ca50b 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -59,6 +59,7 @@
 #include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
 #include "src/tint/lang/wgsl/ast/transform/manager.h"
 #include "src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h"
+#include "src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h"
 #include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
 #include "src/tint/lang/wgsl/ast/transform/packed_vec3.h"
 #include "src/tint/lang/wgsl/ast/transform/preserve_padding.h"
@@ -205,6 +206,9 @@
     manager.Add<ast::transform::RemovePhonies>();
     manager.Add<ast::transform::SimplifyPointers>();
 
+    // MslSubgroupBallot() must come after CanonicalizeEntryPointIO.
+    manager.Add<ast::transform::MslSubgroupBallot>();
+
     // ArrayLengthFromUniform must come after SimplifyPointers, as
     // it assumes that the form of the array length argument is &var.array.
     manager.Add<ast::transform::ArrayLengthFromUniform>();
@@ -244,6 +248,7 @@
                 core::Extension::kChromiumExperimentalFullPtrParameters,
                 core::Extension::kChromiumExperimentalPushConstant,
                 core::Extension::kChromiumExperimentalReadWriteStorageTexture,
+                core::Extension::kChromiumExperimentalSubgroups,
                 core::Extension::kChromiumInternalRelaxedUniformLayout,
                 core::Extension::kF16,
                 core::Extension::kChromiumInternalDualSourceBlending,
@@ -622,6 +627,12 @@
 bool ASTPrinter::EmitFunctionCall(StringStream& out,
                                   const sem::Call* call,
                                   const sem::Function* fn) {
+    if (ast::GetAttribute<ast::transform::MslSubgroupBallot::SimdActiveThreadsMask>(
+            fn->Declaration()->attributes) != nullptr) {
+        out << "as_type<uint2>((ulong)simd_active_threads_mask())";
+        return true;
+    }
+
     out << fn->Declaration()->name->symbol.Name() << "(";
 
     bool first = true;
@@ -1847,6 +1858,11 @@
 }
 
 bool ASTPrinter::EmitFunction(const ast::Function* func) {
+    if (func->body == nullptr) {
+        // An internal function. Do not emit.
+        return true;
+    }
+
     auto* func_sem = builder_.Sem().Get(func);
 
     {
diff --git a/src/tint/lang/msl/writer/common/printer_support.cc b/src/tint/lang/msl/writer/common/printer_support.cc
index 3311449..9312650 100644
--- a/src/tint/lang/msl/writer/common/printer_support.cc
+++ b/src/tint/lang/msl/writer/common/printer_support.cc
@@ -60,6 +60,10 @@
             return "sample_mask";
         case core::BuiltinValue::kPointSize:
             return "point_size";
+        case core::BuiltinValue::kSubgroupInvocationId:
+            return "thread_index_in_simdgroup";
+        case core::BuiltinValue::kSubgroupSize:
+            return "threads_per_simdgroup";
         default:
             break;
     }
diff --git a/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.cc b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.cc
new file mode 100644
index 0000000..8bd3350
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.cc
@@ -0,0 +1,197 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+#include "src/tint/lang/wgsl/sem/call.h"
+#include "src/tint/lang/wgsl/sem/function.h"
+#include "src/tint/lang/wgsl/sem/statement.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MslSubgroupBallot);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MslSubgroupBallot::SimdActiveThreadsMask);
+
+using namespace tint::number_suffixes;  // NOLINT
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform
+struct MslSubgroupBallot::State {
+    /// The source program
+    const Program* const src;
+    /// The target program builder
+    ProgramBuilder b;
+    /// The clone context
+    program::CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+    /// The name of the `tint_subgroup_ballot` helper function.
+    Symbol ballot_helper{};
+
+    /// The name of the `tint_subgroup_size_mask` global variable.
+    Symbol subgroup_size_mask{};
+
+    /// The set of a functions that directly call `subgroupBallot()`.
+    Hashset<const sem::Function*, 4> ballot_callers;
+
+    /// Constructor
+    /// @param program the source program
+    explicit State(const Program* program) : src(program) {}
+
+    /// Runs the transform
+    /// @returns the new program or SkipTransform if the transform is not required
+    ApplyResult Run() {
+        auto& sem = src->Sem();
+
+        bool made_changes = false;
+        for (auto* node : ctx.src->ASTNodes().Objects()) {
+            auto* call = sem.Get<sem::Call>(node);
+            if (call) {
+                // If this is a call to a `subgroupBallot()` builtin, replace it with a call to the
+                // helper function and make a note of the function that we are in.
+                auto* builtin = call->Target()->As<sem::Builtin>();
+                if (builtin && builtin->Type() == core::Function::kSubgroupBallot) {
+                    ctx.Replace(call->Declaration(), b.Call(GetHelper()));
+                    ballot_callers.Add(call->Stmt()->Function());
+                    made_changes = true;
+                }
+            }
+        }
+        if (!made_changes) {
+            return SkipTransform;
+        }
+
+        // Set the subgroup size mask at the start of each entry point that transitively calls
+        // `subgroupBallot()`.
+        for (auto* global : src->AST().GlobalDeclarations()) {
+            auto* func = global->As<Function>();
+            if (func && func->IsEntryPoint() && TransitvelyCallsSubgroupBallot(sem.Get(func))) {
+                SetSubgroupSizeMask(func);
+            }
+        }
+
+        ctx.Clone();
+        return resolver::Resolve(b);
+    }
+
+    /// Get (or create) the `tint_msl_subgroup` helper function.
+    /// @returns the name of the helper function
+    Symbol GetHelper() {
+        if (!ballot_helper) {
+            auto intrinsic = b.Symbols().New("tint_msl_simd_active_threads_mask");
+            subgroup_size_mask = b.Symbols().New("tint_subgroup_size_mask");
+            ballot_helper = b.Symbols().New("tint_msl_subgroup_ballot");
+
+            // Declare the `tint_msl_subgroup_ballot` intrinsic function, which will use the
+            // `simd_active_threads_mask` function to return 64-bit vote.
+            b.Func(intrinsic, Empty, b.ty.vec2<u32>(), nullptr,
+                   Vector{b.ASTNodes().Create<SimdActiveThreadsMask>(b.ID(), b.AllocateNodeID()),
+                          b.Disable(DisabledValidation::kFunctionHasNoBody)});
+
+            // Declare the `tint_subgroup_size_mask` variable.
+            b.GlobalVar(subgroup_size_mask, core::AddressSpace::kPrivate, b.ty.vec4<u32>());
+
+            // Declare the `tint_msl_subgroup_ballot` helper function as follows:
+            //   fn tint_msl_subgroup_ballot() -> vec4u {
+            //     let vote : vec2u = vec4f(tint_simd_active_threads_mask(), 0, 0);
+            //     return (vote & tint_subgroup_size_mask);
+            //   }
+            auto* vote = b.Let(b.Sym(), b.Call(b.ty.vec4<u32>(), b.Call(intrinsic), 0_u, 0_u));
+            b.Func(ballot_helper, Empty, b.ty.vec4<u32>(),
+                   Vector{
+                       b.Decl(vote),
+                       b.Return(b.And(vote, subgroup_size_mask)),
+                   });
+        }
+        return ballot_helper;
+    }
+
+    /// Check if a function directly or transitively calls the `subgroupBallot()` builtin.
+    /// @param func the function to check
+    /// @returns true if the function transitively calls `subgroupBallot()`
+    bool TransitvelyCallsSubgroupBallot(const sem::Function* func) {
+        if (ballot_callers.Contains(func)) {
+            return true;
+        }
+        for (auto* called : func->TransitivelyCalledFunctions()) {
+            if (ballot_callers.Contains(called)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /// Add code to set the `subgroup_size_mask` variable at the start of an entry point.
+    /// @param ep the entry point
+    void SetSubgroupSizeMask(const ast::Function* ep) {
+        // Check the entry point parameters for an existing `subgroup_size` builtin.
+        Symbol subgroup_size;
+        for (auto* param : ep->params) {
+            auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes);
+            if (builtin &&
+                src->Sem()
+                        .Get<sem::BuiltinEnumExpression<core::BuiltinValue>>(builtin->builtin)
+                        ->Value() == core::BuiltinValue::kSubgroupSize) {
+                subgroup_size = ctx.Clone(param->name->symbol);
+            }
+        }
+        if (!subgroup_size.IsValid()) {
+            // No `subgroup_size` builtin parameter was found, so add one.
+            subgroup_size = b.Symbols().New("tint_subgroup_size");
+            ctx.InsertBack(ep->params, b.Param(subgroup_size, b.ty.u32(),
+                                               Vector{
+                                                   b.Builtin(core::BuiltinValue::kSubgroupSize),
+                                               }));
+        }
+
+        // Add the following to the top of the entry point:
+        // {
+        //   let gt = subgroup_size > 32;
+        //   subgroup_size_mask[0] = select(1 << (subgroup_size - 1), 0xffffffff, gt);
+        //   subgroup_size_mask[1] = select(0, 1 << (subgroup_size - 33), gt);
+        // }
+        auto* gt = b.Let(b.Sym("gt"), b.GreaterThan(subgroup_size, 32_u));
+        auto* lo = b.Call("select", b.Shl(1_u, b.Sub(subgroup_size, 1_u)), 0xffffffff_u, gt);
+        auto* hi = b.Call("select", 0_u, b.Shl(1_u, b.Sub(subgroup_size, 33_u)), gt);
+        auto* block = b.Block(Vector{
+            b.Decl(gt),
+            b.Assign(b.IndexAccessor(subgroup_size_mask, 0_u), lo),
+            b.Assign(b.IndexAccessor(subgroup_size_mask, 1_u), hi),
+        });
+        ctx.InsertFront(ep->body->statements, block);
+    }
+};
+
+MslSubgroupBallot::MslSubgroupBallot() = default;
+
+MslSubgroupBallot::~MslSubgroupBallot() = default;
+
+Transform::ApplyResult MslSubgroupBallot::Apply(const Program* src,
+                                                const DataMap&,
+                                                DataMap&) const {
+    return State(src).Run();
+}
+
+MslSubgroupBallot::SimdActiveThreadsMask::~SimdActiveThreadsMask() = default;
+
+const MslSubgroupBallot::SimdActiveThreadsMask* MslSubgroupBallot::SimdActiveThreadsMask::Clone(
+    ast::CloneContext& ctx) const {
+    return ctx.dst->ASTNodes().Create<MslSubgroupBallot::SimdActiveThreadsMask>(
+        ctx.dst->ID(), ctx.dst->AllocateNodeID());
+}
+
+}  // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h
new file mode 100644
index 0000000..982f3e2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h
@@ -0,0 +1,70 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_MSL_SUBGROUP_BALLOT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_MSL_SUBGROUP_BALLOT_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// MslSubgroupBallot is a transform that replaces calls to `subgroupBallot()` with an
+/// implementation that uses MSL's `simd_active_threads_mask()`.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * CanonicalizeEntryPointIO
+class MslSubgroupBallot final : public Castable<MslSubgroupBallot, Transform> {
+  public:
+    /// Constructor
+    MslSubgroupBallot();
+
+    /// Destructor
+    ~MslSubgroupBallot() override;
+
+    /// @copydoc Transform::Apply
+    ApplyResult Apply(const Program* program,
+                      const DataMap& inputs,
+                      DataMap& outputs) const override;
+
+    /// SimdActiveThreadsMask is an InternalAttribute that is used to decorate a stub function so
+    /// that the MSL backend transforms this into calls to the `simd_active_threads_mask` function.
+    class SimdActiveThreadsMask final : public Castable<SimdActiveThreadsMask, InternalAttribute> {
+      public:
+        /// Constructor
+        /// @param pid the identifier of the program that owns this node
+        /// @param nid the unique node identifier
+        SimdActiveThreadsMask(GenerationID pid, NodeID nid) : Base(pid, nid, Empty) {}
+
+        /// Destructor
+        ~SimdActiveThreadsMask() override;
+
+        /// @copydoc InternalAttribute::InternalName
+        std::string InternalName() const override { return "simd_active_threads_mask"; }
+
+        /// Performs a deep clone of this object using the program::CloneContext `ctx`.
+        /// @param ctx the clone context
+        /// @return the newly cloned object
+        const SimdActiveThreadsMask* Clone(CloneContext& ctx) const override;
+    };
+
+  private:
+    struct State;
+};
+
+}  // namespace tint::ast::transform
+
+#endif  // SRC_TINT_LANG_WGSL_AST_TRANSFORM_MSL_SUBGROUP_BALLOT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc
new file mode 100644
index 0000000..0ac0105
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc
@@ -0,0 +1,163 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h"
+
+#include "src/tint/lang/wgsl/ast/transform/helper_test.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using MslSubgroupBallotTest = TransformTest;
+
+TEST_F(MslSubgroupBallotTest, EmptyModule) {
+    auto* src = "";
+
+    EXPECT_FALSE(ShouldRun<MslSubgroupBallot>(src));
+}
+
+TEST_F(MslSubgroupBallotTest, DirectUse) {
+    auto* src = R"(
+enable chromium_experimental_subgroups;
+
+@compute @workgroup_size(64)
+fn foo() {
+  let x : vec4u = subgroupBallot();
+}
+)";
+
+    auto* expect =
+        R"(
+enable chromium_experimental_subgroups;
+
+@internal(simd_active_threads_mask) @internal(disable_validation__function_has_no_body)
+fn tint_msl_simd_active_threads_mask() -> vec2<u32>
+
+var<private> tint_subgroup_size_mask : vec4<u32>;
+
+fn tint_msl_subgroup_ballot() -> vec4<u32> {
+  let tint_symbol = vec4<u32>(tint_msl_simd_active_threads_mask(), 0u, 0u);
+  return (tint_symbol & tint_subgroup_size_mask);
+}
+
+@compute @workgroup_size(64)
+fn foo(@builtin(subgroup_size) tint_subgroup_size : u32) {
+  {
+    let gt = (tint_subgroup_size > 32u);
+    tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+    tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+  }
+  let x : vec4u = tint_msl_subgroup_ballot();
+}
+)";
+
+    auto got = Run<MslSubgroupBallot>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MslSubgroupBallotTest, IndirectUse) {
+    auto* src = R"(
+enable chromium_experimental_subgroups;
+
+fn bar() -> vec4u {
+  return subgroupBallot();
+}
+
+@compute @workgroup_size(64)
+fn foo() {
+  let x: vec4u = bar();
+}
+)";
+
+    auto* expect =
+        R"(
+enable chromium_experimental_subgroups;
+
+@internal(simd_active_threads_mask) @internal(disable_validation__function_has_no_body)
+fn tint_msl_simd_active_threads_mask() -> vec2<u32>
+
+var<private> tint_subgroup_size_mask : vec4<u32>;
+
+fn tint_msl_subgroup_ballot() -> vec4<u32> {
+  let tint_symbol = vec4<u32>(tint_msl_simd_active_threads_mask(), 0u, 0u);
+  return (tint_symbol & tint_subgroup_size_mask);
+}
+
+fn bar() -> vec4u {
+  return tint_msl_subgroup_ballot();
+}
+
+@compute @workgroup_size(64)
+fn foo(@builtin(subgroup_size) tint_subgroup_size : u32) {
+  {
+    let gt = (tint_subgroup_size > 32u);
+    tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+    tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+  }
+  let x : vec4u = bar();
+}
+)";
+
+    auto got = Run<MslSubgroupBallot>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MslSubgroupBallotTest, PreexistingSubgroupSizeBuiltin) {
+    auto* src = R"(
+enable chromium_experimental_subgroups;
+
+@compute @workgroup_size(64)
+fn foo(@builtin(workgroup_id) group_id: vec3u,
+       @builtin(subgroup_size) size : u32,
+       @builtin(local_invocation_index) index : u32) {
+  let sz = size;
+  let x : vec4u = subgroupBallot();
+}
+)";
+
+    auto* expect =
+        R"(
+enable chromium_experimental_subgroups;
+
+@internal(simd_active_threads_mask) @internal(disable_validation__function_has_no_body)
+fn tint_msl_simd_active_threads_mask() -> vec2<u32>
+
+var<private> tint_subgroup_size_mask : vec4<u32>;
+
+fn tint_msl_subgroup_ballot() -> vec4<u32> {
+  let tint_symbol = vec4<u32>(tint_msl_simd_active_threads_mask(), 0u, 0u);
+  return (tint_symbol & tint_subgroup_size_mask);
+}
+
+@compute @workgroup_size(64)
+fn foo(@builtin(workgroup_id) group_id : vec3u, @builtin(subgroup_size) size : u32, @builtin(local_invocation_index) index : u32) {
+  {
+    let gt = (size > 32u);
+    tint_subgroup_size_mask[0u] = select((1u << (size - 1u)), 4294967295u, gt);
+    tint_subgroup_size_mask[1u] = select(0u, (1u << (size - 33u)), gt);
+  }
+  let sz = size;
+  let x : vec4u = tint_msl_subgroup_ballot();
+}
+)";
+
+    auto got = Run<MslSubgroupBallot>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+}  // namespace
+}  // namespace tint::ast::transform
diff --git a/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl b/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl
index 134107e..78d6841 100644
--- a/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl
+++ b/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl
@@ -1,21 +1,29 @@
-SKIP: FAILED
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_private_vars_struct {
+  uint4 tint_subgroup_size_mask;
+};
 
 
-enable chromium_experimental_subgroups;
-
-fn subgroupBallot_7e6d0e() {
-  var res : vec4<u32> = subgroupBallot();
-  prevent_dce = res;
+uint4 tint_msl_subgroup_ballot(thread tint_private_vars_struct* const tint_private_vars) {
+  uint4 const tint_symbol = uint4(as_type<uint2>((ulong)simd_active_threads_mask()), 0u, 0u);
+  return (tint_symbol & (*(tint_private_vars)).tint_subgroup_size_mask);
 }
 
-@group(2) @binding(0) var<storage, read_write> prevent_dce : vec4<u32>;
-
-@compute @workgroup_size(1)
-fn compute_main() {
-  subgroupBallot_7e6d0e();
+void subgroupBallot_7e6d0e(thread tint_private_vars_struct* const tint_private_vars, device uint4* const tint_symbol_1) {
+  uint4 res = tint_msl_subgroup_ballot(tint_private_vars);
+  *(tint_symbol_1) = res;
 }
 
-Failed to generate: builtins/gen/literal/subgroupBallot/7e6d0e.wgsl:24:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+kernel void compute_main(device uint4* tint_symbol_2 [[buffer(0)]], uint tint_subgroup_size [[threads_per_simdgroup]]) {
+  thread tint_private_vars_struct tint_private_vars = {};
+  {
+    bool const gt = (tint_subgroup_size > 32u);
+    tint_private_vars.tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+    tint_private_vars.tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+  }
+  subgroupBallot_7e6d0e(&(tint_private_vars), tint_symbol_2);
+  return;
+}
 
diff --git a/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl b/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl
index b16ab89..78d6841 100644
--- a/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl
+++ b/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl
@@ -1,21 +1,29 @@
-SKIP: FAILED
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_private_vars_struct {
+  uint4 tint_subgroup_size_mask;
+};
 
 
-enable chromium_experimental_subgroups;
-
-fn subgroupBallot_7e6d0e() {
-  var res : vec4<u32> = subgroupBallot();
-  prevent_dce = res;
+uint4 tint_msl_subgroup_ballot(thread tint_private_vars_struct* const tint_private_vars) {
+  uint4 const tint_symbol = uint4(as_type<uint2>((ulong)simd_active_threads_mask()), 0u, 0u);
+  return (tint_symbol & (*(tint_private_vars)).tint_subgroup_size_mask);
 }
 
-@group(2) @binding(0) var<storage, read_write> prevent_dce : vec4<u32>;
-
-@compute @workgroup_size(1)
-fn compute_main() {
-  subgroupBallot_7e6d0e();
+void subgroupBallot_7e6d0e(thread tint_private_vars_struct* const tint_private_vars, device uint4* const tint_symbol_1) {
+  uint4 res = tint_msl_subgroup_ballot(tint_private_vars);
+  *(tint_symbol_1) = res;
 }
 
-Failed to generate: builtins/gen/var/subgroupBallot/7e6d0e.wgsl:24:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+kernel void compute_main(device uint4* tint_symbol_2 [[buffer(0)]], uint tint_subgroup_size [[threads_per_simdgroup]]) {
+  thread tint_private_vars_struct tint_private_vars = {};
+  {
+    bool const gt = (tint_subgroup_size > 32u);
+    tint_private_vars.tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+    tint_private_vars.tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+  }
+  subgroupBallot_7e6d0e(&(tint_private_vars), tint_symbol_2);
+  return;
+}
 
diff --git a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl
index 6e4dcdd..c79a68a 100644
--- a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl
+++ b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl
@@ -1,16 +1,29 @@
-SKIP: FAILED
+#include <metal_stdlib>
 
+using namespace metal;
 
-enable chromium_experimental_subgroups;
+template<typename T, size_t N>
+struct tint_array {
+    const constant T& operator[](size_t i) const constant { return elements[i]; }
+    device T& operator[](size_t i) device { return elements[i]; }
+    const device T& operator[](size_t i) const device { return elements[i]; }
+    thread T& operator[](size_t i) thread { return elements[i]; }
+    const thread T& operator[](size_t i) const thread { return elements[i]; }
+    threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+    const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+    T elements[N];
+};
 
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+struct tint_symbol_3 {
+  /* 0x0000 */ tint_array<uint, 1> arr;
+};
 
-@compute @workgroup_size(1)
-fn tint_symbol(@builtin(subgroup_invocation_id) subgroup_invocation_id : u32, @builtin(subgroup_size) subgroup_size : u32) {
-  output[subgroup_invocation_id] = subgroup_size;
+void tint_symbol_inner(uint subgroup_invocation_id, uint subgroup_size, device tint_array<uint, 1>* const tint_symbol_1) {
+  (*(tint_symbol_1))[subgroup_invocation_id] = subgroup_size;
 }
 
-Failed to generate: shader_io/compute_subgroup_builtins.wgsl:1:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+kernel void tint_symbol(device tint_symbol_3* tint_symbol_2 [[buffer(0)]], uint subgroup_invocation_id [[thread_index_in_simdgroup]], uint subgroup_size [[threads_per_simdgroup]]) {
+  tint_symbol_inner(subgroup_invocation_id, subgroup_size, &((*(tint_symbol_2)).arr));
+  return;
+}
 
diff --git a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl
index 3c8b497..d5f1182 100644
--- a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl
+++ b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl
@@ -1,23 +1,35 @@
-SKIP: FAILED
+#include <metal_stdlib>
 
+using namespace metal;
 
-enable chromium_experimental_subgroups;
+template<typename T, size_t N>
+struct tint_array {
+    const constant T& operator[](size_t i) const constant { return elements[i]; }
+    device T& operator[](size_t i) device { return elements[i]; }
+    const device T& operator[](size_t i) const device { return elements[i]; }
+    thread T& operator[](size_t i) thread { return elements[i]; }
+    const thread T& operator[](size_t i) const thread { return elements[i]; }
+    threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+    const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+    T elements[N];
+};
 
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+struct tint_symbol_4 {
+  /* 0x0000 */ tint_array<uint, 1> arr;
+};
 
 struct ComputeInputs {
-  @builtin(subgroup_invocation_id)
-  subgroup_invocation_id : u32,
-  @builtin(subgroup_size)
-  subgroup_size : u32,
+  uint subgroup_invocation_id;
+  uint subgroup_size;
+};
+
+void tint_symbol_inner(ComputeInputs inputs, device tint_array<uint, 1>* const tint_symbol_2) {
+  (*(tint_symbol_2))[inputs.subgroup_invocation_id] = inputs.subgroup_size;
 }
 
-@compute @workgroup_size(1)
-fn tint_symbol(inputs : ComputeInputs) {
-  output[inputs.subgroup_invocation_id] = inputs.subgroup_size;
+kernel void tint_symbol(device tint_symbol_4* tint_symbol_3 [[buffer(0)]], uint subgroup_invocation_id [[thread_index_in_simdgroup]], uint subgroup_size [[threads_per_simdgroup]]) {
+  ComputeInputs const tint_symbol_1 = {.subgroup_invocation_id=subgroup_invocation_id, .subgroup_size=subgroup_size};
+  tint_symbol_inner(tint_symbol_1, &((*(tint_symbol_3)).arr));
+  return;
 }
 
-Failed to generate: shader_io/compute_subgroup_builtins_struct.wgsl:1:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-