blob: 86013562fc2abf84ba5f57c0e5d087bb6071c6ba [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/zero_init_workgroup_memory.h"
#include <algorithm>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/ast/workgroup_decoration.h"
#include "src/program_builder.h"
#include "src/sem/atomic_type.h"
#include "src/sem/function.h"
#include "src/sem/variable.h"
#include "src/utils/map.h"
#include "src/utils/unique_vector.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);
namespace tint {
namespace transform {
/// PIMPL state for the ZeroInitWorkgroupMemory transform
struct ZeroInitWorkgroupMemory::State {
/// The clone context
CloneContext& ctx;
/// An alias to *ctx.dst
ProgramBuilder& b = *ctx.dst;
/// The constant size of the workgroup. If 0, then #workgroup_size_expr should
/// be used instead.
uint32_t workgroup_size_const = 0;
/// The size of the workgroup as an expression generator. Use if
/// #workgroup_size_const is 0.
std::function<const ast::Expression*()> workgroup_size_expr;
/// ArrayIndex describes how an array subscript is derived from the local
/// invocation index, using the form:
/// `array_index = (local_invocation_index % modulo) / division`
struct ArrayIndex {
  /// Right-hand operand of the modulus step of the expression
  uint32_t modulo{1};
  /// Right-hand operand of the division step of the expression
  uint32_t division{1};

  /// Equality operator
  /// @param rhs the ArrayIndex to compare against this ArrayIndex
  /// @returns true if both the modulo and division terms of `rhs` match
  /// those of this ArrayIndex
  bool operator==(const ArrayIndex& rhs) const {
    if (modulo != rhs.modulo) {
      return false;
    }
    return division == rhs.division;
  }

  /// Hash functor for the ArrayIndex type
  struct Hasher {
    /// @param index the ArrayIndex to hash
    /// @returns a hash that combines the modulo and division of `index`
    size_t operator()(const ArrayIndex& index) const {
      return utils::Hash(index.modulo, index.division);
    }
  };
};
/// A list of unique ArrayIndex values
using ArrayIndices = utils::UniqueVector<ArrayIndex, ArrayIndex::Hasher>;
/// Expression holds information about an expression that is being built for a
/// statement that will zero workgroup values.
struct Expression {
  /// The AST node for the expression
  const ast::Expression* expr{nullptr};
  /// How many iterations are needed to zero the value
  uint32_t num_iterations{0};
  /// Every array index that this expression makes use of
  ArrayIndices array_indices;
};
/// Statement holds information about a statement that will zero workgroup
/// values.
struct Statement {
  /// The AST statement node. Defaults to nullptr so that a default-constructed
  /// Statement never carries an indeterminate pointer (mirrors Expression).
  const ast::Statement* stmt = nullptr;
  /// The number of iterations required to zero the value
  uint32_t num_iterations = 0;
  /// All array indices used by this statement
  ArrayIndices array_indices;
};
/// All statements that zero workgroup memory
std::vector<Statement> statements;
/// A map of ArrayIndex to the name reserved for the `let` declaration of that
/// index.
std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;
/// Constructor
/// @param c the CloneContext used for the transform
explicit State(CloneContext& c) : ctx(c) {}
/// Run inserts the workgroup memory zero-initialization logic at the top of
/// the given function, followed by a call to `workgroupBarrier()`.
/// If no zeroing statements are generated for `fn`, then nothing is inserted.
/// @param fn a compute shader entry point function
void Run(const ast::Function* fn) {
auto& sem = ctx.src->Sem();
// Work out the linear workgroup size from the entry point's workgroup
// decoration. This sets #workgroup_size_const and #workgroup_size_expr.
CalculateWorkgroupSize(
ast::GetDecoration<ast::WorkgroupDecoration>(fn->decorations));
// Generate a list of statements to zero initialize each of the
// workgroup storage variables used by `fn`. This will populate #statements.
auto* func = sem.Get(fn);
for (auto* var : func->TransitivelyReferencedGlobals()) {
if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
// The root expression is just the (cloned) variable name. `num_values`
// is passed through as the iteration count, and a bare variable uses no
// array indices.
BuildZeroingStatements(
var->Type()->UnwrapRef(), [&](uint32_t num_values) {
auto var_name = ctx.Clone(var->Declaration()->symbol);
return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
});
}
}
if (statements.empty()) {
return; // No workgroup variables to initialize.
}
// Scan the entry point for an existing local_invocation_index builtin
// parameter
// `local_index` is a generator rather than a single node: every call builds
// a new expression that refers to the local invocation index.
std::function<const ast::Expression*()> local_index;
for (auto* param : fn->params) {
// Case 1: the parameter itself is decorated with the builtin.
if (auto* builtin =
ast::GetDecoration<ast::BuiltinDecoration>(param->decorations)) {
if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); };
break;
}
}
// Case 2: the parameter is a structure, and one of its members is
// decorated with the builtin. The index is then read as `param.member`.
if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
for (auto* member : str->Members()) {
if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
member->Declaration()->decorations)) {
if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
local_index = [=] {
auto* param_expr = b.Expr(ctx.Clone(param->symbol));
auto member_name = ctx.Clone(member->Declaration()->symbol);
return b.MemberAccessor(param_expr, member_name);
};
// Note: this only leaves the member loop. The enclosing parameter loop
// carries on scanning the remaining parameters.
break;
}
}
}
}
}
if (!local_index) {
// No existing local index parameter. Append one to the entry point.
auto* param =
b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(),
{b.Builtin(ast::Builtin::kLocalInvocationIndex)});
ctx.InsertBack(fn->params, param);
local_index = [=] { return b.Expr(param->symbol); };
}
// Take the zeroing statements and bin them by the number of iterations
// required to zero the workgroup data. We then emit these in blocks,
// possibly wrapped in if-statements or for-loops.
std::unordered_map<uint32_t, std::vector<Statement>>
stmts_by_num_iterations;
// Each distinct iteration count is recorded once, the first time it is
// seen, and the list is then sorted in ascending order so that the blocks
// are emitted in a deterministic order.
std::vector<uint32_t> num_sorted_iterations;
for (auto& s : statements) {
auto& stmts = stmts_by_num_iterations[s.num_iterations];
if (stmts.empty()) {
num_sorted_iterations.emplace_back(s.num_iterations);
}
stmts.emplace_back(s);
}
std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end());
// Loop over the statements, grouped by num_iterations.
for (auto num_iterations : num_sorted_iterations) {
auto& stmts = stmts_by_num_iterations[num_iterations];
// Gather all the array indices used by all the statements in the block.
ArrayIndices array_indices;
for (auto& s : stmts) {
for (auto& idx : s.array_indices) {
array_indices.add(idx);
}
}
// Determine the block type used to emit these statements.
if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
// Either the workgroup size is dynamic, or smaller than num_iterations.
// In either case, we need to generate a for loop to ensure we
// initialize all the array elements.
//
// for (var idx : u32 = local_index;
// idx < num_iterations;
// idx += workgroup_size) {
// ...
// }
auto idx = b.Symbols().New("idx");
auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
auto* cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(idx), b.Expr(num_iterations));
// The loop stride is the constant workgroup size when it is known
// (non-zero), otherwise the generated workgroup size expression.
auto* cont = b.Assign(
idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const)
: workgroup_size_expr()));
// The array index `let` declarations come first in the loop body, and are
// derived from the loop variable `idx`.
auto block = DeclareArrayIndices(num_iterations, array_indices,
[&] { return b.Expr(idx); });
for (auto& s : stmts) {
block.emplace_back(s.stmt);
}
auto* for_loop = b.For(init, cond, cont, b.Block(block));
ctx.InsertFront(fn->body->statements, for_loop);
} else if (num_iterations < workgroup_size_const) {
// Workgroup size is a known constant, but is greater than
// num_iterations. Emit an if statement:
//
// if (local_index < num_iterations) {
// ...
// }
auto* cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, local_index(), b.Expr(num_iterations));
// Here the array indices are derived directly from the local invocation
// index, as there is no loop variable.
auto block = DeclareArrayIndices(num_iterations, array_indices,
[&] { return b.Expr(local_index()); });
for (auto& s : stmts) {
block.emplace_back(s.stmt);
}
auto* if_stmt = b.If(cond, b.Block(block));
ctx.InsertFront(fn->body->statements, if_stmt);
} else {
// Workgroup size exactly equals num_iterations.
// No need for any conditionals. Just emit a basic block:
//
// {
// ...
// }
auto block = DeclareArrayIndices(num_iterations, array_indices,
[&] { return b.Expr(local_index()); });
for (auto& s : stmts) {
block.emplace_back(s.stmt);
}
ctx.InsertFront(fn->body->statements, b.Block(block));
}
}
// Append a single workgroup barrier after the zero initialization.
// NOTE(review): this relies on CloneContext::InsertFront() keeping nodes in
// the order in which they were inserted, so that the barrier ends up after
// the zeroing blocks emitted above - confirm against CloneContext.
ctx.InsertFront(fn->body->statements,
b.CallStmt(b.Call("workgroupBarrier")));
}
/// BuildZeroingExpr is a callback that builds the sub-expression that is to
/// be zeroed. It is given `num_values`, the number of elements that the
/// resulting expression will be used to zero, and returns that expression.
using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;

/// BuildZeroingStatements() appends to #statements the statements required
/// to zero initialize the workgroup storage expression of type `ty`.
/// Structures and arrays that cannot be zeroed with a single assignment are
/// handled by recursing into their members / element type.
/// @param ty the expression type
/// @param get_expr a function that builds the AST nodes for the expression.
void BuildZeroingStatements(const sem::Type* ty,
                            const BuildZeroingExpr& get_expr) {
  if (CanTriviallyZero(ty)) {
    // A single assignment of the type's zero-value constructor is enough.
    auto target = get_expr(1u);
    auto* zero = b.Construct(CreateASTTypeFor(ctx, ty));
    statements.emplace_back(Statement{b.Assign(target.expr, zero),
                                      target.num_iterations,
                                      target.array_indices});
  } else if (auto* atomic = ty->As<sem::Atomic>()) {
    // Atomics are zeroed with an atomicStore() of the zero value of the type
    // that the atomic wraps.
    auto* zero = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
    auto target = get_expr(1u);
    auto* call = b.Call("atomicStore", b.AddressOf(target.expr), zero);
    statements.emplace_back(Statement{b.CallStmt(call), target.num_iterations,
                                      target.array_indices});
  } else if (auto* str = ty->As<sem::Struct>()) {
    // Zero each member in turn, via a member accessor on the parent
    // expression.
    for (auto* member : str->Members()) {
      auto member_name = ctx.Clone(member->Declaration()->symbol);
      BuildZeroingExpr member_expr = [&](uint32_t num_values) {
        auto parent = get_expr(num_values);
        return Expression{b.MemberAccessor(parent.expr, member_name),
                          parent.num_iterations, parent.array_indices};
      };
      BuildZeroingStatements(member->Type(), member_expr);
    }
  } else if (auto* arr = ty->As<sem::Array>()) {
    BuildZeroingExpr element_expr = [&](uint32_t num_values) {
      // `num_values` counts the values to zero for one element. Zeroing the
      // whole array therefore takes `num_values * arr->Count()` iterations,
      // and the index used for this array is `(idx % modulo) / division`.
      auto division = num_values;
      auto modulo = num_values * arr->Count();
      auto parent = get_expr(modulo);
      auto indices = parent.array_indices;
      indices.add(ArrayIndex{modulo, division});
      // All uses of the same (modulo, division) pair share one `let` name.
      auto index_name =
          utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
                             [&] { return b.Symbols().New("i"); });
      return Expression{b.IndexAccessor(parent.expr, index_name),
                        parent.num_iterations, indices};
    };
    BuildZeroingStatements(arr->ElemType(), element_expr);
  } else {
    TINT_UNREACHABLE(Transform, b.Diagnostics())
        << "could not zero workgroup type: " << ty->type_name();
  }
}
/// DeclareArrayIndices returns a list of statements that contain the `let`
/// declarations for all of the ArrayIndices. Each declaration has the form
/// `let <name> : u32 = (iteration % modulo) / division`, where the modulo
/// and division operations are omitted when they are not needed.
/// @param num_iterations the number of iterations for the block
/// @param array_indices the list of array indices to generate `let`
/// declarations for
/// @param iteration a function that returns the index of the current
/// iteration.
/// @returns the list of `let` statements that declare the array indices
ast::StatementList DeclareArrayIndices(
    uint32_t num_iterations,
    const ArrayIndices& array_indices,
    const std::function<const ast::Expression*()>& iteration) {
  ast::StatementList stmts;
  for (const auto& index : array_indices) {
    // The name was reserved when the index was first encountered by
    // BuildZeroingStatements(); at() fails loudly if it is missing.
    auto name = array_index_names.at(index);
    // The modulo is only emitted when the block has more iterations than
    // index.modulo. Otherwise the iteration value is presumably already
    // within range, and is used as-is.
    auto* mod =
        (num_iterations > index.modulo)
            ? b.create<ast::BinaryExpression>(
                  ast::BinaryOp::kModulo, iteration(), b.Expr(index.modulo))
            : iteration();
    // Dividing by 1 is a no-op, so skip the division in that case.
    auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod;
    auto* decl = b.Decl(b.Const(name, b.ty.u32(), div));
    stmts.emplace_back(decl);
  }
  return stmts;
}
/// CalculateWorkgroupSize initializes the members #workgroup_size_const and
/// #workgroup_size_expr with the linear workgroup size (the product of all
/// the workgroup size dimensions).
/// If every dimension has a constant value, then #workgroup_size_const holds
/// the product and #workgroup_size_expr is left null. Otherwise
/// #workgroup_size_const is set to 0 and #workgroup_size_expr generates a u32
/// expression for the product.
/// @param deco the workgroup decoration applied to the entry point function
void CalculateWorkgroupSize(const ast::WorkgroupDecoration* deco) {
  workgroup_size_const = 1u;
  workgroup_size_expr = nullptr;
  for (auto* expr : deco->Values()) {
    if (!expr) {
      continue;  // Dimension not specified.
    }
    auto* sem = ctx.src->Sem().Get(expr);
    if (auto c = sem->ConstantValue()) {
      // Fold constant dimensions in to the unsigned constant product.
      if (c.ElementType()->Is<sem::I32>()) {
        workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
        continue;
      } else if (c.ElementType()->Is<sem::U32>()) {
        workgroup_size_const *= c.Elements()[0].u32;
        continue;
      }
    }
    // Constant value could not be found. Build expression instead.
    // The previous generator (if any) is captured by value and chained, so
    // that the final expression multiplies all non-constant dimensions.
    workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
      auto* e = ctx.Clone(expr);
      // i32 dimensions are converted to u32, so the generated product is
      // always an unsigned expression.
      if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) {
        e = b.Construct<ProgramBuilder::u32>(e);
      }
      return size ? b.Mul(size(), e) : e;
    };
  }
  if (workgroup_size_expr) {
    if (workgroup_size_const != 1) {
      // Fold workgroup_size_const in to workgroup_size_expr.
      // The generated expression is always u32 (see above), so the constant
      // is always multiplied in as an unsigned value.
      workgroup_size_expr = [this, const_size = workgroup_size_const,
                             expr_size = workgroup_size_expr] {
        return b.Mul(expr_size(), const_size);
      };
    }
    // Indicate that workgroup_size_expr should be used instead of the
    // constant.
    workgroup_size_const = 0;
  }
}
/// @returns true if a variable with store type `ty` can be efficiently zeroed
/// by assigning a type constructor that has no operands. When
/// CanTriviallyZero() returns false, the initialization has to be
/// decomposed into multiple sub-initializations.
/// @param ty the type to inspect
bool CanTriviallyZero(const sem::Type* ty) {
  // Atomics and arrays are never zeroed with a single assignment.
  if (ty->Is<sem::Atomic>() || ty->Is<sem::Array>()) {
    return false;
  }
  // A structure is trivially zeroable only if all of its members are.
  if (auto* str = ty->As<sem::Struct>()) {
    for (auto* member : str->Members()) {
      const bool member_is_trivial = CanTriviallyZero(member->Type());
      if (!member_is_trivial) {
        return false;
      }
    }
  }
  // Everything else that is storable can be zeroed directly.
  return true;
}
};
// Defaulted constructor, defined out of line.
ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
// Defaulted destructor, defined out of line.
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
// Applies the zero-initialization logic to every compute entry point of the
// source program, then clones the program in to the destination. A new State
// is used for each entry point. The data map arguments are not used.
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
  for (auto* fn : ctx.src->AST().Functions()) {
    if (fn->PipelineStage() != ast::PipelineStage::kCompute) {
      continue;  // Only compute entry points are transformed.
    }
    State state{ctx};
    state.Run(fn);
  }
  ctx.Clone();
}
} // namespace transform
} // namespace tint