msl: Handle workgroup matrix allocations
Work around an MSL compiler bug by using a threadgroup memory argument,
rather than a function-scope variable, for any workgroup variable that
contains a matrix.
The generator now provides a list of threadgroup memory arguments for
each entry point, so that the runtime knows how many bytes to allocate
for each argument.
Bug: tint:938
Change-Id: Ia4af33cd6a44c4f74258793443eb737c2931f5eb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64042
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/transform/module_scope_var_to_entry_point_param.cc b/src/transform/module_scope_var_to_entry_point_param.cc
index 3567d69..e865cab 100644
--- a/src/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/transform/module_scope_var_to_entry_point_param.cc
@@ -15,6 +15,7 @@
#include "src/transform/module_scope_var_to_entry_point_param.h"
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -29,6 +30,24 @@
namespace tint {
namespace transform {
+namespace {
+// Reports whether `type` (refs unwrapped) is, or transitively holds, a matrix.
+bool ContainsMatrix(const sem::Type* type) {
+  const sem::Type* unwrapped = type->UnwrapRef();
+  if (unwrapped->Is<sem::Matrix>()) {
+    return true;
+  }
+  if (auto* arr = unwrapped->As<sem::Array>()) {
+    return ContainsMatrix(arr->ElemType());
+  }
+  if (auto* strct = unwrapped->As<sem::Struct>()) {
+    for (auto* m : strct->Members()) {
+      if (ContainsMatrix(m->Type())) return true;
+    }
+  }
+  return false;
+}
+}  // namespace
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
@@ -105,6 +124,9 @@
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
+ // Track whether the new variable is a pointer or not.
+ bool is_pointer = false;
+
if (is_entry_point) {
if (store_type->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
@@ -117,17 +139,36 @@
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
ctx.InsertFront(func_ast->params(), param);
} else {
- // For a private or workgroup variable, redeclare it at function
- // scope. Disable storage class validation on this variable.
- auto* disable_validation =
- ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
- ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
- auto* constructor = ctx.Clone(var->Declaration()->constructor());
- auto* local_var = ctx.dst->Var(
- new_var_symbol, store_type, var->StorageClass(), constructor,
- ast::DecorationList{disable_validation});
- ctx.InsertFront(func_ast->body()->statements(),
- ctx.dst->Decl(local_var));
+ if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
+ ContainsMatrix(var->Type())) {
+ // Due to a bug in the MSL compiler, we use a threadgroup memory
+ // argument for any workgroup allocation that contains a matrix.
+ // See crbug.com/tint/938.
+ auto* disable_validation =
+ ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
+ ctx.dst->ID(),
+ ast::DisabledValidation::kEntryPointParameter);
+ auto* param_type =
+ ctx.dst->ty.pointer(store_type, var->StorageClass());
+ auto* param = ctx.dst->Param(new_var_symbol, param_type,
+ {disable_validation});
+ ctx.InsertFront(func_ast->params(), param);
+ is_pointer = true;
+ } else {
+ // For any other private or workgroup variable, redeclare it at
+ // function scope. Disable storage class validation on this
+ // variable.
+ auto* disable_validation =
+ ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
+ ctx.dst->ID(),
+ ast::DisabledValidation::kIgnoreStorageClass);
+ auto* constructor = ctx.Clone(var->Declaration()->constructor());
+ auto* local_var = ctx.dst->Var(
+ new_var_symbol, store_type, var->StorageClass(), constructor,
+ ast::DecorationList{disable_validation});
+ ctx.InsertFront(func_ast->body()->statements(),
+ ctx.dst->Decl(local_var));
+ }
}
} else {
// For a regular function, redeclare the variable as a parameter.
@@ -135,6 +176,7 @@
auto* param_type = store_type;
if (!store_type->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
+ is_pointer = true;
}
ctx.InsertBack(func_ast->params(),
ctx.dst->Param(new_var_symbol, param_type));
@@ -145,7 +187,7 @@
for (auto* user : var->Users()) {
if (user->Stmt()->Function() == func_ast) {
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
- if (!is_entry_point && !store_type->is_handle()) {
+ if (is_pointer) {
// If this identifier is used by an address-of operator, just remove
// the address-of instead of adding a deref, since we already have a
// pointer.
@@ -172,11 +214,15 @@
// Add new arguments for any variables that are needed by the callee.
// For entry points, pass non-handle types as pointers.
for (auto* target_var : target_sem->ReferencedModuleVariables()) {
+ bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
+ bool is_workgroup_matrix =
+ target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
+ ContainsMatrix(target_var->Type());
if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
- if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
+ if (is_entry_point && !is_handle && !is_workgroup_matrix) {
arg = ctx.dst->AddressOf(arg);
}
ctx.InsertBack(call->params(), arg);