msl: Handle workgroup matrix allocations

Work around an MSL compiler bug by using a threadgroup memory argument
for any workgroup variable that contains a matrix.

The generator now provides a list of threadgroup memory arguments for
each entry point, so that the runtime knows how many bytes to allocate
for each argument.

Bug: tint:938
Change-Id: Ia4af33cd6a44c4f74258793443eb737c2931f5eb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64042
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/transform/module_scope_var_to_entry_point_param.cc b/src/transform/module_scope_var_to_entry_point_param.cc
index 3567d69..e865cab 100644
--- a/src/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/transform/module_scope_var_to_entry_point_param.cc
@@ -15,6 +15,7 @@
 #include "src/transform/module_scope_var_to_entry_point_param.h"
 
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -29,6 +30,24 @@
 
 namespace tint {
 namespace transform {
+namespace {
+// Returns `true` if `type` is a matrix, or an array/struct that contains one.
+bool ContainsMatrix(const sem::Type* type) {
+  type = type->UnwrapRef();
+  if (type->Is<sem::Matrix>()) {
+    return true;
+  } else if (auto* ary = type->As<sem::Array>()) {
+    return ContainsMatrix(ary->ElemType());  // Recurse into the element type.
+  } else if (auto* str = type->As<sem::Struct>()) {
+    for (auto* member : str->Members()) {
+      if (ContainsMatrix(member->Type())) {  // One matching member suffices.
+        return true;
+      }
+    }
+  }
+  return false;  // Any other type, or a struct with no matrix members.
+}
+}  // namespace
 
 ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
 
@@ -105,6 +124,9 @@
 
       auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
 
+      // Track whether the new variable is a pointer or not.
+      bool is_pointer = false;
+
       if (is_entry_point) {
         if (store_type->is_handle()) {
           // For a texture or sampler variable, redeclare it as an entry point
@@ -117,17 +139,36 @@
           auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
           ctx.InsertFront(func_ast->params(), param);
         } else {
-          // For a private or workgroup variable, redeclare it at function
-          // scope. Disable storage class validation on this variable.
-          auto* disable_validation =
-              ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
-                  ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
-          auto* constructor = ctx.Clone(var->Declaration()->constructor());
-          auto* local_var = ctx.dst->Var(
-              new_var_symbol, store_type, var->StorageClass(), constructor,
-              ast::DecorationList{disable_validation});
-          ctx.InsertFront(func_ast->body()->statements(),
-                          ctx.dst->Decl(local_var));
+          if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
+              ContainsMatrix(var->Type())) {
+            // Due to a bug in the MSL compiler, we use a threadgroup memory
+            // argument for any workgroup allocation that contains a matrix.
+            // See crbug.com/tint/938.
+            auto* disable_validation =
+                ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
+                    ctx.dst->ID(),
+                    ast::DisabledValidation::kEntryPointParameter);
+            auto* param_type =
+                ctx.dst->ty.pointer(store_type, var->StorageClass());
+            auto* param = ctx.dst->Param(new_var_symbol, param_type,
+                                         {disable_validation});
+            ctx.InsertFront(func_ast->params(), param);
+            is_pointer = true;
+          } else {
+            // For any other private or workgroup variable, redeclare it at
+            // function scope. Disable storage class validation on this
+            // variable.
+            auto* disable_validation =
+                ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
+                    ctx.dst->ID(),
+                    ast::DisabledValidation::kIgnoreStorageClass);
+            auto* constructor = ctx.Clone(var->Declaration()->constructor());
+            auto* local_var = ctx.dst->Var(
+                new_var_symbol, store_type, var->StorageClass(), constructor,
+                ast::DecorationList{disable_validation});
+            ctx.InsertFront(func_ast->body()->statements(),
+                            ctx.dst->Decl(local_var));
+          }
         }
       } else {
         // For a regular function, redeclare the variable as a parameter.
@@ -135,6 +176,7 @@
         auto* param_type = store_type;
         if (!store_type->is_handle()) {
           param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
+          is_pointer = true;
         }
         ctx.InsertBack(func_ast->params(),
                        ctx.dst->Param(new_var_symbol, param_type));
@@ -145,7 +187,7 @@
       for (auto* user : var->Users()) {
         if (user->Stmt()->Function() == func_ast) {
           ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
-          if (!is_entry_point && !store_type->is_handle()) {
+          if (is_pointer) {
             // If this identifier is used by an address-of operator, just remove
             // the address-of instead of adding a deref, since we already have a
             // pointer.
@@ -172,11 +214,15 @@
       // Add new arguments for any variables that are needed by the callee.
       // For entry points, pass non-handle types as pointers.
       for (auto* target_var : target_sem->ReferencedModuleVariables()) {
+        bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
+        bool is_workgroup_matrix =
+            target_var->StorageClass() == ast::StorageClass::kWorkgroup &&
+            ContainsMatrix(target_var->Type());
         if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
             target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
             target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
           ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
-          if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
+          if (is_entry_point && !is_handle && !is_workgroup_matrix) {
             arg = ctx.dst->AddressOf(arg);
           }
           ctx.InsertBack(call->params(), arg);