| // Copyright 2025 The Dawn & Tint Authors |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are met: |
| // |
| // 1. Redistributions of source code must retain the above copyright notice, this |
| // list of conditions and the following disclaimer. |
| // |
| // 2. Redistributions in binary form must reproduce the above copyright notice, |
| // this list of conditions and the following disclaimer in the documentation |
| // and/or other materials provided with the distribution. |
| // |
| // 3. Neither the name of the copyright holder nor the names of its |
| // contributors may be used to endorse or promote products derived from |
| // this software without specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
| #include "src/tint/lang/msl/writer/raise/argument_buffers.h" |
| |
| #include <string> |
| #include <utility> |
| |
| #include "src/tint/lang/core/ir/builder.h" |
| #include "src/tint/lang/core/ir/validator.h" |
| #include "src/tint/lang/core/type/binding_array.h" |
| |
| namespace tint::msl::writer::raise { |
| namespace { |
| |
| using namespace tint::core::fluent_types; // NOLINT |
| |
| /// PIMPL state for the transform. |
| struct State { |
| /// The IR module. |
| core::ir::Module& ir; |
| |
| /// The IR builder. |
| core::ir::Builder b{ir}; |
| |
| /// The type manager. |
| core::type::Manager& ty{ir.Types()}; |
| |
| /// The type of the structures that will contain the argument buffers. One entry per binding |
| /// group. |
| Hashmap<uint32_t, const core::type::Pointer*, 4> arg_buffers{}; |
| |
| /// The list of module-scope variables. |
| Vector<core::ir::Var*, 8> module_vars{}; |
| |
| /// Maps a binding argument buffer index to the function param |
| Hashmap<uint32_t, core::ir::Value*, 8> id_to_arg_buffer{}; |
| |
| /// Maps from variable to the argument buffer index |
| Hashmap<core::ir::Var*, uint32_t, 4> var_to_struct_idx{}; |
| |
| /// Maps a global `var` to an entry point parameter argument buffer |
| Hashmap<core::ir::Var*, Hashmap<core::ir::Function*, core::ir::FunctionParam*, 4>, 4> |
| var_to_function_param{}; |
| |
| /// A map from block to its containing function. |
| Hashmap<core::ir::Block*, core::ir::Function*, 64> block_to_function{}; |
| |
| // The name of the argument buffer structures. |
| static constexpr const char* kArgBufferName = "tint_arg_buffer_struct"; |
| static constexpr const char* kArgBufferParamName = "tint_arg_buffer"; |
| |
| /// Process the module. |
| void Process() { |
| // Seed the block-to-function map with the function entry blocks. |
| // This is used to determine the owning function for any given instruction. |
| for (auto& func : ir.functions) { |
| block_to_function.Add(func->Block(), func); |
| } |
| |
| CreateArgumentBuffers(); |
| |
| for (auto func : ir.functions) { |
| if (!func->IsEntryPoint()) { |
| continue; |
| } |
| AddArgumentBuffersToEntryPoint(func); |
| } |
| |
| // Replace uses of each module-scope variable. |
| for (auto& var : module_vars) { |
| if (!var->BindingPoint().has_value()) { |
| continue; |
| } |
| |
| Vector<core::ir::Instruction*, 16> to_destroy; |
| auto* ptr = var->Result()->Type()->As<core::type::Pointer>(); |
| var->Result()->ForEachUseUnsorted([&](core::ir::Usage use) { // |
| auto* extracted_variable = GetVariableFromStruct(var, use.instruction); |
| |
| // Everything but handles are just replaced with values from the structure. |
| if (ptr->AddressSpace() != core::AddressSpace::kHandle) { |
| use.instruction->SetOperand(use.operand_index, extracted_variable); |
| return; |
| } |
| |
| Switch( |
| use.instruction, |
| // Loads are replaced with a direct access to the variable. |
| [&](core::ir::Load* load) { |
| load->Result()->ReplaceAllUsesWith(extracted_variable); |
| to_destroy.Push(load); |
| }, |
| // Accesses are replaced with accesses of the extracted variable. |
| [&](core::ir::Access* access) { |
| auto* ba = ptr->StoreType()->As<core::type::BindingArray>(); |
| TINT_ASSERT(ba != nullptr); |
| auto* elem_type = ba->ElemType(); |
| |
| access->SetOperand(core::ir::Access::kObjectOperandOffset, |
| extracted_variable); |
| access->Result()->SetType(elem_type); |
| |
| // Accesses of the previously ptr<binding_array<T, N>> would return a ptr<T> |
| // but the new access returns a T. We need to modify all the previous load |
| // through the ptr<T> to direct accesses. |
| access->Result()->ForEachUseUnsorted([&](core::ir::Usage access_use) { |
| TINT_ASSERT(access_use.instruction->Is<core::ir::Load>()); |
| access_use.instruction->Result()->ReplaceAllUsesWith(access->Result()); |
| to_destroy.Push(access_use.instruction); |
| }); |
| }, |
| TINT_ICE_ON_NO_MATCH); |
| }); |
| var->Destroy(); |
| |
| // Clean up instructions that need to be removed. |
| for (auto* inst : to_destroy) { |
| inst->Destroy(); |
| } |
| } |
| } |
| |
| /// Create the argument buffers. Each bind group will have a separate structure. |
| void CreateArgumentBuffers() { |
| Vector<core::ir::Var*, 4> vars; |
| for (auto* global : *ir.root_block) { |
| auto* var = global->As<core::ir::Var>(); |
| if (!var) { |
| continue; |
| } |
| |
| // Only deal with vars which have binding points. |
| auto bp = var->BindingPoint(); |
| if (!bp.has_value()) { |
| continue; |
| } |
| |
| vars.Push(var); |
| } |
| |
| // Metal requires the argument buffer `id` entries to be in increasing order. SOrt the Vars |
| // such that when we create the struct we will create it in ascending order. |
| vars.Sort([&](const auto* va, const auto* vb) { |
| return va->BindingPoint() < vb->BindingPoint(); |
| }); |
| |
| // Collect a list of struct members for the variable declarations. |
| Hashmap<uint32_t, Vector<core::type::Manager::StructMemberDesc, 8>, 4> group_to_members; |
| for (auto& var : vars) { |
| auto bp = var->BindingPoint(); |
| auto* type = var->Result()->Type(); |
| |
| // Handle types drop the pointer and are passed around by value. |
| auto* ptr = type->As<core::type::Pointer>(); |
| if (ptr->AddressSpace() == core::AddressSpace::kHandle) { |
| type = ptr->StoreType(); |
| } |
| |
| auto name = ir.NameOf(var); |
| if (!name) { |
| name = ir.symbols.New(); |
| } |
| module_vars.Push(var); |
| |
| auto& struct_members = group_to_members.GetOrAddZero(bp->group); |
| var_to_struct_idx.Add(var, static_cast<uint32_t>(struct_members.Length())); |
| |
| struct_members.Push(core::type::Manager::StructMemberDesc{ |
| name, type, |
| core::IOAttributes{ |
| .binding_point = BindingPoint{0, bp->binding}, |
| }}); |
| } |
| |
| // Sort the keys for deterministic struct generation |
| auto keys = group_to_members.Keys().Sort(); |
| |
| for (auto& k : keys) { |
| // Create the structure. |
| auto name = ir.symbols.New(std::string(kArgBufferName) + "_" + std::to_string(k)); |
| auto members = group_to_members.Get(k); |
| |
| auto* strct = ty.Struct(name, std::move(*members)); |
| strct->SetStructFlag(core::type::kExplicitLayout); |
| |
| auto* type = ty.ptr(uniform, strct, read); |
| arg_buffers.Add(k, type); |
| } |
| } |
| |
| /// Add an argument buffer structure to an entry point function. |
| /// @param func the entry point function to modify |
| void AddArgumentBuffersToEntryPoint(core::ir::Function* func) { |
| auto keys = arg_buffers.Keys().Sort(); |
| |
| for (auto& buffer_id : keys) { |
| auto name = std::string(kArgBufferParamName) + "_" + std::to_string(buffer_id); |
| auto* param = b.FunctionParam(name, *arg_buffers.Get(buffer_id)); |
| param->SetBindingPoint(BindingPoint{buffer_id, 0}); |
| func->AppendParam(param); |
| |
| auto* ld = b.Load(param); |
| func->Block()->Prepend(ld); |
| |
| id_to_arg_buffer.Add(buffer_id, ld->Result()); |
| } |
| } |
| |
| /// Add an entry to each function which uses a module-scoped variable to |
| /// receive the variable as a parameter. |
| /// @param func the function to modify |
| /// @returns the function param |
| core::ir::FunctionParam* AddModuleVarsToFunction(core::ir::Function* func, core::ir::Var* var) { |
| auto& v = var_to_function_param.GetOrAddZero(var); |
| return v.GetOrAdd(func, [&] { |
| auto* type = var->Result()->Type(); |
| |
| // Handle types drop the pointer and are passed around by value. |
| auto* ptr = type->As<core::type::Pointer>(); |
| if (ptr->AddressSpace() == core::AddressSpace::kHandle) { |
| type = ptr->StoreType(); |
| } |
| |
| // Add a new parameter to receive the variable parameter. |
| core::ir::FunctionParam* param = nullptr; |
| |
| auto name = ir.NameOf(var).Name(); |
| if (name.empty()) { |
| param = b.FunctionParam(type); |
| } else { |
| param = b.FunctionParam(ir.NameOf(var).Name(), type); |
| } |
| func->AppendParam(param); |
| |
| func->ForEachUseUnsorted([&](core::ir::Usage use) { |
| if (auto* call = use.instruction->As<core::ir::UserCall>()) { |
| call->AppendArg(GetVariableFromStruct(var, call)); |
| } |
| }); |
| |
| return param; |
| }); |
| } |
| |
| /// Get a variable from the argument buffer, inserting new access |
| /// instructions before @p inst. |
| /// @param var the variable to get the replacement for |
| /// @param inst the instruction that uses the variable |
| /// @returns the variable extracted from the structure |
| core::ir::Value* GetVariableFromStruct(core::ir::Var* var, core::ir::Instruction* inst) { |
| auto* func = ContainingFunction(inst); |
| |
| auto* type = var->Result()->Type(); |
| |
| // Handle types drop the pointer and are passed around by value. |
| auto* ptr = type->As<core::type::Pointer>(); |
| if (ptr->AddressSpace() == core::AddressSpace::kHandle) { |
| type = ptr->StoreType(); |
| } |
| |
| if (func->IsEntryPoint()) { |
| auto* arg_buffer = *id_to_arg_buffer.Get(var->BindingPoint()->group); |
| auto idx = *var_to_struct_idx.Get(var); |
| |
| auto* access = b.Access(type, arg_buffer, u32(idx)); |
| access->InsertBefore(inst); |
| return access->Result(); |
| } |
| |
| return AddModuleVarsToFunction(func, var); |
| } |
| |
| /// Get the function that contains an instruction. |
| /// @param inst the instruction |
| /// @returns the function |
| core::ir::Function* ContainingFunction(core::ir::Instruction* inst) { |
| return block_to_function.GetOrAdd(inst->Block(), [&] { // |
| return ContainingFunction(inst->Block()->Parent()); |
| }); |
| } |
| }; |
| |
| } // namespace |
| |
| Result<SuccessType> ArgumentBuffers(core::ir::Module& ir) { |
| auto result = ValidateAndDumpIfNeeded(ir, "msl.ArgumentBuffers"); |
| if (result != Success) { |
| return result.Failure(); |
| } |
| |
| State{ir}.Process(); |
| |
| return Success; |
| } |
| |
| } // namespace tint::msl::writer::raise |