blob: 44d577bb52132b914d8e59d608ca838dc1bdb6f8 [file] [log] [blame]
Ryan Harrisondbc13af2022-02-21 15:19:07 +00001// Copyright 2021 The Tint Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "src/tint/transform/canonicalize_entry_point_io.h"
16
17#include <algorithm>
18#include <string>
19#include <unordered_set>
20#include <utility>
21#include <vector>
22
23#include "src/tint/ast/disable_validation_attribute.h"
24#include "src/tint/program_builder.h"
25#include "src/tint/sem/function.h"
26#include "src/tint/transform/unshadow.h"
27
Ben Clayton0ce9ab02022-05-05 20:23:40 +000028using namespace tint::number_suffixes; // NOLINT
29
Ryan Harrisondbc13af2022-02-21 15:19:07 +000030TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO);
31TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config);
32
dan sinclairb5599d32022-04-07 16:55:14 +000033namespace tint::transform {
Ryan Harrisondbc13af2022-02-21 15:19:07 +000034
35CanonicalizeEntryPointIO::CanonicalizeEntryPointIO() = default;
36CanonicalizeEntryPointIO::~CanonicalizeEntryPointIO() = default;
37
38namespace {
39
40// Comparison function used to reorder struct members such that all members with
41// location attributes appear first (ordered by location slot), followed by
42// those with builtin attributes.
dan sinclair41e4d9a2022-05-01 14:40:55 +000043bool StructMemberComparator(const ast::StructMember* a, const ast::StructMember* b) {
44 auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a->attributes);
45 auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b->attributes);
46 auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a->attributes);
47 auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b->attributes);
48 if (a_loc) {
49 if (!b_loc) {
50 // `a` has location attribute and `b` does not: `a` goes first.
51 return true;
52 }
53 // Both have location attributes: smallest goes first.
54 return a_loc->value < b_loc->value;
55 } else {
56 if (b_loc) {
57 // `b` has location attribute and `a` does not: `b` goes first.
58 return false;
59 }
60 // Both are builtins: order doesn't matter, just use enum value.
61 return a_blt->builtin < b_blt->builtin;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000062 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +000063}
64
65// Returns true if `attr` is a shader IO attribute.
66bool IsShaderIOAttribute(const ast::Attribute* attr) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000067 return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, ast::InvariantAttribute,
68 ast::LocationAttribute>();
Ryan Harrisondbc13af2022-02-21 15:19:07 +000069}
70
71// Returns true if `attrs` contains a `sample_mask` builtin.
72bool HasSampleMask(const ast::AttributeList& attrs) {
dan sinclair41e4d9a2022-05-01 14:40:55 +000073 auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs);
74 return builtin && builtin->builtin == ast::Builtin::kSampleMask;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000075}
76
77} // namespace
78
79/// State holds the current transform state for a single entry point.
80struct CanonicalizeEntryPointIO::State {
dan sinclair41e4d9a2022-05-01 14:40:55 +000081 /// OutputValue represents a shader result that the wrapper function produces.
82 struct OutputValue {
83 /// The name of the output value.
84 std::string name;
85 /// The type of the output value.
86 const ast::Type* type;
87 /// The shader IO attributes.
88 ast::AttributeList attributes;
89 /// The value itself.
90 const ast::Expression* value;
Ryan Harrisondbc13af2022-02-21 15:19:07 +000091 };
Ryan Harrisondbc13af2022-02-21 15:19:07 +000092
dan sinclair41e4d9a2022-05-01 14:40:55 +000093 /// The clone context.
94 CloneContext& ctx;
95 /// The transform config.
96 CanonicalizeEntryPointIO::Config const cfg;
97 /// The entry point function (AST).
98 const ast::Function* func_ast;
99 /// The entry point function (SEM).
100 const sem::Function* func_sem;
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000101
dan sinclair41e4d9a2022-05-01 14:40:55 +0000102 /// The new entry point wrapper function's parameters.
103 ast::VariableList wrapper_ep_parameters;
104 /// The members of the wrapper function's struct parameter.
105 ast::StructMemberList wrapper_struct_param_members;
106 /// The name of the wrapper function's struct parameter.
107 Symbol wrapper_struct_param_name;
108 /// The parameters that will be passed to the original function.
109 ast::ExpressionList inner_call_parameters;
110 /// The members of the wrapper function's struct return type.
111 ast::StructMemberList wrapper_struct_output_members;
112 /// The wrapper function output values.
113 std::vector<OutputValue> wrapper_output_values;
114 /// The body of the wrapper function.
115 ast::StatementList wrapper_body;
116 /// Input names used by the entrypoint
117 std::unordered_set<std::string> input_names;
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000118
dan sinclair41e4d9a2022-05-01 14:40:55 +0000119 /// Constructor
120 /// @param context the clone context
121 /// @param config the transform config
122 /// @param function the entry point function
123 State(CloneContext& context,
124 const CanonicalizeEntryPointIO::Config& config,
125 const ast::Function* function)
126 : ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000127
dan sinclair41e4d9a2022-05-01 14:40:55 +0000128 /// Clones the shader IO attributes from `src`.
129 /// @param src the attributes to clone
130 /// @param do_interpolate whether to clone InterpolateAttribute
131 /// @return the cloned attributes
132 ast::AttributeList CloneShaderIOAttributes(const ast::AttributeList& src, bool do_interpolate) {
133 ast::AttributeList new_attributes;
134 for (auto* attr : src) {
135 if (IsShaderIOAttribute(attr) &&
136 (do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
137 new_attributes.push_back(ctx.Clone(attr));
138 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000139 }
dan sinclair41e4d9a2022-05-01 14:40:55 +0000140 return new_attributes;
141 }
142
143 /// Create or return a symbol for the wrapper function's struct parameter.
144 /// @returns the symbol for the struct parameter
145 Symbol InputStructSymbol() {
146 if (!wrapper_struct_param_name.IsValid()) {
147 wrapper_struct_param_name = ctx.dst->Sym();
148 }
149 return wrapper_struct_param_name;
150 }
151
152 /// Add a shader input to the entry point.
153 /// @param name the name of the shader input
154 /// @param type the type of the shader input
155 /// @param attributes the attributes to apply to the shader input
156 /// @returns an expression which evaluates to the value of the shader input
157 const ast::Expression* AddInput(std::string name,
158 const sem::Type* type,
159 ast::AttributeList attributes) {
160 auto* ast_type = CreateASTTypeFor(ctx, type);
161 if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
162 // Vulkan requires that integer user-defined fragment inputs are
163 // always decorated with `Flat`.
164 // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
165 // attribute is required for integers.
166 if (type->is_integer_scalar_or_vector() &&
167 ast::HasAttribute<ast::LocationAttribute>(attributes) &&
168 !ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
169 func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
170 attributes.push_back(ctx.dst->Interpolate(ast::InterpolationType::kFlat,
171 ast::InterpolationSampling::kNone));
172 }
173
174 // Disable validation for use of the `input` storage class.
175 attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
176
177 // In GLSL, if it's a builtin, override the name with the
178 // corresponding gl_ builtin name
179 auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attributes);
180 if (cfg.shader_style == ShaderStyle::kGlsl && builtin) {
181 name = GLSLBuiltinToString(builtin->builtin, func_ast->PipelineStage(),
182 ast::StorageClass::kInput);
183 }
184 auto symbol = ctx.dst->Symbols().New(name);
185
186 // Create the global variable and use its value for the shader input.
187 const ast::Expression* value = ctx.dst->Expr(symbol);
188
189 if (builtin) {
190 if (cfg.shader_style == ShaderStyle::kGlsl) {
191 value = FromGLSLBuiltin(builtin->builtin, value, ast_type);
192 } else if (builtin->builtin == ast::Builtin::kSampleMask) {
193 // Vulkan requires the type of a SampleMask builtin to be an array.
194 // Declare it as array<u32, 1> and then load the first element.
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000195 ast_type = ctx.dst->ty.array(ast_type, 1_u);
196 value = ctx.dst->IndexAccessor(value, 0_i);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000197 }
198 }
199 ctx.dst->Global(symbol, ast_type, ast::StorageClass::kInput, std::move(attributes));
200 return value;
201 } else if (cfg.shader_style == ShaderStyle::kMsl &&
202 ast::HasAttribute<ast::BuiltinAttribute>(attributes)) {
203 // If this input is a builtin and we are targeting MSL, then add it to the
204 // parameter list and pass it directly to the inner function.
205 Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
206 : ctx.dst->Symbols().New(name);
207 wrapper_ep_parameters.push_back(
208 ctx.dst->Param(symbol, ast_type, std::move(attributes)));
209 return ctx.dst->Expr(symbol);
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000210 } else {
dan sinclair41e4d9a2022-05-01 14:40:55 +0000211 // Otherwise, move it to the new structure member list.
212 Symbol symbol = input_names.emplace(name).second ? ctx.dst->Symbols().Register(name)
213 : ctx.dst->Symbols().New(name);
214 wrapper_struct_param_members.push_back(
215 ctx.dst->Member(symbol, ast_type, std::move(attributes)));
216 return ctx.dst->MemberAccessor(InputStructSymbol(), symbol);
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000217 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000218 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000219
dan sinclair41e4d9a2022-05-01 14:40:55 +0000220 /// Add a shader output to the entry point.
221 /// @param name the name of the shader output
222 /// @param type the type of the shader output
223 /// @param attributes the attributes to apply to the shader output
224 /// @param value the value of the shader output
225 void AddOutput(std::string name,
226 const sem::Type* type,
227 ast::AttributeList attributes,
228 const ast::Expression* value) {
229 // Vulkan requires that integer user-defined vertex outputs are
230 // always decorated with `Flat`.
231 // TODO(crbug.com/tint/1224): Remove this once a flat interpolation
232 // attribute is required for integers.
233 if (cfg.shader_style == ShaderStyle::kSpirv && type->is_integer_scalar_or_vector() &&
234 ast::HasAttribute<ast::LocationAttribute>(attributes) &&
235 !ast::HasAttribute<ast::InterpolateAttribute>(attributes) &&
236 func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
237 attributes.push_back(ctx.dst->Interpolate(ast::InterpolationType::kFlat,
238 ast::InterpolationSampling::kNone));
239 }
240
241 // In GLSL, if it's a builtin, override the name with the
242 // corresponding gl_ builtin name
243 if (cfg.shader_style == ShaderStyle::kGlsl) {
244 if (auto* b = ast::GetAttribute<ast::BuiltinAttribute>(attributes)) {
245 name = GLSLBuiltinToString(b->builtin, func_ast->PipelineStage(),
246 ast::StorageClass::kOutput);
247 value = ToGLSLBuiltin(b->builtin, value, type);
248 }
249 }
250
251 OutputValue output;
252 output.name = name;
253 output.type = CreateASTTypeFor(ctx, type);
254 output.attributes = std::move(attributes);
255 output.value = value;
256 wrapper_output_values.push_back(output);
257 }
258
259 /// Process a non-struct parameter.
260 /// This creates a new object for the shader input, moving the shader IO
261 /// attributes to it. It also adds an expression to the list of parameters
262 /// that will be passed to the original function.
263 /// @param param the original function parameter
264 void ProcessNonStructParameter(const sem::Parameter* param) {
265 // Remove the shader IO attributes from the inner function parameter, and
266 // attach them to the new object instead.
267 ast::AttributeList attributes;
268 for (auto* attr : param->Declaration()->attributes) {
269 if (IsShaderIOAttribute(attr)) {
270 ctx.Remove(param->Declaration()->attributes, attr);
271 attributes.push_back(ctx.Clone(attr));
272 }
273 }
274
275 auto name = ctx.src->Symbols().NameFor(param->Declaration()->symbol);
276 auto* input_expr = AddInput(name, param->Type(), std::move(attributes));
277 inner_call_parameters.push_back(input_expr);
278 }
279
280 /// Process a struct parameter.
281 /// This creates new objects for each struct member, moving the shader IO
282 /// attributes to them. It also creates the structure that will be passed to
283 /// the original function.
284 /// @param param the original function parameter
285 void ProcessStructParameter(const sem::Parameter* param) {
286 auto* str = param->Type()->As<sem::Struct>();
287
288 // Recreate struct members in the outer entry point and build an initializer
289 // list to pass them through to the inner function.
290 ast::ExpressionList inner_struct_values;
291 for (auto* member : str->Members()) {
292 if (member->Type()->Is<sem::Struct>()) {
293 TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
294 continue;
295 }
296
297 auto* member_ast = member->Declaration();
298 auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
299
300 // In GLSL, do not add interpolation attributes on vertex input
301 bool do_interpolate = true;
302 if (cfg.shader_style == ShaderStyle::kGlsl &&
303 func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
304 do_interpolate = false;
305 }
306 auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
307 auto* input_expr = AddInput(name, member->Type(), std::move(attributes));
308 inner_struct_values.push_back(input_expr);
309 }
310
311 // Construct the original structure using the new shader input objects.
312 inner_call_parameters.push_back(
313 ctx.dst->Construct(ctx.Clone(param->Declaration()->type), inner_struct_values));
314 }
315
316 /// Process the entry point return type.
317 /// This generates a list of output values that are returned by the original
318 /// function.
319 /// @param inner_ret_type the original function return type
320 /// @param original_result the result object produced by the original function
321 void ProcessReturnType(const sem::Type* inner_ret_type, Symbol original_result) {
322 bool do_interpolate = true;
323 // In GLSL, do not add interpolation attributes on fragment output
324 if (cfg.shader_style == ShaderStyle::kGlsl &&
325 func_ast->PipelineStage() == ast::PipelineStage::kFragment) {
326 do_interpolate = false;
327 }
328 if (auto* str = inner_ret_type->As<sem::Struct>()) {
329 for (auto* member : str->Members()) {
330 if (member->Type()->Is<sem::Struct>()) {
331 TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
332 continue;
333 }
334
335 auto* member_ast = member->Declaration();
336 auto name = ctx.src->Symbols().NameFor(member_ast->symbol);
337 auto attributes = CloneShaderIOAttributes(member_ast->attributes, do_interpolate);
338
339 // Extract the original structure member.
340 AddOutput(name, member->Type(), std::move(attributes),
341 ctx.dst->MemberAccessor(original_result, name));
342 }
343 } else if (!inner_ret_type->Is<sem::Void>()) {
344 auto attributes =
345 CloneShaderIOAttributes(func_ast->return_type_attributes, do_interpolate);
346
347 // Propagate the non-struct return value as is.
348 AddOutput("value", func_sem->ReturnType(), std::move(attributes),
349 ctx.dst->Expr(original_result));
350 }
351 }
352
353 /// Add a fixed sample mask to the wrapper function output.
354 /// If there is already a sample mask, bitwise-and it with the fixed mask.
355 /// Otherwise, create a new output value from the fixed mask.
356 void AddFixedSampleMask() {
357 // Check the existing output values for a sample mask builtin.
358 for (auto& outval : wrapper_output_values) {
359 if (HasSampleMask(outval.attributes)) {
360 // Combine the authored sample mask with the fixed mask.
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000361 outval.value = ctx.dst->And(outval.value, u32(cfg.fixed_sample_mask));
dan sinclair41e4d9a2022-05-01 14:40:55 +0000362 return;
363 }
364 }
365
366 // No existing sample mask builtin was found, so create a new output value
367 // using the fixed sample mask.
368 AddOutput("fixed_sample_mask", ctx.dst->create<sem::U32>(),
369 {ctx.dst->Builtin(ast::Builtin::kSampleMask)},
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000370 ctx.dst->Expr(u32(cfg.fixed_sample_mask)));
dan sinclair41e4d9a2022-05-01 14:40:55 +0000371 }
372
373 /// Add a point size builtin to the wrapper function output.
374 void AddVertexPointSize() {
375 // Create a new output value and assign it a literal 1.0 value.
376 AddOutput("vertex_point_size", ctx.dst->create<sem::F32>(),
Ben Clayton0a3cda92022-05-10 17:30:15 +0000377 {ctx.dst->Builtin(ast::Builtin::kPointSize)}, ctx.dst->Expr(1_f));
dan sinclair41e4d9a2022-05-01 14:40:55 +0000378 }
379
380 /// Create an expression for gl_Position.[component]
381 /// @param component the component of gl_Position to access
382 /// @returns the new expression
383 const ast::Expression* GLPosition(const char* component) {
384 Symbol pos = ctx.dst->Symbols().Register("gl_Position");
385 Symbol c = ctx.dst->Symbols().Register(component);
386 return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), ctx.dst->Expr(c));
387 }
388
389 /// Create the wrapper function's struct parameter and type objects.
390 void CreateInputStruct() {
391 // Sort the struct members to satisfy HLSL interfacing matching rules.
392 std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
393 StructMemberComparator);
394
395 // Create the new struct type.
396 auto struct_name = ctx.dst->Sym();
397 auto* in_struct = ctx.dst->create<ast::Struct>(struct_name, wrapper_struct_param_members,
398 ast::AttributeList{});
399 ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
400
401 // Create a new function parameter using this struct type.
402 auto* param = ctx.dst->Param(InputStructSymbol(), ctx.dst->ty.type_name(struct_name));
403 wrapper_ep_parameters.push_back(param);
404 }
405
406 /// Create and return the wrapper function's struct result object.
407 /// @returns the struct type
408 ast::Struct* CreateOutputStruct() {
409 ast::StatementList assignments;
410
411 auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
412
413 // Create the struct members and their corresponding assignment statements.
414 std::unordered_set<std::string> member_names;
415 for (auto& outval : wrapper_output_values) {
416 // Use the original output name, unless that is already taken.
417 Symbol name;
418 if (member_names.count(outval.name)) {
419 name = ctx.dst->Symbols().New(outval.name);
420 } else {
421 name = ctx.dst->Symbols().Register(outval.name);
422 }
423 member_names.insert(ctx.dst->Symbols().NameFor(name));
424
425 wrapper_struct_output_members.push_back(
426 ctx.dst->Member(name, outval.type, std::move(outval.attributes)));
427 assignments.push_back(
428 ctx.dst->Assign(ctx.dst->MemberAccessor(wrapper_result, name), outval.value));
429 }
430
431 // Sort the struct members to satisfy HLSL interfacing matching rules.
432 std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
433 StructMemberComparator);
434
435 // Create the new struct type.
436 auto* out_struct = ctx.dst->create<ast::Struct>(
437 ctx.dst->Sym(), wrapper_struct_output_members, ast::AttributeList{});
438 ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
439
440 // Create the output struct object, assign its members, and return it.
441 auto* result_object = ctx.dst->Var(wrapper_result, ctx.dst->ty.type_name(out_struct->name));
442 wrapper_body.push_back(ctx.dst->Decl(result_object));
443 wrapper_body.insert(wrapper_body.end(), assignments.begin(), assignments.end());
444 wrapper_body.push_back(ctx.dst->Return(wrapper_result));
445
446 return out_struct;
447 }
448
449 /// Create and assign the wrapper function's output variables.
450 void CreateGlobalOutputVariables() {
451 for (auto& outval : wrapper_output_values) {
452 // Disable validation for use of the `output` storage class.
453 ast::AttributeList attributes = std::move(outval.attributes);
454 attributes.push_back(ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
455
456 // Create the global variable and assign it the output value.
457 auto name = ctx.dst->Symbols().New(outval.name);
458 auto* type = outval.type;
459 const ast::Expression* lhs = ctx.dst->Expr(name);
460 if (HasSampleMask(attributes)) {
461 // Vulkan requires the type of a SampleMask builtin to be an array.
462 // Declare it as array<u32, 1> and then store to the first element.
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000463 type = ctx.dst->ty.array(type, 1_u);
464 lhs = ctx.dst->IndexAccessor(lhs, 0_i);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000465 }
466 ctx.dst->Global(name, type, ast::StorageClass::kOutput, std::move(attributes));
467 wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
468 }
469 }
470
471 // Recreate the original function without entry point attributes and call it.
472 /// @returns the inner function call expression
473 const ast::CallExpression* CallInnerFunction() {
474 Symbol inner_name;
475 if (cfg.shader_style == ShaderStyle::kGlsl) {
476 // In GLSL, clone the original entry point name, as the wrapper will be
477 // called "main".
478 inner_name = ctx.Clone(func_ast->symbol);
479 } else {
480 // Add a suffix to the function name, as the wrapper function will take
481 // the original entry point name.
482 auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
483 inner_name = ctx.dst->Symbols().New(ep_name + "_inner");
484 }
485
486 // Clone everything, dropping the function and return type attributes.
487 // The parameter attributes will have already been stripped during
488 // processing.
489 auto* inner_function = ctx.dst->create<ast::Function>(
490 inner_name, ctx.Clone(func_ast->params), ctx.Clone(func_ast->return_type),
491 ctx.Clone(func_ast->body), ast::AttributeList{}, ast::AttributeList{});
492 ctx.Replace(func_ast, inner_function);
493
494 // Call the function.
495 return ctx.dst->Call(inner_function->symbol, inner_call_parameters);
496 }
497
498 /// Process the entry point function.
499 void Process() {
500 bool needs_fixed_sample_mask = false;
501 bool needs_vertex_point_size = false;
502 if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
503 cfg.fixed_sample_mask != 0xFFFFFFFF) {
504 needs_fixed_sample_mask = true;
505 }
506 if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
507 cfg.emit_vertex_point_size) {
508 needs_vertex_point_size = true;
509 }
510
511 // Exit early if there is no shader IO to handle.
512 if (func_sem->Parameters().size() == 0 && func_sem->ReturnType()->Is<sem::Void>() &&
513 !needs_fixed_sample_mask && !needs_vertex_point_size &&
514 cfg.shader_style != ShaderStyle::kGlsl) {
515 return;
516 }
517
518 // Process the entry point parameters, collecting those that need to be
519 // aggregated into a single structure.
520 if (!func_sem->Parameters().empty()) {
521 for (auto* param : func_sem->Parameters()) {
522 if (param->Type()->Is<sem::Struct>()) {
523 ProcessStructParameter(param);
524 } else {
525 ProcessNonStructParameter(param);
526 }
527 }
528
529 // Create a structure parameter for the outer entry point if necessary.
530 if (!wrapper_struct_param_members.empty()) {
531 CreateInputStruct();
532 }
533 }
534
535 // Recreate the original function and call it.
536 auto* call_inner = CallInnerFunction();
537
538 // Process the return type, and start building the wrapper function body.
539 std::function<const ast::Type*()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
540 if (func_sem->ReturnType()->Is<sem::Void>()) {
541 // The function call is just a statement with no result.
542 wrapper_body.push_back(ctx.dst->CallStmt(call_inner));
543 } else {
544 // Capture the result of calling the original function.
545 auto* inner_result =
546 ctx.dst->Let(ctx.dst->Symbols().New("inner_result"), nullptr, call_inner);
547 wrapper_body.push_back(ctx.dst->Decl(inner_result));
548
549 // Process the original return type to determine the outputs that the
550 // outer function needs to produce.
551 ProcessReturnType(func_sem->ReturnType(), inner_result->symbol);
552 }
553
554 // Add a fixed sample mask, if necessary.
555 if (needs_fixed_sample_mask) {
556 AddFixedSampleMask();
557 }
558
559 // Add the pointsize builtin, if necessary.
560 if (needs_vertex_point_size) {
561 AddVertexPointSize();
562 }
563
564 // Produce the entry point outputs, if necessary.
565 if (!wrapper_output_values.empty()) {
566 if (cfg.shader_style == ShaderStyle::kSpirv || cfg.shader_style == ShaderStyle::kGlsl) {
567 CreateGlobalOutputVariables();
568 } else {
569 auto* output_struct = CreateOutputStruct();
570 wrapper_ret_type = [&, output_struct] {
571 return ctx.dst->ty.type_name(output_struct->name);
572 };
573 }
574 }
575
576 if (cfg.shader_style == ShaderStyle::kGlsl &&
577 func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
578 auto* pos_y = GLPosition("y");
579 auto* negate_pos_y =
580 ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, GLPosition("y"));
581 wrapper_body.push_back(ctx.dst->Assign(pos_y, negate_pos_y));
582
Ben Clayton0a3cda92022-05-10 17:30:15 +0000583 auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z"));
dan sinclair41e4d9a2022-05-01 14:40:55 +0000584 auto* fixed_z = ctx.dst->Sub(two_z, GLPosition("w"));
585 wrapper_body.push_back(ctx.dst->Assign(GLPosition("z"), fixed_z));
586 }
587
588 // Create the wrapper entry point function.
589 // For GLSL, use "main", otherwise take the name of the original
590 // entry point function.
591 Symbol name;
592 if (cfg.shader_style == ShaderStyle::kGlsl) {
593 name = ctx.dst->Symbols().New("main");
594 } else {
595 name = ctx.Clone(func_ast->symbol);
596 }
597
598 auto* wrapper_func = ctx.dst->create<ast::Function>(
599 name, wrapper_ep_parameters, wrapper_ret_type(), ctx.dst->Block(wrapper_body),
600 ctx.Clone(func_ast->attributes), ast::AttributeList{});
601 ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
602 }
603
604 /// Retrieve the gl_ string corresponding to a builtin.
605 /// @param builtin the builtin
606 /// @param stage the current pipeline stage
607 /// @param storage_class the storage class (input or output)
608 /// @returns the gl_ string corresponding to that builtin
609 const char* GLSLBuiltinToString(ast::Builtin builtin,
610 ast::PipelineStage stage,
611 ast::StorageClass storage_class) {
612 switch (builtin) {
613 case ast::Builtin::kPosition:
614 switch (stage) {
615 case ast::PipelineStage::kVertex:
616 return "gl_Position";
617 case ast::PipelineStage::kFragment:
618 return "gl_FragCoord";
619 default:
620 return "";
621 }
622 case ast::Builtin::kVertexIndex:
623 return "gl_VertexID";
624 case ast::Builtin::kInstanceIndex:
625 return "gl_InstanceID";
626 case ast::Builtin::kFrontFacing:
627 return "gl_FrontFacing";
628 case ast::Builtin::kFragDepth:
629 return "gl_FragDepth";
630 case ast::Builtin::kLocalInvocationId:
631 return "gl_LocalInvocationID";
632 case ast::Builtin::kLocalInvocationIndex:
633 return "gl_LocalInvocationIndex";
634 case ast::Builtin::kGlobalInvocationId:
635 return "gl_GlobalInvocationID";
636 case ast::Builtin::kNumWorkgroups:
637 return "gl_NumWorkGroups";
638 case ast::Builtin::kWorkgroupId:
639 return "gl_WorkGroupID";
640 case ast::Builtin::kSampleIndex:
641 return "gl_SampleID";
642 case ast::Builtin::kSampleMask:
643 if (storage_class == ast::StorageClass::kInput) {
644 return "gl_SampleMaskIn";
645 } else {
646 return "gl_SampleMask";
647 }
648 default:
649 return "";
650 }
651 }
652
653 /// Convert a given GLSL builtin value to the corresponding WGSL value.
654 /// @param builtin the builtin variable
655 /// @param value the value to convert
656 /// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
657 /// @returns an expression representing the GLSL builtin converted to what
658 /// WGSL expects
659 const ast::Expression* FromGLSLBuiltin(ast::Builtin builtin,
660 const ast::Expression* value,
661 const ast::Type*& ast_type) {
662 switch (builtin) {
663 case ast::Builtin::kVertexIndex:
664 case ast::Builtin::kInstanceIndex:
665 case ast::Builtin::kSampleIndex:
666 // GLSL uses i32 for these, so bitcast to u32.
667 value = ctx.dst->Bitcast(ast_type, value);
668 ast_type = ctx.dst->ty.i32();
669 break;
670 case ast::Builtin::kSampleMask:
671 // gl_SampleMask is an array of i32. Retrieve the first element and
672 // bitcast it to u32.
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000673 value = ctx.dst->IndexAccessor(value, 0_i);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000674 value = ctx.dst->Bitcast(ast_type, value);
Ben Clayton0ce9ab02022-05-05 20:23:40 +0000675 ast_type = ctx.dst->ty.array(ctx.dst->ty.i32(), 1_u);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000676 break;
677 default:
678 break;
679 }
680 return value;
681 }
682
683 /// Convert a given WGSL value to the type expected when assigning to a
684 /// GLSL builtin.
685 /// @param builtin the builtin variable
686 /// @param value the value to convert
687 /// @param type (out) the type to which the value was converted
688 /// @returns the converted value which can be assigned to the GLSL builtin
689 const ast::Expression* ToGLSLBuiltin(ast::Builtin builtin,
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000690 const ast::Expression* value,
dan sinclair41e4d9a2022-05-01 14:40:55 +0000691 const sem::Type*& type) {
692 switch (builtin) {
693 case ast::Builtin::kVertexIndex:
694 case ast::Builtin::kInstanceIndex:
695 case ast::Builtin::kSampleIndex:
696 case ast::Builtin::kSampleMask:
697 type = ctx.dst->create<sem::I32>();
698 value = ctx.dst->Bitcast(CreateASTTypeFor(ctx, type), value);
699 break;
700 default:
701 break;
702 }
703 return value;
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000704 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000705};
706
dan sinclair41e4d9a2022-05-01 14:40:55 +0000707void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
708 auto* cfg = inputs.Get<Config>();
709 if (cfg == nullptr) {
710 ctx.dst->Diagnostics().add_error(
711 diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
712 return;
713 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000714
dan sinclair41e4d9a2022-05-01 14:40:55 +0000715 // Remove entry point IO attributes from struct declarations.
716 // New structures will be created for each entry point, as necessary.
717 for (auto* ty : ctx.src->AST().TypeDecls()) {
718 if (auto* struct_ty = ty->As<ast::Struct>()) {
719 for (auto* member : struct_ty->members) {
720 for (auto* attr : member->attributes) {
721 if (IsShaderIOAttribute(attr)) {
722 ctx.Remove(member->attributes, attr);
723 }
724 }
725 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000726 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000727 }
728
dan sinclair41e4d9a2022-05-01 14:40:55 +0000729 for (auto* func_ast : ctx.src->AST().Functions()) {
730 if (!func_ast->IsEntryPoint()) {
731 continue;
732 }
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000733
dan sinclair41e4d9a2022-05-01 14:40:55 +0000734 State state(ctx, *cfg, func_ast);
735 state.Process();
736 }
737
738 ctx.Clone();
Ryan Harrisondbc13af2022-02-21 15:19:07 +0000739}
740
741CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
742 uint32_t sample_mask,
743 bool emit_point_size)
744 : shader_style(style),
745 fixed_sample_mask(sample_mask),
746 emit_vertex_point_size(emit_point_size) {}
747
748CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
749CanonicalizeEntryPointIO::Config::~Config() = default;
750
dan sinclairb5599d32022-04-07 16:55:14 +0000751} // namespace tint::transform