|  | // Copyright 2021 The Tint Authors. | 
|  | // | 
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | // you may not use this file except in compliance with the License. | 
|  | // You may obtain a copy of the License at | 
|  | // | 
|  | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | #include "src/transform/module_scope_var_to_entry_point_param.h" | 
|  |  | 
|  | #include <unordered_map> | 
|  | #include <unordered_set> | 
|  | #include <utility> | 
|  | #include <vector> | 
|  |  | 
|  | #include "src/ast/disable_validation_decoration.h" | 
|  | #include "src/program_builder.h" | 
|  | #include "src/sem/call.h" | 
|  | #include "src/sem/function.h" | 
|  | #include "src/sem/statement.h" | 
|  | #include "src/sem/variable.h" | 
|  |  | 
|  | TINT_INSTANTIATE_TYPEINFO(tint::transform::ModuleScopeVarToEntryPointParam); | 
|  |  | 
|  | namespace tint { | 
|  | namespace transform { | 
|  | namespace { | 
|  | // Returns `true` if `type` is or contains a matrix type. | 
|  | bool ContainsMatrix(const sem::Type* type) { | 
|  | type = type->UnwrapRef(); | 
|  | if (type->Is<sem::Matrix>()) { | 
|  | return true; | 
|  | } else if (auto* ary = type->As<sem::Array>()) { | 
|  | return ContainsMatrix(ary->ElemType()); | 
|  | } else if (auto* str = type->As<sem::Struct>()) { | 
|  | for (auto* member : str->Members()) { | 
|  | if (ContainsMatrix(member->Type())) { | 
|  | return true; | 
|  | } | 
|  | } | 
|  | } | 
|  | return false; | 
|  | } | 
|  | }  // namespace | 
|  |  | 
|  | /// State holds the current transform state. | 
|  | struct ModuleScopeVarToEntryPointParam::State { | 
|  | /// The clone context. | 
|  | CloneContext& ctx; | 
|  |  | 
|  | /// Constructor | 
|  | /// @param context the clone context | 
|  | explicit State(CloneContext& context) : ctx(context) {} | 
|  |  | 
|  | /// Clone any struct types that are contained in `ty` (including `ty` itself), | 
|  | /// and add it to the global declarations now, so that they precede new global | 
|  | /// declarations that need to reference them. | 
|  | /// @param ty the type to clone | 
|  | void CloneStructTypes(const sem::Type* ty) { | 
|  | if (auto* str = ty->As<sem::Struct>()) { | 
|  | if (!cloned_structs_.emplace(str).second) { | 
|  | // The struct has already been cloned. | 
|  | return; | 
|  | } | 
|  |  | 
|  | // Recurse into members. | 
|  | for (auto* member : str->Members()) { | 
|  | CloneStructTypes(member->Type()); | 
|  | } | 
|  |  | 
|  | // Clone the struct and add it to the global declaration list. | 
|  | // Remove the old declaration. | 
|  | auto* ast_str = str->Declaration(); | 
|  | ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str)); | 
|  | ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str); | 
|  | } else if (auto* arr = ty->As<sem::Array>()) { | 
|  | CloneStructTypes(arr->ElemType()); | 
|  | } | 
|  | } | 
|  |  | 
|  | /// Process the module. | 
|  | void Process() { | 
|  | // Predetermine the list of function calls that need to be replaced. | 
|  | using CallList = std::vector<const ast::CallExpression*>; | 
|  | std::unordered_map<const ast::Function*, CallList> calls_to_replace; | 
|  |  | 
|  | std::vector<const ast::Function*> functions_to_process; | 
|  |  | 
|  | // Build a list of functions that transitively reference any module-scope | 
|  | // variables. | 
|  | for (auto* func_ast : ctx.src->AST().Functions()) { | 
|  | auto* func_sem = ctx.src->Sem().Get(func_ast); | 
|  |  | 
|  | bool needs_processing = false; | 
|  | for (auto* var : func_sem->ReferencedModuleVariables()) { | 
|  | if (var->StorageClass() != ast::StorageClass::kNone) { | 
|  | needs_processing = true; | 
|  | break; | 
|  | } | 
|  | } | 
|  | if (needs_processing) { | 
|  | functions_to_process.push_back(func_ast); | 
|  |  | 
|  | // Find all of the calls to this function that will need to be replaced. | 
|  | for (auto* call : func_sem->CallSites()) { | 
|  | auto* call_sem = ctx.src->Sem().Get(call); | 
|  | calls_to_replace[call_sem->Stmt()->Function()].push_back(call); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // Build a list of `&ident` expressions. We'll use this later to avoid | 
|  | // generating expressions of the form `&*ident`, which break WGSL validation | 
|  | // rules when this expression is passed to a function. | 
|  | // TODO(jrprice): We should add support for bidirectional SEM tree traversal | 
|  | // so that we can do this on the fly instead. | 
|  | std::unordered_map<const ast::IdentifierExpression*, | 
|  | const ast::UnaryOpExpression*> | 
|  | ident_to_address_of; | 
|  | for (auto* node : ctx.src->ASTNodes().Objects()) { | 
|  | auto* address_of = node->As<ast::UnaryOpExpression>(); | 
|  | if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { | 
|  | continue; | 
|  | } | 
|  | if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) { | 
|  | ident_to_address_of[ident] = address_of; | 
|  | } | 
|  | } | 
|  |  | 
|  | for (auto* func_ast : functions_to_process) { | 
|  | auto* func_sem = ctx.src->Sem().Get(func_ast); | 
|  | bool is_entry_point = func_ast->IsEntryPoint(); | 
|  |  | 
|  | // Map module-scope variables onto their replacement. | 
|  | struct NewVar { | 
|  | Symbol symbol; | 
|  | bool is_pointer; | 
|  | }; | 
|  | std::unordered_map<const sem::Variable*, NewVar> var_to_newvar; | 
|  |  | 
|  | // We aggregate all workgroup variables into a struct to avoid hitting | 
|  | // MSL's limit for threadgroup memory arguments. | 
|  | Symbol workgroup_parameter_symbol; | 
|  | ast::StructMemberList workgroup_parameter_members; | 
|  | auto workgroup_param = [&]() { | 
|  | if (!workgroup_parameter_symbol.IsValid()) { | 
|  | workgroup_parameter_symbol = ctx.dst->Sym(); | 
|  | } | 
|  | return workgroup_parameter_symbol; | 
|  | }; | 
|  |  | 
|  | for (auto* var : func_sem->ReferencedModuleVariables()) { | 
|  | auto sc = var->StorageClass(); | 
|  | if (sc == ast::StorageClass::kNone) { | 
|  | continue; | 
|  | } | 
|  | if (sc != ast::StorageClass::kPrivate && | 
|  | sc != ast::StorageClass::kStorage && | 
|  | sc != ast::StorageClass::kUniform && | 
|  | sc != ast::StorageClass::kUniformConstant && | 
|  | sc != ast::StorageClass::kWorkgroup) { | 
|  | TINT_ICE(Transform, ctx.dst->Diagnostics()) | 
|  | << "unhandled module-scope storage class (" << sc << ")"; | 
|  | } | 
|  |  | 
|  | // This is the symbol for the variable that replaces the module-scope | 
|  | // var. | 
|  | auto new_var_symbol = ctx.dst->Sym(); | 
|  |  | 
|  | // Helper to create an AST node for the store type of the variable. | 
|  | auto store_type = [&]() { | 
|  | return CreateASTTypeFor(ctx, var->Type()->UnwrapRef()); | 
|  | }; | 
|  |  | 
|  | // Track whether the new variable is a pointer or not. | 
|  | bool is_pointer = false; | 
|  |  | 
|  | if (is_entry_point) { | 
|  | if (var->Type()->UnwrapRef()->is_handle()) { | 
|  | // For a texture or sampler variable, redeclare it as an entry point | 
|  | // parameter. Disable entry point parameter validation. | 
|  | auto* disable_validation = | 
|  | ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); | 
|  | auto decos = ctx.Clone(var->Declaration()->decorations); | 
|  | decos.push_back(disable_validation); | 
|  | auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos); | 
|  | ctx.InsertFront(func_ast->params, param); | 
|  | } else if (sc == ast::StorageClass::kStorage || | 
|  | sc == ast::StorageClass::kUniform) { | 
|  | // Variables into the Storage and Uniform storage classes are | 
|  | // redeclared as entry point parameters with a pointer type. | 
|  | auto attributes = ctx.Clone(var->Declaration()->decorations); | 
|  | attributes.push_back(ctx.dst->Disable( | 
|  | ast::DisabledValidation::kEntryPointParameter)); | 
|  | attributes.push_back( | 
|  | ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); | 
|  | auto* param_type = ctx.dst->ty.pointer( | 
|  | store_type(), sc, var->Declaration()->declared_access); | 
|  | auto* param = | 
|  | ctx.dst->Param(new_var_symbol, param_type, attributes); | 
|  | ctx.InsertFront(func_ast->params, param); | 
|  | is_pointer = true; | 
|  | } else if (sc == ast::StorageClass::kWorkgroup && | 
|  | ContainsMatrix(var->Type())) { | 
|  | // Due to a bug in the MSL compiler, we use a threadgroup memory | 
|  | // argument for any workgroup allocation that contains a matrix. | 
|  | // See crbug.com/tint/938. | 
|  | // TODO(jrprice): Do this for all other workgroup variables too. | 
|  |  | 
|  | // Create a member in the workgroup parameter struct. | 
|  | auto member = ctx.Clone(var->Declaration()->symbol); | 
|  | workgroup_parameter_members.push_back( | 
|  | ctx.dst->Member(member, store_type())); | 
|  | CloneStructTypes(var->Type()->UnwrapRef()); | 
|  |  | 
|  | // Create a function-scope variable that is a pointer to the member. | 
|  | auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor( | 
|  | ctx.dst->Deref(workgroup_param()), member)); | 
|  | auto* local_var = | 
|  | ctx.dst->Const(new_var_symbol, | 
|  | ctx.dst->ty.pointer( | 
|  | store_type(), ast::StorageClass::kWorkgroup), | 
|  | member_ptr); | 
|  | ctx.InsertFront(func_ast->body->statements, | 
|  | ctx.dst->Decl(local_var)); | 
|  | is_pointer = true; | 
|  | } else { | 
|  | // Variables in the Private and Workgroup storage classes are | 
|  | // redeclared at function scope. Disable storage class validation on | 
|  | // this variable. | 
|  | auto* disable_validation = | 
|  | ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass); | 
|  | auto* constructor = ctx.Clone(var->Declaration()->constructor); | 
|  | auto* local_var = | 
|  | ctx.dst->Var(new_var_symbol, store_type(), sc, constructor, | 
|  | ast::DecorationList{disable_validation}); | 
|  | ctx.InsertFront(func_ast->body->statements, | 
|  | ctx.dst->Decl(local_var)); | 
|  | } | 
|  | } else { | 
|  | // For a regular function, redeclare the variable as a parameter. | 
|  | // Use a pointer for non-handle types. | 
|  | auto* param_type = store_type(); | 
|  | ast::DecorationList attributes; | 
|  | if (!var->Type()->UnwrapRef()->is_handle()) { | 
|  | param_type = ctx.dst->ty.pointer( | 
|  | param_type, sc, var->Declaration()->declared_access); | 
|  | is_pointer = true; | 
|  |  | 
|  | // Disable validation of the parameter's storage class and of | 
|  | // arguments passed it. | 
|  | attributes.push_back( | 
|  | ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass)); | 
|  | attributes.push_back(ctx.dst->Disable( | 
|  | ast::DisabledValidation::kIgnoreInvalidPointerArgument)); | 
|  | } | 
|  | ctx.InsertBack( | 
|  | func_ast->params, | 
|  | ctx.dst->Param(new_var_symbol, param_type, attributes)); | 
|  | } | 
|  |  | 
|  | // Replace all uses of the module-scope variable. | 
|  | // For non-entry points, dereference non-handle pointer parameters. | 
|  | for (auto* user : var->Users()) { | 
|  | if (user->Stmt()->Function() == func_ast) { | 
|  | const ast::Expression* expr = ctx.dst->Expr(new_var_symbol); | 
|  | if (is_pointer) { | 
|  | // If this identifier is used by an address-of operator, just | 
|  | // remove the address-of instead of adding a deref, since we | 
|  | // already have a pointer. | 
|  | auto* ident = | 
|  | user->Declaration()->As<ast::IdentifierExpression>(); | 
|  | if (ident_to_address_of.count(ident)) { | 
|  | ctx.Replace(ident_to_address_of[ident], expr); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | expr = ctx.dst->Deref(expr); | 
|  | } | 
|  | ctx.Replace(user->Declaration(), expr); | 
|  | } | 
|  | } | 
|  |  | 
|  | var_to_newvar[var] = {new_var_symbol, is_pointer}; | 
|  | } | 
|  |  | 
|  | if (!workgroup_parameter_members.empty()) { | 
|  | // Create the workgroup memory parameter. | 
|  | // The parameter is a struct that contains members for each workgroup | 
|  | // variable. | 
|  | auto* str = ctx.dst->Structure(ctx.dst->Sym(), | 
|  | std::move(workgroup_parameter_members)); | 
|  | auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str), | 
|  | ast::StorageClass::kWorkgroup); | 
|  | auto* disable_validation = | 
|  | ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); | 
|  | auto* param = | 
|  | ctx.dst->Param(workgroup_param(), param_type, {disable_validation}); | 
|  | ctx.InsertFront(func_ast->params, param); | 
|  | } | 
|  |  | 
|  | // Pass the variables as pointers to any functions that need them. | 
|  | for (auto* call : calls_to_replace[func_ast]) { | 
|  | auto* target = ctx.src->AST().Functions().Find(call->func->symbol); | 
|  | auto* target_sem = ctx.src->Sem().Get(target); | 
|  |  | 
|  | // Add new arguments for any variables that are needed by the callee. | 
|  | // For entry points, pass non-handle types as pointers. | 
|  | for (auto* target_var : target_sem->ReferencedModuleVariables()) { | 
|  | auto sc = target_var->StorageClass(); | 
|  | if (sc == ast::StorageClass::kNone) { | 
|  | continue; | 
|  | } | 
|  |  | 
|  | auto new_var = var_to_newvar[target_var]; | 
|  | bool is_handle = target_var->Type()->UnwrapRef()->is_handle(); | 
|  | const ast::Expression* arg = ctx.dst->Expr(new_var.symbol); | 
|  | if (is_entry_point && !is_handle && !new_var.is_pointer) { | 
|  | // We need to pass a pointer and we don't already have one, so take | 
|  | // the address of the new variable. | 
|  | arg = ctx.dst->AddressOf(arg); | 
|  | } | 
|  | ctx.InsertBack(call->args, arg); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // Now remove all module-scope variables with these storage classes. | 
|  | for (auto* var_ast : ctx.src->AST().GlobalVariables()) { | 
|  | auto* var_sem = ctx.src->Sem().Get(var_ast); | 
|  | if (var_sem->StorageClass() != ast::StorageClass::kNone) { | 
|  | ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | private: | 
|  | std::unordered_set<const sem::Struct*> cloned_structs_; | 
|  | }; | 
|  |  | 
|  | ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default; | 
|  |  | 
|  | ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default; | 
|  |  | 
|  | void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, | 
|  | const DataMap&, | 
|  | DataMap&) { | 
|  | State state{ctx}; | 
|  | state.Process(); | 
|  | ctx.Clone(); | 
|  | } | 
|  |  | 
|  | }  // namespace transform | 
|  | }  // namespace tint |