blob: 294f859613ae5e27891b3625b44c50b76c13991a [file] [log] [blame]
// Copyright 2024 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "src/tint/lang/spirv/reader/lower/shader_io.h"
#include <utility>
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/referenced_module_vars.h"
#include "src/tint/lang/core/ir/validator.h"
namespace tint::spirv::reader::lower {
namespace {
using namespace tint::core::fluent_types; // NOLINT
/// PIMPL state for the transform.
struct State {
/// The IR module.
core::ir::Module& ir;
/// The IR builder.
core::ir::Builder b{ir};
/// The type manager.
core::type::Manager& ty{ir.Types()};
/// A map from block to its containing function.
Hashmap<core::ir::Block*, core::ir::Function*, 64> block_to_function{};
/// A map from each function to a map from input variable to parameter.
Hashmap<core::ir::Function*, Hashmap<core::ir::Var*, core::ir::Value*, 4>, 8>
function_parameter_map{};
/// The set of output variables that have been processed.
Hashset<core::ir::Var*, 4> output_variables{};
/// The mapping from functions to their transitively referenced output variables.
core::ir::ReferencedModuleVars<core::ir::Module> referenced_output_vars{
ir, [](const core::ir::Var* var) {
auto* view = var->Result(0)->Type()->As<core::type::MemoryView>();
return view && view->AddressSpace() == core::AddressSpace::kOut;
}};
/// Process the module.
void Process() {
// Process outputs first, as that may introduce new functions that input variables need to
// be propagated through.
ProcessOutputs();
ProcessInputs();
}
/// Process output variables.
/// Changes output variables to the `private` address space and wraps entry points that produce
/// outputs with new functions that copy the outputs from the private variables to the return
/// value.
void ProcessOutputs() {
// Update entry point functions to return their outputs, using a wrapper function.
// Use a worklist as `ProcessEntryPointOutputs()` will add new functions.
Vector<core::ir::Function*, 4> entry_points;
for (auto& func : ir.functions) {
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
entry_points.Push(func);
}
}
for (auto& ep : entry_points) {
ProcessEntryPointOutputs(ep);
}
// Remove attributes from all of the original structs and module-scope output variables.
// This is done last as we need to copy attributes during `ProcessEntryPointOutputs()`.
for (auto& var : output_variables) {
var->SetAttributes({});
if (auto* str = var->Result(0)->Type()->UnwrapPtr()->As<core::type::Struct>()) {
for (auto* member : str->Members()) {
// TODO(crbug.com/tint/745): Remove the const_cast.
const_cast<core::type::StructMember*>(member)->SetAttributes({});
}
}
}
}
/// Process input variables.
/// Pass inputs down the call stack as parameters to any functions that need them.
void ProcessInputs() {
// Seed the block-to-function map with the function entry blocks.
for (auto& func : ir.functions) {
block_to_function.Add(func->Block(), func);
}
// Gather the list of all module-scope input variables.
Vector<core::ir::Var*, 4> inputs;
for (auto* global : *ir.root_block) {
if (auto* var = global->As<core::ir::Var>()) {
auto addrspace = var->Result(0)->Type()->As<core::type::Pointer>()->AddressSpace();
if (addrspace == core::AddressSpace::kIn) {
inputs.Push(var);
}
}
}
// Replace the input variables with function parameters.
for (auto* var : inputs) {
ReplaceInputPointerUses(var, var->Result(0));
var->Destroy();
}
}
/// Replace an output pointer address space to make it `private`.
/// @param value the output variable
void ReplaceOutputPointerAddressSpace(core::ir::InstructionResult* value) {
// Change the address space to `private`.
auto* old_ptr_type = value->Type();
auto* new_ptr_type = ty.ptr(core::AddressSpace::kPrivate, old_ptr_type->UnwrapPtr());
value->SetType(new_ptr_type);
// Update all uses of the module-scope variable.
value->ForEachUseUnsorted([&](core::ir::Usage use) {
if (auto* access = use.instruction->As<core::ir::Access>()) {
ReplaceOutputPointerAddressSpace(access->Result(0));
} else if (!use.instruction->IsAnyOf<core::ir::Load, core::ir::LoadVectorElement,
core::ir::Store, core::ir::StoreVectorElement>()) {
TINT_UNREACHABLE()
<< "unexpected instruction: " << use.instruction->TypeInfo().name;
}
});
}
/// Process the outputs of an entry point function, adding a wrapper function to forward outputs
/// through the return value.
/// @param ep the entry point
void ProcessEntryPointOutputs(core::ir::Function* ep) {
const auto& referenced_outputs = referenced_output_vars.TransitiveReferences(ep);
if (referenced_outputs.IsEmpty()) {
return;
}
// Add a wrapper function to return either a single value or a struct.
auto* wrapper = b.Function(ty.void_(), ep->Stage());
if (auto name = ir.NameOf(ep)) {
ir.SetName(ep, name.Name() + "_inner");
ir.SetName(wrapper, name);
}
// Call the original entry point and make it a regular function.
ep->SetStage(core::ir::Function::PipelineStage::kUndefined);
b.Append(wrapper->Block(), [&] { //
b.Call(ep);
});
// Collect all outputs into a list of struct member declarations.
// Also add instructions to load their final values in the wrapper function.
Vector<core::ir::Value*, 4> results;
Vector<core::type::Manager::StructMemberDesc, 4> output_descriptors;
auto add_output = [&](Symbol name, const core::type::Type* type,
core::IOAttributes attributes) {
if (!name) {
name = ir.symbols.New();
}
output_descriptors.Push(core::type::Manager::StructMemberDesc{name, type, attributes});
};
for (auto* var : referenced_outputs) {
// Change the address space of the variable to private and update its uses, if we
// haven't already seen this variable.
if (output_variables.Add(var)) {
ReplaceOutputPointerAddressSpace(var->Result(0));
}
// Copy the variable attributes to the struct member.
auto var_attributes = var->Attributes();
auto var_type = var->Result(0)->Type()->UnwrapPtr();
if (auto* str = var_type->As<core::type::Struct>()) {
// Add an output for each member of the struct.
for (auto* member : str->Members()) {
// Use the base variable attributes if not specified directly on the member.
auto member_attributes = member->Attributes();
if (auto base_loc = var_attributes.location) {
// Location values increment from the base location value on the variable.
member_attributes.location = base_loc.value() + member->Index();
}
if (!member_attributes.interpolation) {
member_attributes.interpolation = var_attributes.interpolation;
}
add_output(member->Name(), member->Type(), std::move(member_attributes));
// Load the final result from the member of the original struct variable.
b.Append(wrapper->Block(), [&] { //
auto* access =
b.Access(ty.ptr<private_>(member->Type()), var, u32(member->Index()));
results.Push(b.Load(access)->Result(0));
});
}
} else {
// Load the final result from the original variable.
b.Append(wrapper->Block(), [&] {
results.Push(b.Load(var)->Result(0));
// If we're dealing with sample_mask, extract the scalar from the array.
if (var_attributes.builtin == core::BuiltinValue::kSampleMask) {
var_type = ty.u32();
results.Back() = b.Access(ty.u32(), results.Back(), u32(0))->Result(0);
}
});
add_output(ir.NameOf(var), var_type, std::move(var_attributes));
}
}
if (output_descriptors.Length() == 1) {
// Copy the output attributes to the function return.
wrapper->SetReturnAttributes(output_descriptors[0].attributes);
// Return the output from the wrapper function.
wrapper->SetReturnType(output_descriptors[0].type);
b.Append(wrapper->Block(), [&] { //
b.Return(wrapper, results[0]);
});
} else {
// Create a struct to hold all of the output values.
auto* str = ty.Struct(ir.symbols.New(), std::move(output_descriptors));
wrapper->SetReturnType(str);
// Collect the output values and return them from the wrapper function.
b.Append(wrapper->Block(), [&] { //
b.Return(wrapper, b.Construct(str, std::move(results)));
});
}
}
/// Replace a use of an input pointer value.
/// @param var the originating input variable
/// @param value the input pointer value
void ReplaceInputPointerUses(core::ir::Var* var, core::ir::Value* value) {
Vector<core::ir::Instruction*, 8> to_destroy;
value->ForEachUseUnsorted([&](core::ir::Usage use) {
auto* object = value;
if (object->Type()->Is<core::type::Pointer>()) {
// Get (or create) the function parameter that will replace the variable.
auto* func = ContainingFunction(use.instruction);
object = GetParameter(func, var);
}
Switch(
use.instruction,
[&](core::ir::Load* l) {
// Fold the load away and replace its uses with the new parameter.
l->Result(0)->ReplaceAllUsesWith(object);
to_destroy.Push(l);
},
[&](core::ir::LoadVectorElement* lve) {
// Replace the vector element load with an access instruction.
auto* access = b.AccessWithResult(lve->DetachResult(), object, lve->Index());
access->InsertBefore(lve);
to_destroy.Push(lve);
},
[&](core::ir::Access* a) {
// Remove the pointer from the source and destination type.
a->SetOperand(core::ir::Access::kObjectOperandOffset, object);
a->Result(0)->SetType(a->Result(0)->Type()->UnwrapPtr());
ReplaceInputPointerUses(var, a->Result(0));
},
TINT_ICE_ON_NO_MATCH);
});
// Clean up orphaned instructions.
for (auto* inst : to_destroy) {
inst->Destroy();
}
}
/// Get the function that contains an instruction.
/// @param inst the instruction
/// @returns the function
core::ir::Function* ContainingFunction(core::ir::Instruction* inst) {
return block_to_function.GetOrAdd(inst->Block(), [&] { //
return ContainingFunction(inst->Block()->Parent());
});
}
/// Get or create a function parameter to replace a module-scope variable.
/// @param func the function
/// @param var the module-scope variable
/// @returns the function parameter
core::ir::Value* GetParameter(core::ir::Function* func, core::ir::Var* var) {
return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&] {
const bool entry_point = func->Stage() != core::ir::Function::PipelineStage::kUndefined;
auto* var_type = var->Result(0)->Type()->UnwrapPtr();
// Use a scalar u32 for sample_mask builtins for entry point parameters.
if (entry_point && var->Attributes().builtin == core::BuiltinValue::kSampleMask) {
TINT_ASSERT(var_type->Is<core::type::Array>());
TINT_ASSERT(var_type->As<core::type::Array>()->ConstantCount() == 1u);
var_type = ty.u32();
}
// Create a new function parameter for the input.
auto* param = b.FunctionParam(var_type);
func->AppendParam(param);
if (auto name = ir.NameOf(var)) {
ir.SetName(param, name);
}
// Add attributes to the parameter if this is an entry point function.
if (entry_point) {
AddEntryPointParameterAttributes(param, var->Attributes());
}
// Update the callsites of this function.
func->ForEachUseUnsorted([&](core::ir::Usage use) {
if (auto* call = use.instruction->As<core::ir::UserCall>()) {
// Recurse into the calling function.
auto* caller = ContainingFunction(call);
call->AppendArg(GetParameter(caller, var));
} else if (!use.instruction->Is<core::ir::Return>()) {
TINT_UNREACHABLE()
<< "unexpected instruction: " << use.instruction->TypeInfo().name;
}
});
core::ir::Value* result = param;
if (entry_point && var->Attributes().builtin == core::BuiltinValue::kSampleMask) {
// Construct an array from the scalar sample_mask builtin value for entry points.
auto* construct = b.Construct(var->Result(0)->Type()->UnwrapPtr(), param);
func->Block()->Prepend(construct);
result = construct->Result(0);
}
return result;
});
}
/// Add attributes to an entry point function parameter.
/// @param param the parameter
/// @param attributes the attributes
void AddEntryPointParameterAttributes(core::ir::FunctionParam* param,
const core::IOAttributes& attributes) {
if (auto* str = param->Type()->UnwrapPtr()->As<core::type::Struct>()) {
for (auto* member : str->Members()) {
// Use the base variable attributes if not specified directly on the member.
auto member_attributes = member->Attributes();
if (auto base_loc = attributes.location) {
// Location values increment from the base location value on the variable.
member_attributes.location = base_loc.value() + member->Index();
}
if (!member_attributes.interpolation) {
member_attributes.interpolation = attributes.interpolation;
}
// TODO(crbug.com/tint/745): Remove the const_cast.
const_cast<core::type::StructMember*>(member)->SetAttributes(
std::move(member_attributes));
}
} else {
// Set attributes directly on the function parameter.
param->SetAttributes(attributes);
}
}
};
} // namespace
Result<SuccessType> ShaderIO(core::ir::Module& ir) {
auto result = ValidateAndDumpIfNeeded(ir, "spirv.ShaderIO");
if (result != Success) {
return result.Failure();
}
State{ir}.Process();
return Success;
}
} // namespace tint::spirv::reader::lower