| // Copyright 2021 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/tint/transform/zero_init_workgroup_memory.h" |
| |
| #include <algorithm> |
| #include <map> |
| #include <unordered_map> |
| #include <utility> |
| #include <vector> |
| |
| #include "src/tint/ast/workgroup_attribute.h" |
| #include "src/tint/program_builder.h" |
| #include "src/tint/sem/atomic.h" |
| #include "src/tint/sem/function.h" |
| #include "src/tint/sem/variable.h" |
| #include "src/tint/utils/map.h" |
| #include "src/tint/utils/unique_vector.h" |
| |
| TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory); |
| |
| namespace tint::transform { |
| |
| using StatementList = utils::Vector<const ast::Statement*, 8>; |
| |
| /// PIMPL state for the ZeroInitWorkgroupMemory transform |
| struct ZeroInitWorkgroupMemory::State { |
| /// The clone context |
| CloneContext& ctx; |
| |
| /// An alias to *ctx.dst |
| ProgramBuilder& b = *ctx.dst; |
| |
| /// The constant size of the workgroup. If 0, then #workgroup_size_expr should |
| /// be used instead. |
| uint32_t workgroup_size_const = 0; |
| /// The size of the workgroup as an expression generator. Use if |
| /// #workgroup_size_const is 0. |
| std::function<const ast::Expression*()> workgroup_size_expr; |
| |
| /// ArrayIndex represents a function on the local invocation index, of |
| /// the form: `array_index = (local_invocation_index % modulo) / division` |
| struct ArrayIndex { |
| /// The RHS of the modulus part of the expression |
| uint32_t modulo = 1; |
| /// The RHS of the division part of the expression |
| uint32_t division = 1; |
| |
| /// Equality operator |
| /// @param i the ArrayIndex to compare to this ArrayIndex |
| /// @returns true if `i` and this ArrayIndex are equal |
| bool operator==(const ArrayIndex& i) const { |
| return modulo == i.modulo && division == i.division; |
| } |
| |
| /// Hash function for the ArrayIndex type |
| struct Hasher { |
| /// @param i the ArrayIndex to calculate a hash for |
| /// @returns the hash value for the ArrayIndex `i` |
| size_t operator()(const ArrayIndex& i) const { |
| return utils::Hash(i.modulo, i.division); |
| } |
| }; |
| }; |
| |
| /// A list of unique ArrayIndex |
| using ArrayIndices = utils::UniqueVector<ArrayIndex, 4, ArrayIndex::Hasher>; |
| |
| /// Expression holds information about an expression that is being built for a |
| /// statement will zero workgroup values. |
| struct Expression { |
| /// The AST expression node |
| const ast::Expression* expr = nullptr; |
| /// The number of iterations required to zero the value |
| uint32_t num_iterations = 0; |
| /// All array indices used by this expression |
| ArrayIndices array_indices; |
| }; |
| |
| /// Statement holds information about a statement that will zero workgroup |
| /// values. |
| struct Statement { |
| /// The AST statement node |
| const ast::Statement* stmt; |
| /// The number of iterations required to zero the value |
| uint32_t num_iterations; |
| /// All array indices used by this statement |
| ArrayIndices array_indices; |
| }; |
| |
| /// All statements that zero workgroup memory |
| std::vector<Statement> statements; |
| |
| /// A map of ArrayIndex to the name reserved for the `let` declaration of that |
| /// index. |
| std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names; |
| |
| /// Constructor |
| /// @param c the CloneContext used for the transform |
| explicit State(CloneContext& c) : ctx(c) {} |
| |
| /// Run inserts the workgroup memory zero-initialization logic at the top of |
| /// the given function |
| /// @param fn a compute shader entry point function |
| void Run(const ast::Function* fn) { |
| auto& sem = ctx.src->Sem(); |
| |
| CalculateWorkgroupSize(ast::GetAttribute<ast::WorkgroupAttribute>(fn->attributes)); |
| |
| // Generate a list of statements to zero initialize each of the |
| // workgroup storage variables used by `fn`. This will populate #statements. |
| auto* func = sem.Get(fn); |
| for (auto* var : func->TransitivelyReferencedGlobals()) { |
| if (var->AddressSpace() == ast::AddressSpace::kWorkgroup) { |
| BuildZeroingStatements(var->Type()->UnwrapRef(), [&](uint32_t num_values) { |
| auto var_name = ctx.Clone(var->Declaration()->symbol); |
| return Expression{b.Expr(var_name), num_values, ArrayIndices{}}; |
| }); |
| } |
| } |
| |
| if (statements.empty()) { |
| return; // No workgroup variables to initialize. |
| } |
| |
| // Scan the entry point for an existing local_invocation_index builtin |
| // parameter |
| std::function<const ast::Expression*()> local_index; |
| for (auto* param : fn->params) { |
| if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) { |
| if (builtin->builtin == ast::BuiltinValue::kLocalInvocationIndex) { |
| local_index = [=] { return b.Expr(ctx.Clone(param->symbol)); }; |
| break; |
| } |
| } |
| |
| if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) { |
| for (auto* member : str->Members()) { |
| if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>( |
| member->Declaration()->attributes)) { |
| if (builtin->builtin == ast::BuiltinValue::kLocalInvocationIndex) { |
| local_index = [=] { |
| auto* param_expr = b.Expr(ctx.Clone(param->symbol)); |
| auto member_name = ctx.Clone(member->Declaration()->symbol); |
| return b.MemberAccessor(param_expr, member_name); |
| }; |
| break; |
| } |
| } |
| } |
| } |
| } |
| if (!local_index) { |
| // No existing local index parameter. Append one to the entry point. |
| auto* param = b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(), |
| utils::Vector{ |
| b.Builtin(ast::BuiltinValue::kLocalInvocationIndex), |
| }); |
| ctx.InsertBack(fn->params, param); |
| local_index = [=] { return b.Expr(param->symbol); }; |
| } |
| |
| // Take the zeroing statements and bin them by the number of iterations |
| // required to zero the workgroup data. We then emit these in blocks, |
| // possibly wrapped in if-statements or for-loops. |
| std::unordered_map<uint32_t, std::vector<Statement>> stmts_by_num_iterations; |
| std::vector<uint32_t> num_sorted_iterations; |
| for (auto& s : statements) { |
| auto& stmts = stmts_by_num_iterations[s.num_iterations]; |
| if (stmts.empty()) { |
| num_sorted_iterations.emplace_back(s.num_iterations); |
| } |
| stmts.emplace_back(s); |
| } |
| std::sort(num_sorted_iterations.begin(), num_sorted_iterations.end()); |
| |
| // Loop over the statements, grouped by num_iterations. |
| for (auto num_iterations : num_sorted_iterations) { |
| auto& stmts = stmts_by_num_iterations[num_iterations]; |
| |
| // Gather all the array indices used by all the statements in the block. |
| ArrayIndices array_indices; |
| for (auto& s : stmts) { |
| for (auto& idx : s.array_indices) { |
| array_indices.Add(idx); |
| } |
| } |
| |
| // Determine the block type used to emit these statements. |
| |
| if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) { |
| // Either the workgroup size is dynamic, or smaller than num_iterations. |
| // In either case, we need to generate a for loop to ensure we |
| // initialize all the array elements. |
| // |
| // for (var idx : u32 = local_index; |
| // idx < num_iterations; |
| // idx += workgroup_size) { |
| // ... |
| // } |
| auto idx = b.Symbols().New("idx"); |
| auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index())); |
| auto* cond = b.create<ast::BinaryExpression>(ast::BinaryOp::kLessThan, b.Expr(idx), |
| b.Expr(u32(num_iterations))); |
| auto* cont = b.Assign( |
| idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const)) |
| : workgroup_size_expr())); |
| |
| auto block = |
| DeclareArrayIndices(num_iterations, array_indices, [&] { return b.Expr(idx); }); |
| for (auto& s : stmts) { |
| block.Push(s.stmt); |
| } |
| auto* for_loop = b.For(init, cond, cont, b.Block(block)); |
| ctx.InsertFront(fn->body->statements, for_loop); |
| } else if (num_iterations < workgroup_size_const) { |
| // Workgroup size is a known constant, but is greater than |
| // num_iterations. Emit an if statement: |
| // |
| // if (local_index < num_iterations) { |
| // ... |
| // } |
| auto* cond = b.create<ast::BinaryExpression>( |
| ast::BinaryOp::kLessThan, local_index(), b.Expr(u32(num_iterations))); |
| auto block = DeclareArrayIndices(num_iterations, array_indices, |
| [&] { return b.Expr(local_index()); }); |
| for (auto& s : stmts) { |
| block.Push(s.stmt); |
| } |
| auto* if_stmt = b.If(cond, b.Block(block)); |
| ctx.InsertFront(fn->body->statements, if_stmt); |
| } else { |
| // Workgroup size exactly equals num_iterations. |
| // No need for any conditionals. Just emit a basic block: |
| // |
| // { |
| // ... |
| // } |
| auto block = DeclareArrayIndices(num_iterations, array_indices, |
| [&] { return b.Expr(local_index()); }); |
| for (auto& s : stmts) { |
| block.Push(s.stmt); |
| } |
| ctx.InsertFront(fn->body->statements, b.Block(block)); |
| } |
| } |
| |
| // Append a single workgroup barrier after the zero initialization. |
| ctx.InsertFront(fn->body->statements, b.CallStmt(b.Call("workgroupBarrier"))); |
| } |
| |
| /// BuildZeroingExpr is a function that builds a sub-expression used to zero |
| /// workgroup values. `num_values` is the number of elements that the |
| /// expression will be used to zero. Returns the expression. |
| using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>; |
| |
| /// BuildZeroingStatements() generates the statements required to zero |
| /// initialize the workgroup storage expression of type `ty`. |
| /// @param ty the expression type |
| /// @param get_expr a function that builds the AST nodes for the expression. |
| void BuildZeroingStatements(const sem::Type* ty, const BuildZeroingExpr& get_expr) { |
| if (CanTriviallyZero(ty)) { |
| auto var = get_expr(1u); |
| auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty)); |
| statements.emplace_back( |
| Statement{b.Assign(var.expr, zero_init), var.num_iterations, var.array_indices}); |
| return; |
| } |
| |
| if (auto* atomic = ty->As<sem::Atomic>()) { |
| auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type())); |
| auto expr = get_expr(1u); |
| auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init); |
| statements.emplace_back( |
| Statement{b.CallStmt(store), expr.num_iterations, expr.array_indices}); |
| return; |
| } |
| |
| if (auto* str = ty->As<sem::Struct>()) { |
| for (auto* member : str->Members()) { |
| auto name = ctx.Clone(member->Declaration()->symbol); |
| BuildZeroingStatements(member->Type(), [&](uint32_t num_values) { |
| auto s = get_expr(num_values); |
| return Expression{b.MemberAccessor(s.expr, name), s.num_iterations, |
| s.array_indices}; |
| }); |
| } |
| return; |
| } |
| |
| if (auto* arr = ty->As<sem::Array>()) { |
| BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) { |
| // num_values is the number of values to zero for the element type. |
| // The number of iterations required to zero the array and its elements |
| // is: |
| // `num_values * arr->Count()` |
| // The index for this array is: |
| // `(idx % modulo) / division` |
| auto count = arr->ConstantCount(); |
| if (!count) { |
| ctx.dst->Diagnostics().add_error(diag::System::Transform, |
| sem::Array::kErrExpectedConstantCount); |
| return Expression{}; |
| } |
| auto modulo = num_values * count.value(); |
| auto division = num_values; |
| auto a = get_expr(modulo); |
| auto array_indices = a.array_indices; |
| array_indices.Add(ArrayIndex{modulo, division}); |
| auto index = utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division}, |
| [&] { return b.Symbols().New("i"); }); |
| return Expression{b.IndexAccessor(a.expr, index), a.num_iterations, array_indices}; |
| }); |
| return; |
| } |
| |
| TINT_UNREACHABLE(Transform, b.Diagnostics()) |
| << "could not zero workgroup type: " << ty->FriendlyName(ctx.src->Symbols()); |
| } |
| |
| /// DeclareArrayIndices returns a list of statements that contain the `let` |
| /// declarations for all of the ArrayIndices. |
| /// @param num_iterations the number of iterations for the block |
| /// @param array_indices the list of array indices to generate `let` |
| /// declarations for |
| /// @param iteration a function that returns the index of the current |
| /// iteration. |
| /// @returns the list of `let` statements that declare the array indices |
| StatementList DeclareArrayIndices(uint32_t num_iterations, |
| const ArrayIndices& array_indices, |
| const std::function<const ast::Expression*()>& iteration) { |
| StatementList stmts; |
| std::map<Symbol, ArrayIndex> indices_by_name; |
| for (auto index : array_indices) { |
| auto name = array_index_names.at(index); |
| auto* mod = (num_iterations > index.modulo) |
| ? b.create<ast::BinaryExpression>(ast::BinaryOp::kModulo, iteration(), |
| b.Expr(u32(index.modulo))) |
| : iteration(); |
| auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod; |
| auto* decl = b.Decl(b.Let(name, b.ty.u32(), div)); |
| stmts.Push(decl); |
| } |
| return stmts; |
| } |
| |
| /// CalculateWorkgroupSize initializes the members #workgroup_size_const and |
| /// #workgroup_size_expr with the linear workgroup size. |
| /// @param attr the workgroup attribute applied to the entry point function |
| void CalculateWorkgroupSize(const ast::WorkgroupAttribute* attr) { |
| bool is_signed = false; |
| workgroup_size_const = 1u; |
| workgroup_size_expr = nullptr; |
| for (auto* expr : attr->Values()) { |
| if (!expr) { |
| continue; |
| } |
| auto* sem = ctx.src->Sem().Get(expr); |
| if (auto* c = sem->ConstantValue()) { |
| workgroup_size_const *= c->As<AInt>(); |
| continue; |
| } |
| // Constant value could not be found. Build expression instead. |
| workgroup_size_expr = [this, expr, size = workgroup_size_expr] { |
| auto* e = ctx.Clone(expr); |
| if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) { |
| e = b.Construct<u32>(e); |
| } |
| return size ? b.Mul(size(), e) : e; |
| }; |
| } |
| if (workgroup_size_expr) { |
| if (workgroup_size_const != 1) { |
| // Fold workgroup_size_const in to workgroup_size_expr |
| workgroup_size_expr = [this, is_signed, const_size = workgroup_size_const, |
| expr_size = workgroup_size_expr] { |
| return is_signed ? b.Mul(expr_size(), i32(const_size)) |
| : b.Mul(expr_size(), u32(const_size)); |
| }; |
| } |
| // Indicate that workgroup_size_expr should be used instead of the |
| // constant. |
| workgroup_size_const = 0; |
| } |
| } |
| |
| /// @returns true if a variable with store type `ty` can be efficiently zeroed |
| /// by assignment of a type initializer without operands. If |
| /// CanTriviallyZero() returns false, then the type needs to be |
| /// initialized by decomposing the initialization into multiple |
| /// sub-initializations. |
| /// @param ty the type to inspect |
| bool CanTriviallyZero(const sem::Type* ty) { |
| if (ty->Is<sem::Atomic>()) { |
| return false; |
| } |
| if (auto* str = ty->As<sem::Struct>()) { |
| for (auto* member : str->Members()) { |
| if (!CanTriviallyZero(member->Type())) { |
| return false; |
| } |
| } |
| } |
| if (ty->Is<sem::Array>()) { |
| return false; |
| } |
| // True for all other storable types |
| return true; |
| } |
| }; |
| |
| ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; |
| |
| ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; |
| |
| bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const { |
| for (auto* global : program->AST().GlobalVariables()) { |
| if (auto* var = global->As<ast::Var>()) { |
| if (var->declared_address_space == ast::AddressSpace::kWorkgroup) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) const { |
| for (auto* fn : ctx.src->AST().Functions()) { |
| if (fn->PipelineStage() == ast::PipelineStage::kCompute) { |
| State{ctx}.Run(fn); |
| } |
| } |
| ctx.Clone(); |
| } |
| |
| } // namespace tint::transform |