[msl] Add SimdBallot transform

Replace calls to subgroupBallot() with a call to a helper function
that calls an MSL intrinsic for simd_ballot and masks the result.

Bug: 42251016
Change-Id: I0c6b5266018e6a264c065c4ef73ac0f342075899
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/204676
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/msl/builtin_fn.cc b/src/tint/lang/msl/builtin_fn.cc
index 8b11556..a434592 100644
--- a/src/tint/lang/msl/builtin_fn.cc
+++ b/src/tint/lang/msl/builtin_fn.cc
@@ -106,6 +106,8 @@
             return "sign";
         case BuiltinFn::kThreadgroupBarrier:
             return "threadgroup_barrier";
+        case BuiltinFn::kSimdBallot:
+            return "simd_ballot";
     }
     return "<unknown>";
 }
diff --git a/src/tint/lang/msl/builtin_fn.h b/src/tint/lang/msl/builtin_fn.h
index d1f7ea3..b044e7c 100644
--- a/src/tint/lang/msl/builtin_fn.h
+++ b/src/tint/lang/msl/builtin_fn.h
@@ -79,6 +79,7 @@
     kModf,
     kSign,
     kThreadgroupBarrier,
+    kSimdBallot,
     kNone,
 };
 
diff --git a/src/tint/lang/msl/intrinsic/data.cc b/src/tint/lang/msl/intrinsic/data.cc
index 65fd564..6de5735 100644
--- a/src/tint/lang/msl/intrinsic/data.cc
+++ b/src/tint/lang/msl/intrinsic/data.cc
@@ -2705,6 +2705,11 @@
     /* usage */ core::ParameterUsage::kTexture,
     /* matcher_indices */ MatcherIndicesIndex(9),
   },
+  {
+    /* [309] */
+    /* usage */ core::ParameterUsage::kNone,
+    /* matcher_indices */ MatcherIndicesIndex(4),
+  },
 };
 
 static_assert(ParameterIndex::CanIndex(kParameters),
@@ -4684,6 +4689,17 @@
   },
   {
     /* [171] */
+    /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
+    /* num_parameters */ 1,
+    /* num_explicit_templates */ 0,
+    /* num_templates   */ 0,
+    /* templates */ TemplateIndex(/* invalid */),
+    /* parameters */ ParameterIndex(309),
+    /* return_matcher_indices */ MatcherIndicesIndex(121),
+    /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+  },
+  {
+    /* [172] */
     /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
     /* num_parameters */ 2,
     /* num_explicit_templates */ 0,
@@ -5040,6 +5056,12 @@
     /* num overloads */ 1,
     /* overloads */ OverloadIndex(170),
   },
+  {
+    /* [32] */
+    /* fn simd_ballot(bool) -> vec2<u32> */
+    /* num overloads */ 1,
+    /* overloads */ OverloadIndex(171),
+  },
 };
 
 constexpr IntrinsicInfo kBinaryOperators[] = {
@@ -5047,13 +5069,13 @@
     /* [0] */
     /* op +[T : iu8](T, T) -> T */
     /* num overloads */ 1,
-    /* overloads */ OverloadIndex(171),
+    /* overloads */ OverloadIndex(172),
   },
   {
     /* [1] */
     /* op *[T : iu8](T, T) -> T */
     /* num overloads */ 1,
-    /* overloads */ OverloadIndex(171),
+    /* overloads */ OverloadIndex(172),
   },
 };
 constexpr uint8_t kBinaryOperatorPlus = 0;
diff --git a/src/tint/lang/msl/msl.def b/src/tint/lang/msl/msl.def
index a81b4a9..468a53f 100644
--- a/src/tint/lang/msl/msl.def
+++ b/src/tint/lang/msl/msl.def
@@ -339,6 +339,8 @@
 implicit(N: num, T: f32_f16) fn sign(vec<N, T>) -> vec<N, T>
 @stage("compute") fn threadgroup_barrier(u32)
 
+@stage("fragment", "compute") fn simd_ballot(bool) -> vec2<u32>
+
 ////////////////////////////////////////////////////////////////////////////////
 // Binary Operators                                                           //
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index c5a18b2..5ab53b9 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -866,6 +866,11 @@
 
             out << ")";
             return;
+        } else if (c->Func() == msl::BuiltinFn::kSimdBallot) {
+            out << "as_type<uint2>((simd_vote::vote_t)simd_ballot(";
+            EmitValue(out, c->Args()[0]);
+            out << "))";
+            return;
         }
 
         out << c->Func() << "(";
diff --git a/src/tint/lang/msl/writer/raise/BUILD.bazel b/src/tint/lang/msl/writer/raise/BUILD.bazel
index 3e019ea..3291fb7 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.bazel
+++ b/src/tint/lang/msl/writer/raise/BUILD.bazel
@@ -45,6 +45,7 @@
     "packed_vec3.cc",
     "raise.cc",
     "shader_io.cc",
+    "simd_ballot.cc",
   ],
   hdrs = [
     "binary_polyfill.h",
@@ -53,6 +54,7 @@
     "packed_vec3.h",
     "raise.h",
     "shader_io.h",
+    "simd_ballot.h",
   ],
   deps = [
     "//src/tint/api/common",
@@ -100,6 +102,7 @@
     "module_scope_vars_test.cc",
     "packed_vec3_test.cc",
     "shader_io_test.cc",
+    "simd_ballot_test.cc",
   ],
   deps = [
     "//src/tint/api/common",
diff --git a/src/tint/lang/msl/writer/raise/BUILD.cmake b/src/tint/lang/msl/writer/raise/BUILD.cmake
index 39d4546..7d1c017 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.cmake
+++ b/src/tint/lang/msl/writer/raise/BUILD.cmake
@@ -53,6 +53,8 @@
   lang/msl/writer/raise/raise.h
   lang/msl/writer/raise/shader_io.cc
   lang/msl/writer/raise/shader_io.h
+  lang/msl/writer/raise/simd_ballot.cc
+  lang/msl/writer/raise/simd_ballot.h
 )
 
 tint_target_add_dependencies(tint_lang_msl_writer_raise lib
@@ -107,6 +109,7 @@
   lang/msl/writer/raise/module_scope_vars_test.cc
   lang/msl/writer/raise/packed_vec3_test.cc
   lang/msl/writer/raise/shader_io_test.cc
+  lang/msl/writer/raise/simd_ballot_test.cc
 )
 
 tint_target_add_dependencies(tint_lang_msl_writer_raise_test test
diff --git a/src/tint/lang/msl/writer/raise/BUILD.gn b/src/tint/lang/msl/writer/raise/BUILD.gn
index d6973f0..4b7e044 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.gn
+++ b/src/tint/lang/msl/writer/raise/BUILD.gn
@@ -57,6 +57,8 @@
       "raise.h",
       "shader_io.cc",
       "shader_io.h",
+      "simd_ballot.cc",
+      "simd_ballot.h",
     ]
     deps = [
       "${dawn_root}/src/utils:utils",
@@ -102,6 +104,7 @@
         "module_scope_vars_test.cc",
         "packed_vec3_test.cc",
         "shader_io_test.cc",
+        "simd_ballot_test.cc",
       ]
       deps = [
         "${dawn_root}/src/utils:utils",
diff --git a/src/tint/lang/msl/writer/raise/raise.cc b/src/tint/lang/msl/writer/raise/raise.cc
index 49e6311..0230db1 100644
--- a/src/tint/lang/msl/writer/raise/raise.cc
+++ b/src/tint/lang/msl/writer/raise/raise.cc
@@ -51,6 +51,7 @@
 #include "src/tint/lang/msl/writer/raise/module_scope_vars.h"
 #include "src/tint/lang/msl/writer/raise/packed_vec3.h"
 #include "src/tint/lang/msl/writer/raise/shader_io.h"
+#include "src/tint/lang/msl/writer/raise/simd_ballot.h"
 
 namespace tint::msl::writer {
 
@@ -130,6 +131,7 @@
     RUN_TRANSFORM(raise::ShaderIO, module,
                   raise::ShaderIOConfig{options.emit_vertex_point_size, options.fixed_sample_mask});
     RUN_TRANSFORM(raise::PackedVec3, module);
+    RUN_TRANSFORM(raise::SimdBallot, module);
     RUN_TRANSFORM(raise::ModuleScopeVars, module);
     RUN_TRANSFORM(raise::BinaryPolyfill, module);
     RUN_TRANSFORM(raise::BuiltinPolyfill, module);
diff --git a/src/tint/lang/msl/writer/raise/simd_ballot.cc b/src/tint/lang/msl/writer/raise/simd_ballot.cc
new file mode 100644
index 0000000..565d85d
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/simd_ballot.cc
@@ -0,0 +1,172 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/msl/writer/raise/simd_ballot.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/transform/common/referenced_module_vars.h"
+#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/msl/ir/builtin_call.h"
+
+namespace tint::msl::writer::raise {
+namespace {
+
+using namespace tint::core::fluent_types;  // NOLINT
+
+/// PIMPL state for the transform.
+struct State {
+    /// The IR module.
+    core::ir::Module& ir;
+
+    /// The IR builder.
+    core::ir::Builder b{ir};
+
+    /// The type manager.
+    core::type::Manager& ty{ir.Types()};
+
+    /// The subgroupBallot polyfill function.
+    core::ir::Function* subgroup_ballot_polyfill = nullptr;
+
+    /// The subgroup_size_mask module-scope variable.
+    core::ir::Var* subgroup_size_mask = nullptr;
+
+    /// Process the module.
+    void Process() {
+        // Polyfill every call to `subgroupBallot`.
+        for (auto* inst : ir.Instructions()) {
+            auto* call = inst->As<core::ir::CoreBuiltinCall>();
+            if (call && call->Func() == core::BuiltinFn::kSubgroupBallot) {
+                Replace(call);
+            }
+        }
+        if (!subgroup_size_mask) {
+            return;  // Nothing was polyfilled, so there is no mask variable to initialize.
+        }
+        // Set the subgroup size mask value from all entry points that use it.
+        core::ir::ReferencedModuleVars refs(ir, [&](const core::ir::Var* var) {  //
+            return var == subgroup_size_mask;
+        });
+        for (auto func : ir.functions) {
+            if (func->Stage() != core::ir::Function::PipelineStage::kUndefined &&
+                refs.TransitiveReferences(func).Contains(subgroup_size_mask)) {
+                SetSubgroupSizeMaskForEntryPoint(func);
+            }
+        }
+    }
+
+    /// Replace a call to subgroupBallot with a call to the polyfill function.
+    /// @param call the `subgroupBallot` call, which is destroyed once its result is taken over
+    void Replace(core::ir::CoreBuiltinCall* call) {
+        auto* fn = SubgroupBallotPolyfill();
+        b.InsertBefore(call, [&] { b.CallWithResult(call->DetachResult(), fn, call->Args()[0]); });
+        call->Destroy();
+    }
+
+    /// Return the polyfill function, building it (and its mask variable) on first use.
+    core::ir::Function* SubgroupBallotPolyfill() {
+        if (subgroup_ballot_polyfill != nullptr) {
+            return subgroup_ballot_polyfill;  // Already built by an earlier call.
+        }
+
+        // The polyfill masks its result with a module-scope variable, so declare that first.
+        b.Append(ir.root_block, [&] {
+            subgroup_size_mask = b.Var<private_, vec2<u32>>("tint_subgroup_size_mask");
+        });
+
+        // Build the polyfill function, which has the following shape:
+        //   fn tint_subgroup_ballot(pred: bool) -> vec4u {
+        //     let simd_vote: vec2u = msl.simd_ballot(pred);
+        //     return vec4u(simd_vote & tint_subgroup_size_mask, 0, 0);
+        //   }
+        auto* predicate = b.FunctionParam("pred", ty.bool_());
+        auto* polyfill = b.Function("tint_subgroup_ballot", ty.vec4<u32>());
+        polyfill->SetParams({predicate});
+        b.Append(polyfill->Block(), [&] {
+            auto* vote = b.Call<msl::ir::BuiltinCall>(ty.vec2<u32>(), msl::BuiltinFn::kSimdBallot,
+                                                      predicate);
+            auto* mask = b.Load(subgroup_size_mask);
+            auto* masked_vote = b.And<vec2<u32>>(vote, mask);
+            b.Return(polyfill, b.Construct(ty.vec4<u32>(), masked_vote, u32(0), u32(0)));
+        });
+
+        subgroup_ballot_polyfill = polyfill;
+        return polyfill;
+    }
+
+    /// Set the subgroup_size_mask variable at the start of an entry point.
+    /// @param ep the entry point; a `subgroup_size` builtin parameter is added if it has none
+    void SetSubgroupSizeMaskForEntryPoint(core::ir::Function* ep) {
+        // Reuse a user provided subgroup_size builtin if there is one, otherwise create our own.
+        core::ir::FunctionParam* subgroup_size = nullptr;
+        for (auto* param : ep->Params()) {
+            if (param->Attributes().builtin == core::BuiltinValue::kSubgroupSize) {
+                subgroup_size = param;
+                break;
+            }
+        }
+        if (!subgroup_size) {
+            subgroup_size = b.FunctionParam("tint_subgroup_size", ty.u32());
+            subgroup_size->SetBuiltin(core::BuiltinValue::kSubgroupSize);
+            ep->AppendParam(subgroup_size);
+        }
+
+        // Element 0 masks invocations 0-31 (the low word) and element 1 masks invocations 32-63:
+        //   let size_gt_32 = (subgroup_size > 32u);
+        //   let low  = select(4294967295u >> (32u - subgroup_size), 4294967295u, size_gt_32);
+        //   let high = select(0u, (4294967295u >> (64u - subgroup_size)), size_gt_32);
+        //   tint_subgroup_size_mask[0u] = low;
+        //   tint_subgroup_size_mask[1u] = high;
+        b.InsertBefore(ep->Block()->Front(), [&] {
+            auto* gt32 = b.GreaterThan<bool>(subgroup_size, u32(32));
+            auto* low_mask =
+                b.ShiftRight<u32>(u32::Highest(), b.Subtract<u32>(u32(32), subgroup_size));
+            auto* low = b.Call<u32>(core::BuiltinFn::kSelect, low_mask, u32::Highest(), gt32);
+            auto* high_mask =
+                b.ShiftRight<u32>(u32::Highest(), b.Subtract<u32>(u32(64), subgroup_size));
+            auto* high = b.Call<u32>(core::BuiltinFn::kSelect, u32(0), high_mask, gt32);
+            b.StoreVectorElement(subgroup_size_mask, u32(0), low);
+            b.StoreVectorElement(subgroup_size_mask, u32(1), high);
+        });
+    }
+};
+
+}  // namespace
+
+Result<SuccessType> SimdBallot(core::ir::Module& ir) {
+    // Return the failure without touching the module if the validation step does not succeed.
+    if (auto valid = ValidateAndDumpIfNeeded(ir, "SimdBallot transform"); valid != Success) {
+        return valid.Failure();
+    }
+
+    State state{ir};
+    state.Process();
+    return Success;
+}
+
+}  // namespace tint::msl::writer::raise
diff --git a/src/tint/lang/msl/writer/raise/simd_ballot.h b/src/tint/lang/msl/writer/raise/simd_ballot.h
new file mode 100644
index 0000000..d34519d
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/simd_ballot.h
@@ -0,0 +1,48 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_MSL_WRITER_RAISE_SIMD_BALLOT_H_
+#define SRC_TINT_LANG_MSL_WRITER_RAISE_SIMD_BALLOT_H_
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}  // namespace tint::core::ir
+
+namespace tint::msl::writer::raise {
+
+/// SimdBallot is a transform that replaces core `subgroupBallot` calls with calls to a helper
+/// function, which invokes the MSL `simd_ballot` intrinsic and masks the result it returns.
+/// @param module the module to transform
+/// @returns success or failure
+Result<SuccessType> SimdBallot(core::ir::Module& module);
+
+}  // namespace tint::msl::writer::raise
+
+#endif  // SRC_TINT_LANG_MSL_WRITER_RAISE_SIMD_BALLOT_H_
diff --git a/src/tint/lang/msl/writer/raise/simd_ballot_test.cc b/src/tint/lang/msl/writer/raise/simd_ballot_test.cc
new file mode 100644
index 0000000..3e8e1f7
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/simd_ballot_test.cc
@@ -0,0 +1,317 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/msl/writer/raise/simd_ballot.h"
+
+#include "gtest/gtest.h"
+
+#include "src/tint/lang/core/fluent_types.h"
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/number.h"
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+namespace tint::msl::writer::raise {
+namespace {
+
+using MslWriter_SimdBallotTest = core::ir::transform::TransformTest;
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_WithUserDeclaredSubgroupSize) {
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    auto* subgroup_size = b.FunctionParam("user_subgroup_size", ty.u32());
+    subgroup_size->SetBuiltin(core::BuiltinValue::kSubgroupSize);
+    func->SetParams({subgroup_size});
+    b.Append(func->Block(), [&] {  //
+        b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%foo = @fragment func(%user_subgroup_size:u32 [@subgroup_size]):void {
+  $B1: {
+    %3:vec4<u32> = subgroupBallot true
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+$B1: {  # root
+  %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = @fragment func(%user_subgroup_size:u32 [@subgroup_size]):void {
+  $B2: {
+    %4:bool = gt %user_subgroup_size, 32u
+    %5:u32 = sub 32u, %user_subgroup_size
+    %6:u32 = shr 4294967295u, %5
+    %7:u32 = select %6, 4294967295u, %4
+    %8:u32 = sub 64u, %user_subgroup_size
+    %9:u32 = shr 4294967295u, %8
+    %10:u32 = select 0u, %9, %4
+    store_vector_element %tint_subgroup_size_mask, 0u, %7
+    store_vector_element %tint_subgroup_size_mask, 1u, %10
+    %11:vec4<u32> = call %tint_subgroup_ballot, true
+    ret
+  }
+}
+%tint_subgroup_ballot = func(%pred:bool):vec4<u32> {
+  $B3: {
+    %14:vec2<u32> = msl.simd_ballot %pred
+    %15:vec2<u32> = load %tint_subgroup_size_mask
+    %16:vec2<u32> = and %14, %15
+    %17:vec4<u32> = construct %16, 0u, 0u
+    ret %17
+  }
+}
+)";
+    // The user-declared builtin must be reused instead of a second parameter being appended.
+    Run(SimdBallot);
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_WithoutUserDeclaredSubgroupSize) {
+    // A fragment entry point with no parameters, so the transform has to add the builtin itself.
+    auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(ep->Block(), [&] {
+        b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+        b.Return(ep);
+    });
+
+    auto* before = R"(
+%foo = @fragment func():void {
+  $B1: {
+    %2:vec4<u32> = subgroupBallot true
+    ret
+  }
+}
+)";
+    EXPECT_EQ(before, str());
+
+    auto* after = R"(
+$B1: {  # root
+  %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = @fragment func(%tint_subgroup_size:u32 [@subgroup_size]):void {
+  $B2: {
+    %4:bool = gt %tint_subgroup_size, 32u
+    %5:u32 = sub 32u, %tint_subgroup_size
+    %6:u32 = shr 4294967295u, %5
+    %7:u32 = select %6, 4294967295u, %4
+    %8:u32 = sub 64u, %tint_subgroup_size
+    %9:u32 = shr 4294967295u, %8
+    %10:u32 = select 0u, %9, %4
+    store_vector_element %tint_subgroup_size_mask, 0u, %7
+    store_vector_element %tint_subgroup_size_mask, 1u, %10
+    %11:vec4<u32> = call %tint_subgroup_ballot, true
+    ret
+  }
+}
+%tint_subgroup_ballot = func(%pred:bool):vec4<u32> {
+  $B3: {
+    %14:vec2<u32> = msl.simd_ballot %pred
+    %15:vec2<u32> = load %tint_subgroup_size_mask
+    %16:vec2<u32> = and %14, %15
+    %17:vec4<u32> = construct %16, 0u, 0u
+    ret %17
+  }
+}
+)";
+    Run(SimdBallot);
+
+    EXPECT_EQ(after, str());
+}
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_InHelperFunction) {
+    // The ballot lives in a helper function that is called from two separate entry points.
+    auto* helper = b.Function("foo", ty.vec4<u32>());
+    auto* helper_pred = b.FunctionParam("pred", ty.bool_());
+    helper->SetParams({helper_pred});
+    b.Append(helper->Block(), [&] {  //
+        b.Return(helper, b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, helper_pred));
+    });
+
+    // The first entry point has a plain u32 parameter that carries no builtin attribute.
+    auto* first_ep = b.Function("ep1", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    first_ep->SetParams({b.FunctionParam("user_subgroup_size", ty.u32())});
+    b.Append(first_ep->Block(), [&] {
+        b.Call<vec4<u32>>(helper, true);
+        b.Return(first_ep);
+    });
+
+    auto* second_ep = b.Function("ep2", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(second_ep->Block(), [&] {
+        b.Call<vec4<u32>>(helper, false);
+        b.Return(second_ep);
+    });
+
+    auto* before = R"(
+%foo = func(%pred:bool):vec4<u32> {
+  $B1: {
+    %3:vec4<u32> = subgroupBallot %pred
+    ret %3
+  }
+}
+%ep1 = @fragment func(%user_subgroup_size:u32):void {
+  $B2: {
+    %6:vec4<u32> = call %foo, true
+    ret
+  }
+}
+%ep2 = @fragment func():void {
+  $B3: {
+    %8:vec4<u32> = call %foo, false
+    ret
+  }
+}
+)";
+    EXPECT_EQ(before, str());
+
+    auto* after = R"(
+$B1: {  # root
+  %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = func(%pred:bool):vec4<u32> {
+  $B2: {
+    %4:vec4<u32> = call %tint_subgroup_ballot, %pred
+    ret %4
+  }
+}
+%ep1 = @fragment func(%user_subgroup_size:u32, %tint_subgroup_size:u32 [@subgroup_size]):void {
+  $B3: {
+    %9:bool = gt %tint_subgroup_size, 32u
+    %10:u32 = sub 32u, %tint_subgroup_size
+    %11:u32 = shr 4294967295u, %10
+    %12:u32 = select %11, 4294967295u, %9
+    %13:u32 = sub 64u, %tint_subgroup_size
+    %14:u32 = shr 4294967295u, %13
+    %15:u32 = select 0u, %14, %9
+    store_vector_element %tint_subgroup_size_mask, 0u, %12
+    store_vector_element %tint_subgroup_size_mask, 1u, %15
+    %16:vec4<u32> = call %foo, true
+    ret
+  }
+}
+%ep2 = @fragment func(%tint_subgroup_size_1:u32 [@subgroup_size]):void {  # %tint_subgroup_size_1: 'tint_subgroup_size'
+  $B4: {
+    %19:bool = gt %tint_subgroup_size_1, 32u
+    %20:u32 = sub 32u, %tint_subgroup_size_1
+    %21:u32 = shr 4294967295u, %20
+    %22:u32 = select %21, 4294967295u, %19
+    %23:u32 = sub 64u, %tint_subgroup_size_1
+    %24:u32 = shr 4294967295u, %23
+    %25:u32 = select 0u, %24, %19
+    store_vector_element %tint_subgroup_size_mask, 0u, %22
+    store_vector_element %tint_subgroup_size_mask, 1u, %25
+    %26:vec4<u32> = call %foo, false
+    ret
+  }
+}
+%tint_subgroup_ballot = func(%pred_1:bool):vec4<u32> {  # %pred_1: 'pred'
+  $B5: {
+    %28:vec2<u32> = msl.simd_ballot %pred_1
+    %29:vec2<u32> = load %tint_subgroup_size_mask
+    %30:vec2<u32> = and %28, %29
+    %31:vec4<u32> = construct %30, 0u, 0u
+    ret %31
+  }
+}
+)";
+
+    Run(SimdBallot);
+
+    EXPECT_EQ(after, str());
+}
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_MultipleCalls) {
+    // Four ballots in one entry point: expect one polyfill function and one mask setup.
+    auto* ep = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(ep->Block(), [&] {
+        b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+        b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, false);
+        b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+        b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, false);
+        b.Return(ep);
+    });
+
+    auto* before = R"(
+%foo = @fragment func():void {
+  $B1: {
+    %2:vec4<u32> = subgroupBallot true
+    %3:vec4<u32> = subgroupBallot false
+    %4:vec4<u32> = subgroupBallot true
+    %5:vec4<u32> = subgroupBallot false
+    ret
+  }
+}
+)";
+    EXPECT_EQ(before, str());
+
+    auto* after = R"(
+$B1: {  # root
+  %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = @fragment func(%tint_subgroup_size:u32 [@subgroup_size]):void {
+  $B2: {
+    %4:bool = gt %tint_subgroup_size, 32u
+    %5:u32 = sub 32u, %tint_subgroup_size
+    %6:u32 = shr 4294967295u, %5
+    %7:u32 = select %6, 4294967295u, %4
+    %8:u32 = sub 64u, %tint_subgroup_size
+    %9:u32 = shr 4294967295u, %8
+    %10:u32 = select 0u, %9, %4
+    store_vector_element %tint_subgroup_size_mask, 0u, %7
+    store_vector_element %tint_subgroup_size_mask, 1u, %10
+    %11:vec4<u32> = call %tint_subgroup_ballot, true
+    %13:vec4<u32> = call %tint_subgroup_ballot, false
+    %14:vec4<u32> = call %tint_subgroup_ballot, true
+    %15:vec4<u32> = call %tint_subgroup_ballot, false
+    ret
+  }
+}
+%tint_subgroup_ballot = func(%pred:bool):vec4<u32> {
+  $B3: {
+    %17:vec2<u32> = msl.simd_ballot %pred
+    %18:vec2<u32> = load %tint_subgroup_size_mask
+    %19:vec2<u32> = and %17, %18
+    %20:vec4<u32> = construct %19, 0u, 0u
+    ret %20
+  }
+}
+)";
+    Run(SimdBallot);
+
+    EXPECT_EQ(after, str());
+}
+
+}  // namespace
+}  // namespace tint::msl::writer::raise
diff --git a/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl b/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl
index e94069d..65c71e1 100644
--- a/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl
@@ -1,9 +1,37 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
 
-../../src/tint/lang/msl/writer/printer/printer.cc:988 internal compiler error: TINT_UNREACHABLE unhandled: subgroupBallot
-********************************************************************
-*  The tint shader compiler has encountered an unexpected error.   *
-*                                                                  *
-*  Please help us fix this issue by submitting a bug report at     *
-*  crbug.com/tint with the source program that triggered the bug.  *
-********************************************************************
+struct tint_module_vars_struct {
+  device uint4* prevent_dce;
+  thread uint2* tint_subgroup_size_mask;
+};
+
+uint4 tint_subgroup_ballot(bool pred, tint_module_vars_struct tint_module_vars) {
+  uint2 const v = as_type<uint2>((simd_vote::vote_t)simd_ballot(pred));
+  return uint4((v & (*tint_module_vars.tint_subgroup_size_mask)), 0u, 0u);
+}
+
+uint4 subgroupBallot_1a8251(tint_module_vars_struct tint_module_vars) {
+  uint4 res = tint_subgroup_ballot(true, tint_module_vars);
+  return res;
+}
+
+fragment void fragment_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+  thread uint2 tint_subgroup_size_mask = 0u;
+  tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+  uint const v_1 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+  uint const v_2 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+  (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_1;
+  (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_2;
+  (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}
+
+kernel void compute_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+  thread uint2 tint_subgroup_size_mask = 0u;
+  tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+  uint const v_3 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+  uint const v_4 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+  (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_3;
+  (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_4;
+  (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}
diff --git a/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl
index e94069d..73db8be 100644
--- a/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl
@@ -1,9 +1,38 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
 
-../../src/tint/lang/msl/writer/printer/printer.cc:988 internal compiler error: TINT_UNREACHABLE unhandled: subgroupBallot
-********************************************************************
-*  The tint shader compiler has encountered an unexpected error.   *
-*                                                                  *
-*  Please help us fix this issue by submitting a bug report at     *
-*  crbug.com/tint with the source program that triggered the bug.  *
-********************************************************************
+struct tint_module_vars_struct {
+  device uint4* prevent_dce;
+  thread uint2* tint_subgroup_size_mask;
+};
+
+uint4 tint_subgroup_ballot(bool pred, tint_module_vars_struct tint_module_vars) {
+  uint2 const v = as_type<uint2>((simd_vote::vote_t)simd_ballot(pred));
+  return uint4((v & (*tint_module_vars.tint_subgroup_size_mask)), 0u, 0u);
+}
+
+uint4 subgroupBallot_1a8251(tint_module_vars_struct tint_module_vars) {
+  bool arg_0 = true;
+  uint4 res = tint_subgroup_ballot(arg_0, tint_module_vars);
+  return res;
+}
+
+fragment void fragment_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+  thread uint2 tint_subgroup_size_mask = 0u;
+  tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+  uint const v_1 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+  uint const v_2 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+  (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_1;
+  (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_2;
+  (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}
+
+kernel void compute_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+  thread uint2 tint_subgroup_size_mask = 0u;
+  tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+  uint const v_3 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+  uint const v_4 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+  (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_3;
+  (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_4;
+  (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}