[msl] Add SimdBallot transform
Replace calls to subgroupBallot() with a call to a helper function
that calls an MSL intrinsic for simd_ballot and masks the result.
Bug: 42251016
Change-Id: I0c6b5266018e6a264c065c4ef73ac0f342075899
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/204676
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/msl/builtin_fn.cc b/src/tint/lang/msl/builtin_fn.cc
index 8b11556..a434592 100644
--- a/src/tint/lang/msl/builtin_fn.cc
+++ b/src/tint/lang/msl/builtin_fn.cc
@@ -106,6 +106,8 @@
return "sign";
case BuiltinFn::kThreadgroupBarrier:
return "threadgroup_barrier";
+ case BuiltinFn::kSimdBallot:
+ return "simd_ballot";
}
return "<unknown>";
}
diff --git a/src/tint/lang/msl/builtin_fn.h b/src/tint/lang/msl/builtin_fn.h
index d1f7ea3..b044e7c 100644
--- a/src/tint/lang/msl/builtin_fn.h
+++ b/src/tint/lang/msl/builtin_fn.h
@@ -79,6 +79,7 @@
kModf,
kSign,
kThreadgroupBarrier,
+ kSimdBallot,
kNone,
};
diff --git a/src/tint/lang/msl/intrinsic/data.cc b/src/tint/lang/msl/intrinsic/data.cc
index 65fd564..6de5735 100644
--- a/src/tint/lang/msl/intrinsic/data.cc
+++ b/src/tint/lang/msl/intrinsic/data.cc
@@ -2705,6 +2705,11 @@
/* usage */ core::ParameterUsage::kTexture,
/* matcher_indices */ MatcherIndicesIndex(9),
},
+ {
+ /* [309] */
+ /* usage */ core::ParameterUsage::kNone,
+ /* matcher_indices */ MatcherIndicesIndex(4),
+ },
};
static_assert(ParameterIndex::CanIndex(kParameters),
@@ -4684,6 +4689,17 @@
},
{
/* [171] */
+ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
+ /* num_parameters */ 1,
+ /* num_explicit_templates */ 0,
+ /* num_templates */ 0,
+ /* templates */ TemplateIndex(/* invalid */),
+ /* parameters */ ParameterIndex(309),
+ /* return_matcher_indices */ MatcherIndicesIndex(121),
+ /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */),
+ },
+ {
+ /* [172] */
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* num_parameters */ 2,
/* num_explicit_templates */ 0,
@@ -5040,6 +5056,12 @@
/* num overloads */ 1,
/* overloads */ OverloadIndex(170),
},
+ {
+ /* [32] */
+ /* fn simd_ballot(bool) -> vec2<u32> */
+ /* num overloads */ 1,
+ /* overloads */ OverloadIndex(171),
+ },
};
constexpr IntrinsicInfo kBinaryOperators[] = {
@@ -5047,13 +5069,13 @@
/* [0] */
/* op +[T : iu8](T, T) -> T */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(171),
+ /* overloads */ OverloadIndex(172),
},
{
/* [1] */
/* op *[T : iu8](T, T) -> T */
/* num overloads */ 1,
- /* overloads */ OverloadIndex(171),
+ /* overloads */ OverloadIndex(172),
},
};
constexpr uint8_t kBinaryOperatorPlus = 0;
diff --git a/src/tint/lang/msl/msl.def b/src/tint/lang/msl/msl.def
index a81b4a9..468a53f 100644
--- a/src/tint/lang/msl/msl.def
+++ b/src/tint/lang/msl/msl.def
@@ -339,6 +339,8 @@
implicit(N: num, T: f32_f16) fn sign(vec<N, T>) -> vec<N, T>
@stage("compute") fn threadgroup_barrier(u32)
+@stage("fragment", "compute") fn simd_ballot(bool) -> vec2<u32>
+
////////////////////////////////////////////////////////////////////////////////
// Binary Operators //
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index c5a18b2..5ab53b9 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -866,6 +866,11 @@
out << ")";
return;
+ } else if (c->Func() == msl::BuiltinFn::kSimdBallot) {
+ out << "as_type<uint2>((simd_vote::vote_t)simd_ballot(";
+ EmitValue(out, c->Args()[0]);
+ out << "))";
+ return;
}
out << c->Func() << "(";
diff --git a/src/tint/lang/msl/writer/raise/BUILD.bazel b/src/tint/lang/msl/writer/raise/BUILD.bazel
index 3e019ea..3291fb7 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.bazel
+++ b/src/tint/lang/msl/writer/raise/BUILD.bazel
@@ -45,6 +45,7 @@
"packed_vec3.cc",
"raise.cc",
"shader_io.cc",
+ "simd_ballot.cc",
],
hdrs = [
"binary_polyfill.h",
@@ -53,6 +54,7 @@
"packed_vec3.h",
"raise.h",
"shader_io.h",
+ "simd_ballot.h",
],
deps = [
"//src/tint/api/common",
@@ -100,6 +102,7 @@
"module_scope_vars_test.cc",
"packed_vec3_test.cc",
"shader_io_test.cc",
+ "simd_ballot_test.cc",
],
deps = [
"//src/tint/api/common",
diff --git a/src/tint/lang/msl/writer/raise/BUILD.cmake b/src/tint/lang/msl/writer/raise/BUILD.cmake
index 39d4546..7d1c017 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.cmake
+++ b/src/tint/lang/msl/writer/raise/BUILD.cmake
@@ -53,6 +53,8 @@
lang/msl/writer/raise/raise.h
lang/msl/writer/raise/shader_io.cc
lang/msl/writer/raise/shader_io.h
+ lang/msl/writer/raise/simd_ballot.cc
+ lang/msl/writer/raise/simd_ballot.h
)
tint_target_add_dependencies(tint_lang_msl_writer_raise lib
@@ -107,6 +109,7 @@
lang/msl/writer/raise/module_scope_vars_test.cc
lang/msl/writer/raise/packed_vec3_test.cc
lang/msl/writer/raise/shader_io_test.cc
+ lang/msl/writer/raise/simd_ballot_test.cc
)
tint_target_add_dependencies(tint_lang_msl_writer_raise_test test
diff --git a/src/tint/lang/msl/writer/raise/BUILD.gn b/src/tint/lang/msl/writer/raise/BUILD.gn
index d6973f0..4b7e044 100644
--- a/src/tint/lang/msl/writer/raise/BUILD.gn
+++ b/src/tint/lang/msl/writer/raise/BUILD.gn
@@ -57,6 +57,8 @@
"raise.h",
"shader_io.cc",
"shader_io.h",
+ "simd_ballot.cc",
+ "simd_ballot.h",
]
deps = [
"${dawn_root}/src/utils:utils",
@@ -102,6 +104,7 @@
"module_scope_vars_test.cc",
"packed_vec3_test.cc",
"shader_io_test.cc",
+ "simd_ballot_test.cc",
]
deps = [
"${dawn_root}/src/utils:utils",
diff --git a/src/tint/lang/msl/writer/raise/raise.cc b/src/tint/lang/msl/writer/raise/raise.cc
index 49e6311..0230db1 100644
--- a/src/tint/lang/msl/writer/raise/raise.cc
+++ b/src/tint/lang/msl/writer/raise/raise.cc
@@ -51,6 +51,7 @@
#include "src/tint/lang/msl/writer/raise/module_scope_vars.h"
#include "src/tint/lang/msl/writer/raise/packed_vec3.h"
#include "src/tint/lang/msl/writer/raise/shader_io.h"
+#include "src/tint/lang/msl/writer/raise/simd_ballot.h"
namespace tint::msl::writer {
@@ -130,6 +131,7 @@
RUN_TRANSFORM(raise::ShaderIO, module,
raise::ShaderIOConfig{options.emit_vertex_point_size, options.fixed_sample_mask});
RUN_TRANSFORM(raise::PackedVec3, module);
+ RUN_TRANSFORM(raise::SimdBallot, module);
RUN_TRANSFORM(raise::ModuleScopeVars, module);
RUN_TRANSFORM(raise::BinaryPolyfill, module);
RUN_TRANSFORM(raise::BuiltinPolyfill, module);
diff --git a/src/tint/lang/msl/writer/raise/simd_ballot.cc b/src/tint/lang/msl/writer/raise/simd_ballot.cc
new file mode 100644
index 0000000..565d85d
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/simd_ballot.cc
@@ -0,0 +1,172 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/msl/writer/raise/simd_ballot.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/transform/common/referenced_module_vars.h"
+#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/msl/ir/builtin_call.h"
+
+namespace tint::msl::writer::raise {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+
+/// PIMPL state for the transform.
+struct State {
+ /// The IR module.
+ core::ir::Module& ir;
+
+ /// The IR builder.
+ core::ir::Builder b{ir};
+
+ /// The type manager.
+ core::type::Manager& ty{ir.Types()};
+
+ /// The subgroupBallot polyfill function.
+ core::ir::Function* subgroup_ballot_polyfill = nullptr;
+
+ /// The subgroup_size_mask module-scope variable.
+ core::ir::Var* subgroup_size_mask = nullptr;
+
+ /// Process the module.
+ void Process() {
+ // Find calls to `subgroupBallot`.
+ for (auto* inst : ir.Instructions()) {
+ if (auto* call = inst->As<core::ir::CoreBuiltinCall>()) {
+ if (call->Func() == core::BuiltinFn::kSubgroupBallot) {
+ Replace(call);
+ }
+ }
+ }
+
+ // Set the subgroup size mask value from all entry points that use it.
+ core::ir::ReferencedModuleVars refs(ir, [&](const core::ir::Var* var) { //
+ return var == subgroup_size_mask;
+ });
+ for (auto func : ir.functions) {
+ if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
+ if (refs.TransitiveReferences(func).Contains(subgroup_size_mask)) {
+ SetSubgroupSizeMaskForEntryPoint(func);
+ }
+ }
+ }
+ }
+
+ /// Replace a call to subgroupBallot with a call to a polyfill function.
+ void Replace(core::ir::CoreBuiltinCall* call) {
+ b.InsertBefore(call, [&] {
+ b.CallWithResult(call->DetachResult(), SubgroupBallotPolyfill(), call->Args()[0]);
+ });
+ call->Destroy();
+ }
+
+ /// Get (or create, on first call) the polyfill function.
+ core::ir::Function* SubgroupBallotPolyfill() {
+ if (subgroup_ballot_polyfill) {
+ // The polyfill has already been created.
+ return subgroup_ballot_polyfill;
+ }
+
+ // Declare the subgroup_size_mask variable that we need.
+ b.Append(ir.root_block, [&] {
+ subgroup_size_mask = b.Var<private_, vec2<u32>>("tint_subgroup_size_mask");
+ });
+
+ // Create the polyfill function, which looks like this:
+ // fn tint_subgroup_ballot(pred: bool) -> vec4u {
+ // let simd_vote: vec2u = msl.simd_ballot(pred);
+ // return vec4u(simd_vote & tint_subgroup_size_mask, 0, 0);
+ // }
+ auto* pred = b.FunctionParam("pred", ty.bool_());
+ subgroup_ballot_polyfill = b.Function("tint_subgroup_ballot", ty.vec4<u32>());
+ subgroup_ballot_polyfill->SetParams({pred});
+ b.Append(subgroup_ballot_polyfill->Block(), [&] {
+ auto* simd_vote =
+ b.Call<msl::ir::BuiltinCall>(ty.vec2<u32>(), msl::BuiltinFn::kSimdBallot, pred);
+ auto* masked = b.And<vec2<u32>>(simd_vote, b.Load(subgroup_size_mask));
+ auto* result = b.Construct(ty.vec4<u32>(), masked, u32(0), u32(0));
+ b.Return(subgroup_ballot_polyfill, result);
+ });
+
+ return subgroup_ballot_polyfill;
+ }
+
+ /// Set the subgroup_size_mask variable from an entry point.
+ void SetSubgroupSizeMaskForEntryPoint(core::ir::Function* ep) {
+ // Check if there is a user provided subgroup_size builtin.
+ core::ir::FunctionParam* subgroup_size = nullptr;
+ for (auto* param : ep->Params()) {
+ if (param->Attributes().builtin == core::BuiltinValue::kSubgroupSize) {
+ subgroup_size = param;
+ break;
+ }
+ }
+ if (!subgroup_size) {
+ // No user defined subgroup_size builtin was found, so create our own.
+ subgroup_size = b.FunctionParam("tint_subgroup_size", ty.u32());
+ subgroup_size->SetBuiltin(core::BuiltinValue::kSubgroupSize);
+ ep->AppendParam(subgroup_size);
+ }
+
+ // Set the subgroup_size_mask based on the subgroup_size:
+ // let size_gt_32 = (subgroup_size > 32u);
+ // let high = select(4294967295u >> (32u - subgroup_size), 4294967295u, size_gt_32);
+ // let low = select(0u, (4294967295u >> (64u - subgroup_size)), size_gt_32);
+ // tint_subgroup_size_mask[0u] = high;
+ // tint_subgroup_size_mask[1u] = low;
+ b.InsertBefore(ep->Block()->Front(), [&] {
+ auto* gt32 = b.GreaterThan<bool>(subgroup_size, u32(32));
+ auto* high_mask =
+ b.ShiftRight<u32>(u32::Highest(), b.Subtract<u32>(u32(32), subgroup_size));
+ auto* high = b.Call<u32>(core::BuiltinFn::kSelect, high_mask, u32::Highest(), gt32);
+ auto* low_mask =
+ b.ShiftRight<u32>(u32::Highest(), b.Subtract<u32>(u32(64), subgroup_size));
+ auto* low = b.Call<u32>(core::BuiltinFn::kSelect, u32(0), low_mask, gt32);
+ b.StoreVectorElement(subgroup_size_mask, u32(0), high);
+ b.StoreVectorElement(subgroup_size_mask, u32(1), low);
+ });
+ }
+};
+
+} // namespace
+
+Result<SuccessType> SimdBallot(core::ir::Module& ir) {
+ auto result = ValidateAndDumpIfNeeded(ir, "SimdBallot transform");
+ if (result != Success) {
+ return result.Failure();
+ }
+
+ State{ir}.Process();
+
+ return Success;
+}
+
+} // namespace tint::msl::writer::raise
diff --git a/src/tint/lang/msl/writer/raise/simd_ballot.h b/src/tint/lang/msl/writer/raise/simd_ballot.h
new file mode 100644
index 0000000..d34519d
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/simd_ballot.h
@@ -0,0 +1,48 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_MSL_WRITER_RAISE_SIMD_BALLOT_H_
+#define SRC_TINT_LANG_MSL_WRITER_RAISE_SIMD_BALLOT_H_
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+} // namespace tint::core::ir
+
+namespace tint::msl::writer::raise {
+
+/// This transform replaces core `subgroupBallot` calls with calls to a `simd_ballot` MSL intrinsic,
+/// adding conversion and masking operations to produce the correct result.
+/// @param module the module to transform
+/// @returns success or failure
+Result<SuccessType> SimdBallot(core::ir::Module& module);
+
+} // namespace tint::msl::writer::raise
+
+#endif // SRC_TINT_LANG_MSL_WRITER_RAISE_SIMD_BALLOT_H_
diff --git a/src/tint/lang/msl/writer/raise/simd_ballot_test.cc b/src/tint/lang/msl/writer/raise/simd_ballot_test.cc
new file mode 100644
index 0000000..3e8e1f7
--- /dev/null
+++ b/src/tint/lang/msl/writer/raise/simd_ballot_test.cc
@@ -0,0 +1,317 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/msl/writer/raise/simd_ballot.h"
+
+#include "gtest/gtest.h"
+
+#include "src/tint/lang/core/fluent_types.h"
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/number.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::msl::writer::raise {
+namespace {
+
+using MslWriter_SimdBallotTest = core::ir::transform::TransformTest;
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_WithUserDeclaredSubgroupSize) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ auto* subgroup_size = b.FunctionParam("user_subgroup_size", ty.u32());
+ func->SetParams({subgroup_size});
+ b.Append(func->Block(), [&] { //
+ b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = @fragment func(%user_subgroup_size:u32):void {
+ $B1: {
+ %3:vec4<u32> = subgroupBallot true
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = @fragment func(%user_subgroup_size:u32, %tint_subgroup_size:u32 [@subgroup_size]):void {
+ $B2: {
+ %5:bool = gt %tint_subgroup_size, 32u
+ %6:u32 = sub 32u, %tint_subgroup_size
+ %7:u32 = shr 4294967295u, %6
+ %8:u32 = select %7, 4294967295u, %5
+ %9:u32 = sub 64u, %tint_subgroup_size
+ %10:u32 = shr 4294967295u, %9
+ %11:u32 = select 0u, %10, %5
+ store_vector_element %tint_subgroup_size_mask, 0u, %8
+ store_vector_element %tint_subgroup_size_mask, 1u, %11
+ %12:vec4<u32> = call %tint_subgroup_ballot, true
+ ret
+ }
+}
+%tint_subgroup_ballot = func(%pred:bool):vec4<u32> {
+ $B3: {
+ %15:vec2<u32> = msl.simd_ballot %pred
+ %16:vec2<u32> = load %tint_subgroup_size_mask
+ %17:vec2<u32> = and %15, %16
+ %18:vec4<u32> = construct %17, 0u, 0u
+ ret %18
+ }
+}
+)";
+
+ Run(SimdBallot);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_WithoutUserDeclaredSubgroupSize) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] { //
+ b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = @fragment func():void {
+ $B1: {
+ %2:vec4<u32> = subgroupBallot true
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = @fragment func(%tint_subgroup_size:u32 [@subgroup_size]):void {
+ $B2: {
+ %4:bool = gt %tint_subgroup_size, 32u
+ %5:u32 = sub 32u, %tint_subgroup_size
+ %6:u32 = shr 4294967295u, %5
+ %7:u32 = select %6, 4294967295u, %4
+ %8:u32 = sub 64u, %tint_subgroup_size
+ %9:u32 = shr 4294967295u, %8
+ %10:u32 = select 0u, %9, %4
+ store_vector_element %tint_subgroup_size_mask, 0u, %7
+ store_vector_element %tint_subgroup_size_mask, 1u, %10
+ %11:vec4<u32> = call %tint_subgroup_ballot, true
+ ret
+ }
+}
+%tint_subgroup_ballot = func(%pred:bool):vec4<u32> {
+ $B3: {
+ %14:vec2<u32> = msl.simd_ballot %pred
+ %15:vec2<u32> = load %tint_subgroup_size_mask
+ %16:vec2<u32> = and %14, %15
+ %17:vec4<u32> = construct %16, 0u, 0u
+ ret %17
+ }
+}
+)";
+
+ Run(SimdBallot);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_InHelperFunction) {
+ auto* foo = b.Function("foo", ty.vec4<u32>());
+ auto* pred = b.FunctionParam("pred", ty.bool_());
+ foo->SetParams({pred});
+ b.Append(foo->Block(), [&] { //
+ auto* result = b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, pred);
+ b.Return(foo, result);
+ });
+
+ auto* ep1 = b.Function("ep1", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ auto* subgroup_size = b.FunctionParam("user_subgroup_size", ty.u32());
+ ep1->SetParams({subgroup_size});
+ b.Append(ep1->Block(), [&] { //
+ b.Call<vec4<u32>>(foo, true);
+ b.Return(ep1);
+ });
+
+ auto* ep2 = b.Function("ep2", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(ep2->Block(), [&] { //
+ b.Call<vec4<u32>>(foo, false);
+ b.Return(ep2);
+ });
+
+ auto* src = R"(
+%foo = func(%pred:bool):vec4<u32> {
+ $B1: {
+ %3:vec4<u32> = subgroupBallot %pred
+ ret %3
+ }
+}
+%ep1 = @fragment func(%user_subgroup_size:u32):void {
+ $B2: {
+ %6:vec4<u32> = call %foo, true
+ ret
+ }
+}
+%ep2 = @fragment func():void {
+ $B3: {
+ %8:vec4<u32> = call %foo, false
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = func(%pred:bool):vec4<u32> {
+ $B2: {
+ %4:vec4<u32> = call %tint_subgroup_ballot, %pred
+ ret %4
+ }
+}
+%ep1 = @fragment func(%user_subgroup_size:u32, %tint_subgroup_size:u32 [@subgroup_size]):void {
+ $B3: {
+ %9:bool = gt %tint_subgroup_size, 32u
+ %10:u32 = sub 32u, %tint_subgroup_size
+ %11:u32 = shr 4294967295u, %10
+ %12:u32 = select %11, 4294967295u, %9
+ %13:u32 = sub 64u, %tint_subgroup_size
+ %14:u32 = shr 4294967295u, %13
+ %15:u32 = select 0u, %14, %9
+ store_vector_element %tint_subgroup_size_mask, 0u, %12
+ store_vector_element %tint_subgroup_size_mask, 1u, %15
+ %16:vec4<u32> = call %foo, true
+ ret
+ }
+}
+%ep2 = @fragment func(%tint_subgroup_size_1:u32 [@subgroup_size]):void { # %tint_subgroup_size_1: 'tint_subgroup_size'
+ $B4: {
+ %19:bool = gt %tint_subgroup_size_1, 32u
+ %20:u32 = sub 32u, %tint_subgroup_size_1
+ %21:u32 = shr 4294967295u, %20
+ %22:u32 = select %21, 4294967295u, %19
+ %23:u32 = sub 64u, %tint_subgroup_size_1
+ %24:u32 = shr 4294967295u, %23
+ %25:u32 = select 0u, %24, %19
+ store_vector_element %tint_subgroup_size_mask, 0u, %22
+ store_vector_element %tint_subgroup_size_mask, 1u, %25
+ %26:vec4<u32> = call %foo, false
+ ret
+ }
+}
+%tint_subgroup_ballot = func(%pred_1:bool):vec4<u32> { # %pred_1: 'pred'
+ $B5: {
+ %28:vec2<u32> = msl.simd_ballot %pred_1
+ %29:vec2<u32> = load %tint_subgroup_size_mask
+ %30:vec2<u32> = and %28, %29
+ %31:vec4<u32> = construct %30, 0u, 0u
+ ret %31
+ }
+}
+)";
+
+ Run(SimdBallot);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_SimdBallotTest, SimdBallot_MultipleCalls) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] { //
+ b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+ b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, false);
+ b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, true);
+ b.Call<vec4<u32>>(core::BuiltinFn::kSubgroupBallot, false);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = @fragment func():void {
+ $B1: {
+ %2:vec4<u32> = subgroupBallot true
+ %3:vec4<u32> = subgroupBallot false
+ %4:vec4<u32> = subgroupBallot true
+ %5:vec4<u32> = subgroupBallot false
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+$B1: { # root
+ %tint_subgroup_size_mask:ptr<private, vec2<u32>, read_write> = var
+}
+
+%foo = @fragment func(%tint_subgroup_size:u32 [@subgroup_size]):void {
+ $B2: {
+ %4:bool = gt %tint_subgroup_size, 32u
+ %5:u32 = sub 32u, %tint_subgroup_size
+ %6:u32 = shr 4294967295u, %5
+ %7:u32 = select %6, 4294967295u, %4
+ %8:u32 = sub 64u, %tint_subgroup_size
+ %9:u32 = shr 4294967295u, %8
+ %10:u32 = select 0u, %9, %4
+ store_vector_element %tint_subgroup_size_mask, 0u, %7
+ store_vector_element %tint_subgroup_size_mask, 1u, %10
+ %11:vec4<u32> = call %tint_subgroup_ballot, true
+ %13:vec4<u32> = call %tint_subgroup_ballot, false
+ %14:vec4<u32> = call %tint_subgroup_ballot, true
+ %15:vec4<u32> = call %tint_subgroup_ballot, false
+ ret
+ }
+}
+%tint_subgroup_ballot = func(%pred:bool):vec4<u32> {
+ $B3: {
+ %17:vec2<u32> = msl.simd_ballot %pred
+ %18:vec2<u32> = load %tint_subgroup_size_mask
+ %19:vec2<u32> = and %17, %18
+ %20:vec4<u32> = construct %19, 0u, 0u
+ ret %20
+ }
+}
+)";
+
+ Run(SimdBallot);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::msl::writer::raise
diff --git a/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl b/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl
index e94069d..65c71e1 100644
--- a/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/literal/subgroupBallot/1a8251.wgsl.expected.ir.msl
@@ -1,9 +1,37 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-../../src/tint/lang/msl/writer/printer/printer.cc:988 internal compiler error: TINT_UNREACHABLE unhandled: subgroupBallot
-********************************************************************
-* The tint shader compiler has encountered an unexpected error. *
-* *
-* Please help us fix this issue by submitting a bug report at *
-* crbug.com/tint with the source program that triggered the bug. *
-********************************************************************
+struct tint_module_vars_struct {
+ device uint4* prevent_dce;
+ thread uint2* tint_subgroup_size_mask;
+};
+
+uint4 tint_subgroup_ballot(bool pred, tint_module_vars_struct tint_module_vars) {
+ uint2 const v = as_type<uint2>((simd_vote::vote_t)simd_ballot(pred));
+ return uint4((v & (*tint_module_vars.tint_subgroup_size_mask)), 0u, 0u);
+}
+
+uint4 subgroupBallot_1a8251(tint_module_vars_struct tint_module_vars) {
+ uint4 res = tint_subgroup_ballot(true, tint_module_vars);
+ return res;
+}
+
+fragment void fragment_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+ thread uint2 tint_subgroup_size_mask = 0u;
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+ uint const v_1 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+ uint const v_2 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+ (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_1;
+ (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_2;
+ (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}
+
+kernel void compute_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+ thread uint2 tint_subgroup_size_mask = 0u;
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+ uint const v_3 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+ uint const v_4 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+ (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_3;
+ (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_4;
+ (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}
diff --git a/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl b/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl
index e94069d..73db8be 100644
--- a/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl
+++ b/test/tint/builtins/gen/var/subgroupBallot/1a8251.wgsl.expected.ir.msl
@@ -1,9 +1,38 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-../../src/tint/lang/msl/writer/printer/printer.cc:988 internal compiler error: TINT_UNREACHABLE unhandled: subgroupBallot
-********************************************************************
-* The tint shader compiler has encountered an unexpected error. *
-* *
-* Please help us fix this issue by submitting a bug report at *
-* crbug.com/tint with the source program that triggered the bug. *
-********************************************************************
+struct tint_module_vars_struct {
+ device uint4* prevent_dce;
+ thread uint2* tint_subgroup_size_mask;
+};
+
+uint4 tint_subgroup_ballot(bool pred, tint_module_vars_struct tint_module_vars) {
+ uint2 const v = as_type<uint2>((simd_vote::vote_t)simd_ballot(pred));
+ return uint4((v & (*tint_module_vars.tint_subgroup_size_mask)), 0u, 0u);
+}
+
+uint4 subgroupBallot_1a8251(tint_module_vars_struct tint_module_vars) {
+ bool arg_0 = true;
+ uint4 res = tint_subgroup_ballot(arg_0, tint_module_vars);
+ return res;
+}
+
+fragment void fragment_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+ thread uint2 tint_subgroup_size_mask = 0u;
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+ uint const v_1 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+ uint const v_2 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+ (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_1;
+ (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_2;
+ (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}
+
+kernel void compute_main(uint tint_subgroup_size [[threads_per_simdgroup]], device uint4* prevent_dce [[buffer(0)]]) {
+ thread uint2 tint_subgroup_size_mask = 0u;
+ tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.prevent_dce=prevent_dce, .tint_subgroup_size_mask=(&tint_subgroup_size_mask)};
+ uint const v_3 = select((4294967295u >> (32u - tint_subgroup_size)), 4294967295u, (tint_subgroup_size > 32u));
+ uint const v_4 = select(0u, (4294967295u >> (64u - tint_subgroup_size)), (tint_subgroup_size > 32u));
+ (*tint_module_vars.tint_subgroup_size_mask)[0u] = v_3;
+ (*tint_module_vars.tint_subgroup_size_mask)[1u] = v_4;
+ (*tint_module_vars.prevent_dce) = subgroupBallot_1a8251(tint_module_vars);
+}