| // Copyright 2024 The Dawn & Tint Authors |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are met: |
| // |
| // 1. Redistributions of source code must retain the above copyright notice, this |
| // list of conditions and the following disclaimer. |
| // |
| // 2. Redistributions in binary form must reproduce the above copyright notice, |
| // this list of conditions and the following disclaimer in the documentation |
| // and/or other materials provided with the distribution. |
| // |
| // 3. Neither the name of the copyright holder nor the names of its |
| // contributors may be used to endorse or promote products derived from |
| // this software without specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
| #include "src/tint/lang/spirv/reader/lower/shader_io.h" |
| |
| #include <utility> |
| |
| #include "src/tint/lang/core/ir/builder.h" |
| #include "src/tint/lang/core/ir/module.h" |
| #include "src/tint/lang/core/ir/referenced_module_vars.h" |
| #include "src/tint/lang/core/ir/validator.h" |
| |
| namespace tint::spirv::reader::lower { |
| |
| namespace { |
| |
| using namespace tint::core::fluent_types; // NOLINT |
| |
| /// PIMPL state for the transform. |
| struct State { |
| /// The IR module. |
| core::ir::Module& ir; |
| |
| /// The IR builder. |
| core::ir::Builder b{ir}; |
| |
| /// The type manager. |
| core::type::Manager& ty{ir.Types()}; |
| |
| /// A map from block to its containing function. |
| Hashmap<core::ir::Block*, core::ir::Function*, 64> block_to_function{}; |
| |
| /// A map from each function to a map from input variable to parameter. |
| Hashmap<core::ir::Function*, Hashmap<core::ir::Var*, core::ir::Value*, 4>, 8> |
| function_parameter_map{}; |
| |
| /// The set of output variables that have been processed. |
| Hashset<core::ir::Var*, 4> output_variables{}; |
| |
| /// The mapping from functions to their transitively referenced output variables. |
| core::ir::ReferencedModuleVars<core::ir::Module> referenced_output_vars{ |
| ir, [](const core::ir::Var* var) { |
| auto* view = var->Result(0)->Type()->As<core::type::MemoryView>(); |
| return view && view->AddressSpace() == core::AddressSpace::kOut; |
| }}; |
| |
| /// Process the module. |
| void Process() { |
| // Process outputs first, as that may introduce new functions that input variables need to |
| // be propagated through. |
| ProcessOutputs(); |
| ProcessInputs(); |
| } |
| |
| /// Process output variables. |
| /// Changes output variables to the `private` address space and wraps entry points that produce |
| /// outputs with new functions that copy the outputs from the private variables to the return |
| /// value. |
| void ProcessOutputs() { |
| // Update entry point functions to return their outputs, using a wrapper function. |
| // Use a worklist as `ProcessEntryPointOutputs()` will add new functions. |
| Vector<core::ir::Function*, 4> entry_points; |
| for (auto& func : ir.functions) { |
| if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) { |
| entry_points.Push(func); |
| } |
| } |
| for (auto& ep : entry_points) { |
| ProcessEntryPointOutputs(ep); |
| } |
| |
| // Remove attributes from all of the original structs and module-scope output variables. |
| // This is done last as we need to copy attributes during `ProcessEntryPointOutputs()`. |
| for (auto& var : output_variables) { |
| var->SetAttributes({}); |
| if (auto* str = var->Result(0)->Type()->UnwrapPtr()->As<core::type::Struct>()) { |
| for (auto* member : str->Members()) { |
| // TODO(crbug.com/tint/745): Remove the const_cast. |
| const_cast<core::type::StructMember*>(member)->SetAttributes({}); |
| } |
| } |
| } |
| } |
| |
| /// Process input variables. |
| /// Pass inputs down the call stack as parameters to any functions that need them. |
| void ProcessInputs() { |
| // Seed the block-to-function map with the function entry blocks. |
| for (auto& func : ir.functions) { |
| block_to_function.Add(func->Block(), func); |
| } |
| |
| // Gather the list of all module-scope input variables. |
| Vector<core::ir::Var*, 4> inputs; |
| for (auto* global : *ir.root_block) { |
| if (auto* var = global->As<core::ir::Var>()) { |
| auto addrspace = var->Result(0)->Type()->As<core::type::Pointer>()->AddressSpace(); |
| if (addrspace == core::AddressSpace::kIn) { |
| inputs.Push(var); |
| } |
| } |
| } |
| |
| // Replace the input variables with function parameters. |
| for (auto* var : inputs) { |
| ReplaceInputPointerUses(var, var->Result(0)); |
| var->Destroy(); |
| } |
| } |
| |
| /// Replace an output pointer address space to make it `private`. |
| /// @param value the output variable |
| void ReplaceOutputPointerAddressSpace(core::ir::InstructionResult* value) { |
| // Change the address space to `private`. |
| auto* old_ptr_type = value->Type(); |
| auto* new_ptr_type = ty.ptr(core::AddressSpace::kPrivate, old_ptr_type->UnwrapPtr()); |
| value->SetType(new_ptr_type); |
| |
| // Update all uses of the module-scope variable. |
| value->ForEachUseUnsorted([&](core::ir::Usage use) { |
| if (auto* access = use.instruction->As<core::ir::Access>()) { |
| ReplaceOutputPointerAddressSpace(access->Result(0)); |
| } else if (!use.instruction->IsAnyOf<core::ir::Load, core::ir::LoadVectorElement, |
| core::ir::Store, core::ir::StoreVectorElement>()) { |
| TINT_UNREACHABLE() |
| << "unexpected instruction: " << use.instruction->TypeInfo().name; |
| } |
| }); |
| } |
| |
| /// Process the outputs of an entry point function, adding a wrapper function to forward outputs |
| /// through the return value. |
| /// @param ep the entry point |
| void ProcessEntryPointOutputs(core::ir::Function* ep) { |
| const auto& referenced_outputs = referenced_output_vars.TransitiveReferences(ep); |
| if (referenced_outputs.IsEmpty()) { |
| return; |
| } |
| |
| // Add a wrapper function to return either a single value or a struct. |
| auto* wrapper = b.Function(ty.void_(), ep->Stage()); |
| if (auto name = ir.NameOf(ep)) { |
| ir.SetName(ep, name.Name() + "_inner"); |
| ir.SetName(wrapper, name); |
| } |
| |
| // Call the original entry point and make it a regular function. |
| ep->SetStage(core::ir::Function::PipelineStage::kUndefined); |
| b.Append(wrapper->Block(), [&] { // |
| b.Call(ep); |
| }); |
| |
| // Collect all outputs into a list of struct member declarations. |
| // Also add instructions to load their final values in the wrapper function. |
| Vector<core::ir::Value*, 4> results; |
| Vector<core::type::Manager::StructMemberDesc, 4> output_descriptors; |
| auto add_output = [&](Symbol name, const core::type::Type* type, |
| core::IOAttributes attributes) { |
| if (!name) { |
| name = ir.symbols.New(); |
| } |
| output_descriptors.Push(core::type::Manager::StructMemberDesc{name, type, attributes}); |
| }; |
| for (auto* var : referenced_outputs) { |
| // Change the address space of the variable to private and update its uses, if we |
| // haven't already seen this variable. |
| if (output_variables.Add(var)) { |
| ReplaceOutputPointerAddressSpace(var->Result(0)); |
| } |
| |
| // Copy the variable attributes to the struct member. |
| auto var_attributes = var->Attributes(); |
| auto var_type = var->Result(0)->Type()->UnwrapPtr(); |
| if (auto* str = var_type->As<core::type::Struct>()) { |
| // Add an output for each member of the struct. |
| for (auto* member : str->Members()) { |
| // Use the base variable attributes if not specified directly on the member. |
| auto member_attributes = member->Attributes(); |
| if (auto base_loc = var_attributes.location) { |
| // Location values increment from the base location value on the variable. |
| member_attributes.location = base_loc.value() + member->Index(); |
| } |
| if (!member_attributes.interpolation) { |
| member_attributes.interpolation = var_attributes.interpolation; |
| } |
| |
| add_output(member->Name(), member->Type(), std::move(member_attributes)); |
| |
| // Load the final result from the member of the original struct variable. |
| b.Append(wrapper->Block(), [&] { // |
| auto* access = |
| b.Access(ty.ptr<private_>(member->Type()), var, u32(member->Index())); |
| results.Push(b.Load(access)->Result(0)); |
| }); |
| } |
| } else { |
| // Load the final result from the original variable. |
| b.Append(wrapper->Block(), [&] { |
| results.Push(b.Load(var)->Result(0)); |
| |
| // If we're dealing with sample_mask, extract the scalar from the array. |
| if (var_attributes.builtin == core::BuiltinValue::kSampleMask) { |
| var_type = ty.u32(); |
| results.Back() = b.Access(ty.u32(), results.Back(), u32(0))->Result(0); |
| } |
| }); |
| add_output(ir.NameOf(var), var_type, std::move(var_attributes)); |
| } |
| } |
| |
| if (output_descriptors.Length() == 1) { |
| // Copy the output attributes to the function return. |
| wrapper->SetReturnAttributes(output_descriptors[0].attributes); |
| |
| // Return the output from the wrapper function. |
| wrapper->SetReturnType(output_descriptors[0].type); |
| b.Append(wrapper->Block(), [&] { // |
| b.Return(wrapper, results[0]); |
| }); |
| } else { |
| // Create a struct to hold all of the output values. |
| auto* str = ty.Struct(ir.symbols.New(), std::move(output_descriptors)); |
| wrapper->SetReturnType(str); |
| |
| // Collect the output values and return them from the wrapper function. |
| b.Append(wrapper->Block(), [&] { // |
| b.Return(wrapper, b.Construct(str, std::move(results))); |
| }); |
| } |
| } |
| |
| /// Replace a use of an input pointer value. |
| /// @param var the originating input variable |
| /// @param value the input pointer value |
| void ReplaceInputPointerUses(core::ir::Var* var, core::ir::Value* value) { |
| Vector<core::ir::Instruction*, 8> to_destroy; |
| value->ForEachUseUnsorted([&](core::ir::Usage use) { |
| auto* object = value; |
| if (object->Type()->Is<core::type::Pointer>()) { |
| // Get (or create) the function parameter that will replace the variable. |
| auto* func = ContainingFunction(use.instruction); |
| object = GetParameter(func, var); |
| } |
| |
| Switch( |
| use.instruction, |
| [&](core::ir::Load* l) { |
| // Fold the load away and replace its uses with the new parameter. |
| l->Result(0)->ReplaceAllUsesWith(object); |
| to_destroy.Push(l); |
| }, |
| [&](core::ir::LoadVectorElement* lve) { |
| // Replace the vector element load with an access instruction. |
| auto* access = b.AccessWithResult(lve->DetachResult(), object, lve->Index()); |
| access->InsertBefore(lve); |
| to_destroy.Push(lve); |
| }, |
| [&](core::ir::Access* a) { |
| // Remove the pointer from the source and destination type. |
| a->SetOperand(core::ir::Access::kObjectOperandOffset, object); |
| a->Result(0)->SetType(a->Result(0)->Type()->UnwrapPtr()); |
| ReplaceInputPointerUses(var, a->Result(0)); |
| }, |
| TINT_ICE_ON_NO_MATCH); |
| }); |
| |
| // Clean up orphaned instructions. |
| for (auto* inst : to_destroy) { |
| inst->Destroy(); |
| } |
| } |
| |
| /// Get the function that contains an instruction. |
| /// @param inst the instruction |
| /// @returns the function |
| core::ir::Function* ContainingFunction(core::ir::Instruction* inst) { |
| return block_to_function.GetOrAdd(inst->Block(), [&] { // |
| return ContainingFunction(inst->Block()->Parent()); |
| }); |
| } |
| |
| /// Get or create a function parameter to replace a module-scope variable. |
| /// @param func the function |
| /// @param var the module-scope variable |
| /// @returns the function parameter |
| core::ir::Value* GetParameter(core::ir::Function* func, core::ir::Var* var) { |
| return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&] { |
| const bool entry_point = func->Stage() != core::ir::Function::PipelineStage::kUndefined; |
| auto* var_type = var->Result(0)->Type()->UnwrapPtr(); |
| |
| // Use a scalar u32 for sample_mask builtins for entry point parameters. |
| if (entry_point && var->Attributes().builtin == core::BuiltinValue::kSampleMask) { |
| TINT_ASSERT(var_type->Is<core::type::Array>()); |
| TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u); |
| var_type = ty.u32(); |
| } |
| |
| // Create a new function parameter for the input. |
| auto* param = b.FunctionParam(var_type); |
| func->AppendParam(param); |
| if (auto name = ir.NameOf(var)) { |
| ir.SetName(param, name); |
| } |
| |
| // Add attributes to the parameter if this is an entry point function. |
| if (entry_point) { |
| AddEntryPointParameterAttributes(param, var->Attributes()); |
| } |
| |
| // Update the callsites of this function. |
| func->ForEachUseUnsorted([&](core::ir::Usage use) { |
| if (auto* call = use.instruction->As<core::ir::UserCall>()) { |
| // Recurse into the calling function. |
| auto* caller = ContainingFunction(call); |
| call->AppendArg(GetParameter(caller, var)); |
| } else if (!use.instruction->Is<core::ir::Return>()) { |
| TINT_UNREACHABLE() |
| << "unexpected instruction: " << use.instruction->TypeInfo().name; |
| } |
| }); |
| |
| core::ir::Value* result = param; |
| if (entry_point && var->Attributes().builtin == core::BuiltinValue::kSampleMask) { |
| // Construct an array from the scalar sample_mask builtin value for entry points. |
| auto* construct = b.Construct(var->Result(0)->Type()->UnwrapPtr(), param); |
| func->Block()->Prepend(construct); |
| result = construct->Result(0); |
| } |
| return result; |
| }); |
| } |
| |
| /// Add attributes to an entry point function parameter. |
| /// @param param the parameter |
| /// @param attributes the attributes |
| void AddEntryPointParameterAttributes(core::ir::FunctionParam* param, |
| const core::IOAttributes& attributes) { |
| if (auto* str = param->Type()->UnwrapPtr()->As<core::type::Struct>()) { |
| for (auto* member : str->Members()) { |
| // Use the base variable attributes if not specified directly on the member. |
| auto member_attributes = member->Attributes(); |
| if (auto base_loc = attributes.location) { |
| // Location values increment from the base location value on the variable. |
| member_attributes.location = base_loc.value() + member->Index(); |
| } |
| if (!member_attributes.interpolation) { |
| member_attributes.interpolation = attributes.interpolation; |
| } |
| // TODO(crbug.com/tint/745): Remove the const_cast. |
| const_cast<core::type::StructMember*>(member)->SetAttributes( |
| std::move(member_attributes)); |
| } |
| } else { |
| // Set attributes directly on the function parameter. |
| param->SetAttributes(attributes); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| Result<SuccessType> ShaderIO(core::ir::Module& ir) { |
| auto result = ValidateAndDumpIfNeeded(ir, "spirv.ShaderIO"); |
| if (result != Success) { |
| return result.Failure(); |
| } |
| |
| State{ir}.Process(); |
| |
| return Success; |
| } |
| |
| } // namespace tint::spirv::reader::lower |