blob: 0afe1a07f34096e91a58be07cf8c08b9397e7890 [file] [log] [blame]
// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/msl/writer/ast_raise/quad_swap.h"
#include <utility>
#include "src/tint/lang/core/builtin_fn.h"
#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/wgsl/ast/transform/transform.h"
#include "src/tint/lang/wgsl/builtin_fn.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/call.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/statement.h"
// Instantiate Tint's type info for the transform itself and for the two custom AST nodes that
// this file creates: `QuadShuffle` (passed in the attribute list of the body-less placeholder
// intrinsic function) and `ThreadIndexInQuadgroup` (passed in the attribute list of the
// parameter that is appended to entry points).
TINT_INSTANTIATE_TYPEINFO(tint::msl::writer::QuadSwap);
TINT_INSTANTIATE_TYPEINFO(tint::msl::writer::QuadSwap::QuadShuffle);
TINT_INSTANTIATE_TYPEINFO(tint::msl::writer::QuadSwap::ThreadIndexInQuadgroup);
using namespace tint::core::number_suffixes; // NOLINT
using namespace tint::core::fluent_types; // NOLINT
namespace tint::msl::writer {
/// PIMPL state for the transform.
///
/// The transform rewrites every call to a `quadSwap*()` builtin into a call to a generated
/// `tint_msl_quadSwap*` helper function. Each helper forwards to a body-less
/// `tint_msl_quad_shuffle` placeholder function, passing the value of a private global variable
/// XOR'd with a per-builtin constant. That global is assigned from a new entry point parameter
/// at the start of every entry point that (transitively) reaches a `quadSwap*()` call.
struct QuadSwap::State {
    /// The source program
    const Program& src;
    /// The target program builder
    ProgramBuilder b;
    /// The clone context
    program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};
    /// The names of the `tint_msl_quadSwap*` helper functions, keyed by the builtin function and
    /// the type that the call operates on.
    Hashmap<std::pair<wgsl::BuiltinFn, const core::type::Type*>, Symbol, 8> quad_swap_helpers;
    /// The names of the `tint_msl_quad_shuffle*` intrinsic placeholder functions, keyed by type.
    Hashmap<const core::type::Type*, Symbol, 8> quad_shuffle_intrinsics;
    /// The name of the `tint_msl_thread_index_in_quadgroup` global variable.
    /// Invalid until the variable has been declared by `GetGlobalVar()`.
    Symbol thread_index_in_quadgroup{};
    /// The set of functions that directly call `quadSwap*()` builtin functions.
    Hashset<const sem::Function*, 4> quad_swap_callers;

    /// Constructor
    /// @param program the source program
    explicit State(const Program& program) : src(program) {}

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        auto& sem = src.Sem();

        bool made_changes = false;
        for (auto* node : ctx.src->ASTNodes().Objects()) {
            auto* call = sem.Get<sem::Call>(node);
            if (call) {
                // If this is a call to a `quadSwap*()` builtin, replace it with a call to the
                // helper function and make a note of the function that we are in.
                auto* builtin = call->Target()->As<sem::BuiltinFn>();
                if (builtin && builtin->IsQuadSwap()) {
                    // The argument is cloned eagerly so that it can be handed to the
                    // replacement call expression.
                    auto* expr = ctx.Clone(call->Arguments()[0]->Declaration());
                    ctx.Replace(call->Declaration(),
                                b.Call(GetHelper(builtin->Fn(), call->Type()), expr));
                    quad_swap_callers.Add(call->Stmt()->Function());
                    made_changes = true;
                }
            }
        }
        if (!made_changes) {
            return SkipTransform;
        }

        // Set the thread_index_in_quadgroup global variable at the start of each entry point that
        // transitively calls the `quadSwap*()` builtin function.
        for (auto* global : src.AST().GlobalDeclarations()) {
            auto* func = global->As<ast::Function>();
            if (func && func->IsEntryPoint() && TransitivelyCallsQuadSwap(sem.Get(func))) {
                SetThreadIndexInQuadgroup(func);
            }
        }

        ctx.Clone();
        return resolver::Resolve(b);
    }

    /// Get or create the appropriate `tint_msl_quadSwap*` helper function for the given builtin
    /// function and return type.
    /// @param func the `quadSwap*` builtin function that is being replaced
    /// @param type the type of the value being swapped (parameter and return type of the helper)
    /// @returns the name of the helper function
    Symbol GetHelper(wgsl::BuiltinFn func, const core::type::Type* type) {
        return quad_swap_helpers.GetOrAdd(std::make_pair(func, type), [&] {
            Symbol quad_swap_helper = b.Symbols().New(std::string("tint_msl_") + str(func));
            Symbol intrinsic = GetIntrinsic(type);

            // Declare the `tint_msl_quadSwap*` helper function as follows:
            //   fn tint_msl_quadSwapFnName(e : T) -> T {
            //     return tint_msl_quad_shuffle(e, tint_msl_thread_index_in_quadgroup ^ rhs);
            //   }
            // where rhs is determined based on the builtin function call:
            //   +------------------+------+
            //   | func             | rhs  |
            //   +------------------+------+
            //   | quadSwapX        | 0b1  |
            //   | quadSwapY        | 0b10 |
            //   | quadSwapDiagonal | 0b11 |
            //   +------------------+------+
            auto* expr = b.Param("e", CreateASTTypeFor(ctx, type));
            core::u32 rhs;
            switch (func) {
                case wgsl::BuiltinFn::kQuadSwapX:
                    rhs = 0b1;
                    break;
                case wgsl::BuiltinFn::kQuadSwapY:
                    rhs = 0b10;
                    break;
                case wgsl::BuiltinFn::kQuadSwapDiagonal:
                    rhs = 0b11;
                    break;
                default:
                    TINT_UNREACHABLE() << "unsupported builtin function";
            }
            b.Func(quad_swap_helper, Vector{expr}, CreateASTTypeFor(ctx, type),
                   Vector{
                       b.Return(b.Call(intrinsic, expr, b.Xor(GetGlobalVar(), rhs))),
                   });
            return quad_swap_helper;
        });
    }

    /// Get or create the `tint_msl_quad_shuffle` intrinsic placeholder function for the given
    /// return type.
    /// @param type the type of the data being shuffled (first parameter and return type)
    /// @returns the name of the function
    Symbol GetIntrinsic(const core::type::Type* type) {
        return quad_shuffle_intrinsics.GetOrAdd(type, [&] {
            auto intrinsic = b.Symbols().New(std::string("tint_msl_quad_shuffle"));

            // Declare the `tint_msl_quad_shuffle` function, which will be replaced by the MSL
            // `quad_shuffle` intrinsic function to perform the swap.
            // The function has no body, so validation of that is disabled, and it is tagged with
            // the `QuadShuffle` node so that it can be identified later.
            {
                auto* data = b.Param("data", CreateASTTypeFor(ctx, type));
                auto* quad_lane_id = b.Param("quad_lane_id", b.ty.u32());
                b.Func(intrinsic, Vector{data, quad_lane_id}, CreateASTTypeFor(ctx, type), nullptr,
                       Vector{b.ASTNodes().Create<QuadShuffle>(b.ID(), b.AllocateNodeID()),
                              b.Disable(ast::DisabledValidation::kFunctionHasNoBody)});
            }
            return intrinsic;
        });
    }

    /// Get or create the `tint_msl_thread_index_in_quadgroup` global variable.
    /// @returns the name of the variable
    Symbol GetGlobalVar() {
        if (!thread_index_in_quadgroup) {
            thread_index_in_quadgroup = b.Symbols().New("tint_msl_thread_index_in_quadgroup");

            // Declare the `tint_msl_thread_index_in_quadgroup` variable.
            b.GlobalVar(thread_index_in_quadgroup, core::AddressSpace::kPrivate, b.ty.u32());
        }
        return thread_index_in_quadgroup;
    }

    /// Check if a function directly or transitively calls a `quadSwap*()` builtin.
    /// Relies on `quad_swap_callers` having been populated by `Run()`.
    /// @param func the function to check
    /// @returns true if the function transitively calls `quadSwap*()`
    bool TransitivelyCallsQuadSwap(const sem::Function* func) {
        if (quad_swap_callers.Contains(func)) {
            return true;
        }
        for (auto* called : func->TransitivelyCalledFunctions()) {
            if (quad_swap_callers.Contains(called)) {
                return true;
            }
        }
        return false;
    }

    /// Add code to set the `thread_index_in_quadgroup` variable at the start of an entry point.
    /// @param ep the entry point
    void SetThreadIndexInQuadgroup(const ast::Function* ep) {
        // Add the entry point parameter with an attribute to indicate to the MSL backend that
        // it should be annotated with [[thread_index_in_quadgroup]].
        Symbol thread_index_in_quadgroup_param = b.Symbols().New("tint_thread_index_in_quadgroup");
        ctx.InsertBack(ep->params, b.Param(thread_index_in_quadgroup_param, b.ty.u32(),
                                           Vector{b.ASTNodes().Create<ThreadIndexInQuadgroup>(
                                               b.ID(), b.AllocateNodeID())}));

        // Add the following to the top of the entry point:
        // {
        //   tint_msl_thread_index_in_quadgroup = tint_thread_index_in_quadgroup;
        // }
        // The global is obtained through GetGlobalVar() rather than by reading the member
        // directly, so this does not depend on a helper function having been generated first.
        auto* block = b.Block(Vector{
            b.Assign(GetGlobalVar(), thread_index_in_quadgroup_param),
        });
        ctx.InsertFront(ep->body->statements, block);
    }
};
/// Constructor (defaulted).
QuadSwap::QuadSwap() = default;
/// Destructor (defaulted).
QuadSwap::~QuadSwap() = default;
/// Applies the transform to `src` by constructing the PIMPL state and running it.
/// The input and output data maps are not used by this transform.
/// @param src the source program
/// @returns the result of `State::Run()`: the new program, or SkipTransform if the transform is
/// not required
ast::transform::Transform::ApplyResult QuadSwap::Apply(const Program& src,
                                                       const ast::transform::DataMap&,
                                                       ast::transform::DataMap&) const {
    State state{src};
    return state.Run();
}
/// Destructor (defaulted).
QuadSwap::QuadShuffle::~QuadShuffle() = default;
/// Creates a fresh `QuadShuffle` node in the destination of the clone context.
/// No state is copied from this node; only a new generation ID and node ID are used.
/// @param ctx the clone context
/// @returns the newly created node
const QuadSwap::QuadShuffle* QuadSwap::QuadShuffle::Clone(ast::CloneContext& ctx) const {
    auto* target = ctx.dst;
    return target->ASTNodes().Create<QuadSwap::QuadShuffle>(target->ID(),
                                                            target->AllocateNodeID());
}
/// Destructor (defaulted).
QuadSwap::ThreadIndexInQuadgroup::~ThreadIndexInQuadgroup() = default;
/// Creates a fresh `ThreadIndexInQuadgroup` node in the destination of the clone context.
/// No state is copied from this node; only a new generation ID and node ID are used.
/// @param ctx the clone context
/// @returns the newly created node
const QuadSwap::ThreadIndexInQuadgroup* QuadSwap::ThreadIndexInQuadgroup::Clone(
    ast::CloneContext& ctx) const {
    auto* target = ctx.dst;
    return target->ASTNodes().Create<QuadSwap::ThreadIndexInQuadgroup>(target->ID(),
                                                                       target->AllocateNodeID());
}
} // namespace tint::msl::writer