Metal: Allocate threadgroup memory based on Tint reflection

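Tint passes threadgroup memory in MSL as entry point arguments, because
threadgroup variables at module scope cannot be default initialized: MSL
lacks default constructors for matrices in threadgroup memory.

The size of each threadgroup memory argument must be supplied by the host.
Record the per-entry-point allocation sizes reported by Tint on the compute
pipeline, and set every non-zero size on the compute command encoder with
setThreadgroupMemoryLength:atIndex: when the pipeline is encoded.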

Bug: dawn:1110
Change-Id: I7462fa448c6ebdb3cc4dc24bd5ff0a99287cdba0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64240
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/metal/ComputePipelineMTL.h b/src/dawn_native/metal/ComputePipelineMTL.h
index 6f777b6..4ecb450 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.h
+++ b/src/dawn_native/metal/ComputePipelineMTL.h
@@ -47,6 +47,7 @@
         NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState;
         MTLSize mLocalWorkgroupSize;
         bool mRequiresStorageBufferLength;
+        std::vector<uint32_t> mWorkgroupAllocations;
     };
 
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm
index 14110cc..8879fb2 100644
--- a/src/dawn_native/metal/ComputePipelineMTL.mm
+++ b/src/dawn_native/metal/ComputePipelineMTL.mm
@@ -54,11 +54,18 @@
         mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
 
         mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
+        mWorkgroupAllocations = std::move(computeData.workgroupAllocations);
         return {};
     }
 
     void ComputePipeline::Encode(id<MTLComputeCommandEncoder> encoder) {
         [encoder setComputePipelineState:mMtlComputePipelineState.Get()];
+        for (size_t i = 0; i < mWorkgroupAllocations.size(); ++i) {
+            if (mWorkgroupAllocations[i] == 0) {
+                continue;
+            }
+            [encoder setThreadgroupMemoryLength:mWorkgroupAllocations[i] atIndex:i];
+        }
     }
 
     MTLSize ComputePipeline::GetLocalWorkGroupSize() const {
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index ab87929..4cb91a4 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -37,6 +37,7 @@
         struct MetalFunctionData {
             NSPRef<id<MTLFunction>> function;
             bool needsStorageBufferLength;
+            std::vector<uint32_t> workgroupAllocations;
         };
         MaybeError CreateFunction(const char* entryPointName,
                                   SingleShaderStage stage,
@@ -53,7 +54,8 @@
                                                   const RenderPipeline* renderPipeline,
                                                   std::string* remappedEntryPointName,
                                                   bool* needsStorageBufferLength,
-                                                  bool* hasInvariantAttribute);
+                                                  bool* hasInvariantAttribute,
+                                                  std::vector<uint32_t>* workgroupAllocations);
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
         MaybeError Initialize(ShaderModuleParseResult* parseResult);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 6685f4f..1ba4ee6 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -44,14 +44,16 @@
         return InitializeBase(parseResult);
     }
 
-    ResultOrError<std::string> ShaderModule::TranslateToMSL(const char* entryPointName,
-                                                            SingleShaderStage stage,
-                                                            const PipelineLayout* layout,
-                                                            uint32_t sampleMask,
-                                                            const RenderPipeline* renderPipeline,
-                                                            std::string* remappedEntryPointName,
-                                                            bool* needsStorageBufferLength,
-                                                            bool* hasInvariantAttribute) {
+    ResultOrError<std::string> ShaderModule::TranslateToMSL(
+        const char* entryPointName,
+        SingleShaderStage stage,
+        const PipelineLayout* layout,
+        uint32_t sampleMask,
+        const RenderPipeline* renderPipeline,
+        std::string* remappedEntryPointName,
+        bool* needsStorageBufferLength,
+        bool* hasInvariantAttribute,
+        std::vector<uint32_t>* workgroupAllocations) {
         ScopedTintICEHandler scopedICEHandler(GetDevice());
 
         std::ostringstream errorStream;
@@ -166,6 +168,7 @@
 
         *needsStorageBufferLength = result.needs_storage_buffer_sizes;
         *hasInvariantAttribute = result.has_invariant_attribute;
+        *workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]);
 
         return std::move(result.msl);
     }
@@ -190,7 +193,7 @@
         DAWN_TRY_ASSIGN(msl,
                         TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
                                        &remappedEntryPointName, &out->needsStorageBufferLength,
-                                       &hasInvariantAttribute));
+                                       &hasInvariantAttribute, &out->workgroupAllocations));
 
         // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
         // category. -Wunused-variable in particular comes up a lot in generated code, and some
diff --git a/src/tests/end2end/ComputeSharedMemoryTests.cpp b/src/tests/end2end/ComputeSharedMemoryTests.cpp
index 4a822a8..7519e68 100644
--- a/src/tests/end2end/ComputeSharedMemoryTests.cpp
+++ b/src/tests/end2end/ComputeSharedMemoryTests.cpp
@@ -100,6 +100,102 @@
         })");
 }
 
+// Test using assorted types (struct, matrix, array, vector) in workgroup memory.
+// MSL lacks constructors for matrices in threadgroup memory, so this is a basic
+// check that values written to workgroup memory can be read back correctly.
+TEST_P(ComputeSharedMemoryTests, AssortedTypes) {
+    wgpu::ComputePipelineDescriptor csDesc;
+    csDesc.compute.module = utils::CreateShaderModule(device, R"(
+        struct StructValues {
+            m: mat2x2<f32>;
+        };
+
+        [[block]] struct Dst {
+            d_struct : StructValues;
+            d_matrix : mat2x2<f32>;
+            d_array : array<u32, 4>;
+            d_vector : vec4<f32>;
+        };
+
+        [[group(0), binding(0)]] var<storage, write> dst : Dst;
+
+        var<workgroup> wg_struct : StructValues;
+        var<workgroup> wg_matrix : mat2x2<f32>;
+        var<workgroup> wg_array : array<u32, 4>;
+        var<workgroup> wg_vector : vec4<f32>;
+
+        [[stage(compute), workgroup_size(4,1,1)]]
+        fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
+
+            let i = 4u * LocalInvocationID.x;
+            if (LocalInvocationID.x == 0u) {
+                wg_struct.m = mat2x2<f32>(
+                    vec2<f32>(f32(i), f32(i + 1u)),
+                    vec2<f32>(f32(i + 2u), f32(i + 3u)));
+            } elseif (LocalInvocationID.x == 1u) {
+                wg_matrix = mat2x2<f32>(
+                    vec2<f32>(f32(i), f32(i + 1u)),
+                    vec2<f32>(f32(i + 2u), f32(i + 3u)));
+            } elseif (LocalInvocationID.x == 2u) {
+                wg_array[0u] = i;
+                wg_array[1u] = i + 1u;
+                wg_array[2u] = i + 2u;
+                wg_array[3u] = i + 3u;
+            } elseif (LocalInvocationID.x == 3u) {
+                wg_vector = vec4<f32>(
+                    f32(i), f32(i + 1u), f32(i + 2u), f32(i + 3u));
+            }
+
+            workgroupBarrier();
+
+            if (LocalInvocationID.x == 0u) {
+                dst.d_struct = wg_struct;
+                dst.d_matrix = wg_matrix;
+                dst.d_array = wg_array;
+                dst.d_vector = wg_vector;
+            }
+        }
+    )");
+    csDesc.compute.entryPoint = "main";
+    wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
+
+    // Set up dst storage buffer
+    wgpu::BufferDescriptor dstDesc;
+    dstDesc.size = 64;
+    dstDesc.usage =
+        wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
+    wgpu::Buffer dst = device.CreateBuffer(&dstDesc);
+
+    // Set up bind group and issue dispatch
+    wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+                                                     {
+                                                         {0, dst},
+                                                     });
+
+    wgpu::CommandBuffer commands;
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(pipeline);
+        pass.SetBindGroup(0, bindGroup);
+        pass.Dispatch(1);
+        pass.EndPass();
+
+        commands = encoder.Finish();
+    }
+
+    queue.Submit(1, &commands);
+
+    std::array<float, 4> expectedStruct = {0., 1., 2., 3.};
+    std::array<float, 4> expectedMatrix = {4., 5., 6., 7.};
+    std::array<uint32_t, 4> expectedArray = {8, 9, 10, 11};
+    std::array<float, 4> expectedVector = {12., 13., 14., 15.};
+    EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedStruct.data(), dst, 0, 4);
+    EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedMatrix.data(), dst, 16, 4);
+    EXPECT_BUFFER_U32_RANGE_EQ(expectedArray.data(), dst, 32, 4);
+    EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedVector.data(), dst, 48, 4);
+}
+
 DAWN_INSTANTIATE_TEST(ComputeSharedMemoryTests,
                       D3D12Backend(),
                       MetalBackend(),