[metal] Enable subgroup matrix feature

The feature is supported if the GPU family is Apple 7 or higher.

The two supported configurations are 8x8 matrices with f32 and f16
component types.

Add a Dawn E2E test that tests all supported matrix configurations for
`subgroupMatrixMultiply` and `subgroupMatrixMultiplyAccumulate`.

Bug: 348702031
Change-Id: I579ce8f2a932b324e184c4885a02ee94c257638f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/225854
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/dawn/native/metal/PhysicalDeviceMTL.mm b/src/dawn/native/metal/PhysicalDeviceMTL.mm
index af0d5e1..fd24888 100644
--- a/src/dawn/native/metal/PhysicalDeviceMTL.mm
+++ b/src/dawn/native/metal/PhysicalDeviceMTL.mm
@@ -680,6 +680,10 @@
         EnableFeature(Feature::SubgroupsF16);
     }
 
+    if ([*mDevice supportsFamily:MTLGPUFamilyApple7]) {
+        EnableFeature(Feature::ChromiumExperimentalSubgroupMatrix);
+    }
+
     EnableFeature(Feature::SharedTextureMemoryIOSurface);
 
     EnableFeature(Feature::SharedFenceMTLSharedEvent);
@@ -944,5 +948,24 @@
 #endif
         }
     }
+    if (auto* subgroupMatrixConfigs = info.Get<AdapterPropertiesSubgroupMatrixConfigs>()) {
+        DAWN_ASSERT([*mDevice supportsFamily:MTLGPUFamilyApple7]);
+
+        auto* configs = new SubgroupMatrixConfig[2];
+        subgroupMatrixConfigs->configCount = 2;
+        subgroupMatrixConfigs->configs = configs;
+
+        configs[0].componentType = wgpu::SubgroupMatrixComponentType::F32;
+        configs[0].resultComponentType = wgpu::SubgroupMatrixComponentType::F32;
+        configs[0].M = 8;
+        configs[0].N = 8;
+        configs[0].K = 8;
+
+        configs[1].componentType = wgpu::SubgroupMatrixComponentType::F16;
+        configs[1].resultComponentType = wgpu::SubgroupMatrixComponentType::F16;
+        configs[1].M = 8;
+        configs[1].N = 8;
+        configs[1].K = 8;
+    }
 }
 }  // namespace dawn::native::metal
diff --git a/src/dawn/tests/end2end/SubgroupMatrixTests.cpp b/src/dawn/tests/end2end/SubgroupMatrixTests.cpp
index d334b58..138f7bd 100644
--- a/src/dawn/tests/end2end/SubgroupMatrixTests.cpp
+++ b/src/dawn/tests/end2end/SubgroupMatrixTests.cpp
@@ -27,11 +27,39 @@
 
 #include <vector>
 
+#include "dawn/common/Math.h"
 #include "dawn/tests/DawnTest.h"
+#include "dawn/utils/WGPUHelpers.h"
 
 namespace dawn {
 namespace {
 
+const char* ComponentTypeToWgslType(wgpu::SubgroupMatrixComponentType c) {
+    // Returns the WGSL scalar type name for `c`, or "<invalid>" for an unrecognized value.
+    if (c == wgpu::SubgroupMatrixComponentType::F32) {
+        return "f32";
+    } else if (c == wgpu::SubgroupMatrixComponentType::F16) {
+        return "f16";
+    } else if (c == wgpu::SubgroupMatrixComponentType::U32) {
+        return "u32";
+    } else if (c == wgpu::SubgroupMatrixComponentType::I32) {
+        return "i32";
+    }
+    return "<invalid>";
+}
+
+uint32_t ComponentTypeToByteSize(wgpu::SubgroupMatrixComponentType c) {
+    // f16 occupies two bytes, the 32-bit types occupy four; unrecognized values yield 0.
+    using Type = wgpu::SubgroupMatrixComponentType;
+    if (c == Type::F16) {
+        return 2;
+    }
+    if (c == Type::F32 || c == Type::U32 || c == Type::I32) {
+        return 4;
+    }
+    return 0;
+}
+
 using SubgroupMatrixTest = DawnTest;
 
 // Test that it is only valid to request the AdapterPropertiesSubgroupMatrixConfigs structure if the
@@ -58,5 +86,192 @@
 
 DAWN_INSTANTIATE_TEST(SubgroupMatrixTest, D3D12Backend(), MetalBackend(), VulkanBackend());
 
+enum MatrixOp {  // Selects which WGSL builtin the generated test shader exercises.
+    MatrixMultiply,  // result = subgroupMatrixMultiply(lhs, rhs)
+    MatrixMultiplyAccumulate,  // subgroupMatrixMultiplyAccumulate, applied twice from zero
+};
+DAWN_TEST_PARAM_STRUCT(MatrixMatrixArithmeticParams, MatrixOp);
+
+class SubgroupMatrixArithmeticTest : public DawnTestWithParams<MatrixMatrixArithmeticParams> {
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        std::vector<wgpu::FeatureName> required;  // only features the adapter supports
+        for (wgpu::FeatureName wanted : {wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix,
+                                         wgpu::FeatureName::ShaderF16}) {
+            if (SupportsFeatures({wanted})) {
+                required.push_back(wanted);
+            }
+        }
+        return required;
+    }
+};
+
+using SubgroupMatrix_MatrixMatrixArithmeticTest = SubgroupMatrixArithmeticTest;
+TEST_P(SubgroupMatrix_MatrixMatrixArithmeticTest, MatrixMultiply) {
+    DAWN_TEST_UNSUPPORTED_IF(
+        !adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix));
+
+    MatrixOp op = GetParam().mMatrixOp;
+
+    // Query the supported subgroup matrix configurations.
+    wgpu::AdapterInfo info;
+    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroupMatrixConfigs;
+    info.nextInChain = &subgroupMatrixConfigs;
+    ASSERT_EQ(adapter.GetInfo(&info), wgpu::Status::Success);
+
+    // Test each supported config.
+    for (size_t i = 0; i < subgroupMatrixConfigs.configCount; i++) {
+        auto& config = subgroupMatrixConfigs.configs[i];
+        uint32_t componentByteSize = ComponentTypeToByteSize(config.componentType);
+        uint32_t resultComponentByteSize = ComponentTypeToByteSize(config.resultComponentType);
+
+        // Generate a matching matmul shader; row-major strides equal each matrix's column count.
+        std::ostringstream shader;
+        shader << "enable chromium_experimental_subgroup_matrix;\n";
+        if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 ||
+            config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+            shader << "enable f16;\n";
+        }
+        shader << "\n";
+        shader << "alias ComponentType = " << ComponentTypeToWgslType(config.componentType)
+               << ";\n";
+        shader << "alias ResultComponentType = "
+               << ComponentTypeToWgslType(config.resultComponentType) << ";\n";
+        shader << "\n";
+        shader << "const M = " << config.M << ";\n";
+        shader << "const N = " << config.N << ";\n";
+        shader << "const K = " << config.K << ";\n";
+        shader << R"(
+@group(0) @binding(0) var<storage, read>       inputs : array<ComponentType, K*M + N*K>;
+@group(0) @binding(1) var<storage, read_write> output : array<ResultComponentType, M*N>;
+
+@compute @workgroup_size(N, M)
+fn main() {
+    let lhs = subgroupMatrixLoad<subgroup_matrix_left<ComponentType, K, M>>(&inputs,  0, false, K);
+    let rhs = subgroupMatrixLoad<subgroup_matrix_right<ComponentType, N, K>>(&inputs, K*M, false, N);
+)";
+        switch (op) {
+            case MatrixMultiply:
+                shader << "let result = subgroupMatrixMultiply(lhs, rhs);\n";
+                break;
+            case MatrixMultiplyAccumulate:
+                // Perform the multiplication twice, accumulating into a zero matrix the first time.
+                shader << "let zero = subgroup_matrix_result<ResultComponentType, N, M>();\n";
+                shader << "var result = subgroupMatrixMultiplyAccumulate(lhs, rhs, zero);\n";
+                shader << "result = subgroupMatrixMultiplyAccumulate(lhs, rhs, result);\n";
+                break;
+        }
+        shader << R"(
+    subgroupMatrixStore(&output, 0, result, false, N);
+})";
+
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = utils::CreateShaderModule(device, shader.str());
+        wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
+
+        // Convert the matrix multiplication result value to the result component type.
+        auto toResultType = [&](auto value) -> uint32_t {
+            switch (config.resultComponentType) {
+                case wgpu::SubgroupMatrixComponentType::F32: {
+                    float valueF32 = static_cast<float>(value);
+                    return *reinterpret_cast<uint32_t*>(&valueF32);
+                }
+                case wgpu::SubgroupMatrixComponentType::F16: {
+                    uint16_t valueF16 = Float32ToFloat16(static_cast<float>(value));
+                    return (uint32_t(valueF16) << 16) | valueF16;
+                }
+                case wgpu::SubgroupMatrixComponentType::U32:
+                    return uint32_t(value);
+                case wgpu::SubgroupMatrixComponentType::I32:
+                    int32_t valueI32 = static_cast<int32_t>(value);
+                    return *reinterpret_cast<uint32_t*>(&valueI32);
+            }
+            return 0;
+        };
+
+        // Generate the value to fill the input matrices with as a 32-bit word, and generate the
+        // corresponding output value as well. Pack multiple copies of the value together if the
+        // size of the input component type is less than 32 bits.
+        uint32_t inputValue = 0;
+        uint32_t outputValue = 0;
+        switch (config.componentType) {
+            case wgpu::SubgroupMatrixComponentType::F32: {
+                float in = 0.5;
+                float out = in * in * config.K;
+                if (op == MatrixMultiplyAccumulate) {
+                    out *= 2;
+                }
+                inputValue = *reinterpret_cast<uint32_t*>(&in);
+                outputValue = toResultType(out);
+                break;
+            }
+            case wgpu::SubgroupMatrixComponentType::F16: {
+                float inF32 = 0.5;
+                uint16_t in = Float32ToFloat16(inF32);
+                float out = inF32 * inF32 * config.K;
+                if (op == MatrixMultiplyAccumulate) {
+                    out *= 2;
+                }
+                inputValue = (uint32_t(in) << 16) | in;
+                outputValue = toResultType(out);
+                break;
+            }
+            case wgpu::SubgroupMatrixComponentType::U32:
+            case wgpu::SubgroupMatrixComponentType::I32: {
+                uint32_t in = 2;
+                uint32_t out = in * in * config.K;
+                if (op == MatrixMultiplyAccumulate) {
+                    out *= 2;
+                }
+                inputValue = in;
+                outputValue = toResultType(out);
+                break;
+            }
+        }
+
+        uint32_t numInputElements = (config.M + config.N) * config.K;
+        std::vector<uint32_t> inValues(numInputElements * componentByteSize / 4, inputValue);
+        std::vector<uint32_t> expected(config.M * config.N * resultComponentByteSize / 4,
+                                       outputValue);
+        wgpu::Buffer inputs = utils::CreateBufferFromData(
+            device, inValues.data(), inValues.size() * 4, wgpu::BufferUsage::Storage);
+
+        wgpu::BufferDescriptor outputDescriptor;
+        outputDescriptor.size = config.M * config.N * resultComponentByteSize;
+        outputDescriptor.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage;
+        wgpu::Buffer output = device.CreateBuffer(&outputDescriptor);
+
+        wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+                                                         {{0, inputs}, {1, output}});
+
+        wgpu::CommandBuffer commands;
+        {
+            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+            wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+            pass.SetPipeline(pipeline);
+            pass.SetBindGroup(0, bindGroup);
+            pass.DispatchWorkgroups(1);
+            pass.End();
+
+            commands = encoder.Finish();
+        }
+
+        queue.Submit(1, &commands);
+
+        EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
+    }
+}
+
+DAWN_INSTANTIATE_TEST_P(SubgroupMatrix_MatrixMatrixArithmeticTest,
+                        {
+                            D3D12Backend(),
+                            MetalBackend(),
+                            VulkanBackend(),
+                        },
+                        {
+                            MatrixOp::MatrixMultiply,
+                            MatrixOp::MatrixMultiplyAccumulate,
+                        });
+
 }  // anonymous namespace
 }  // namespace dawn