// blob: 457516fdfac8984c309f3d1fc55d072b1f3c1a06 [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/tint/lang/core/builtin_value.h"
#include "src/tint/lang/wgsl/ast/disable_validation_attribute.h"
#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/function.h"
using namespace tint::number_suffixes; // NOLINT
// Type-info instantiations for the transform itself, its Config data, and the HLSLWaveIntrinsic
// attribute. NOTE(review): the exact expansion of TINT_INSTANTIATE_TYPEINFO is not visible in this
// file - confirm against the macro definition.
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::Config);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic);
namespace tint::ast::transform {
/// Constructor (compiler-generated default).
CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;
/// Destructor (compiler-generated default).
CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
namespace {
/// Info for a member of one of the wrapper entry point's IO structures.
/// This is an aggregate that is brace-initialized in declaration order (member, location, index),
/// so the order of the fields must not change.
struct MemberInfo {
/// The struct member item
const StructMember* member;
/// The struct member location if provided
std::optional<uint32_t> location;
/// The struct member index if provided
std::optional<uint32_t> index;
};
/// Ranks a builtin value for the purpose of ordering structure members.
/// FXC is sensitive to field order in structures; StructMemberComparator uses this rank to ensure
/// that FXC is happy with the order of emitted fields.
/// @param builtin the builtin value to rank
/// @returns the rank of @p builtin (a lower rank sorts first), or 0 if @p builtin has no entry in
/// the table
uint32_t BuiltinOrder(core::BuiltinValue builtin) {
    struct Rank {
        core::BuiltinValue builtin;
        uint32_t order;
    };
    static constexpr Rank kRanks[] = {
        {core::BuiltinValue::kPosition, 1},
        {core::BuiltinValue::kVertexIndex, 2},
        {core::BuiltinValue::kInstanceIndex, 3},
        {core::BuiltinValue::kFrontFacing, 4},
        {core::BuiltinValue::kFragDepth, 5},
        {core::BuiltinValue::kLocalInvocationId, 6},
        {core::BuiltinValue::kLocalInvocationIndex, 7},
        {core::BuiltinValue::kGlobalInvocationId, 8},
        {core::BuiltinValue::kWorkgroupId, 9},
        {core::BuiltinValue::kNumWorkgroups, 10},
        {core::BuiltinValue::kSampleIndex, 11},
        {core::BuiltinValue::kSampleMask, 12},
        {core::BuiltinValue::kPointSize, 13},
    };
    for (const Rank& rank : kRanks) {
        if (rank.builtin == builtin) {
            return rank.order;
        }
    }
    // Builtins without a table entry all share rank 0.
    return 0;
}
/// @param attr the attribute to test (must not be null)
/// @returns true if @p attr is a shader IO attribute, i.e. a builtin, interpolate, invariant,
/// location or index attribute
bool IsShaderIOAttribute(const Attribute* attr) {
    const bool is_shader_io =
        attr->IsAnyOf<BuiltinAttribute, InterpolateAttribute, InvariantAttribute,
                      LocationAttribute, IndexAttribute>();
    return is_shader_io;
}
} // namespace
/// PIMPL state for the transform
struct CanonicalizeEntryPointIO::State {
/// OutputValue represents a shader result that the wrapper function produces.
/// Instances are built by AddOutput() and later turned into either members of the wrapper's
/// output structure or module-scope output variables.
struct OutputValue {
/// The name of the output value.
std::string name;
/// The (AST) type of the output value.
Type type;
/// The shader IO attributes.
tint::Vector<const Attribute*, 8> attributes;
/// The expression that produces the value itself.
const Expression* value;
/// The output location, if provided.
std::optional<uint32_t> location;
/// The output index, if provided.
std::optional<uint32_t> index;
};
/// The clone context. Used to clone nodes of the source program and to register removals,
/// replacements and insertions.
program::CloneContext& ctx;
/// The program builder that creates the nodes of the target program.
ProgramBuilder& b;
/// The transform config (a copy of the config passed to the constructor).
CanonicalizeEntryPointIO::Config const cfg;
/// The entry point function (AST).
const Function* func_ast;
/// The entry point function (SEM).
const sem::Function* func_sem;
/// The new entry point wrapper function's parameters.
tint::Vector<const Parameter*, 8> wrapper_ep_parameters;
/// The members of the wrapper function's struct parameter.
tint::Vector<MemberInfo, 8> wrapper_struct_param_members;
/// The name of the wrapper function's struct parameter. Created lazily by InputStructSymbol().
Symbol wrapper_struct_param_name;
/// The parameters that will be passed to the original function.
tint::Vector<const Expression*, 8> inner_call_parameters;
/// The members of the wrapper function's struct return type.
tint::Vector<MemberInfo, 8> wrapper_struct_output_members;
/// The wrapper function output values.
tint::Vector<OutputValue, 8> wrapper_output_values;
/// The body of the wrapper function.
tint::Vector<const Statement*, 8> wrapper_body;
/// Input names used by the entrypoint. AddInput() registers a name verbatim the first time it is
/// seen, and requests a new symbol for any later input with the same name.
std::unordered_set<std::string> input_names;
/// A map of cloned attribute to builtin value
Hashmap<const BuiltinAttribute*, core::BuiltinValue, 16> builtin_attrs;
/// A map of builtin values to HLSL wave intrinsic functions.
Hashmap<core::BuiltinValue, Symbol, 2> wave_intrinsics;
/// Constructor. Also looks up the semantic node of @p function in the source program.
/// @param context the clone context
/// @param builder the program builder
/// @param config the transform config
/// @param function the entry point function
State(program::CloneContext& context,
      ProgramBuilder& builder,
      const CanonicalizeEntryPointIO::Config& config,
      const Function* function)
    : ctx(context),
      b(builder),
      cfg(config),
      func_ast(function),
      func_sem(context.src->Sem().Get(function)) {}
/// Clones the attribute @p in and appends the clone to @p out. If @p in is a builtin attribute,
/// then builtin_attrs is also updated so that the builtin value of the clone can be looked up
/// later.
/// @param in the attribute to clone
/// @param out the output Attributes
template <size_t N>
void CloneAttribute(const Attribute* in, tint::Vector<const Attribute*, N>& out) {
    auto* cloned = ctx.Clone(in);
    out.Push(cloned);
    auto* src_builtin = in->As<BuiltinAttribute>();
    if (!src_builtin) {
        return;
    }
    // The builtin value is resolved from the semantic info of the source attribute.
    auto value = ctx.src->Sem().Get(src_builtin)->Value();
    builtin_attrs.Add(cloned->As<BuiltinAttribute>(), value);
}
/// Clones the shader IO attributes from @p in.
/// Attributes for which IsShaderIOAttribute() returns false are skipped.
/// @param in the attributes to clone (taken by reference, the list itself is not copied)
/// @param do_interpolate whether to clone InterpolateAttribute
/// @return the cloned attributes
template <size_t N>
auto CloneShaderIOAttributes(const tint::Vector<const Attribute*, N>& in, bool do_interpolate) {
    tint::Vector<const Attribute*, N> out;
    for (auto* attr : in) {
        if (!IsShaderIOAttribute(attr)) {
            continue;
        }
        if (!do_interpolate && attr->template Is<InterpolateAttribute>()) {
            // Interpolation attributes were explicitly excluded by the caller.
            continue;
        }
        CloneAttribute(attr, out);
    }
    return out;
}
/// Resolves the builtin value of a builtin attribute that may belong to either the source or the
/// target program.
/// @param attr the input attribute
/// @returns the builtin value of the attribute
core::BuiltinValue BuiltinOf(const BuiltinAttribute* attr) {
    const bool from_source_program = attr->generation_id != b.ID();
    if (from_source_program) {
        // attr belongs to the source program.
        // Obtain the builtin value from the semantic info.
        return ctx.src->Sem().Get(attr)->Value();
    }
    // attr belongs to the target program.
    // Obtain the builtin value from #builtin_attrs.
    if (auto recorded = builtin_attrs.Get(attr)) {
        return *recorded;
    }
    TINT_ICE() << "could not obtain builtin value from attribute";
    return core::BuiltinValue::kUndefined;
}
/// @param attrs the input attribute list
/// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute,
/// otherwise core::BuiltinValue::kUndefined
core::BuiltinValue BuiltinOf(VectorRef<const Attribute*> attrs) {
    auto* builtin_attr = GetAttribute<BuiltinAttribute>(attrs);
    return builtin_attr ? BuiltinOf(builtin_attr) : core::BuiltinValue::kUndefined;
}
/// Returns the symbol for the wrapper function's struct parameter, creating it on first use.
/// @returns the symbol for the struct parameter
Symbol InputStructSymbol() {
    if (wrapper_struct_param_name.IsValid()) {
        return wrapper_struct_param_name;
    }
    wrapper_struct_param_name = b.Sym();
    return wrapper_struct_param_name;
}
/// Call a wave intrinsic function for the subgroup builtin contained in @p attrs, if any.
/// Only applies when the shader style is HLSL. The intrinsic function is declared the first time
/// a given builtin is requested, and the result (including the "no intrinsic" result for other
/// builtins) is cached in wave_intrinsics.
/// @param attrs the attribute list that may contain a subgroup builtin
/// @returns an expression that calls a HLSL wave intrinsic, or nullptr
const ast::CallExpression* CallWaveIntrinsic(VectorRef<const Attribute*> attrs) {
    if (cfg.shader_style != ShaderStyle::kHlsl) {
        return nullptr;
    }

    // Declares a body-less function that returns u32 and is tagged with the intrinsic attribute.
    auto declare_intrinsic = [&](const char* name, HLSLWaveIntrinsic::Op op) {
        auto symbol = b.Symbols().New(name);
        b.Func(symbol, Empty, b.ty.u32(), nullptr,
               Vector{b.ASTNodes().Create<HLSLWaveIntrinsic>(b.ID(), b.AllocateNodeID(), op),
                      b.Disable(DisabledValidation::kFunctionHasNoBody)});
        return symbol;
    };

    // Get or create the intrinsic function for the builtin.
    auto builtin = BuiltinOf(attrs);
    auto intrinsic_fn = wave_intrinsics.GetOrCreate(builtin, [&] {
        if (builtin == core::BuiltinValue::kSubgroupInvocationId) {
            return declare_intrinsic("__WaveGetLaneIndex",
                                     HLSLWaveIntrinsic::Op::kWaveGetLaneIndex);
        }
        if (builtin == core::BuiltinValue::kSubgroupSize) {
            return declare_intrinsic("__WaveGetLaneCount",
                                     HLSLWaveIntrinsic::Op::kWaveGetLaneCount);
        }
        // Not a subgroup builtin: there is no intrinsic to call.
        return Symbol();
    });

    // Call the intrinsic function, if there is one.
    return intrinsic_fn ? b.Call(intrinsic_fn) : nullptr;
}
/// Add a shader input to the entry point.
/// How the input is materialized depends on the configured shader style:
///  * SPIR-V and GLSL: a module-scope variable in the `in` address space is created.
///  * MSL, when the input is a builtin: a parameter is added to the wrapper function.
///  * Otherwise: a member is added to the wrapper function's struct parameter.
/// @param name the name of the shader input
/// @param type the type of the shader input
/// @param location the location if provided
/// @param attrs the attributes to apply to the shader input
/// @returns an expression which evaluates to the value of the shader input
const Expression* AddInput(std::string name,
const type::Type* type,
std::optional<uint32_t> location,
tint::Vector<const Attribute*, 8> attrs) {
auto ast_type = CreateASTTypeFor(ctx, type);
// kUndefined if `attrs` holds no builtin attribute.
auto builtin_attr = BuiltinOf(attrs);
if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
// Vulkan requires that integer user-defined fragment inputs are always decorated with
// `Flat`. See:
// https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is
// required for integers.
// Note: for GLSL the attribute is only added if the input has a location attribute.
if (func_ast->PipelineStage() == PipelineStage::kFragment &&
type->is_integer_scalar_or_vector() && !HasAttribute<InterpolateAttribute>(attrs) &&
(HasAttribute<LocationAttribute>(attrs) ||
cfg.shader_style == ShaderStyle::kSpirv)) {
attrs.Push(b.Interpolate(core::InterpolationType::kFlat,
core::InterpolationSampling::kUndefined));
}
// Disable validation for use of the `input` address space.
attrs.Push(b.Disable(DisabledValidation::kIgnoreAddressSpace));
// In GLSL, if it's a builtin, override the name with the
// corresponding gl_ builtin name
if (cfg.shader_style == ShaderStyle::kGlsl &&
builtin_attr != core::BuiltinValue::kUndefined) {
name = GLSLBuiltinToString(builtin_attr, func_ast->PipelineStage(),
core::AddressSpace::kIn);
}
auto symbol = b.Symbols().New(name);
// Create the global variable and use its value for the shader input.
const Expression* value = b.Expr(symbol);
if (builtin_attr != core::BuiltinValue::kUndefined) {
if (cfg.shader_style == ShaderStyle::kGlsl) {
// May wrap `value` in a conversion and change `ast_type` to the type GLSL uses.
value = FromGLSLBuiltin(builtin_attr, value, ast_type);
} else if (builtin_attr == core::BuiltinValue::kSampleMask) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then load the first element.
ast_type = b.ty.array(ast_type, 1_u);
value = b.IndexAccessor(value, 0_i);
}
}
b.GlobalVar(symbol, ast_type, core::AddressSpace::kIn, std::move(attrs));
return value;
} else if (cfg.shader_style == ShaderStyle::kMsl &&
builtin_attr != core::BuiltinValue::kUndefined) {
// If this input is a builtin and we are targeting MSL, then add it to the
// parameter list and pass it directly to the inner function.
// The first input with a given name registers the name as-is, later ones get a new symbol.
Symbol symbol = input_names.emplace(name).second ? b.Symbols().Register(name)
: b.Symbols().New(name);
wrapper_ep_parameters.Push(b.Param(symbol, ast_type, std::move(attrs)));
return b.Expr(symbol);
} else {
// Otherwise, move it to the new structure member list.
// The first input with a given name registers the name as-is, later ones get a new symbol.
Symbol symbol = input_names.emplace(name).second ? b.Symbols().Register(name)
: b.Symbols().New(name);
wrapper_struct_param_members.Push(
{b.Member(symbol, ast_type, std::move(attrs)), location, std::nullopt});
return b.MemberAccessor(InputStructSymbol(), symbol);
}
}
/// Add a shader output to the entry point.
/// The output is recorded in wrapper_output_values; it is materialized later by either
/// CreateOutputStruct() or CreateGlobalOutputVariables().
/// @param name the name of the shader output
/// @param type the type of the shader output
/// @param location the location if provided
/// @param index the index if provided
/// @param attrs the attributes to apply to the shader output
/// @param value the value of the shader output
void AddOutput(std::string name,
               const type::Type* type,
               std::optional<uint32_t> location,
               std::optional<uint32_t> index,
               tint::Vector<const Attribute*, 8> attrs,
               const Expression* value) {
    auto builtin_attr = BuiltinOf(attrs);
    // Vulkan requires that integer user-defined vertex outputs are always decorated with
    // `Flat`.
    // TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required
    // for integers.
    if (cfg.shader_style == ShaderStyle::kSpirv &&
        func_ast->PipelineStage() == PipelineStage::kVertex &&
        type->is_integer_scalar_or_vector() && HasAttribute<LocationAttribute>(attrs) &&
        !HasAttribute<InterpolateAttribute>(attrs)) {
        attrs.Push(b.Interpolate(core::InterpolationType::kFlat,
                                 core::InterpolationSampling::kUndefined));
    }
    // In GLSL, if it's a builtin, override the name with the
    // corresponding gl_ builtin name
    if (cfg.shader_style == ShaderStyle::kGlsl) {
        if (builtin_attr != core::BuiltinValue::kUndefined) {
            name = GLSLBuiltinToString(builtin_attr, func_ast->PipelineStage(),
                                       core::AddressSpace::kOut);
            // May replace `type` with the type that the GLSL builtin expects.
            value = ToGLSLBuiltin(builtin_attr, value, type);
        }
    }
    // `name` and `attrs` are by-value parameters that are not used again, so move them (and the
    // assembled record) instead of copying the string and the attribute vector.
    OutputValue output;
    output.name = std::move(name);
    output.type = CreateASTTypeFor(ctx, type);
    output.attributes = std::move(attrs);
    output.value = value;
    output.location = location;
    output.index = index;
    wrapper_output_values.Push(std::move(output));
}
/// Process a non-struct parameter.
/// The shader IO attributes are removed from the inner function's parameter and attached to a
/// new shader input object; the expression for that input is appended to the arguments that
/// will be passed to the original function. If the parameter is satisfied by a wave intrinsic,
/// all of its attributes are removed and the intrinsic call is passed instead.
/// @param param the original function parameter
void ProcessNonStructParameter(const sem::Parameter* param) {
    auto* decl = param->Declaration();

    if (auto* wave_call = CallWaveIntrinsic(decl->attributes)) {
        inner_call_parameters.Push(wave_call);
        for (auto* attr : decl->attributes) {
            ctx.Remove(decl->attributes, attr);
        }
        return;
    }

    // Interpolation attributes are not carried over for vertex inputs.
    const bool is_vertex_stage = func_ast->PipelineStage() == PipelineStage::kVertex;

    tint::Vector<const Attribute*, 8> io_attributes;
    for (auto* attr : decl->attributes) {
        if (!IsShaderIOAttribute(attr)) {
            continue;
        }
        // Strip the attribute from the inner function parameter...
        ctx.Remove(decl->attributes, attr);
        if (is_vertex_stage && attr->Is<InterpolateAttribute>()) {
            continue;
        }
        // ...and attach it to the new object instead.
        CloneAttribute(attr, io_attributes);
    }

    auto name = decl->name->symbol.Name();
    auto* input_expr = AddInput(name, param->Type(), param->Location(), std::move(io_attributes));
    inner_call_parameters.Push(input_expr);
}
/// Process a struct parameter.
/// A new shader input is created for each member of the structure (or a wave intrinsic call,
/// where applicable), and the original structure is then reconstructed from those values and
/// passed to the original function.
/// @param param the original function parameter
void ProcessStructParameter(const sem::Parameter* param) {
    // Interpolation attributes are not carried over for vertex inputs.
    const bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
    auto* str = param->Type()->As<sem::Struct>();

    // One value per struct member, in declaration order.
    tint::Vector<const Expression*, 8> member_values;
    for (auto* member : str->Members()) {
        if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
            TINT_ICE() << "nested IO struct";
            continue;
        }

        auto* member_decl = member->Declaration();
        if (auto* wave_call = CallWaveIntrinsic(member_decl->attributes)) {
            member_values.Push(wave_call);
            continue;
        }

        auto member_name = member->Name().Name();
        auto cloned_attrs = CloneShaderIOAttributes(member_decl->attributes, do_interpolate);
        auto* input_expr = AddInput(member_name, member->Type(), member->Attributes().location,
                                    std::move(cloned_attrs));
        member_values.Push(input_expr);
    }

    // Construct the original structure using the new shader input objects.
    inner_call_parameters.Push(b.Call(ctx.Clone(param->Declaration()->type), member_values));
}
/// Process the entry point return type.
/// Registers (via AddOutput) one output value per member of a returned structure, or a single
/// output named "value" for a non-struct, non-void return type. Nothing is registered for void.
/// @param inner_ret_type the original function return type
/// @param original_result the result object produced by the original function
void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) {
    // Interpolation attributes are not carried over for fragment outputs.
    const bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kFragment;

    auto* str = inner_ret_type->As<sem::Struct>();
    if (!str) {
        if (inner_ret_type->Is<type::Void>()) {
            return;
        }
        auto ret_attrs = CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate);
        // Propagate the non-struct return value as is.
        AddOutput("value", func_sem->ReturnType(), func_sem->ReturnLocation(),
                  func_sem->ReturnIndex(), std::move(ret_attrs), b.Expr(original_result));
        return;
    }

    for (auto* member : str->Members()) {
        if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
            TINT_ICE() << "nested IO struct";
            continue;
        }
        auto member_name = member->Name().Name();
        auto member_attrs =
            CloneShaderIOAttributes(member->Declaration()->attributes, do_interpolate);
        // Extract the original structure member.
        AddOutput(member_name, member->Type(), member->Attributes().location,
                  member->Attributes().index, std::move(member_attrs),
                  b.MemberAccessor(original_result, member_name));
    }
}
/// Apply the configured fixed sample mask to the wrapper function output.
/// If an output value with the sample mask builtin already exists, its value is bitwise-and'ed
/// with the fixed mask. Otherwise a new output value holding the fixed mask is added.
void AddFixedSampleMask() {
    // Look for an authored sample mask among the existing output values.
    for (auto& output : wrapper_output_values) {
        if (BuiltinOf(output.attributes) != core::BuiltinValue::kSampleMask) {
            continue;
        }
        // Combine the authored sample mask with the fixed mask.
        output.value = b.And(output.value, u32(cfg.fixed_sample_mask));
        return;
    }

    // None found: emit the fixed mask as a new sample mask output. The attribute is recorded in
    // builtin_attrs so that BuiltinOf() can resolve it.
    auto* mask_attr = b.Builtin(core::BuiltinValue::kSampleMask);
    builtin_attrs.Add(mask_attr, core::BuiltinValue::kSampleMask);
    AddOutput("fixed_sample_mask", b.create<type::U32>(), std::nullopt, std::nullopt, {mask_attr},
              b.Expr(u32(cfg.fixed_sample_mask)));
}
/// Add a point size builtin output, with the literal value 1.0, to the wrapper function output.
void AddVertexPointSize() {
    // The attribute is recorded in builtin_attrs so that BuiltinOf() can resolve it.
    auto* point_size_attr = b.Builtin(core::BuiltinValue::kPointSize);
    builtin_attrs.Add(point_size_attr, core::BuiltinValue::kPointSize);
    AddOutput("vertex_point_size", b.create<type::F32>(), std::nullopt, std::nullopt,
              {point_size_attr}, b.Expr(1_f));
}
/// Build a member accessor expression of the form `gl_Position.<component>`.
/// @param component the component of gl_Position to access
/// @returns the new expression
const Expression* GLPosition(const char* component) {
    Symbol position_sym = b.Symbols().Register("gl_Position");
    Symbol component_sym = b.Symbols().Register(component);
    return b.MemberAccessor(b.Expr(position_sym), component_sym);
}
/// Comparison function used to reorder struct members such that all members with
/// location attributes appear first (ordered by location slot), followed by
/// those with builtin attributes.
/// @param x a struct member
/// @param y another struct member
/// @returns true if a comes before b
bool StructMemberComparator(const MemberInfo& x, const MemberInfo& y) {
auto* x_loc = GetAttribute<LocationAttribute>(x.member->attributes);
auto* y_loc = GetAttribute<LocationAttribute>(y.member->attributes);
auto* x_blt = GetAttribute<BuiltinAttribute>(x.member->attributes);
auto* y_blt = GetAttribute<BuiltinAttribute>(y.member->attributes);
if (x_loc) {
if (!y_loc) {
// `a` has location attribute and `b` does not: `a` goes first.
return true;
}
// Both have location attributes: smallest goes first.
return x.location < y.location;
} else {
if (y_loc) {
// `b` has location attribute and `a` does not: `b` goes first.
return false;
}
// Both are builtins: order matters for FXC.
auto builtin_a = BuiltinOf(x_blt);
auto builtin_b = BuiltinOf(y_blt);
return BuiltinOrder(builtin_a) < BuiltinOrder(builtin_b);
}
}
/// Create the wrapper function's struct parameter and type objects.
void CreateInputStruct() {
// Sort the struct members to satisfy HLSL interfacing matching rules.
std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
[&](auto& x, auto& y) { return StructMemberComparator(x, y); });
tint::Vector<const StructMember*, 8> members;
for (auto& mem : wrapper_struct_param_members) {
members.Push(mem.member);
}
// Create the new struct type.
auto struct_name = b.Sym();
auto* in_struct = b.create<Struct>(b.Ident(struct_name), std::move(members), Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
// Create a new function parameter using this struct type.
auto* param = b.Param(InputStructSymbol(), b.ty(struct_name));
wrapper_ep_parameters.Push(param);
}
/// Create and return the wrapper function's struct result object.
/// @returns the struct type
Struct* CreateOutputStruct() {
tint::Vector<const Statement*, 8> assignments;
auto wrapper_result = b.Symbols().New("wrapper_result");
// Create the struct members and their corresponding assignment statements.
std::unordered_set<std::string> member_names;
for (auto& outval : wrapper_output_values) {
// Use the original output name, unless that is already taken.
Symbol name;
if (member_names.count(outval.name)) {
name = b.Symbols().New(outval.name);
} else {
name = b.Symbols().Register(outval.name);
}
member_names.insert(name.Name());
wrapper_struct_output_members.Push(
{b.Member(name, outval.type, std::move(outval.attributes)), outval.location,
std::nullopt});
assignments.Push(b.Assign(b.MemberAccessor(wrapper_result, name), outval.value));
}
// Sort the struct members to satisfy HLSL interfacing matching rules.
std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
[&](auto& x, auto& y) { return StructMemberComparator(x, y); });
tint::Vector<const StructMember*, 8> members;
for (auto& mem : wrapper_struct_output_members) {
members.Push(mem.member);
}
// Create the new struct type.
auto* out_struct = b.create<Struct>(b.Ident(b.Sym()), std::move(members), Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
// Create the output struct object, assign its members, and return it.
auto* result_object = b.Var(wrapper_result, b.ty(out_struct->name->symbol));
wrapper_body.Push(b.Decl(result_object));
for (auto* assignment : assignments) {
wrapper_body.Push(assignment);
}
wrapper_body.Push(b.Return(wrapper_result));
return out_struct;
}
/// Create and assign the wrapper function's output variables.
void CreateGlobalOutputVariables() {
for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` address space.
auto attributes = std::move(outval.attributes);
attributes.Push(b.Disable(DisabledValidation::kIgnoreAddressSpace));
// Create the global variable and assign it the output value.
auto name = b.Symbols().New(outval.name);
Type type = outval.type;
const Expression* lhs = b.Expr(name);
if (BuiltinOf(attributes) == core::BuiltinValue::kSampleMask) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then store to the first element.
type = b.ty.array(type, 1_u);
lhs = b.IndexAccessor(lhs, 0_i);
}
b.GlobalVar(name, type, core::AddressSpace::kOut, std::move(attributes));
wrapper_body.Push(b.Assign(lhs, outval.value));
}
}
/// Recreate the original function without entry point attributes and call it.
/// The recreated function replaces the original entry point function.
/// @returns the inner function call expression
const CallExpression* CallInnerFunction() {
    // In GLSL, clone the original entry point name, as the wrapper will be called "main".
    // Otherwise add a suffix to the function name, as the wrapper function will take the
    // original entry point name.
    Symbol inner_name = (cfg.shader_style == ShaderStyle::kGlsl)
                            ? ctx.Clone(func_ast->name->symbol)
                            : b.Symbols().New(func_ast->name->symbol.Name() + "_inner");

    // Clone everything, dropping the function and return type attributes.
    // The parameter attributes will have already been stripped during
    // processing.
    auto* inner_function = b.create<Function>(b.Ident(inner_name), ctx.Clone(func_ast->params),
                                              ctx.Clone(func_ast->return_type),
                                              ctx.Clone(func_ast->body), Empty, Empty);
    ctx.Replace(func_ast, inner_function);

    // Call the function with the arguments collected while processing the parameters.
    return b.Call(inner_function->name->symbol, inner_call_parameters);
}
/// Process the entry point function.
/// The original function is replaced by an inner function without entry point attributes, and a
/// wrapper entry point function is inserted after it. The wrapper gathers the shader inputs,
/// calls the inner function, and produces the shader outputs in the form required by the
/// configured shader style.
void Process() {
bool needs_fixed_sample_mask = false;
bool needs_vertex_point_size = false;
// A fixed sample mask of 0xFFFFFFFF means that no masking is required.
if (func_ast->PipelineStage() == PipelineStage::kFragment &&
cfg.fixed_sample_mask != 0xFFFFFFFF) {
needs_fixed_sample_mask = true;
}
if (func_ast->PipelineStage() == PipelineStage::kVertex && cfg.emit_vertex_point_size) {
needs_vertex_point_size = true;
}
// Exit early if there is no shader IO to handle.
// Note: GLSL entry points are always processed, as the wrapper function is named "main".
if (func_sem->Parameters().Length() == 0 && func_sem->ReturnType()->Is<type::Void>() &&
!needs_fixed_sample_mask && !needs_vertex_point_size &&
cfg.shader_style != ShaderStyle::kGlsl) {
return;
}
// Process the entry point parameters, collecting those that need to be
// aggregated into a single structure.
if (!func_sem->Parameters().IsEmpty()) {
for (auto* param : func_sem->Parameters()) {
if (param->Type()->Is<type::Struct>()) {
ProcessStructParameter(param);
} else {
ProcessNonStructParameter(param);
}
}
// Create a structure parameter for the outer entry point if necessary.
if (!wrapper_struct_param_members.IsEmpty()) {
CreateInputStruct();
}
}
// Recreate the original function and call it.
auto* call_inner = CallInnerFunction();
// Process the return type, and start building the wrapper function body.
// The return type is held as a factory, which is only invoked when the wrapper function is
// created below. It defaults to void.
std::function<Type()> wrapper_ret_type = [&] { return b.ty.void_(); };
if (func_sem->ReturnType()->Is<type::Void>()) {
// The function call is just a statement with no result.
wrapper_body.Push(b.CallStmt(call_inner));
} else {
// Capture the result of calling the original function.
auto* inner_result = b.Let(b.Symbols().New("inner_result"), call_inner);
wrapper_body.Push(b.Decl(inner_result));
// Process the original return type to determine the outputs that the
// outer function needs to produce.
ProcessReturnType(func_sem->ReturnType(), inner_result->name->symbol);
}
// Add a fixed sample mask, if necessary.
if (needs_fixed_sample_mask) {
AddFixedSampleMask();
}
// Add the pointsize builtin, if necessary.
if (needs_vertex_point_size) {
AddVertexPointSize();
}
// Produce the entry point outputs, if necessary.
if (!wrapper_output_values.IsEmpty()) {
if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
// SPIR-V and GLSL: outputs are module-scope variables, the wrapper returns void.
CreateGlobalOutputVariables();
} else {
// Other styles: outputs are returned from the wrapper in a structure.
auto* output_struct = CreateOutputStruct();
wrapper_ret_type = [&, output_struct] { return b.ty(output_struct->name->symbol); };
}
}
// For GLSL vertex shaders, append to the wrapper body:
//   gl_Position.y = -gl_Position.y;
//   gl_Position.z = 2.0 * gl_Position.z - gl_Position.w;
if (cfg.shader_style == ShaderStyle::kGlsl &&
func_ast->PipelineStage() == PipelineStage::kVertex) {
auto* pos_y = GLPosition("y");
auto* negate_pos_y = b.create<UnaryOpExpression>(UnaryOp::kNegation, GLPosition("y"));
wrapper_body.Push(b.Assign(pos_y, negate_pos_y));
auto* two_z = b.Mul(b.Expr(2_f), GLPosition("z"));
auto* fixed_z = b.Sub(two_z, GLPosition("w"));
wrapper_body.Push(b.Assign(GLPosition("z"), fixed_z));
}
// Create the wrapper entry point function.
// For GLSL, use "main", otherwise take the name of the original
// entry point function.
Symbol name;
if (cfg.shader_style == ShaderStyle::kGlsl) {
name = b.Symbols().New("main");
} else {
name = ctx.Clone(func_ast->name->symbol);
}
// The wrapper takes the attributes of the original entry point function.
auto* wrapper_func =
b.create<Function>(b.Ident(name), wrapper_ep_parameters, b.ty(wrapper_ret_type()),
b.Block(wrapper_body), ctx.Clone(func_ast->attributes), Empty);
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
}
/// Retrieve the gl_ string corresponding to a builtin.
/// @param builtin the builtin
/// @param stage the current pipeline stage (only used for the position builtin)
/// @param address_space the address space (input or output; only used for the sample mask
/// builtin)
/// @returns the gl_ string corresponding to that builtin, or an empty string if there is none
const char* GLSLBuiltinToString(core::BuiltinValue builtin,
                                PipelineStage stage,
                                core::AddressSpace address_space) {
    switch (builtin) {
        case core::BuiltinValue::kPosition:
            // The position builtin maps to a different variable in each stage.
            if (stage == PipelineStage::kVertex) {
                return "gl_Position";
            }
            if (stage == PipelineStage::kFragment) {
                return "gl_FragCoord";
            }
            return "";
        case core::BuiltinValue::kVertexIndex:
            return "gl_VertexID";
        case core::BuiltinValue::kInstanceIndex:
            return "gl_InstanceID";
        case core::BuiltinValue::kFrontFacing:
            return "gl_FrontFacing";
        case core::BuiltinValue::kFragDepth:
            return "gl_FragDepth";
        case core::BuiltinValue::kLocalInvocationId:
            return "gl_LocalInvocationID";
        case core::BuiltinValue::kLocalInvocationIndex:
            return "gl_LocalInvocationIndex";
        case core::BuiltinValue::kGlobalInvocationId:
            return "gl_GlobalInvocationID";
        case core::BuiltinValue::kNumWorkgroups:
            return "gl_NumWorkGroups";
        case core::BuiltinValue::kWorkgroupId:
            return "gl_WorkGroupID";
        case core::BuiltinValue::kSampleIndex:
            return "gl_SampleID";
        case core::BuiltinValue::kSampleMask:
            // The sample mask has distinct input and output variables.
            return address_space == core::AddressSpace::kIn ? "gl_SampleMaskIn"
                                                            : "gl_SampleMask";
        default:
            return "";
    }
}
/// Convert a given GLSL builtin value to the corresponding WGSL value.
/// @param builtin the builtin variable
/// @param value the value to convert
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types. Left unchanged for
/// builtins that need no conversion.
/// @returns an expression representing the GLSL builtin converted to what
/// WGSL expects
const Expression* FromGLSLBuiltin(core::BuiltinValue builtin,
                                  const Expression* value,
                                  Type& ast_type) {
    switch (builtin) {
        case core::BuiltinValue::kVertexIndex:
        case core::BuiltinValue::kInstanceIndex:
        case core::BuiltinValue::kSampleIndex: {
            // GLSL uses i32 for these, so bitcast to the WGSL type.
            const Expression* converted = b.Bitcast(ast_type, value);
            ast_type = b.ty.i32();
            return converted;
        }
        case core::BuiltinValue::kSampleMask: {
            // gl_SampleMask is an array of i32. Retrieve the first element and
            // bitcast it to the WGSL type.
            const Expression* first_element = b.IndexAccessor(value, 0_i);
            const Expression* converted = b.Bitcast(ast_type, first_element);
            ast_type = b.ty.array(b.ty.i32(), 1_u);
            return converted;
        }
        default:
            return value;
    }
}
/// Convert a given WGSL value to the type expected when assigning to a
/// GLSL builtin.
/// @param builtin the builtin variable
/// @param value the value to convert
/// @param type (out) the type to which the value was converted. Left unchanged for builtins
/// that need no conversion.
/// @returns the converted value which can be assigned to the GLSL builtin
const Expression* ToGLSLBuiltin(core::BuiltinValue builtin,
                                const Expression* value,
                                const type::Type*& type) {
    const bool needs_i32 = builtin == core::BuiltinValue::kVertexIndex ||
                           builtin == core::BuiltinValue::kInstanceIndex ||
                           builtin == core::BuiltinValue::kSampleIndex ||
                           builtin == core::BuiltinValue::kSampleMask;
    if (!needs_i32) {
        return value;
    }
    // These builtins are converted to i32 with a bitcast.
    type = b.create<type::I32>();
    return b.Bitcast(CreateASTTypeFor(ctx, type), value);
}
};
/// Runs the transform on @p src.
/// @param src the source program
/// @param inputs the transform inputs, which must contain a Config
/// @returns the result of resolving the transformed program. If the Config is missing, an error
/// diagnostic is added to the program before it is resolved.
Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
                                                       const DataMap& inputs,
                                                       DataMap&) const {
    ProgramBuilder b;
    program::CloneContext ctx{&b, src, /* auto_clone_symbols */ true};

    auto* cfg = inputs.Get<Config>();
    if (!cfg) {
        b.Diagnostics().add_error(diag::System::Transform,
                                  "missing transform data for " + std::string(TypeInfo().name));
        return resolver::Resolve(b);
    }

    // Remove entry point IO attributes from struct declarations.
    // New structures will be created for each entry point, as necessary.
    for (auto* decl : src->AST().TypeDecls()) {
        auto* struct_decl = decl->As<Struct>();
        if (!struct_decl) {
            continue;
        }
        for (auto* member : struct_decl->members) {
            for (auto* attr : member->attributes) {
                if (IsShaderIOAttribute(attr)) {
                    ctx.Remove(member->attributes, attr);
                }
            }
        }
    }

    // Canonicalize the IO of each entry point in turn.
    for (auto* func : src->AST().Functions()) {
        if (func->IsEntryPoint()) {
            State state(ctx, b, *cfg, func);
            state.Process();
        }
    }

    ctx.Clone();
    return resolver::Resolve(b);
}
/// Constructor
/// @param style the approach to use for emitting shader IO
/// @param sample_mask the fixed sample mask to apply to fragment shader outputs. The transform
/// treats 0xFFFFFFFF as "no fixed sample mask".
/// @param emit_point_size true to add a point size builtin output to vertex shaders
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
uint32_t sample_mask,
bool emit_point_size)
: shader_style(style),
fixed_sample_mask(sample_mask),
emit_vertex_point_size(emit_point_size) {}
/// Copy constructor (compiler-generated default).
CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
/// Destructor (compiler-generated default).
CanonicalizeEntryPointIO::Config::~Config() = default;
/// Constructor
/// @param pid the generation identifier, forwarded to the base class
/// @param nid the node identifier, forwarded to the base class
/// @param o the wave intrinsic operation
CanonicalizeEntryPointIO::HLSLWaveIntrinsic::HLSLWaveIntrinsic(GenerationID pid, NodeID nid, Op o)
: Base(pid, nid, Empty), op(o) {}
/// Destructor (compiler-generated default).
CanonicalizeEntryPointIO::HLSLWaveIntrinsic::~HLSLWaveIntrinsic() = default;
/// @returns the internal name of the attribute for the wave intrinsic operation, or an empty
/// string if the operation is not one of the known enumerators
std::string CanonicalizeEntryPointIO::HLSLWaveIntrinsic::InternalName() const {
    switch (op) {
        case Op::kWaveGetLaneCount:
            return "intrinsic_wave_get_lane_count";
        case Op::kWaveGetLaneIndex:
            return "intrinsic_wave_get_lane_index";
    }
    // Only reached if `op` holds a value that is not handled by the switch above. No string
    // stream is needed to produce the empty result.
    return std::string();
}
/// Creates a copy of this attribute, with the same operation, in the destination of @p ctx.
/// @param ctx the clone context
/// @returns the newly created attribute
const CanonicalizeEntryPointIO::HLSLWaveIntrinsic*
CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Clone(ast::CloneContext& ctx) const {
    auto* dst = ctx.dst;
    return dst->ASTNodes().Create<CanonicalizeEntryPointIO::HLSLWaveIntrinsic>(
        dst->ID(), dst->AllocateNodeID(), op);
}
} // namespace tint::ast::transform