Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 1 | // Copyright 2021 The Tint Authors. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "src/tint/transform/canonicalize_entry_point_io.h" |
| 16 | |
| 17 | #include <algorithm> |
| 18 | #include <string> |
| 19 | #include <unordered_set> |
| 20 | #include <utility> |
| 21 | #include <vector> |
| 22 | |
| 23 | #include "src/tint/ast/disable_validation_attribute.h" |
| 24 | #include "src/tint/program_builder.h" |
| 25 | #include "src/tint/sem/function.h" |
| 26 | #include "src/tint/transform/unshadow.h" |
| 27 | |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 28 | using namespace tint::number_suffixes; // NOLINT |
| 29 | |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 30 | TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO); |
| 31 | TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config); |
| 32 | |
dan sinclair | b5599d3 | 2022-04-07 16:55:14 +0000 | [diff] [blame] | 33 | namespace tint::transform { |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 34 | |
| 35 | CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default; |
| 36 | CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default; |
| 37 | |
| 38 | namespace { |
| 39 | |
| 40 | // Comparison function used to reorder struct members such that all members with |
| 41 | // location attributes appear first (ordered by location slot), followed by |
| 42 | // those with builtin attributes. |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 43 | bool StructMemberComparator(const ast::StructMember* a, const ast::StructMember* b) { |
| 44 | auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes); |
| 45 | auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes); |
| 46 | auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes); |
| 47 | auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes); |
| 48 | if (a_loc) { |
| 49 | if (!b_loc) { |
| 50 | // `a` has location attribute and `b` does not: `a` goes first. |
| 51 | return true; |
| 52 | } |
| 53 | // Both have location attributes: smallest goes first. |
| 54 | return a_loc->value < b_loc->value; |
| 55 | } else { |
| 56 | if (b_loc) { |
| 57 | // `b` has location attribute and `a` does not: `b` goes first. |
| 58 | return false; |
| 59 | } |
| 60 | // Both are builtins: order doesn't matter, just use enum value. |
| 61 | return a_blt->builtin < b_blt->builtin; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 62 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 63 | } |
| 64 | |
| 65 | // Returns true if `attr` is a shader IO attribute. |
| 66 | bool IsShaderIOAttribute(const ast::Attribute* attr) { |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 67 | return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, ast::InvariantAttribute, |
| 68 | ast::LocationAttribute>(); |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 69 | } |
| 70 | |
| 71 | // Returns true if `attrs` contains a `sample_mask` builtin. |
| 72 | bool HasSampleMask(const ast::AttributeList& attrs) { |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 73 | auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs); |
| 74 | return builtin && builtin->builtin == ast::Builtin::kSampleMask; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 75 | } |
| 76 | |
| 77 | } // namespace |
| 78 | |
| 79 | /// State holds the current transform state for a single entry point. |
| 80 | struct CanonicalizeEntryPointIO::State { |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 81 | /// OutputValue represents a shader result that the wrapper function produces. |
| 82 | struct OutputValue { |
| 83 | /// The name of the output value. |
| 84 | std::string name; |
| 85 | /// The type of the output value. |
| 86 | const ast::Type* type; |
| 87 | /// The shader IO attributes. |
| 88 | ast::AttributeList attributes; |
| 89 | /// The value itself. |
| 90 | const ast::Expression* value; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 91 | }; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 92 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 93 | /// The clone context. |
| 94 | CloneContext& ctx; |
| 95 | /// The transform config. |
| 96 | CanonicalizeEntryPointIO::Config const cfg; |
| 97 | /// The entry point function (AST). |
| 98 | const ast::Function* func_ast; |
| 99 | /// The entry point function (SEM). |
| 100 | const sem::Function* func_sem; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 101 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 102 | /// The new entry point wrapper function's parameters. |
| 103 | ast::VariableList wrapper_ep_parameters; |
| 104 | /// The members of the wrapper function's struct parameter. |
| 105 | ast::StructMemberList wrapper_struct_param_members; |
| 106 | /// The name of the wrapper function's struct parameter. |
| 107 | Symbol wrapper_struct_param_name; |
| 108 | /// The parameters that will be passed to the original function. |
| 109 | ast::ExpressionList inner_call_parameters; |
| 110 | /// The members of the wrapper function's struct return type. |
| 111 | ast::StructMemberList wrapper_struct_output_members; |
| 112 | /// The wrapper function output values. |
| 113 | std::vector<OutputValue> wrapper_output_values; |
| 114 | /// The body of the wrapper function. |
| 115 | ast::StatementList wrapper_body; |
| 116 | /// Input names used by the entrypoint |
| 117 | std::unordered_set<std::string> input_names; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 118 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 119 | /// Constructor |
| 120 | /// @param context the clone context |
| 121 | /// @param config the transform config |
| 122 | /// @param function the entry point function |
| 123 | State(CloneContext& context, |
| 124 | const CanonicalizeEntryPointIO::Config& config, |
| 125 | const ast::Function* function) |
| 126 | : ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {} |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 127 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 128 | /// Clones the shader IO attributes from `src`. |
| 129 | /// @param src the attributes to clone |
| 130 | /// @param do_interpolate whether to clone InterpolateAttribute |
| 131 | /// @return the cloned attributes |
| 132 | ast::AttributeList CloneShaderIOAttributes(const ast::AttributeList& src, bool do_interpolate) { |
| 133 | ast::AttributeList new_attributes; |
| 134 | for (auto* attr : src) { |
| 135 | if (IsShaderIOAttribute(attr) && |
| 136 | (do_interpolate || !attr->Is<ast::InterpolateAttribute>())) { |
| 137 | new_attributes.push_back(ctx.Clone(attr)); |
| 138 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 139 | } |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 140 | return new_attributes; |
| 141 | } |
| 142 | |
| 143 | /// Create or return a symbol for the wrapper function's struct parameter. |
| 144 | /// @returns the symbol for the struct parameter |
| 145 | Symbol InputStructSymbol() { |
| 146 | if (!wrapper_struct_param_name.IsValid()) { |
| 147 | wrapper_struct_param_name = ctx.dst->Sym(); |
| 148 | } |
| 149 | return wrapper_struct_param_name; |
| 150 | } |
| 151 | |
| 152 | /// Add a shader input to the entry point. |
| 153 | /// @param name the name of the shader input |
| 154 | /// @param type the type of the shader input |
| 155 | /// @param attributes the attributes to apply to the shader input |
| 156 | /// @returns an expression which evaluates to the value of the shader input |
| 157 | const ast::Expression* AddInput(std::string name, |
| 158 | const sem::Type* type, |
| 159 | ast::AttributeList attributes) { |
| 160 | auto* ast_type = CreateASTTypeFor(ctx, type); |
| 161 | if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) { |
| 162 | // Vulkan requires that integer user-defined fragment inputs are |
| 163 | // always decorated with `Flat`. |
| 164 | // TODO(crbug.com/tint/1224): Remove this once a flat interpolation |
| 165 | // attribute is required for integers. |
| 166 | if (type->is_integer_scalar_or_vector() && |
| 167 | ast::HasAttribute<ast::LocationAttribute>(attributes) && |
| 168 | !ast::HasAttribute<ast::InterpolateAttribute>(attributes) && |
| 169 | func_ast->PipelineStage() == ast::PipelineStage::kFragment) { |
| 170 | attributes.push_back(ctx.dst->Interpolate(ast::InterpolationType::kFlat, |
| 171 | ast::InterpolationSampling::kNone)); |
| 172 | } |
| 173 | |
| 174 | // Disable validation for use of the `input` storage class. |
| 175 | attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); |
| 176 | |
| 177 | // In GLSL, if it's a builtin, override the name with the |
| 178 | // corresponding gl_ builtin name |
| 179 | auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attributes); |
| 180 | if (cfg.shader_style == ShaderStyle::kGlsl && builtin) { |
| 181 | name = GLSLBuiltinToString(builtin->builtin, func_ast->PipelineStage(), |
| 182 | ast::StorageClass::kInput); |
| 183 | } |
| 184 | auto symbol = ctx.dst->Symbols().New(name); |
| 185 | |
| 186 | // Create the global variable and use its value for the shader input. |
| 187 | const ast::Expression* value = ctx.dst->Expr(symbol); |
| 188 | |
| 189 | if (builtin) { |
| 190 | if (cfg.shader_style == ShaderStyle::kGlsl) { |
| 191 | value = FromGLSLBuiltin(builtin->builtin, value, ast_type); |
| 192 | } else if (builtin->builtin == ast::Builtin::kSampleMask) { |
| 193 | // Vulkan requires the type of a SampleMask builtin to be an array. |
| 194 | // Declare it as array<u32, 1> and then load the first element. |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 195 | ast_type = ctx.dst->ty.array(ast_type, 1_u); |
| 196 | value = ctx.dst->IndexAccessor(value, 0_i); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 197 | } |
| 198 | } |
| 199 | ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput, std::move(attributes)); |
| 200 | return value; |
| 201 | } else if (cfg.shader_style == ShaderStyle::kMsl && |
| 202 | ast::HasAttribute<ast::BuiltinAttribute>(attributes)) { |
| 203 | // If this input is a builtin and we are targeting MSL, then add it to the |
| 204 | // parameter list and pass it directly to the inner function. |
| 205 | Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name) |
| 206 | : ctx.dst->Symbols().New(name); |
| 207 | wrapper_ep_parameters.push_back( |
| 208 | ctx.dst->Param(symbol, ast_type, std::move(attributes))); |
| 209 | return ctx.dst->Expr(symbol); |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 210 | } else { |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 211 | // Otherwise, move it to the new structure member list. |
| 212 | Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name) |
| 213 | : ctx.dst->Symbols().New(name); |
| 214 | wrapper_struct_param_members.push_back( |
| 215 | ctx.dst->Member(symbol, ast_type, std::move(attributes))); |
| 216 | return ctx.dst->MemberAccessor(InputStructSymbol(), symbol); |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 217 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 218 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 219 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 220 | /// Add a shader output to the entry point. |
| 221 | /// @param name the name of the shader output |
| 222 | /// @param type the type of the shader output |
| 223 | /// @param attributes the attributes to apply to the shader output |
| 224 | /// @param value the value of the shader output |
| 225 | void AddOutput(std::string name, |
| 226 | const sem::Type* type, |
| 227 | ast::AttributeList attributes, |
| 228 | const ast::Expression* value) { |
| 229 | // Vulkan requires that integer user-defined vertex outputs are |
| 230 | // always decorated with `Flat`. |
| 231 | // TODO(crbug.com/tint/1224): Remove this once a flat interpolation |
| 232 | // attribute is required for integers. |
| 233 | if (cfg.shader_style == ShaderStyle::kSpirv && type->is_integer_scalar_or_vector() && |
| 234 | ast::HasAttribute<ast::LocationAttribute>(attributes) && |
| 235 | !ast::HasAttribute<ast::InterpolateAttribute>(attributes) && |
| 236 | func_ast->PipelineStage() == ast::PipelineStage::kVertex) { |
| 237 | attributes.push_back(ctx.dst->Interpolate(ast::InterpolationType::kFlat, |
| 238 | ast::InterpolationSampling::kNone)); |
| 239 | } |
| 240 | |
| 241 | // In GLSL, if it's a builtin, override the name with the |
| 242 | // corresponding gl_ builtin name |
| 243 | if (cfg.shader_style == ShaderStyle::kGlsl) { |
| 244 | if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(attributes)) { |
| 245 | name = GLSLBuiltinToString(b->builtin, func_ast->PipelineStage(), |
| 246 | ast::StorageClass::kOutput); |
| 247 | value = ToGLSLBuiltin(b->builtin, value, type); |
| 248 | } |
| 249 | } |
| 250 | |
| 251 | OutputValue output; |
| 252 | output.name = name; |
| 253 | output.type = CreateASTTypeFor(ctx, type); |
| 254 | output.attributes = std::move(attributes); |
| 255 | output.value = value; |
| 256 | wrapper_output_values.push_back(output); |
| 257 | } |
| 258 | |
| 259 | /// Process a non-struct parameter. |
| 260 | /// This creates a new object for the shader input, moving the shader IO |
| 261 | /// attributes to it. It also adds an expression to the list of parameters |
| 262 | /// that will be passed to the original function. |
| 263 | /// @param param the original function parameter |
| 264 | void ProcessNonStructParameter(const sem::Parameter* param) { |
| 265 | // Remove the shader IO attributes from the inner function parameter, and |
| 266 | // attach them to the new object instead. |
| 267 | ast::AttributeList attributes; |
| 268 | for (auto* attr : param->Declaration()->attributes) { |
| 269 | if (IsShaderIOAttribute(attr)) { |
| 270 | ctx.Remove(param->Declaration()->attributes, attr); |
| 271 | attributes.push_back(ctx.Clone(attr)); |
| 272 | } |
| 273 | } |
| 274 | |
| 275 | auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol); |
| 276 | auto* input_expr = AddInput(name, param->Type(), std::move(attributes)); |
| 277 | inner_call_parameters.push_back(input_expr); |
| 278 | } |
| 279 | |
| 280 | /// Process a struct parameter. |
| 281 | /// This creates new objects for each struct member, moving the shader IO |
| 282 | /// attributes to them. It also creates the structure that will be passed to |
| 283 | /// the original function. |
| 284 | /// @param param the original function parameter |
| 285 | void ProcessStructParameter(const sem::Parameter* param) { |
| 286 | auto* str = param->Type()->As<sem::Struct>(); |
| 287 | |
| 288 | // Recreate struct members in the outer entry point and build an initializer |
| 289 | // list to pass them through to the inner function. |
| 290 | ast::ExpressionList inner_struct_values; |
| 291 | for (auto* member : str->Members()) { |
| 292 | if (member->Type()->Is<sem::Struct>()) { |
| 293 | TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct"; |
| 294 | continue; |
| 295 | } |
| 296 | |
| 297 | auto* member_ast = member->Declaration(); |
| 298 | auto name = ctx.src->Symbols().NameFor(member_ast->symbol); |
| 299 | |
| 300 | // In GLSL, do not add interpolation attributes on vertex input |
| 301 | bool do_interpolate = true; |
| 302 | if (cfg.shader_style == ShaderStyle::kGlsl && |
| 303 | func_ast->PipelineStage() == ast::PipelineStage::kVertex) { |
| 304 | do_interpolate = false; |
| 305 | } |
| 306 | auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate); |
| 307 | auto* input_expr = AddInput(name, member->Type(), std::move(attributes)); |
| 308 | inner_struct_values.push_back(input_expr); |
| 309 | } |
| 310 | |
| 311 | // Construct the original structure using the new shader input objects. |
| 312 | inner_call_parameters.push_back( |
| 313 | ctx.dst->Construct(ctx.Clone(param->Declaration()->type), inner_struct_values)); |
| 314 | } |
| 315 | |
| 316 | /// Process the entry point return type. |
| 317 | /// This generates a list of output values that are returned by the original |
| 318 | /// function. |
| 319 | /// @param inner_ret_type the original function return type |
| 320 | /// @param original_result the result object produced by the original function |
| 321 | void ProcessReturnType(const sem::Type* inner_ret_type, Symbol original_result) { |
| 322 | bool do_interpolate = true; |
| 323 | // In GLSL, do not add interpolation attributes on fragment output |
| 324 | if (cfg.shader_style == ShaderStyle::kGlsl && |
| 325 | func_ast->PipelineStage() == ast::PipelineStage::kFragment) { |
| 326 | do_interpolate = false; |
| 327 | } |
| 328 | if (auto* str = inner_ret_type->As<sem::Struct>()) { |
| 329 | for (auto* member : str->Members()) { |
| 330 | if (member->Type()->Is<sem::Struct>()) { |
| 331 | TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct"; |
| 332 | continue; |
| 333 | } |
| 334 | |
| 335 | auto* member_ast = member->Declaration(); |
| 336 | auto name = ctx.src->Symbols().NameFor(member_ast->symbol); |
| 337 | auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate); |
| 338 | |
| 339 | // Extract the original structure member. |
| 340 | AddOutput(name, member->Type(), std::move(attributes), |
| 341 | ctx.dst->MemberAccessor(original_result, name)); |
| 342 | } |
| 343 | } else if (!inner_ret_type->Is<sem::Void>()) { |
| 344 | auto attributes = |
| 345 | CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate); |
| 346 | |
| 347 | // Propagate the non-struct return value as is. |
| 348 | AddOutput("value", func_sem->ReturnType(), std::move(attributes), |
| 349 | ctx.dst->Expr(original_result)); |
| 350 | } |
| 351 | } |
| 352 | |
| 353 | /// Add a fixed sample mask to the wrapper function output. |
| 354 | /// If there is already a sample mask, bitwise-and it with the fixed mask. |
| 355 | /// Otherwise, create a new output value from the fixed mask. |
| 356 | void AddFixedSampleMask() { |
| 357 | // Check the existing output values for a sample mask builtin. |
| 358 | for (auto& outval : wrapper_output_values) { |
| 359 | if (HasSampleMask(outval.attributes)) { |
| 360 | // Combine the authored sample mask with the fixed mask. |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 361 | outval.value = ctx.dst->And(outval.value, u32(cfg.fixed_sample_mask)); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 362 | return; |
| 363 | } |
| 364 | } |
| 365 | |
| 366 | // No existing sample mask builtin was found, so create a new output value |
| 367 | // using the fixed sample mask. |
| 368 | AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(), |
| 369 | {ctx.dst->Builtin(ast::Builtin::kSampleMask)}, |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 370 | ctx.dst->Expr(u32(cfg.fixed_sample_mask))); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 371 | } |
| 372 | |
| 373 | /// Add a point size builtin to the wrapper function output. |
| 374 | void AddVertexPointSize() { |
| 375 | // Create a new output value and assign it a literal 1.0 value. |
| 376 | AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(), |
Ben Clayton | 0a3cda9 | 2022-05-10 17:30:15 +0000 | [diff] [blame] | 377 | {ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1_f)); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 378 | } |
| 379 | |
| 380 | /// Create an expression for gl_Position.[component] |
| 381 | /// @param component the component of gl_Position to access |
| 382 | /// @returns the new expression |
| 383 | const ast::Expression* GLPosition(const char* component) { |
| 384 | Symbol pos = ctx.dst->Symbols().Register("gl_Position"); |
| 385 | Symbol c = ctx.dst->Symbols().Register(component); |
| 386 | return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), ctx.dst->Expr(c)); |
| 387 | } |
| 388 | |
| 389 | /// Create the wrapper function's struct parameter and type objects. |
| 390 | void CreateInputStruct() { |
| 391 | // Sort the struct members to satisfy HLSL interfacing matching rules. |
| 392 | std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(), |
| 393 | StructMemberComparator); |
| 394 | |
| 395 | // Create the new struct type. |
| 396 | auto struct_name = ctx.dst->Sym(); |
| 397 | auto* in_struct = ctx.dst->create<ast::Struct>(struct_name, wrapper_struct_param_members, |
| 398 | ast::AttributeList{}); |
| 399 | ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct); |
| 400 | |
| 401 | // Create a new function parameter using this struct type. |
| 402 | auto* param = ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name)); |
| 403 | wrapper_ep_parameters.push_back(param); |
| 404 | } |
| 405 | |
| 406 | /// Create and return the wrapper function's struct result object. |
| 407 | /// @returns the struct type |
| 408 | ast::Struct* CreateOutputStruct() { |
| 409 | ast::StatementList assignments; |
| 410 | |
| 411 | auto wrapper_result = ctx.dst->Symbols().New("wrapper_result"); |
| 412 | |
| 413 | // Create the struct members and their corresponding assignment statements. |
| 414 | std::unordered_set<std::string> member_names; |
| 415 | for (auto& outval : wrapper_output_values) { |
| 416 | // Use the original output name, unless that is already taken. |
| 417 | Symbol name; |
| 418 | if (member_names.count(outval.name)) { |
| 419 | name = ctx.dst->Symbols().New(outval.name); |
| 420 | } else { |
| 421 | name = ctx.dst->Symbols().Register(outval.name); |
| 422 | } |
| 423 | member_names.insert(ctx.dst->Symbols().NameFor(name)); |
| 424 | |
| 425 | wrapper_struct_output_members.push_back( |
| 426 | ctx.dst->Member(name, outval.type, std::move(outval.attributes))); |
| 427 | assignments.push_back( |
| 428 | ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value)); |
| 429 | } |
| 430 | |
| 431 | // Sort the struct members to satisfy HLSL interfacing matching rules. |
| 432 | std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(), |
| 433 | StructMemberComparator); |
| 434 | |
| 435 | // Create the new struct type. |
| 436 | auto* out_struct = ctx.dst->create<ast::Struct>( |
| 437 | ctx.dst->Sym(), wrapper_struct_output_members, ast::AttributeList{}); |
| 438 | ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct); |
| 439 | |
| 440 | // Create the output struct object, assign its members, and return it. |
| 441 | auto* result_object = ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name)); |
| 442 | wrapper_body.push_back(ctx.dst->Decl(result_object)); |
| 443 | wrapper_body.insert(wrapper_body.end(), assignments.begin(), assignments.end()); |
| 444 | wrapper_body.push_back(ctx.dst->Return(wrapper_result)); |
| 445 | |
| 446 | return out_struct; |
| 447 | } |
| 448 | |
| 449 | /// Create and assign the wrapper function's output variables. |
| 450 | void CreateGlobalOutputVariables() { |
| 451 | for (auto& outval : wrapper_output_values) { |
| 452 | // Disable validation for use of the `output` storage class. |
| 453 | ast::AttributeList attributes = std::move(outval.attributes); |
| 454 | attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); |
| 455 | |
| 456 | // Create the global variable and assign it the output value. |
| 457 | auto name = ctx.dst->Symbols().New(outval.name); |
| 458 | auto* type = outval.type; |
| 459 | const ast::Expression* lhs = ctx.dst->Expr(name); |
| 460 | if (HasSampleMask(attributes)) { |
| 461 | // Vulkan requires the type of a SampleMask builtin to be an array. |
| 462 | // Declare it as array<u32, 1> and then store to the first element. |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 463 | type = ctx.dst->ty.array(type, 1_u); |
| 464 | lhs = ctx.dst->IndexAccessor(lhs, 0_i); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 465 | } |
| 466 | ctx.dst->Global(name, type, ast::StorageClass::kOutput, std::move(attributes)); |
| 467 | wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value)); |
| 468 | } |
| 469 | } |
| 470 | |
| 471 | // Recreate the original function without entry point attributes and call it. |
| 472 | /// @returns the inner function call expression |
| 473 | const ast::CallExpression* CallInnerFunction() { |
| 474 | Symbol inner_name; |
| 475 | if (cfg.shader_style == ShaderStyle::kGlsl) { |
| 476 | // In GLSL, clone the original entry point name, as the wrapper will be |
| 477 | // called "main". |
| 478 | inner_name = ctx.Clone(func_ast->symbol); |
| 479 | } else { |
| 480 | // Add a suffix to the function name, as the wrapper function will take |
| 481 | // the original entry point name. |
| 482 | auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol); |
| 483 | inner_name = ctx.dst->Symbols().New(ep_name + "_inner"); |
| 484 | } |
| 485 | |
| 486 | // Clone everything, dropping the function and return type attributes. |
| 487 | // The parameter attributes will have already been stripped during |
| 488 | // processing. |
| 489 | auto* inner_function = ctx.dst->create<ast::Function>( |
| 490 | inner_name, ctx.Clone(func_ast->params), ctx.Clone(func_ast->return_type), |
| 491 | ctx.Clone(func_ast->body), ast::AttributeList{}, ast::AttributeList{}); |
| 492 | ctx.Replace(func_ast, inner_function); |
| 493 | |
| 494 | // Call the function. |
| 495 | return ctx.dst->Call(inner_function->symbol, inner_call_parameters); |
| 496 | } |
| 497 | |
| 498 | /// Process the entry point function. |
| 499 | void Process() { |
| 500 | bool needs_fixed_sample_mask = false; |
| 501 | bool needs_vertex_point_size = false; |
| 502 | if (func_ast->PipelineStage() == ast::PipelineStage::kFragment && |
| 503 | cfg.fixed_sample_mask != 0xFFFFFFFF) { |
| 504 | needs_fixed_sample_mask = true; |
| 505 | } |
| 506 | if (func_ast->PipelineStage() == ast::PipelineStage::kVertex && |
| 507 | cfg.emit_vertex_point_size) { |
| 508 | needs_vertex_point_size = true; |
| 509 | } |
| 510 | |
| 511 | // Exit early if there is no shader IO to handle. |
| 512 | if (func_sem->Parameters().size() == 0 && func_sem->ReturnType()->Is<sem::Void>() && |
| 513 | !needs_fixed_sample_mask && !needs_vertex_point_size && |
| 514 | cfg.shader_style != ShaderStyle::kGlsl) { |
| 515 | return; |
| 516 | } |
| 517 | |
| 518 | // Process the entry point parameters, collecting those that need to be |
| 519 | // aggregated into a single structure. |
| 520 | if (!func_sem->Parameters().empty()) { |
| 521 | for (auto* param : func_sem->Parameters()) { |
| 522 | if (param->Type()->Is<sem::Struct>()) { |
| 523 | ProcessStructParameter(param); |
| 524 | } else { |
| 525 | ProcessNonStructParameter(param); |
| 526 | } |
| 527 | } |
| 528 | |
| 529 | // Create a structure parameter for the outer entry point if necessary. |
| 530 | if (!wrapper_struct_param_members.empty()) { |
| 531 | CreateInputStruct(); |
| 532 | } |
| 533 | } |
| 534 | |
| 535 | // Recreate the original function and call it. |
| 536 | auto* call_inner = CallInnerFunction(); |
| 537 | |
| 538 | // Process the return type, and start building the wrapper function body. |
| 539 | std::function<const ast::Type*()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); }; |
| 540 | if (func_sem->ReturnType()->Is<sem::Void>()) { |
| 541 | // The function call is just a statement with no result. |
| 542 | wrapper_body.push_back(ctx.dst->CallStmt(call_inner)); |
| 543 | } else { |
| 544 | // Capture the result of calling the original function. |
| 545 | auto* inner_result = |
| 546 | ctx.dst->Let(ctx.dst->Symbols().New("inner_result"), nullptr, call_inner); |
| 547 | wrapper_body.push_back(ctx.dst->Decl(inner_result)); |
| 548 | |
| 549 | // Process the original return type to determine the outputs that the |
| 550 | // outer function needs to produce. |
| 551 | ProcessReturnType(func_sem->ReturnType(), inner_result->symbol); |
| 552 | } |
| 553 | |
| 554 | // Add a fixed sample mask, if necessary. |
| 555 | if (needs_fixed_sample_mask) { |
| 556 | AddFixedSampleMask(); |
| 557 | } |
| 558 | |
| 559 | // Add the pointsize builtin, if necessary. |
| 560 | if (needs_vertex_point_size) { |
| 561 | AddVertexPointSize(); |
| 562 | } |
| 563 | |
| 564 | // Produce the entry point outputs, if necessary. |
| 565 | if (!wrapper_output_values.empty()) { |
| 566 | if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) { |
| 567 | CreateGlobalOutputVariables(); |
| 568 | } else { |
| 569 | auto* output_struct = CreateOutputStruct(); |
| 570 | wrapper_ret_type = [&, output_struct] { |
| 571 | return ctx.dst->ty.type_name(output_struct->name); |
| 572 | }; |
| 573 | } |
| 574 | } |
| 575 | |
| 576 | if (cfg.shader_style == ShaderStyle::kGlsl && |
| 577 | func_ast->PipelineStage() == ast::PipelineStage::kVertex) { |
| 578 | auto* pos_y = GLPosition("y"); |
| 579 | auto* negate_pos_y = |
| 580 | ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, GLPosition("y")); |
| 581 | wrapper_body.push_back(ctx.dst->Assign(pos_y, negate_pos_y)); |
| 582 | |
Ben Clayton | 0a3cda9 | 2022-05-10 17:30:15 +0000 | [diff] [blame] | 583 | auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z")); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 584 | auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w")); |
| 585 | wrapper_body.push_back(ctx.dst->Assign(GLPosition("z"), fixed_z)); |
| 586 | } |
| 587 | |
| 588 | // Create the wrapper entry point function. |
| 589 | // For GLSL, use "main", otherwise take the name of the original |
| 590 | // entry point function. |
| 591 | Symbol name; |
| 592 | if (cfg.shader_style == ShaderStyle::kGlsl) { |
| 593 | name = ctx.dst->Symbols().New("main"); |
| 594 | } else { |
| 595 | name = ctx.Clone(func_ast->symbol); |
| 596 | } |
| 597 | |
| 598 | auto* wrapper_func = ctx.dst->create<ast::Function>( |
| 599 | name, wrapper_ep_parameters, wrapper_ret_type(), ctx.dst->Block(wrapper_body), |
| 600 | ctx.Clone(func_ast->attributes), ast::AttributeList{}); |
| 601 | ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func); |
| 602 | } |
| 603 | |
| 604 | /// Retrieve the gl_ string corresponding to a builtin. |
| 605 | /// @param builtin the builtin |
| 606 | /// @param stage the current pipeline stage |
| 607 | /// @param storage_class the storage class (input or output) |
| 608 | /// @returns the gl_ string corresponding to that builtin |
| 609 | const char* GLSLBuiltinToString(ast::Builtin builtin, |
| 610 | ast::PipelineStage stage, |
| 611 | ast::StorageClass storage_class) { |
| 612 | switch (builtin) { |
| 613 | case ast::Builtin::kPosition: |
| 614 | switch (stage) { |
| 615 | case ast::PipelineStage::kVertex: |
| 616 | return "gl_Position"; |
| 617 | case ast::PipelineStage::kFragment: |
| 618 | return "gl_FragCoord"; |
| 619 | default: |
| 620 | return ""; |
| 621 | } |
| 622 | case ast::Builtin::kVertexIndex: |
| 623 | return "gl_VertexID"; |
| 624 | case ast::Builtin::kInstanceIndex: |
| 625 | return "gl_InstanceID"; |
| 626 | case ast::Builtin::kFrontFacing: |
| 627 | return "gl_FrontFacing"; |
| 628 | case ast::Builtin::kFragDepth: |
| 629 | return "gl_FragDepth"; |
| 630 | case ast::Builtin::kLocalInvocationId: |
| 631 | return "gl_LocalInvocationID"; |
| 632 | case ast::Builtin::kLocalInvocationIndex: |
| 633 | return "gl_LocalInvocationIndex"; |
| 634 | case ast::Builtin::kGlobalInvocationId: |
| 635 | return "gl_GlobalInvocationID"; |
| 636 | case ast::Builtin::kNumWorkgroups: |
| 637 | return "gl_NumWorkGroups"; |
| 638 | case ast::Builtin::kWorkgroupId: |
| 639 | return "gl_WorkGroupID"; |
| 640 | case ast::Builtin::kSampleIndex: |
| 641 | return "gl_SampleID"; |
| 642 | case ast::Builtin::kSampleMask: |
| 643 | if (storage_class == ast::StorageClass::kInput) { |
| 644 | return "gl_SampleMaskIn"; |
| 645 | } else { |
| 646 | return "gl_SampleMask"; |
| 647 | } |
| 648 | default: |
| 649 | return ""; |
| 650 | } |
| 651 | } |
| 652 | |
| 653 | /// Convert a given GLSL builtin value to the corresponding WGSL value. |
| 654 | /// @param builtin the builtin variable |
| 655 | /// @param value the value to convert |
| 656 | /// @param ast_type (inout) the incoming WGSL and outgoing GLSL types |
| 657 | /// @returns an expression representing the GLSL builtin converted to what |
| 658 | /// WGSL expects |
| 659 | const ast::Expression* FromGLSLBuiltin(ast::Builtin builtin, |
| 660 | const ast::Expression* value, |
| 661 | const ast::Type*& ast_type) { |
| 662 | switch (builtin) { |
| 663 | case ast::Builtin::kVertexIndex: |
| 664 | case ast::Builtin::kInstanceIndex: |
| 665 | case ast::Builtin::kSampleIndex: |
| 666 | // GLSL uses i32 for these, so bitcast to u32. |
| 667 | value = ctx.dst->Bitcast(ast_type, value); |
| 668 | ast_type = ctx.dst->ty.i32(); |
| 669 | break; |
| 670 | case ast::Builtin::kSampleMask: |
| 671 | // gl_SampleMask is an array of i32. Retrieve the first element and |
| 672 | // bitcast it to u32. |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 673 | value = ctx.dst->IndexAccessor(value, 0_i); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 674 | value = ctx.dst->Bitcast(ast_type, value); |
Ben Clayton | 0ce9ab0 | 2022-05-05 20:23:40 +0000 | [diff] [blame] | 675 | ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1_u); |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 676 | break; |
| 677 | default: |
| 678 | break; |
| 679 | } |
| 680 | return value; |
| 681 | } |
| 682 | |
| 683 | /// Convert a given WGSL value to the type expected when assigning to a |
| 684 | /// GLSL builtin. |
| 685 | /// @param builtin the builtin variable |
| 686 | /// @param value the value to convert |
| 687 | /// @param type (out) the type to which the value was converted |
| 688 | /// @returns the converted value which can be assigned to the GLSL builtin |
| 689 | const ast::Expression* ToGLSLBuiltin(ast::Builtin builtin, |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 690 | const ast::Expression* value, |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 691 | const sem::Type*& type) { |
| 692 | switch (builtin) { |
| 693 | case ast::Builtin::kVertexIndex: |
| 694 | case ast::Builtin::kInstanceIndex: |
| 695 | case ast::Builtin::kSampleIndex: |
| 696 | case ast::Builtin::kSampleMask: |
| 697 | type = ctx.dst->create<sem::I32>(); |
| 698 | value = ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value); |
| 699 | break; |
| 700 | default: |
| 701 | break; |
| 702 | } |
| 703 | return value; |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 704 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 705 | }; |
| 706 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 707 | void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const { |
| 708 | auto* cfg = inputs.Get<Config>(); |
| 709 | if (cfg == nullptr) { |
| 710 | ctx.dst->Diagnostics().add_error( |
| 711 | diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name)); |
| 712 | return; |
| 713 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 714 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 715 | // Remove entry point IO attributes from struct declarations. |
| 716 | // New structures will be created for each entry point, as necessary. |
| 717 | for (auto* ty : ctx.src->AST().TypeDecls()) { |
| 718 | if (auto* struct_ty = ty->As<ast::Struct>()) { |
| 719 | for (auto* member : struct_ty->members) { |
| 720 | for (auto* attr : member->attributes) { |
| 721 | if (IsShaderIOAttribute(attr)) { |
| 722 | ctx.Remove(member->attributes, attr); |
| 723 | } |
| 724 | } |
| 725 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 726 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 727 | } |
| 728 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 729 | for (auto* func_ast : ctx.src->AST().Functions()) { |
| 730 | if (!func_ast->IsEntryPoint()) { |
| 731 | continue; |
| 732 | } |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 733 | |
dan sinclair | 41e4d9a | 2022-05-01 14:40:55 +0000 | [diff] [blame] | 734 | State state(ctx, *cfg, func_ast); |
| 735 | state.Process(); |
| 736 | } |
| 737 | |
| 738 | ctx.Clone(); |
Ryan Harrison | dbc13af | 2022-02-21 15:19:07 +0000 | [diff] [blame] | 739 | } |
| 740 | |
| 741 | CanonicalizeEntryPointIO::Config::Config(ShaderStyle style, |
| 742 | uint32_t sample_mask, |
| 743 | bool emit_point_size) |
| 744 | : shader_style(style), |
| 745 | fixed_sample_mask(sample_mask), |
| 746 | emit_vertex_point_size(emit_point_size) {} |
| 747 | |
| 748 | CanonicalizeEntryPointIO::Config::Config(const Config&) = default; |
| 749 | CanonicalizeEntryPointIO::Config::~Config() = default; |
| 750 | |
dan sinclair | b5599d3 | 2022-04-07 16:55:14 +0000 | [diff] [blame] | 751 | } // namespace tint::transform |