blob: ba61160940126fe060520f374578db74a90e2fd0 [file] [log] [blame]
James Price36464002021-09-07 18:59:21 +00001// Copyright 2021 The Tint Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "src/transform/module_scope_var_to_entry_point_param.h"
16
17#include <unordered_map>
James Priceacaecab2021-09-13 19:56:01 +000018#include <unordered_set>
James Price36464002021-09-07 18:59:21 +000019#include <utility>
20#include <vector>
21
22#include "src/ast/disable_validation_decoration.h"
23#include "src/program_builder.h"
24#include "src/sem/call.h"
25#include "src/sem/function.h"
26#include "src/sem/statement.h"
27#include "src/sem/variable.h"
28
29TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam);
30
31namespace tint {
32namespace transform {
James Priceacaecab2021-09-13 19:56:01 +000033namespace {
34// Returns `true` if `type` is or contains a matrix type.
35bool ContainsMatrix(const sem::Type* type) {
36 type = type->UnwrapRef();
37 if (type->Is<sem::Matrix>()) {
38 return true;
39 } else if (auto* ary = type->As<sem::Array>()) {
40 return ContainsMatrix(ary->ElemType());
41 } else if (auto* str = type->As<sem::Struct>()) {
42 for (auto* member : str->Members()) {
43 if (ContainsMatrix(member->Type())) {
44 return true;
45 }
46 }
47 }
48 return false;
49}
50} // namespace
James Price36464002021-09-07 18:59:21 +000051
James Price1ca6fba2021-09-29 18:56:17 +000052/// State holds the current transform state.
53struct ModuleScopeVarToEntryPointParam::State {
54 /// The clone context.
55 CloneContext& ctx;
James Price36464002021-09-07 18:59:21 +000056
James Price1ca6fba2021-09-29 18:56:17 +000057 /// Constructor
58 /// @param context the clone context
59 explicit State(CloneContext& context) : ctx(context) {}
James Price36464002021-09-07 18:59:21 +000060
James Price1ca6fba2021-09-29 18:56:17 +000061 /// Clone any struct types that are contained in `ty` (including `ty` itself),
62 /// and add it to the global declarations now, so that they precede new global
63 /// declarations that need to reference them.
64 /// @param ty the type to clone
65 void CloneStructTypes(const sem::Type* ty) {
66 if (auto* str = ty->As<sem::Struct>()) {
67 if (!cloned_structs_.emplace(str).second) {
68 // The struct has already been cloned.
69 return;
James Price36464002021-09-07 18:59:21 +000070 }
James Price36464002021-09-07 18:59:21 +000071
James Price1ca6fba2021-09-29 18:56:17 +000072 // Recurse into members.
73 for (auto* member : str->Members()) {
74 CloneStructTypes(member->Type());
James Price36464002021-09-07 18:59:21 +000075 }
James Price1ca6fba2021-09-29 18:56:17 +000076
77 // Clone the struct and add it to the global declaration list.
78 // Remove the old declaration.
79 auto* ast_str = str->Declaration();
Ben Clayton86481202021-10-19 18:38:54 +000080 ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
James Price1ca6fba2021-09-29 18:56:17 +000081 ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
82 } else if (auto* arr = ty->As<sem::Array>()) {
83 CloneStructTypes(arr->ElemType());
James Price36464002021-09-07 18:59:21 +000084 }
85 }
86
James Price1ca6fba2021-09-29 18:56:17 +000087 /// Process the module.
88 void Process() {
89 // Predetermine the list of function calls that need to be replaced.
90 using CallList = std::vector<const ast::CallExpression*>;
91 std::unordered_map<const ast::Function*, CallList> calls_to_replace;
92
Ben Clayton86481202021-10-19 18:38:54 +000093 std::vector<const ast::Function*> functions_to_process;
James Price1ca6fba2021-09-29 18:56:17 +000094
James Pricee548db92021-10-28 15:00:39 +000095 // Build a list of functions that transitively reference any module-scope
96 // variables.
James Price1ca6fba2021-09-29 18:56:17 +000097 for (auto* func_ast : ctx.src->AST().Functions()) {
98 auto* func_sem = ctx.src->Sem().Get(func_ast);
99
100 bool needs_processing = false;
Ben Clayton2423df32021-11-04 22:29:22 +0000101 for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
James Pricee548db92021-10-28 15:00:39 +0000102 if (var->StorageClass() != ast::StorageClass::kNone) {
James Price1ca6fba2021-09-29 18:56:17 +0000103 needs_processing = true;
104 break;
105 }
106 }
James Price1ca6fba2021-09-29 18:56:17 +0000107 if (needs_processing) {
108 functions_to_process.push_back(func_ast);
109
110 // Find all of the calls to this function that will need to be replaced.
111 for (auto* call : func_sem->CallSites()) {
Ben Claytona9156ff2021-11-05 16:51:38 +0000112 calls_to_replace[call->Stmt()->Function()->Declaration()].push_back(
113 call->Declaration());
James Price1ca6fba2021-09-29 18:56:17 +0000114 }
115 }
James Price36464002021-09-07 18:59:21 +0000116 }
James Price36464002021-09-07 18:59:21 +0000117
James Price1ca6fba2021-09-29 18:56:17 +0000118 // Build a list of `&ident` expressions. We'll use this later to avoid
119 // generating expressions of the form `&*ident`, which break WGSL validation
120 // rules when this expression is passed to a function.
121 // TODO(jrprice): We should add support for bidirectional SEM tree traversal
122 // so that we can do this on the fly instead.
Ben Clayton86481202021-10-19 18:38:54 +0000123 std::unordered_map<const ast::IdentifierExpression*,
124 const ast::UnaryOpExpression*>
James Price1ca6fba2021-09-29 18:56:17 +0000125 ident_to_address_of;
126 for (auto* node : ctx.src->ASTNodes().Objects()) {
127 auto* address_of = node->As<ast::UnaryOpExpression>();
Ben Clayton4f3ff572021-10-15 17:33:10 +0000128 if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
James Price36464002021-09-07 18:59:21 +0000129 continue;
130 }
Ben Clayton4f3ff572021-10-15 17:33:10 +0000131 if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
James Price1ca6fba2021-09-29 18:56:17 +0000132 ident_to_address_of[ident] = address_of;
133 }
134 }
James Price36464002021-09-07 18:59:21 +0000135
James Price1ca6fba2021-09-29 18:56:17 +0000136 for (auto* func_ast : functions_to_process) {
137 auto* func_sem = ctx.src->Sem().Get(func_ast);
138 bool is_entry_point = func_ast->IsEntryPoint();
James Price36464002021-09-07 18:59:21 +0000139
James Pricee548db92021-10-28 15:00:39 +0000140 // Map module-scope variables onto their replacement.
141 struct NewVar {
142 Symbol symbol;
143 bool is_pointer;
144 };
145 std::unordered_map<const sem::Variable*, NewVar> var_to_newvar;
James Price36464002021-09-07 18:59:21 +0000146
James Price1ca6fba2021-09-29 18:56:17 +0000147 // We aggregate all workgroup variables into a struct to avoid hitting
148 // MSL's limit for threadgroup memory arguments.
149 Symbol workgroup_parameter_symbol;
150 ast::StructMemberList workgroup_parameter_members;
151 auto workgroup_param = [&]() {
152 if (!workgroup_parameter_symbol.IsValid()) {
153 workgroup_parameter_symbol = ctx.dst->Sym();
154 }
155 return workgroup_parameter_symbol;
156 };
James Priceacaecab2021-09-13 19:56:01 +0000157
Ben Clayton2423df32021-11-04 22:29:22 +0000158 for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
James Pricee548db92021-10-28 15:00:39 +0000159 auto sc = var->StorageClass();
160 if (sc == ast::StorageClass::kNone) {
James Price1ca6fba2021-09-29 18:56:17 +0000161 continue;
162 }
James Pricee548db92021-10-28 15:00:39 +0000163 if (sc != ast::StorageClass::kPrivate &&
164 sc != ast::StorageClass::kStorage &&
165 sc != ast::StorageClass::kUniform &&
166 sc != ast::StorageClass::kUniformConstant &&
167 sc != ast::StorageClass::kWorkgroup) {
168 TINT_ICE(Transform, ctx.dst->Diagnostics())
169 << "unhandled module-scope storage class (" << sc << ")";
170 }
James Price1ca6fba2021-09-29 18:56:17 +0000171
172 // This is the symbol for the variable that replaces the module-scope
173 // var.
174 auto new_var_symbol = ctx.dst->Sym();
175
176 // Helper to create an AST node for the store type of the variable.
177 auto store_type = [&]() {
178 return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
179 };
180
181 // Track whether the new variable is a pointer or not.
182 bool is_pointer = false;
183
184 if (is_entry_point) {
185 if (var->Type()->UnwrapRef()->is_handle()) {
186 // For a texture or sampler variable, redeclare it as an entry point
187 // parameter. Disable entry point parameter validation.
Corentin Wallez40ef4a82021-09-27 19:00:15 +0000188 auto* disable_validation =
James Price8d7551c2021-10-28 15:00:39 +0000189 ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
Ben Clayton4f3ff572021-10-15 17:33:10 +0000190 auto decos = ctx.Clone(var->Declaration()->decorations);
James Price1ca6fba2021-09-29 18:56:17 +0000191 decos.push_back(disable_validation);
192 auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
Ben Clayton4f3ff572021-10-15 17:33:10 +0000193 ctx.InsertFront(func_ast->params, param);
James Pricee548db92021-10-28 15:00:39 +0000194 } else if (sc == ast::StorageClass::kStorage ||
195 sc == ast::StorageClass::kUniform) {
196 // Variables into the Storage and Uniform storage classes are
197 // redeclared as entry point parameters with a pointer type.
198 auto attributes = ctx.Clone(var->Declaration()->decorations);
James Price8d7551c2021-10-28 15:00:39 +0000199 attributes.push_back(ctx.dst->Disable(
200 ast::DisabledValidation::kEntryPointParameter));
James Pricee548db92021-10-28 15:00:39 +0000201 attributes.push_back(
James Price8d7551c2021-10-28 15:00:39 +0000202 ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
James Pricee548db92021-10-28 15:00:39 +0000203 auto* param_type = ctx.dst->ty.pointer(
204 store_type(), sc, var->Declaration()->declared_access);
205 auto* param =
206 ctx.dst->Param(new_var_symbol, param_type, attributes);
207 ctx.InsertFront(func_ast->params, param);
208 is_pointer = true;
209 } else if (sc == ast::StorageClass::kWorkgroup &&
James Price1ca6fba2021-09-29 18:56:17 +0000210 ContainsMatrix(var->Type())) {
211 // Due to a bug in the MSL compiler, we use a threadgroup memory
212 // argument for any workgroup allocation that contains a matrix.
213 // See crbug.com/tint/938.
214 // TODO(jrprice): Do this for all other workgroup variables too.
215
216 // Create a member in the workgroup parameter struct.
Ben Clayton4f3ff572021-10-15 17:33:10 +0000217 auto member = ctx.Clone(var->Declaration()->symbol);
James Price1ca6fba2021-09-29 18:56:17 +0000218 workgroup_parameter_members.push_back(
219 ctx.dst->Member(member, store_type()));
220 CloneStructTypes(var->Type()->UnwrapRef());
221
222 // Create a function-scope variable that is a pointer to the member.
223 auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
224 ctx.dst->Deref(workgroup_param()), member));
225 auto* local_var =
226 ctx.dst->Const(new_var_symbol,
227 ctx.dst->ty.pointer(
228 store_type(), ast::StorageClass::kWorkgroup),
229 member_ptr);
Ben Clayton4f3ff572021-10-15 17:33:10 +0000230 ctx.InsertFront(func_ast->body->statements,
James Price1ca6fba2021-09-29 18:56:17 +0000231 ctx.dst->Decl(local_var));
James Priceacaecab2021-09-13 19:56:01 +0000232 is_pointer = true;
233 } else {
James Pricee548db92021-10-28 15:00:39 +0000234 // Variables in the Private and Workgroup storage classes are
235 // redeclared at function scope. Disable storage class validation on
236 // this variable.
James Priceacaecab2021-09-13 19:56:01 +0000237 auto* disable_validation =
James Price8d7551c2021-10-28 15:00:39 +0000238 ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
Ben Clayton4f3ff572021-10-15 17:33:10 +0000239 auto* constructor = ctx.Clone(var->Declaration()->constructor);
James Pricee548db92021-10-28 15:00:39 +0000240 auto* local_var =
241 ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
242 ast::DecorationList{disable_validation});
Ben Clayton4f3ff572021-10-15 17:33:10 +0000243 ctx.InsertFront(func_ast->body->statements,
James Priceacaecab2021-09-13 19:56:01 +0000244 ctx.dst->Decl(local_var));
245 }
James Price1ca6fba2021-09-29 18:56:17 +0000246 } else {
247 // For a regular function, redeclare the variable as a parameter.
248 // Use a pointer for non-handle types.
249 auto* param_type = store_type();
250 ast::DecorationList attributes;
Ben Clayton5029e702021-10-15 14:17:31 +0000251 if (!var->Type()->UnwrapRef()->is_handle()) {
James Pricee548db92021-10-28 15:00:39 +0000252 param_type = ctx.dst->ty.pointer(
253 param_type, sc, var->Declaration()->declared_access);
James Price1ca6fba2021-09-29 18:56:17 +0000254 is_pointer = true;
James Price36464002021-09-07 18:59:21 +0000255
James Pricee548db92021-10-28 15:00:39 +0000256 // Disable validation of the parameter's storage class and of
257 // arguments passed it.
258 attributes.push_back(
James Price8d7551c2021-10-28 15:00:39 +0000259 ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass));
260 attributes.push_back(ctx.dst->Disable(
261 ast::DisabledValidation::kIgnoreInvalidPointerArgument));
James Price1ca6fba2021-09-29 18:56:17 +0000262 }
263 ctx.InsertBack(
Ben Clayton4f3ff572021-10-15 17:33:10 +0000264 func_ast->params,
James Price1ca6fba2021-09-29 18:56:17 +0000265 ctx.dst->Param(new_var_symbol, param_type, attributes));
266 }
267
268 // Replace all uses of the module-scope variable.
269 // For non-entry points, dereference non-handle pointer parameters.
270 for (auto* user : var->Users()) {
Ben Claytona9156ff2021-11-05 16:51:38 +0000271 if (user->Stmt()->Function()->Declaration() == func_ast) {
Ben Clayton86481202021-10-19 18:38:54 +0000272 const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
James Price1ca6fba2021-09-29 18:56:17 +0000273 if (is_pointer) {
274 // If this identifier is used by an address-of operator, just
275 // remove the address-of instead of adding a deref, since we
276 // already have a pointer.
277 auto* ident =
278 user->Declaration()->As<ast::IdentifierExpression>();
279 if (ident_to_address_of.count(ident)) {
280 ctx.Replace(ident_to_address_of[ident], expr);
281 continue;
282 }
283
284 expr = ctx.dst->Deref(expr);
James Price36464002021-09-07 18:59:21 +0000285 }
James Price1ca6fba2021-09-29 18:56:17 +0000286 ctx.Replace(user->Declaration(), expr);
James Price36464002021-09-07 18:59:21 +0000287 }
James Price36464002021-09-07 18:59:21 +0000288 }
James Price1ca6fba2021-09-29 18:56:17 +0000289
James Pricee548db92021-10-28 15:00:39 +0000290 var_to_newvar[var] = {new_var_symbol, is_pointer};
James Price36464002021-09-07 18:59:21 +0000291 }
292
James Price1ca6fba2021-09-29 18:56:17 +0000293 if (!workgroup_parameter_members.empty()) {
294 // Create the workgroup memory parameter.
295 // The parameter is a struct that contains members for each workgroup
296 // variable.
297 auto* str = ctx.dst->Structure(ctx.dst->Sym(),
298 std::move(workgroup_parameter_members));
299 auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
300 ast::StorageClass::kWorkgroup);
301 auto* disable_validation =
James Price8d7551c2021-10-28 15:00:39 +0000302 ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
James Price1ca6fba2021-09-29 18:56:17 +0000303 auto* param =
304 ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
Ben Clayton4f3ff572021-10-15 17:33:10 +0000305 ctx.InsertFront(func_ast->params, param);
James Price1ca6fba2021-09-29 18:56:17 +0000306 }
307
308 // Pass the variables as pointers to any functions that need them.
309 for (auto* call : calls_to_replace[func_ast]) {
Ben Clayton735dca82021-11-15 20:45:50 +0000310 auto* target =
311 ctx.src->AST().Functions().Find(call->target.name->symbol);
James Price1ca6fba2021-09-29 18:56:17 +0000312 auto* target_sem = ctx.src->Sem().Get(target);
313
314 // Add new arguments for any variables that are needed by the callee.
315 // For entry points, pass non-handle types as pointers.
Ben Clayton2423df32021-11-04 22:29:22 +0000316 for (auto* target_var : target_sem->TransitivelyReferencedGlobals()) {
James Pricee548db92021-10-28 15:00:39 +0000317 auto sc = target_var->StorageClass();
318 if (sc == ast::StorageClass::kNone) {
319 continue;
James Price1ca6fba2021-09-29 18:56:17 +0000320 }
James Pricee548db92021-10-28 15:00:39 +0000321
322 auto new_var = var_to_newvar[target_var];
323 bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
324 const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
325 if (is_entry_point && !is_handle && !new_var.is_pointer) {
326 // We need to pass a pointer and we don't already have one, so take
327 // the address of the new variable.
328 arg = ctx.dst->AddressOf(arg);
329 }
330 ctx.InsertBack(call->args, arg);
James Price1ca6fba2021-09-29 18:56:17 +0000331 }
332 }
James Price36464002021-09-07 18:59:21 +0000333 }
334
James Price1ca6fba2021-09-29 18:56:17 +0000335 // Now remove all module-scope variables with these storage classes.
336 for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
337 auto* var_sem = ctx.src->Sem().Get(var_ast);
James Pricee548db92021-10-28 15:00:39 +0000338 if (var_sem->StorageClass() != ast::StorageClass::kNone) {
James Price1ca6fba2021-09-29 18:56:17 +0000339 ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
James Price36464002021-09-07 18:59:21 +0000340 }
341 }
342 }
343
James Price1ca6fba2021-09-29 18:56:17 +0000344 private:
345 std::unordered_set<const sem::Struct*> cloned_structs_;
346};
James Price36464002021-09-07 18:59:21 +0000347
James Price1ca6fba2021-09-29 18:56:17 +0000348ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
349
350ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
351
352void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
353 const DataMap&,
354 DataMap&) {
355 State state{ctx};
356 state.Process();
James Price36464002021-09-07 18:59:21 +0000357 ctx.Clone();
358}
359
360} // namespace transform
361} // namespace tint