msl: Use a struct for threadgroup memory arguments
MSL has a limit on the number of threadgroup memory arguments, so use
a struct to support an arbitrary number of workgroup variables.
This commit introduces a `State` object to this transform, which is
used to track which structs have been cloned eagerly, in order to
avoid duplicating them.
Bug: tint:938
Change-Id: Ia467db186e176a08f160455eab5fd3b3662f56b8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/65360
Auto-Submit: James Price <jrprice@google.com>
Kokoro: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/module_scope_var_to_entry_point_param.cc b/src/transform/module_scope_var_to_entry_point_param.cc
index e865cab..787de98 100644
--- a/src/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/transform/module_scope_var_to_entry_point_param.cc
@@ -49,110 +49,164 @@
}
} // namespace
-ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
+/// State holds the current transform state.
+struct ModuleScopeVarToEntryPointParam::State {
+ /// The clone context.
+ CloneContext& ctx;
-ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
+ /// Constructor
+ /// @param context the clone context
+ explicit State(CloneContext& context) : ctx(context) {}
-void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
- const DataMap&,
- DataMap&) {
- // Predetermine the list of function calls that need to be replaced.
- using CallList = std::vector<const ast::CallExpression*>;
- std::unordered_map<const ast::Function*, CallList> calls_to_replace;
-
- std::vector<ast::Function*> functions_to_process;
-
- // Build a list of functions that transitively reference any private or
- // workgroup variables, or texture/sampler variables.
- for (auto* func_ast : ctx.src->AST().Functions()) {
- auto* func_sem = ctx.src->Sem().Get(func_ast);
-
- bool needs_processing = false;
- for (auto* var : func_sem->ReferencedModuleVariables()) {
- if (var->StorageClass() == ast::StorageClass::kPrivate ||
- var->StorageClass() == ast::StorageClass::kWorkgroup ||
- var->StorageClass() == ast::StorageClass::kUniformConstant) {
- needs_processing = true;
- break;
+ /// Clone any struct types that are contained in `ty` (including `ty` itself),
+ /// and add it to the global declarations now, so that they precede new global
+ /// declarations that need to reference them.
+ /// @param ty the type to clone
+ void CloneStructTypes(const sem::Type* ty) {
+ if (auto* str = ty->As<sem::Struct>()) {
+ if (!cloned_structs_.emplace(str).second) {
+ // The struct has already been cloned.
+ return;
}
- }
- if (needs_processing) {
- functions_to_process.push_back(func_ast);
-
- // Find all of the calls to this function that will need to be replaced.
- for (auto* call : func_sem->CallSites()) {
- auto* call_sem = ctx.src->Sem().Get(call);
- calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
+ // Recurse into members.
+ for (auto* member : str->Members()) {
+ CloneStructTypes(member->Type());
}
+
+ // Clone the struct and add it to the global declaration list.
+ // Remove the old declaration.
+ auto* ast_str = str->Declaration();
+ ctx.dst->AST().AddTypeDecl(ctx.Clone(const_cast<ast::Struct*>(ast_str)));
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
+ } else if (auto* arr = ty->As<sem::Array>()) {
+ CloneStructTypes(arr->ElemType());
}
}
- // Build a list of `&ident` expressions. We'll use this later to avoid
- // generating expressions of the form `&*ident`, which break WGSL validation
- // rules when this expression is passed to a function.
- // TODO(jrprice): We should add support for bidirectional SEM tree traversal
- // so that we can do this on the fly instead.
- std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
- ident_to_address_of;
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* address_of = node->As<ast::UnaryOpExpression>();
- if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
- continue;
+ /// Process the module.
+ void Process() {
+ // Predetermine the list of function calls that need to be replaced.
+ using CallList = std::vector<const ast::CallExpression*>;
+ std::unordered_map<const ast::Function*, CallList> calls_to_replace;
+
+ std::vector<ast::Function*> functions_to_process;
+
+ // Build a list of functions that transitively reference any private or
+ // workgroup variables, or texture/sampler variables.
+ for (auto* func_ast : ctx.src->AST().Functions()) {
+ auto* func_sem = ctx.src->Sem().Get(func_ast);
+
+ bool needs_processing = false;
+ for (auto* var : func_sem->ReferencedModuleVariables()) {
+ if (var->StorageClass() == ast::StorageClass::kPrivate ||
+ var->StorageClass() == ast::StorageClass::kWorkgroup ||
+ var->StorageClass() == ast::StorageClass::kUniformConstant) {
+ needs_processing = true;
+ break;
+ }
+ }
+
+ if (needs_processing) {
+ functions_to_process.push_back(func_ast);
+
+ // Find all of the calls to this function that will need to be replaced.
+ for (auto* call : func_sem->CallSites()) {
+ auto* call_sem = ctx.src->Sem().Get(call);
+ calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
+ }
+ }
}
- if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
- ident_to_address_of[ident] = address_of;
- }
- }
- for (auto* func_ast : functions_to_process) {
- auto* func_sem = ctx.src->Sem().Get(func_ast);
- bool is_entry_point = func_ast->IsEntryPoint();
-
- // Map module-scope variables onto their function-scope replacement.
- std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
-
- for (auto* var : func_sem->ReferencedModuleVariables()) {
- if (var->StorageClass() != ast::StorageClass::kPrivate &&
- var->StorageClass() != ast::StorageClass::kWorkgroup &&
- var->StorageClass() != ast::StorageClass::kUniformConstant) {
+ // Build a list of `&ident` expressions. We'll use this later to avoid
+ // generating expressions of the form `&*ident`, which break WGSL validation
+ // rules when this expression is passed to a function.
+ // TODO(jrprice): We should add support for bidirectional SEM tree traversal
+ // so that we can do this on the fly instead.
+ std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
+ ident_to_address_of;
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* address_of = node->As<ast::UnaryOpExpression>();
+ if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
continue;
}
+ if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
+ ident_to_address_of[ident] = address_of;
+ }
+ }
- // This is the symbol for the variable that replaces the module-scope var.
- auto new_var_symbol = ctx.dst->Sym();
+ for (auto* func_ast : functions_to_process) {
+ auto* func_sem = ctx.src->Sem().Get(func_ast);
+ bool is_entry_point = func_ast->IsEntryPoint();
- auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
+ // Map module-scope variables onto their function-scope replacement.
+ std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
- // Track whether the new variable is a pointer or not.
- bool is_pointer = false;
+ // We aggregate all workgroup variables into a struct to avoid hitting
+ // MSL's limit for threadgroup memory arguments.
+ Symbol workgroup_parameter_symbol;
+ ast::StructMemberList workgroup_parameter_members;
+ auto workgroup_param = [&]() {
+ if (!workgroup_parameter_symbol.IsValid()) {
+ workgroup_parameter_symbol = ctx.dst->Sym();
+ }
+ return workgroup_parameter_symbol;
+ };
- if (is_entry_point) {
- if (store_type->is_handle()) {
- // For a texture or sampler variable, redeclare it as an entry point
- // parameter. Disable entry point parameter validation.
- auto* disable_validation =
- ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
- ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
- auto decos = ctx.Clone(var->Declaration()->decorations());
- decos.push_back(disable_validation);
- auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
- ctx.InsertFront(func_ast->params(), param);
- } else {
- if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
- ContainsMatrix(var->Type())) {
- // Due to a bug in the MSL compiler, we use a threadgroup memory
- // argument for any workgroup allocation that contains a matrix.
- // See crbug.com/tint/938.
+ for (auto* var : func_sem->ReferencedModuleVariables()) {
+ if (var->StorageClass() != ast::StorageClass::kPrivate &&
+ var->StorageClass() != ast::StorageClass::kWorkgroup &&
+ var->StorageClass() != ast::StorageClass::kUniformConstant) {
+ continue;
+ }
+
+ // This is the symbol for the variable that replaces the module-scope
+ // var.
+ auto new_var_symbol = ctx.dst->Sym();
+
+ // Helper to create an AST node for the store type of the variable.
+ auto store_type = [&]() {
+ return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
+ };
+
+ // Track whether the new variable is a pointer or not.
+ bool is_pointer = false;
+
+ if (is_entry_point) {
+ if (var->Type()->UnwrapRef()->is_handle()) {
+ // For a texture or sampler variable, redeclare it as an entry point
+ // parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kEntryPointParameter);
- auto* param_type =
- ctx.dst->ty.pointer(store_type, var->StorageClass());
- auto* param = ctx.dst->Param(new_var_symbol, param_type,
- {disable_validation});
+ auto decos = ctx.Clone(var->Declaration()->decorations());
+ decos.push_back(disable_validation);
+ auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
ctx.InsertFront(func_ast->params(), param);
+ } else if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
+ ContainsMatrix(var->Type())) {
+ // Due to a bug in the MSL compiler, we use a threadgroup memory
+ // argument for any workgroup allocation that contains a matrix.
+ // See crbug.com/tint/938.
+ // TODO(jrprice): Do this for all other workgroup variables too.
+
+ // Create a member in the workgroup parameter struct.
+ auto member = ctx.Clone(var->Declaration()->symbol());
+ workgroup_parameter_members.push_back(
+ ctx.dst->Member(member, store_type()));
+ CloneStructTypes(var->Type()->UnwrapRef());
+
+ // Create a function-scope variable that is a pointer to the member.
+ auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
+ ctx.dst->Deref(workgroup_param()), member));
+ auto* local_var =
+ ctx.dst->Const(new_var_symbol,
+ ctx.dst->ty.pointer(
+ store_type(), ast::StorageClass::kWorkgroup),
+ member_ptr);
+ ctx.InsertFront(func_ast->body()->statements(),
+ ctx.dst->Decl(local_var));
is_pointer = true;
} else {
// For any other private or workgroup variable, redeclare it at
@@ -164,83 +218,123 @@
ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor());
auto* local_var = ctx.dst->Var(
- new_var_symbol, store_type, var->StorageClass(), constructor,
+ new_var_symbol, store_type(), var->StorageClass(), constructor,
ast::DecorationList{disable_validation});
ctx.InsertFront(func_ast->body()->statements(),
ctx.dst->Decl(local_var));
}
- }
- } else {
- // For a regular function, redeclare the variable as a parameter.
- // Use a pointer for non-handle types.
- auto* param_type = store_type;
- if (!store_type->is_handle()) {
- param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
- is_pointer = true;
- }
- ctx.InsertBack(func_ast->params(),
- ctx.dst->Param(new_var_symbol, param_type));
- }
+ } else {
+ // For a regular function, redeclare the variable as a parameter.
+ // Use a pointer for non-handle types.
+ auto* param_type = store_type();
+ ast::DecorationList attributes;
+ if (!param_type->is_handle()) {
+ param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
+ is_pointer = true;
- // Replace all uses of the module-scope variable.
- // For non-entry points, dereference non-handle pointer parameters.
- for (auto* user : var->Users()) {
- if (user->Stmt()->Function() == func_ast) {
- ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
- if (is_pointer) {
- // If this identifier is used by an address-of operator, just remove
- // the address-of instead of adding a deref, since we already have a
- // pointer.
- auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
- if (ident_to_address_of.count(ident)) {
- ctx.Replace(ident_to_address_of[ident], expr);
- continue;
+ // Disable validation of arguments passed to this pointer parameter,
+ // as we will sometimes pass pointers to struct members.
+ attributes.push_back(
+ ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
+ ctx.dst->ID(),
+ ast::DisabledValidation::kIgnoreInvalidPointerArgument));
+ }
+ ctx.InsertBack(
+ func_ast->params(),
+ ctx.dst->Param(new_var_symbol, param_type, attributes));
+ }
+
+ // Replace all uses of the module-scope variable.
+ // For non-entry points, dereference non-handle pointer parameters.
+ for (auto* user : var->Users()) {
+ if (user->Stmt()->Function() == func_ast) {
+ ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
+ if (is_pointer) {
+ // If this identifier is used by an address-of operator, just
+ // remove the address-of instead of adding a deref, since we
+ // already have a pointer.
+ auto* ident =
+ user->Declaration()->As<ast::IdentifierExpression>();
+ if (ident_to_address_of.count(ident)) {
+ ctx.Replace(ident_to_address_of[ident], expr);
+ continue;
+ }
+
+ expr = ctx.dst->Deref(expr);
}
-
- expr = ctx.dst->Deref(expr);
+ ctx.Replace(user->Declaration(), expr);
}
- ctx.Replace(user->Declaration(), expr);
}
+
+ var_to_symbol[var] = new_var_symbol;
}
- var_to_symbol[var] = new_var_symbol;
+ if (!workgroup_parameter_members.empty()) {
+ // Create the workgroup memory parameter.
+ // The parameter is a struct that contains members for each workgroup
+ // variable.
+ auto* str = ctx.dst->Structure(ctx.dst->Sym(),
+ std::move(workgroup_parameter_members));
+ auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
+ ast::StorageClass::kWorkgroup);
+ auto* disable_validation =
+ ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
+ ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
+ auto* param =
+ ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
+ ctx.InsertFront(func_ast->params(), param);
+ }
+
+ // Pass the variables as pointers to any functions that need them.
+ for (auto* call : calls_to_replace[func_ast]) {
+ auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
+ auto* target_sem = ctx.src->Sem().Get(target);
+
+ // Add new arguments for any variables that are needed by the callee.
+ // For entry points, pass non-handle types as pointers.
+ for (auto* target_var : target_sem->ReferencedModuleVariables()) {
+ bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
+ bool is_workgroup_matrix =
+ target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
+ ContainsMatrix(target_var->Type());
+ if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
+ target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
+ target_var->StorageClass() ==
+ ast::StorageClass::kUniformConstant) {
+ ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
+ if (is_entry_point && !is_handle && !is_workgroup_matrix) {
+ arg = ctx.dst->AddressOf(arg);
+ }
+ ctx.InsertBack(call->params(), arg);
+ }
+ }
+ }
}
- // Pass the variables as pointers to any functions that need them.
- for (auto* call : calls_to_replace[func_ast]) {
- auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
- auto* target_sem = ctx.src->Sem().Get(target);
-
- // Add new arguments for any variables that are needed by the callee.
- // For entry points, pass non-handle types as pointers.
- for (auto* target_var : target_sem->ReferencedModuleVariables()) {
- bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
- bool is_workgroup_matrix =
- target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
- ContainsMatrix(target_var->Type());
- if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
- target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
- target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
- ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
- if (is_entry_point && !is_handle && !is_workgroup_matrix) {
- arg = ctx.dst->AddressOf(arg);
- }
- ctx.InsertBack(call->params(), arg);
- }
+ // Now remove all module-scope variables with these storage classes.
+ for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
+ auto* var_sem = ctx.src->Sem().Get(var_ast);
+ if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
+ var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
+ var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
+ ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
}
}
}
- // Now remove all module-scope variables with these storage classes.
- for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
- auto* var_sem = ctx.src->Sem().Get(var_ast);
- if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
- var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
- var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
- ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
- }
- }
+ private:
+ std::unordered_set<const sem::Struct*> cloned_structs_;
+};
+ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
+
+ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
+
+void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
+ const DataMap&,
+ DataMap&) {
+ State state{ctx};
+ state.Process();
ctx.Clone();
}