[msl] Add support for chromium_experimental_subgroups
Use a transform to replace calls to `subgroupBallot()` with a helper
function that masks out the unused bits, as well as converting the
64-bit output from MSL's `simd_active_threads_mask()` function to a
vec4u.
Adds the ability to change the target MSL version used for validation,
which is automatically bumped to MSL 2.1 if the subgroups extension is
used.
Bug: tint:2000
Change-Id: I3161b249acd90378be52b18f6325c82fd9dc80b2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/146460
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index dfec136..30e4ddb 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -425,6 +425,8 @@
"lang/wgsl/ast/transform/merge_return.h",
"lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc",
"lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h",
+ "lang/wgsl/ast/transform/msl_subgroup_ballot.cc",
+ "lang/wgsl/ast/transform/msl_subgroup_ballot.h",
"lang/wgsl/ast/transform/multiplanar_external_texture.cc",
"lang/wgsl/ast/transform/multiplanar_external_texture.h",
"lang/wgsl/ast/transform/num_workgroups_from_uniform.cc",
@@ -1847,6 +1849,7 @@
"lang/wgsl/ast/transform/manager_test.cc",
"lang/wgsl/ast/transform/merge_return_test.cc",
"lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc",
+ "lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc",
"lang/wgsl/ast/transform/multiplanar_external_texture_test.cc",
"lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc",
"lang/wgsl/ast/transform/packed_vec3_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 70f1821..b36d852 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -360,6 +360,8 @@
lang/wgsl/ast/transform/merge_return.h
lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.cc
lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h
+ lang/wgsl/ast/transform/msl_subgroup_ballot.cc
+ lang/wgsl/ast/transform/msl_subgroup_ballot.h
lang/wgsl/ast/transform/multiplanar_external_texture.cc
lang/wgsl/ast/transform/multiplanar_external_texture.h
lang/wgsl/ast/transform/num_workgroups_from_uniform.cc
@@ -1498,6 +1500,7 @@
lang/wgsl/ast/transform/localize_struct_array_assignment_test.cc
lang/wgsl/ast/transform/merge_return_test.cc
lang/wgsl/ast/transform/module_scope_var_to_entry_point_param_test.cc
+ lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc
lang/wgsl/ast/transform/multiplanar_external_texture_test.cc
lang/wgsl/ast/transform/num_workgroups_from_uniform_test.cc
lang/wgsl/ast/transform/packed_vec3_test.cc
diff --git a/src/tint/cmd/tint/main.cc b/src/tint/cmd/tint/main.cc
index 7434165..23a7e3f 100644
--- a/src/tint/cmd/tint/main.cc
+++ b/src/tint/cmd/tint/main.cc
@@ -671,10 +671,19 @@
PrintHash(hash);
}
+ // Default to validating against MSL 1.2.
+ // If subgroups are used, bump the version to 2.1.
+ auto msl_version = tint::msl::validate::MslVersion::kMsl_1_2;
+ for (auto* enable : program->AST().Enables()) {
+ if (enable->HasExtension(tint::core::Extension::kChromiumExperimentalSubgroups)) {
+ msl_version = tint::msl::validate::MslVersion::kMsl_2_1;
+ }
+ }
+
if (options.validate && options.skip_hash.count(hash) == 0) {
tint::msl::validate::Result res;
#ifdef TINT_ENABLE_MSL_VALIDATION_USING_METAL_API
- res = tint::msl::validate::UsingMetalAPI(result->msl);
+ res = tint::msl::validate::UsingMetalAPI(result->msl, msl_version);
#else
#ifdef _WIN32
const char* default_xcrun_exe = "metal.exe";
@@ -684,7 +693,7 @@
auto xcrun = tint::Command::LookPath(
options.xcrun_path.empty() ? default_xcrun_exe : std::string(options.xcrun_path));
if (xcrun.Found()) {
- res = tint::msl::validate::Msl(xcrun.Path(), result->msl);
+ res = tint::msl::validate::Msl(xcrun.Path(), result->msl, msl_version);
} else {
res.output = "xcrun executable not found. Cannot validate.";
res.failed = true;
diff --git a/src/tint/lang/msl/validate/msl.cc b/src/tint/lang/msl/validate/msl.cc
index b3d0000..fce70d6 100644
--- a/src/tint/lang/msl/validate/msl.cc
+++ b/src/tint/lang/msl/validate/msl.cc
@@ -21,7 +21,7 @@
namespace tint::msl::validate {
-Result Msl(const std::string& xcrun_path, const std::string& source) {
+Result Msl(const std::string& xcrun_path, const std::string& source, MslVersion version) {
Result result;
auto xcrun = tint::Command(xcrun_path);
@@ -34,17 +34,27 @@
tint::TmpFile file(".metal");
file << source;
+ const char* version_str = nullptr;
+ switch (version) {
+ case MslVersion::kMsl_1_2:
+ version_str = "-std=macos-metal1.2";
+ break;
+ case MslVersion::kMsl_2_1:
+ version_str = "-std=macos-metal2.1";
+ break;
+ }
+
#ifdef _WIN32
// On Windows, we should actually be running metal.exe from the Metal
// Developer Tools for Windows
- auto res = xcrun("-x", "metal", //
- "-o", "NUL", //
- "-std=osx-metal1.2", //
+ auto res = xcrun("-x", "metal", //
+ "-o", "NUL", //
+ version_str, //
"-c", file.Path());
#else
auto res = xcrun("-sdk", "macosx", "metal", //
"-o", "/dev/null", //
- "-std=osx-metal1.2", //
+ version_str, //
"-c", file.Path());
#endif
if (!res.out.empty()) {
diff --git a/src/tint/lang/msl/validate/msl_metal.mm b/src/tint/lang/msl/validate/msl_metal.mm
index 0b2724d..9999f2a 100644
--- a/src/tint/lang/msl/validate/msl_metal.mm
+++ b/src/tint/lang/msl/validate/msl_metal.mm
@@ -24,7 +24,7 @@
namespace tint::msl::validate {
-Result UsingMetalAPI(const std::string& src) {
+Result UsingMetalAPI(const std::string& src, MslVersion version) {
Result result;
NSError* error = nil;
@@ -40,7 +40,14 @@
MTLCompileOptions* compileOptions = [MTLCompileOptions new];
compileOptions.fastMathEnabled = true;
- compileOptions.languageVersion = MTLLanguageVersion1_2;
+ switch (version) {
+ case MslVersion::kMsl_1_2:
+ compileOptions.languageVersion = MTLLanguageVersion1_2;
+ break;
+ case MslVersion::kMsl_2_1:
+ compileOptions.languageVersion = MTLLanguageVersion2_1;
+ break;
+ }
id<MTLLibrary> library = [device newLibraryWithSource:source
options:compileOptions
diff --git a/src/tint/lang/msl/validate/val.h b/src/tint/lang/msl/validate/val.h
index c4c0a15..498516b 100644
--- a/src/tint/lang/msl/validate/val.h
+++ b/src/tint/lang/msl/validate/val.h
@@ -30,6 +30,12 @@
using EntryPointList = std::vector<std::pair<std::string, ast::PipelineStage>>;
+/// The version of MSL to validate against.
+enum class MslVersion {
+ kMsl_1_2,  ///< Validate against MSL 1.2
+ kMsl_2_1,  ///< Validate against MSL 2.1
+};
+
/// The return structure of Validate()
struct Result {
/// True if validation passed
@@ -42,15 +48,17 @@
/// verifying that the shader compiles successfully.
/// @param xcrun_path path to xcrun
/// @param source the generated MSL source
+/// @param version the version of MSL to validate against
/// @return the result of the compile
-Result Msl(const std::string& xcrun_path, const std::string& source);
+Result Msl(const std::string& xcrun_path, const std::string& source, MslVersion version);
#ifdef TINT_ENABLE_MSL_VALIDATION_USING_METAL_API
/// Msl attempts to compile the shader with the runtime Metal Shader Compiler
/// API, verifying that the shader compiles successfully.
/// @param source the generated MSL source
+/// @param version the version of MSL to validate against
/// @return the result of the compile
-Result UsingMetalAPI(const std::string& source);
+Result UsingMetalAPI(const std::string& source, MslVersion version);
#endif // TINT_ENABLE_MSL_VALIDATION_USING_METAL_API
} // namespace tint::msl::validate
diff --git a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
index 86d611f..59ca50b 100644
--- a/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/msl/writer/ast_printer/ast_printer.cc
@@ -59,6 +59,7 @@
#include "src/tint/lang/wgsl/ast/transform/expand_compound_assignment.h"
#include "src/tint/lang/wgsl/ast/transform/manager.h"
#include "src/tint/lang/wgsl/ast/transform/module_scope_var_to_entry_point_param.h"
+#include "src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h"
#include "src/tint/lang/wgsl/ast/transform/multiplanar_external_texture.h"
#include "src/tint/lang/wgsl/ast/transform/packed_vec3.h"
#include "src/tint/lang/wgsl/ast/transform/preserve_padding.h"
@@ -205,6 +206,9 @@
manager.Add<ast::transform::RemovePhonies>();
manager.Add<ast::transform::SimplifyPointers>();
+ // MslSubgroupBallot() must come after CanonicalizeEntryPointIO.
+ manager.Add<ast::transform::MslSubgroupBallot>();
+
// ArrayLengthFromUniform must come after SimplifyPointers, as
// it assumes that the form of the array length argument is &var.array.
manager.Add<ast::transform::ArrayLengthFromUniform>();
@@ -244,6 +248,7 @@
core::Extension::kChromiumExperimentalFullPtrParameters,
core::Extension::kChromiumExperimentalPushConstant,
core::Extension::kChromiumExperimentalReadWriteStorageTexture,
+ core::Extension::kChromiumExperimentalSubgroups,
core::Extension::kChromiumInternalRelaxedUniformLayout,
core::Extension::kF16,
core::Extension::kChromiumInternalDualSourceBlending,
@@ -622,6 +627,12 @@
bool ASTPrinter::EmitFunctionCall(StringStream& out,
const sem::Call* call,
const sem::Function* fn) {
+ if (ast::GetAttribute<ast::transform::MslSubgroupBallot::SimdActiveThreadsMask>(
+ fn->Declaration()->attributes) != nullptr) {
+ out << "as_type<uint2>((ulong)simd_active_threads_mask())";
+ return true;
+ }
+
out << fn->Declaration()->name->symbol.Name() << "(";
bool first = true;
@@ -1847,6 +1858,11 @@
}
bool ASTPrinter::EmitFunction(const ast::Function* func) {
+ if (func->body == nullptr) {
+ // An internal function. Do not emit.
+ return true;
+ }
+
auto* func_sem = builder_.Sem().Get(func);
{
diff --git a/src/tint/lang/msl/writer/common/printer_support.cc b/src/tint/lang/msl/writer/common/printer_support.cc
index 3311449..9312650 100644
--- a/src/tint/lang/msl/writer/common/printer_support.cc
+++ b/src/tint/lang/msl/writer/common/printer_support.cc
@@ -60,6 +60,10 @@
return "sample_mask";
case core::BuiltinValue::kPointSize:
return "point_size";
+ case core::BuiltinValue::kSubgroupInvocationId:
+ return "thread_index_in_simdgroup";
+ case core::BuiltinValue::kSubgroupSize:
+ return "threads_per_simdgroup";
default:
break;
}
diff --git a/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.cc b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.cc
new file mode 100644
index 0000000..8bd3350
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.cc
@@ -0,0 +1,197 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+#include "src/tint/lang/wgsl/sem/call.h"
+#include "src/tint/lang/wgsl/sem/function.h"
+#include "src/tint/lang/wgsl/sem/statement.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MslSubgroupBallot);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::MslSubgroupBallot::SimdActiveThreadsMask);
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::ast::transform {
+
+/// PIMPL state for the transform
+struct MslSubgroupBallot::State {
+    /// The source program
+    const Program* const src;
+    /// The target program builder
+    ProgramBuilder b;
+    /// The clone context
+    program::CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
+
+    /// The name of the `tint_msl_subgroup_ballot` helper function.
+    Symbol ballot_helper{};
+
+    /// The name of the `tint_subgroup_size_mask` global variable.
+    Symbol subgroup_size_mask{};
+
+    /// The set of functions that directly call `subgroupBallot()`.
+    Hashset<const sem::Function*, 4> ballot_callers;
+
+    /// Constructor
+    /// @param program the source program
+    explicit State(const Program* program) : src(program) {}
+
+    /// Runs the transform
+    /// @returns the new program or SkipTransform if the transform is not required
+    ApplyResult Run() {
+        auto& sem = src->Sem();
+
+        bool made_changes = false;
+        for (auto* node : ctx.src->ASTNodes().Objects()) {
+            auto* call = sem.Get<sem::Call>(node);
+            if (call) {
+                // If this is a call to a `subgroupBallot()` builtin, replace it with a call to the
+                // helper function and make a note of the function that we are in.
+                auto* builtin = call->Target()->As<sem::Builtin>();
+                if (builtin && builtin->Type() == core::Function::kSubgroupBallot) {
+                    ctx.Replace(call->Declaration(), b.Call(GetHelper()));
+                    ballot_callers.Add(call->Stmt()->Function());
+                    made_changes = true;
+                }
+            }
+        }
+        if (!made_changes) {
+            return SkipTransform;
+        }
+
+        // Set the subgroup size mask at the start of each entry point that transitively calls
+        // `subgroupBallot()`.
+        for (auto* global : src->AST().GlobalDeclarations()) {
+            auto* func = global->As<Function>();
+            if (func && func->IsEntryPoint() && TransitivelyCallsSubgroupBallot(sem.Get(func))) {
+                SetSubgroupSizeMask(func);
+            }
+        }
+
+        ctx.Clone();
+        return resolver::Resolve(b);
+    }
+
+    /// Get (or create) the `tint_msl_subgroup_ballot` helper function.
+    /// @returns the name of the helper function
+    Symbol GetHelper() {
+        if (!ballot_helper) {
+            auto intrinsic = b.Symbols().New("tint_msl_simd_active_threads_mask");
+            subgroup_size_mask = b.Symbols().New("tint_subgroup_size_mask");
+            ballot_helper = b.Symbols().New("tint_msl_subgroup_ballot");
+
+            // Declare the `tint_msl_simd_active_threads_mask` intrinsic function, which will use
+            // the `simd_active_threads_mask` function to return a 64-bit vote.
+            b.Func(intrinsic, Empty, b.ty.vec2<u32>(), nullptr,
+                   Vector{b.ASTNodes().Create<SimdActiveThreadsMask>(b.ID(), b.AllocateNodeID()),
+                          b.Disable(DisabledValidation::kFunctionHasNoBody)});
+
+            // Declare the `tint_subgroup_size_mask` variable.
+            b.GlobalVar(subgroup_size_mask, core::AddressSpace::kPrivate, b.ty.vec4<u32>());
+
+            // Declare the `tint_msl_subgroup_ballot` helper function as follows:
+            //   fn tint_msl_subgroup_ballot() -> vec4u {
+            //     let vote : vec4u = vec4u(tint_msl_simd_active_threads_mask(), 0, 0);
+            //     return (vote & tint_subgroup_size_mask);
+            //   }
+            auto* vote = b.Let(b.Sym(), b.Call(b.ty.vec4<u32>(), b.Call(intrinsic), 0_u, 0_u));
+            b.Func(ballot_helper, Empty, b.ty.vec4<u32>(),
+                   Vector{
+                       b.Decl(vote),
+                       b.Return(b.And(vote, subgroup_size_mask)),
+                   });
+        }
+        return ballot_helper;
+    }
+
+    /// Check if a function directly or transitively calls the `subgroupBallot()` builtin.
+    /// @param func the function to check
+    /// @returns true if the function transitively calls `subgroupBallot()`
+    bool TransitivelyCallsSubgroupBallot(const sem::Function* func) {
+        if (ballot_callers.Contains(func)) {
+            return true;
+        }
+        for (auto* called : func->TransitivelyCalledFunctions()) {
+            if (ballot_callers.Contains(called)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /// Add code to set the `subgroup_size_mask` variable at the start of an entry point.
+    /// @param ep the entry point
+    void SetSubgroupSizeMask(const ast::Function* ep) {
+        // Check the entry point parameters for an existing `subgroup_size` builtin.
+        Symbol subgroup_size;
+        for (auto* param : ep->params) {
+            auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes);
+            if (builtin &&
+                src->Sem()
+                        .Get<sem::BuiltinEnumExpression<core::BuiltinValue>>(builtin->builtin)
+                        ->Value() == core::BuiltinValue::kSubgroupSize) {
+                subgroup_size = ctx.Clone(param->name->symbol);
+            }
+        }
+        if (!subgroup_size.IsValid()) {
+            // No `subgroup_size` builtin parameter was found, so add one.
+            subgroup_size = b.Symbols().New("tint_subgroup_size");
+            ctx.InsertBack(ep->params, b.Param(subgroup_size, b.ty.u32(),
+                                               Vector{
+                                                   b.Builtin(core::BuiltinValue::kSubgroupSize),
+                                               }));
+        }
+
+        // Add a block to the top of the entry point that sets the low `subgroup_size` bits of
+        // the mask, so that the bits for non-existent invocations are cleared from the vote:
+        //   let gt = subgroup_size > 32;
+        //   subgroup_size_mask[0] = select(0xffffffff >> (32 - subgroup_size), 0xffffffff, gt);
+        //   subgroup_size_mask[1] = select(0, 0xffffffff >> (64 - subgroup_size), gt);
+        auto* gt = b.Let(b.Sym("gt"), b.GreaterThan(subgroup_size, 32_u));
+        auto* lo =
+            b.Call("select", b.Shr(0xffffffff_u, b.Sub(32_u, subgroup_size)), 0xffffffff_u, gt);
+        auto* hi = b.Call("select", 0_u, b.Shr(0xffffffff_u, b.Sub(64_u, subgroup_size)), gt);
+        auto* block = b.Block(Vector{
+            b.Decl(gt),
+            b.Assign(b.IndexAccessor(subgroup_size_mask, 0_u), lo),
+            b.Assign(b.IndexAccessor(subgroup_size_mask, 1_u), hi),
+        });
+        ctx.InsertFront(ep->body->statements, block);
+    }
+};
+
+MslSubgroupBallot::MslSubgroupBallot() = default;
+
+MslSubgroupBallot::~MslSubgroupBallot() = default;
+
+Transform::ApplyResult MslSubgroupBallot::Apply(const Program* src,
+ const DataMap&,
+ DataMap&) const {
+ return State(src).Run();
+}
+
+MslSubgroupBallot::SimdActiveThreadsMask::~SimdActiveThreadsMask() = default;
+
+const MslSubgroupBallot::SimdActiveThreadsMask* MslSubgroupBallot::SimdActiveThreadsMask::Clone(
+ ast::CloneContext& ctx) const {
+ return ctx.dst->ASTNodes().Create<MslSubgroupBallot::SimdActiveThreadsMask>(
+ ctx.dst->ID(), ctx.dst->AllocateNodeID());
+}
+
+} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h
new file mode 100644
index 0000000..982f3e2
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h
@@ -0,0 +1,70 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_MSL_SUBGROUP_BALLOT_H_
+#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_MSL_SUBGROUP_BALLOT_H_
+
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::ast::transform {
+
+/// MslSubgroupBallot is a transform that replaces calls to `subgroupBallot()` with an
+/// implementation that uses MSL's `simd_active_threads_mask()`.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * CanonicalizeEntryPointIO
+class MslSubgroupBallot final : public Castable<MslSubgroupBallot, Transform> {
+ public:
+ /// Constructor
+ MslSubgroupBallot();
+
+ /// Destructor
+ ~MslSubgroupBallot() override;
+
+ /// @copydoc Transform::Apply
+ ApplyResult Apply(const Program* program,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+
+ /// SimdActiveThreadsMask is an InternalAttribute that is used to decorate a stub function so
+ /// that the MSL backend transforms this into calls to the `simd_active_threads_mask` function.
+ class SimdActiveThreadsMask final : public Castable<SimdActiveThreadsMask, InternalAttribute> {
+ public:
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ SimdActiveThreadsMask(GenerationID pid, NodeID nid) : Base(pid, nid, Empty) {}
+
+ /// Destructor
+ ~SimdActiveThreadsMask() override;
+
+ /// @copydoc InternalAttribute::InternalName
+ std::string InternalName() const override { return "simd_active_threads_mask"; }
+
+ /// Performs a deep clone of this object using the CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const SimdActiveThreadsMask* Clone(CloneContext& ctx) const override;
+ };
+
+ private:
+ struct State;
+};
+
+} // namespace tint::ast::transform
+
+#endif // SRC_TINT_LANG_WGSL_AST_TRANSFORM_MSL_SUBGROUP_BALLOT_H_
diff --git a/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc
new file mode 100644
index 0000000..0ac0105
--- /dev/null
+++ b/src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot_test.cc
@@ -0,0 +1,163 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/wgsl/ast/transform/msl_subgroup_ballot.h"
+
+#include "src/tint/lang/wgsl/ast/transform/helper_test.h"
+
+namespace tint::ast::transform {
+namespace {
+
+using MslSubgroupBallotTest = TransformTest;
+
+TEST_F(MslSubgroupBallotTest, EmptyModule) {
+ auto* src = "";
+
+ EXPECT_FALSE(ShouldRun<MslSubgroupBallot>(src));
+}
+
+TEST_F(MslSubgroupBallotTest, DirectUse) {
+ auto* src = R"(
+enable chromium_experimental_subgroups;
+
+@compute @workgroup_size(64)
+fn foo() {
+ let x : vec4u = subgroupBallot();
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_experimental_subgroups;
+
+@internal(simd_active_threads_mask) @internal(disable_validation__function_has_no_body)
+fn tint_msl_simd_active_threads_mask() -> vec2<u32>
+
+var<private> tint_subgroup_size_mask : vec4<u32>;
+
+fn tint_msl_subgroup_ballot() -> vec4<u32> {
+ let tint_symbol = vec4<u32>(tint_msl_simd_active_threads_mask(), 0u, 0u);
+ return (tint_symbol & tint_subgroup_size_mask);
+}
+
+@compute @workgroup_size(64)
+fn foo(@builtin(subgroup_size) tint_subgroup_size : u32) {
+ {
+ let gt = (tint_subgroup_size > 32u);
+ tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+ tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+ }
+ let x : vec4u = tint_msl_subgroup_ballot();
+}
+)";
+
+ auto got = Run<MslSubgroupBallot>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MslSubgroupBallotTest, IndirectUse) {
+ auto* src = R"(
+enable chromium_experimental_subgroups;
+
+fn bar() -> vec4u {
+ return subgroupBallot();
+}
+
+@compute @workgroup_size(64)
+fn foo() {
+ let x: vec4u = bar();
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_experimental_subgroups;
+
+@internal(simd_active_threads_mask) @internal(disable_validation__function_has_no_body)
+fn tint_msl_simd_active_threads_mask() -> vec2<u32>
+
+var<private> tint_subgroup_size_mask : vec4<u32>;
+
+fn tint_msl_subgroup_ballot() -> vec4<u32> {
+ let tint_symbol = vec4<u32>(tint_msl_simd_active_threads_mask(), 0u, 0u);
+ return (tint_symbol & tint_subgroup_size_mask);
+}
+
+fn bar() -> vec4u {
+ return tint_msl_subgroup_ballot();
+}
+
+@compute @workgroup_size(64)
+fn foo(@builtin(subgroup_size) tint_subgroup_size : u32) {
+ {
+ let gt = (tint_subgroup_size > 32u);
+ tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+ tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+ }
+ let x : vec4u = bar();
+}
+)";
+
+ auto got = Run<MslSubgroupBallot>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(MslSubgroupBallotTest, PreexistingSubgroupSizeBuiltin) {
+ auto* src = R"(
+enable chromium_experimental_subgroups;
+
+@compute @workgroup_size(64)
+fn foo(@builtin(workgroup_id) group_id: vec3u,
+ @builtin(subgroup_size) size : u32,
+ @builtin(local_invocation_index) index : u32) {
+ let sz = size;
+ let x : vec4u = subgroupBallot();
+}
+)";
+
+ auto* expect =
+ R"(
+enable chromium_experimental_subgroups;
+
+@internal(simd_active_threads_mask) @internal(disable_validation__function_has_no_body)
+fn tint_msl_simd_active_threads_mask() -> vec2<u32>
+
+var<private> tint_subgroup_size_mask : vec4<u32>;
+
+fn tint_msl_subgroup_ballot() -> vec4<u32> {
+ let tint_symbol = vec4<u32>(tint_msl_simd_active_threads_mask(), 0u, 0u);
+ return (tint_symbol & tint_subgroup_size_mask);
+}
+
+@compute @workgroup_size(64)
+fn foo(@builtin(workgroup_id) group_id : vec3u, @builtin(subgroup_size) size : u32, @builtin(local_invocation_index) index : u32) {
+ {
+ let gt = (size > 32u);
+ tint_subgroup_size_mask[0u] = select((1u << (size - 1u)), 4294967295u, gt);
+ tint_subgroup_size_mask[1u] = select(0u, (1u << (size - 33u)), gt);
+ }
+ let sz = size;
+ let x : vec4u = tint_msl_subgroup_ballot();
+}
+)";
+
+ auto got = Run<MslSubgroupBallot>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::ast::transform
diff --git a/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl b/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl
index 134107e..78d6841 100644
--- a/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl
+++ b/test/tint/builtins/gen/literal/subgroupBallot/7e6d0e.wgsl.expected.msl
@@ -1,21 +1,29 @@
-SKIP: FAILED
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_private_vars_struct {
+ uint4 tint_subgroup_size_mask;
+};
-enable chromium_experimental_subgroups;
-
-fn subgroupBallot_7e6d0e() {
- var res : vec4<u32> = subgroupBallot();
- prevent_dce = res;
+uint4 tint_msl_subgroup_ballot(thread tint_private_vars_struct* const tint_private_vars) {
+ uint4 const tint_symbol = uint4(as_type<uint2>((ulong)simd_active_threads_mask()), 0u, 0u);
+ return (tint_symbol & (*(tint_private_vars)).tint_subgroup_size_mask);
}
-@group(2) @binding(0) var<storage, read_write> prevent_dce : vec4<u32>;
-
-@compute @workgroup_size(1)
-fn compute_main() {
- subgroupBallot_7e6d0e();
+void subgroupBallot_7e6d0e(thread tint_private_vars_struct* const tint_private_vars, device uint4* const tint_symbol_1) {
+ uint4 res = tint_msl_subgroup_ballot(tint_private_vars);
+ *(tint_symbol_1) = res;
}
-Failed to generate: builtins/gen/literal/subgroupBallot/7e6d0e.wgsl:24:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+kernel void compute_main(device uint4* tint_symbol_2 [[buffer(0)]], uint tint_subgroup_size [[threads_per_simdgroup]]) {
+ thread tint_private_vars_struct tint_private_vars = {};
+ {
+ bool const gt = (tint_subgroup_size > 32u);
+ tint_private_vars.tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+ tint_private_vars.tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+ }
+ subgroupBallot_7e6d0e(&(tint_private_vars), tint_symbol_2);
+ return;
+}
diff --git a/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl b/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl
index b16ab89..78d6841 100644
--- a/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl
+++ b/test/tint/builtins/gen/var/subgroupBallot/7e6d0e.wgsl.expected.msl
@@ -1,21 +1,29 @@
-SKIP: FAILED
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_private_vars_struct {
+ uint4 tint_subgroup_size_mask;
+};
-enable chromium_experimental_subgroups;
-
-fn subgroupBallot_7e6d0e() {
- var res : vec4<u32> = subgroupBallot();
- prevent_dce = res;
+uint4 tint_msl_subgroup_ballot(thread tint_private_vars_struct* const tint_private_vars) {
+ uint4 const tint_symbol = uint4(as_type<uint2>((ulong)simd_active_threads_mask()), 0u, 0u);
+ return (tint_symbol & (*(tint_private_vars)).tint_subgroup_size_mask);
}
-@group(2) @binding(0) var<storage, read_write> prevent_dce : vec4<u32>;
-
-@compute @workgroup_size(1)
-fn compute_main() {
- subgroupBallot_7e6d0e();
+void subgroupBallot_7e6d0e(thread tint_private_vars_struct* const tint_private_vars, device uint4* const tint_symbol_1) {
+ uint4 res = tint_msl_subgroup_ballot(tint_private_vars);
+ *(tint_symbol_1) = res;
}
-Failed to generate: builtins/gen/var/subgroupBallot/7e6d0e.wgsl:24:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+kernel void compute_main(device uint4* tint_symbol_2 [[buffer(0)]], uint tint_subgroup_size [[threads_per_simdgroup]]) {
+ thread tint_private_vars_struct tint_private_vars = {};
+ {
+ bool const gt = (tint_subgroup_size > 32u);
+ tint_private_vars.tint_subgroup_size_mask[0u] = select((1u << (tint_subgroup_size - 1u)), 4294967295u, gt);
+ tint_private_vars.tint_subgroup_size_mask[1u] = select(0u, (1u << (tint_subgroup_size - 33u)), gt);
+ }
+ subgroupBallot_7e6d0e(&(tint_private_vars), tint_symbol_2);
+ return;
+}
diff --git a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl
index 6e4dcdd..c79a68a 100644
--- a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl
+++ b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.msl
@@ -1,16 +1,29 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-enable chromium_experimental_subgroups;
+template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+struct tint_symbol_3 {
+ /* 0x0000 */ tint_array<uint, 1> arr;
+};
-@compute @workgroup_size(1)
-fn tint_symbol(@builtin(subgroup_invocation_id) subgroup_invocation_id : u32, @builtin(subgroup_size) subgroup_size : u32) {
- output[subgroup_invocation_id] = subgroup_size;
+void tint_symbol_inner(uint subgroup_invocation_id, uint subgroup_size, device tint_array<uint, 1>* const tint_symbol_1) {
+ (*(tint_symbol_1))[subgroup_invocation_id] = subgroup_size;
}
-Failed to generate: shader_io/compute_subgroup_builtins.wgsl:1:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+kernel void tint_symbol(device tint_symbol_3* tint_symbol_2 [[buffer(0)]], uint subgroup_invocation_id [[thread_index_in_simdgroup]], uint subgroup_size [[threads_per_simdgroup]]) {
+ tint_symbol_inner(subgroup_invocation_id, subgroup_size, &((*(tint_symbol_2)).arr));
+ return;
+}
diff --git a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl
index 3c8b497..d5f1182 100644
--- a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl
+++ b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.msl
@@ -1,23 +1,35 @@
-SKIP: FAILED
+#include <metal_stdlib>
+using namespace metal;
-enable chromium_experimental_subgroups;
+template<typename T, size_t N>
+struct tint_array {
+ const constant T& operator[](size_t i) const constant { return elements[i]; }
+ device T& operator[](size_t i) device { return elements[i]; }
+ const device T& operator[](size_t i) const device { return elements[i]; }
+ thread T& operator[](size_t i) thread { return elements[i]; }
+ const thread T& operator[](size_t i) const thread { return elements[i]; }
+ threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+ const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+ T elements[N];
+};
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+struct tint_symbol_4 {
+ /* 0x0000 */ tint_array<uint, 1> arr;
+};
struct ComputeInputs {
- @builtin(subgroup_invocation_id)
- subgroup_invocation_id : u32,
- @builtin(subgroup_size)
- subgroup_size : u32,
+ uint subgroup_invocation_id;
+ uint subgroup_size;
+};
+
+void tint_symbol_inner(ComputeInputs inputs, device tint_array<uint, 1>* const tint_symbol_2) {
+ (*(tint_symbol_2))[inputs.subgroup_invocation_id] = inputs.subgroup_size;
}
-@compute @workgroup_size(1)
-fn tint_symbol(inputs : ComputeInputs) {
- output[inputs.subgroup_invocation_id] = inputs.subgroup_size;
+kernel void tint_symbol(device tint_symbol_4* tint_symbol_3 [[buffer(0)]], uint subgroup_invocation_id [[thread_index_in_simdgroup]], uint subgroup_size [[threads_per_simdgroup]]) {
+ ComputeInputs const tint_symbol_1 = {.subgroup_invocation_id=subgroup_invocation_id, .subgroup_size=subgroup_size};
+ tint_symbol_inner(tint_symbol_1, &((*(tint_symbol_3)).arr));
+ return;
}
-Failed to generate: shader_io/compute_subgroup_builtins_struct.wgsl:1:8 error: MSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-