| // Copyright 2023 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"

#include <map>
#include <utility>
#include <variant>

#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/utils/containers/reverse.h"

using namespace tint::core::fluent_types;     // NOLINT
using namespace tint::core::number_suffixes;  // NOLINT
| |
| namespace tint::core::ir::transform { |
| |
| namespace { |
| |
| /// PIMPL state for the transform. |
| struct State { |
| /// The IR module. |
| Module& ir; |
| |
| /// The IR builder. |
| Builder b{ir}; |
| |
| /// The type manager. |
| core::type::Manager& ty{ir.Types()}; |
| |
| /// VarSet is a hash set of workgroup variables. |
| using VarSet = Hashset<Var*, 8>; |
| |
| /// A map from variable to an ID used for sorting. |
| Hashmap<Var*, uint32_t, 8> var_to_id{}; |
| |
| /// A map from blocks to their directly referenced workgroup variables. |
| Hashmap<Block*, VarSet, 64> block_to_direct_vars{}; |
| |
| /// A map from functions to their transitively referenced workgroup variables. |
| Hashmap<Function*, VarSet, 8> function_to_transitive_vars{}; |
| |
| /// ArrayIndex represents a required array index for an access instruction. |
| struct ArrayIndex { |
| /// The size of the array that will be indexed. |
| uint32_t count = 0u; |
| }; |
| |
| /// Index represents an index for an access instruction, which is either a constant value or |
| /// an array index that will be dynamically calculated from an array size. |
| using Index = std::variant<uint32_t, ArrayIndex>; |
| |
| /// Store describes a store to a sub-element of a workgroup variable. |
| struct Store { |
| /// The workgroup variable. |
| Var* var = nullptr; |
| /// The store type of the element. |
| const type::Type* store_type = nullptr; |
| /// The list of index operands to get to the element. |
| Vector<Index, 4> indices; |
| }; |
| |
| /// StoreList is a list of `Store` descriptors. |
| using StoreList = Vector<Store, 8>; |
| |
| /// StoreMap is a map from iteration count to a list of `Store` descriptors. |
| using StoreMap = Hashmap<uint32_t, StoreList, 8>; |
| |
| /// Process the module. |
| void Process() { |
| if (ir.root_block->IsEmpty()) { |
| return; |
| } |
| |
| // Loop over module-scope variables, looking for workgroup variables. |
| uint32_t next_id = 0; |
| for (auto inst : *ir.root_block) { |
| if (auto* var = inst->As<Var>()) { |
| auto* ptr = var->Result()->Type()->As<core::type::Pointer>(); |
| if (ptr && ptr->AddressSpace() == core::AddressSpace::kWorkgroup) { |
| // Record the usage of the variable for each block that references it. |
| var->Result()->ForEachUse([&](const Usage& use) { |
| block_to_direct_vars.GetOrZero(use.instruction->Block())->Add(var); |
| }); |
| var_to_id.Add(var, next_id++); |
| } |
| } |
| } |
| |
| // Process each entry point function. |
| for (auto* func : ir.functions) { |
| if (func->Stage() == Function::PipelineStage::kCompute) { |
| ProcessEntryPoint(func); |
| } |
| } |
| } |
| |
| /// Process an entry point function to zero-initialize the workgroup variables that it uses. |
| /// @param func the entry point function |
| void ProcessEntryPoint(Function* func) { |
| // Get list of transitively referenced workgroup variables. |
| auto vars = GetReferencedVars(func); |
| if (vars.IsEmpty()) { |
| return; |
| } |
| |
| // Sort the variables to get deterministic output in tests. |
| auto sorted_vars = vars.Vector(); |
| sorted_vars.Sort([&](Var* first, Var* second) { |
| return *var_to_id.Get(first) < *var_to_id.Get(second); |
| }); |
| |
| // Build list of store descriptors for all workgroup variables. |
| StoreMap stores; |
| for (auto* var : sorted_vars) { |
| PrepareStores(var, var->Result()->Type()->UnwrapPtr(), 1, {}, stores); |
| } |
| |
| // Sort the iteration counts to get deterministic output in tests. |
| auto sorted_iteration_counts = stores.Keys(); |
| sorted_iteration_counts.Sort(); |
| |
| // Capture the first instruction of the function. |
| // All new instructions will be inserted before this. |
| auto* function_start = func->Block()->Front(); |
| |
| // Get the local invocation index and the linearized workgroup size. |
| auto* local_index = GetLocalInvocationIndex(func); |
| auto wgsizes = func->WorkgroupSize().value(); |
| auto wgsize = wgsizes[0] * wgsizes[1] * wgsizes[2]; |
| |
| // Insert instructions to zero-initialize every variable. |
| b.InsertBefore(function_start, [&] { |
| for (auto count : sorted_iteration_counts) { |
| auto element_stores = stores.Get(count); |
| if (count == 1u) { |
| // Make the first invocation in the group perform all of the non-arrayed stores. |
| auto* ifelse = b.If(b.Equal(ty.bool_(), local_index, 0_u)); |
| b.Append(ifelse->True(), [&] { |
| for (auto& store : *element_stores) { |
| GenerateStore(store, count, b.Constant(0_u)); |
| } |
| b.ExitIf(ifelse); |
| }); |
| } else { |
| // Use a loop for arrayed stores. |
| GenerateZeroingLoop(local_index, count, wgsize, *element_stores); |
| } |
| } |
| b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier); |
| }); |
| } |
| |
| /// Get the set of workgroup variables transitively referenced by @p func. |
| /// @param func the function |
| /// @returns the set of transitively referenced workgroup variables |
| VarSet GetReferencedVars(Function* func) { |
| return function_to_transitive_vars.GetOrCreate(func, [&] { |
| VarSet vars; |
| GetReferencedVars(func->Block(), vars); |
| return vars; |
| }); |
| } |
| |
| /// Get the set of workgroup variables transitively referenced by @p block. |
| /// @param block the block |
| /// @param vars the set of transitively referenced workgroup variables to populate |
| void GetReferencedVars(Block* block, VarSet& vars) { |
| // Add directly referenced vars. |
| if (auto itr = block_to_direct_vars.Find(block)) { |
| for (auto* var : *itr) { |
| vars.Add(var); |
| } |
| } |
| |
| // Loop over instructions in the block. |
| for (auto* inst : *block) { |
| tint::Switch( |
| inst, |
| [&](UserCall* call) { |
| // Get variables referenced by a function called from this block. |
| auto callee_vars = GetReferencedVars(call->Target()); |
| for (auto* var : callee_vars) { |
| vars.Add(var); |
| } |
| }, |
| [&](ControlInstruction* ctrl) { |
| // Recurse into control instructions and gather their referenced vars. |
| ctrl->ForeachBlock([&](Block* blk) { GetReferencedVars(blk, vars); }); |
| }); |
| } |
| } |
| |
| /// Recursively generate store descriptors for a workgroup variable. |
| /// Determines the combined array iteration count of each inner element. |
| /// @param var the workgroup variable |
| /// @param type the current element type |
| /// @param iteration_count the iteration count of this inner element of the variable |
| /// @param indices the access indices needed to get to this element |
| /// @param stores the map of stores to populate |
| void PrepareStores(Var* var, |
| const type::Type* type, |
| uint32_t iteration_count, |
| Vector<Index, 4> indices, |
| StoreMap& stores) { |
| // If this type can be trivially zeroed, store to the whole element. |
| if (CanTriviallyZero(type)) { |
| stores.GetOrZero(iteration_count)->Push(Store{var, type, indices}); |
| return; |
| } |
| |
| tint::Switch( |
| type, |
| [&](const type::Array* arr) { |
| // Add an array index to the list and recurse into the element type. |
| TINT_ASSERT(arr->ConstantCount()); |
| auto count = arr->ConstantCount().value(); |
| auto new_indices = indices; |
| if (count > 1) { |
| new_indices.Push(ArrayIndex{count}); |
| } else { |
| new_indices.Push(0u); |
| } |
| PrepareStores(var, arr->ElemType(), iteration_count * count, new_indices, stores); |
| }, |
| [&](const type::Atomic*) { |
| stores.GetOrZero(iteration_count)->Push(Store{var, type, indices}); |
| }, |
| [&](const type::Struct* str) { |
| for (auto* member : str->Members()) { |
| // Add the member index to the index list and recurse into its type. |
| auto new_indices = indices; |
| new_indices.Push(member->Index()); |
| PrepareStores(var, member->Type(), iteration_count, new_indices, stores); |
| } |
| }, |
| [&](Default) { TINT_UNREACHABLE(); }); |
| } |
| |
| /// Get or inject an entry point builtin for the local invocation index. |
| /// @param func the entry point function |
| /// @returns the local invocation index builtin |
| Value* GetLocalInvocationIndex(Function* func) { |
| // Look for an existing local_invocation_index builtin parameter. |
| for (auto* param : func->Params()) { |
| if (auto* str = param->Type()->As<type::Struct>()) { |
| // Check each member for the local invocation index builtin attribute. |
| for (auto* member : str->Members()) { |
| if (member->Attributes().builtin && member->Attributes().builtin.value() == |
| BuiltinValue::kLocalInvocationIndex) { |
| auto* access = b.Access(ty.u32(), param, u32(member->Index())); |
| access->InsertBefore(func->Block()->Front()); |
| return access->Result(); |
| } |
| } |
| } else { |
| // Check if the parameter is the local invocation index. |
| if (param->Builtin() && |
| param->Builtin().value() == FunctionParam::Builtin::kLocalInvocationIndex) { |
| return param; |
| } |
| } |
| } |
| |
| // No local invocation index was found, so add one to the parameter list and use that. |
| Vector<FunctionParam*, 4> params = func->Params(); |
| auto* param = b.FunctionParam("tint_local_index", ty.u32()); |
| param->SetBuiltin(FunctionParam::Builtin::kLocalInvocationIndex); |
| params.Push(param); |
| func->SetParams(params); |
| return param; |
| } |
| |
| /// Generate the store instruction for a given store descriptor. |
| /// @param store the store descriptor |
| /// @param total_count the total number of elements that will be zeroed |
| /// @param linear_index the linear index of the single element that will be zeroed |
| void GenerateStore(const Store& store, uint32_t total_count, Value* linear_index) { |
| auto* to = store.var->Result(); |
| if (!store.indices.IsEmpty()) { |
| // Build the access indices to get to the target element. |
| // We walk backwards along the index list so that adjacent invocation store to |
| // adjacent array elements. |
| uint32_t count = 1; |
| Vector<Value*, 4> indices; |
| for (auto idx : Reverse(store.indices)) { |
| if (std::holds_alternative<ArrayIndex>(idx)) { |
| // Array indices are computed from the linear index based on the size of the |
| // array and the size of the sub-arrays that have already been indexed. |
| auto array_index = std::get<ArrayIndex>(idx); |
| Value* index = linear_index; |
| if (count > 1) { |
| index = b.Divide(ty.u32(), index, u32(count))->Result(); |
| } |
| if (total_count > count * array_index.count) { |
| index = b.Modulo(ty.u32(), index, u32(array_index.count))->Result(); |
| } |
| indices.Push(index); |
| count *= array_index.count; |
| } else { |
| // Constant indices are added to the list unmodified. |
| indices.Push(b.Constant(u32(std::get<uint32_t>(idx)))); |
| } |
| } |
| indices.Reverse(); |
| to = b.Access(ty.ptr(workgroup, store.store_type), to, indices)->Result(); |
| } |
| |
| // Generate the store instruction. |
| if (auto* atomic = store.store_type->As<type::Atomic>()) { |
| auto* zero = b.Constant(ir.constant_values.Zero(atomic->Type())); |
| b.Call(ty.void_(), core::BuiltinFn::kAtomicStore, to, zero); |
| } else { |
| auto* zero = b.Constant(ir.constant_values.Zero(store.store_type)); |
| b.Store(to, zero); |
| } |
| } |
| |
| /// Generate a loop for a list of stores with the same iteration count. |
| /// @param local_index the local invocation index |
| /// @param total_count the number of iterations needed to store to all elements |
| /// @param wgsize the linearized workgroup size |
| /// @param stores the list of store descriptors |
| void GenerateZeroingLoop(Value* local_index, |
| uint32_t total_count, |
| uint32_t wgsize, |
| const StoreList& stores) { |
| // The loop is equivalent to: |
| // for (var idx = local_index; idx < linear_iteration_count; idx += wgsize) { |
| // <store to elements at `idx`> |
| // } |
| auto* loop = b.Loop(); |
| auto* index = b.BlockParam(ty.u32()); |
| loop->Body()->SetParams({index}); |
| b.Append(loop->Initializer(), [&] { // |
| b.NextIteration(loop, local_index); |
| }); |
| b.Append(loop->Body(), [&] { |
| // Exit the loop when the iteration count has been exceeded. |
| auto* gt_max = b.GreaterThan(ty.bool_(), index, u32(total_count - 1u)); |
| auto* ifelse = b.If(gt_max); |
| b.Append(ifelse->True(), [&] { // |
| b.ExitLoop(loop); |
| }); |
| |
| // Insert all of the store instructions. |
| for (auto& store : stores) { |
| GenerateStore(store, total_count, index); |
| } |
| |
| b.Continue(loop); |
| }); |
| b.Append(loop->Continuing(), [&] { // |
| // Increment the loop index by linearized workgroup size. |
| b.NextIteration(loop, b.Add(ty.u32(), index, u32(wgsize))); |
| }); |
| } |
| |
| /// Check if a type can be efficiently zeroed with a single store. Returns `false` if there are |
| /// any nested arrays or atomics. |
| /// @param type the type to inspect |
| /// @returns true if a variable with store type @p ty can be efficiently zeroed |
| bool CanTriviallyZero(const core::type::Type* type) { |
| if (type->IsAnyOf<core::type::Atomic, core::type::Array>()) { |
| return false; |
| } |
| if (auto* str = type->As<core::type::Struct>()) { |
| for (auto* member : str->Members()) { |
| if (!CanTriviallyZero(member->Type())) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| }; |
| |
| } // namespace |
| |
| Result<SuccessType> ZeroInitWorkgroupMemory(Module& ir) { |
| auto result = ValidateAndDumpIfNeeded(ir, "ZeroInitWorkgroupMemory transform"); |
| if (!result) { |
| return result; |
| } |
| |
| State{ir}.Process(); |
| |
| return Success; |
| } |
| |
| } // namespace tint::core::ir::transform |