[metal] Enable subgroup matrix feature
The feature is enabled when the device supports the Apple7 GPU family
(`MTLGPUFamilyApple7`) or newer. The two supported configurations are
8x8 matrices with f32 and f16 component types.
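The configs are surfaced through the adapter info chain. As a rough
sketch of how a client queries them (same pattern as the new end2end
test; error handling elided):

  wgpu::AdapterInfo info;
  wgpu::AdapterPropertiesSubgroupMatrixConfigs matrixConfigs;
  info.nextInChain = &matrixConfigs;
  if (adapter.GetInfo(&info) == wgpu::Status::Success) {
      // matrixConfigs.configCount is 2 on Apple7+; each entry describes
      // M, N, K (all 8 here) plus the component and result component types.
  }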
Add a Dawn E2E test that exercises every supported matrix configuration
for `subgroupMatrixMultiply` and `subgroupMatrixMultiplyAccumulate`.
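For the f32 config, the shader the test generates for the plain
multiply case takes roughly this shape (a sketch of the instantiated
template, not the literal generated source):

  wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
      enable chromium_experimental_subgroup_matrix;

      @group(0) @binding(0) var<storage, read> inputs : array<f32, 8*8 + 8*8>;
      @group(0) @binding(1) var<storage, read_write> output : array<f32, 8*8>;

      @compute @workgroup_size(8, 8)
      fn main() {
          let lhs = subgroupMatrixLoad<subgroup_matrix_left<f32, 8, 8>>(&inputs, 0, false, 8);
          let rhs = subgroupMatrixLoad<subgroup_matrix_right<f32, 8, 8>>(&inputs, 8*8, false, 8);
          let result = subgroupMatrixMultiply(lhs, rhs);
          subgroupMatrixStore(&output, 0, result, false, 8);
      })");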
Bug: 348702031
Change-Id: I579ce8f2a932b324e184c4885a02ee94c257638f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/225854
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/dawn/native/metal/PhysicalDeviceMTL.mm b/src/dawn/native/metal/PhysicalDeviceMTL.mm
index af0d5e1..fd24888 100644
--- a/src/dawn/native/metal/PhysicalDeviceMTL.mm
+++ b/src/dawn/native/metal/PhysicalDeviceMTL.mm
@@ -680,6 +680,10 @@
EnableFeature(Feature::SubgroupsF16);
}
+ if ([*mDevice supportsFamily:MTLGPUFamilyApple7]) {
+ EnableFeature(Feature::ChromiumExperimentalSubgroupMatrix);
+ }
+
EnableFeature(Feature::SharedTextureMemoryIOSurface);
EnableFeature(Feature::SharedFenceMTLSharedEvent);
@@ -944,5 +948,24 @@
#endif
}
}
+ if (auto* subgroupMatrixConfigs = info.Get<AdapterPropertiesSubgroupMatrixConfigs>()) {
+ DAWN_ASSERT([*mDevice supportsFamily:MTLGPUFamilyApple7]);
+
+ auto* configs = new SubgroupMatrixConfig[2];
+ subgroupMatrixConfigs->configCount = 2;
+ subgroupMatrixConfigs->configs = configs;
+
+ configs[0].componentType = wgpu::SubgroupMatrixComponentType::F32;
+ configs[0].resultComponentType = wgpu::SubgroupMatrixComponentType::F32;
+ configs[0].M = 8;
+ configs[0].N = 8;
+ configs[0].K = 8;
+
+ configs[1].componentType = wgpu::SubgroupMatrixComponentType::F16;
+ configs[1].resultComponentType = wgpu::SubgroupMatrixComponentType::F16;
+ configs[1].M = 8;
+ configs[1].N = 8;
+ configs[1].K = 8;
+ }
}
} // namespace dawn::native::metal
diff --git a/src/dawn/tests/end2end/SubgroupMatrixTests.cpp b/src/dawn/tests/end2end/SubgroupMatrixTests.cpp
index d334b58..138f7bd 100644
--- a/src/dawn/tests/end2end/SubgroupMatrixTests.cpp
+++ b/src/dawn/tests/end2end/SubgroupMatrixTests.cpp
@@ -27,11 +27,39 @@
#include <vector>
+#include "dawn/common/Math.h"
#include "dawn/tests/DawnTest.h"
+#include "dawn/utils/WGPUHelpers.h"
namespace dawn {
namespace {
+const char* ComponentTypeToWgslType(wgpu::SubgroupMatrixComponentType c) {
+ switch (c) {
+ case wgpu::SubgroupMatrixComponentType::F32:
+ return "f32";
+ case wgpu::SubgroupMatrixComponentType::F16:
+ return "f16";
+ case wgpu::SubgroupMatrixComponentType::U32:
+ return "u32";
+ case wgpu::SubgroupMatrixComponentType::I32:
+ return "i32";
+ }
+ return "<invalid>";
+}
+
+uint32_t ComponentTypeToByteSize(wgpu::SubgroupMatrixComponentType c) {
+ switch (c) {
+ case wgpu::SubgroupMatrixComponentType::F32:
+ case wgpu::SubgroupMatrixComponentType::U32:
+ case wgpu::SubgroupMatrixComponentType::I32:
+ return 4;
+ case wgpu::SubgroupMatrixComponentType::F16:
+ return 2;
+ }
+ return 0;
+}
+
using SubgroupMatrixTest = DawnTest;
// Test that it is only valid to request the AdapterPropertiesSubgroupMatrixConfigs structure if the
@@ -58,5 +86,192 @@
DAWN_INSTANTIATE_TEST(SubgroupMatrixTest, D3D12Backend(), MetalBackend(), VulkanBackend());
+enum MatrixOp {
+ MatrixMultiply,
+ MatrixMultiplyAccumulate,
+};
+DAWN_TEST_PARAM_STRUCT(MatrixMatrixArithmeticParams, MatrixOp);
+
+class SubgroupMatrixArithmeticTest : public DawnTestWithParams<MatrixMatrixArithmeticParams> {
+ protected:
+ std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+ std::vector<wgpu::FeatureName> features;
+ if (SupportsFeatures({wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix})) {
+ features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+ }
+ if (SupportsFeatures({wgpu::FeatureName::ShaderF16})) {
+ features.push_back(wgpu::FeatureName::ShaderF16);
+ }
+ return features;
+ }
+};
+
+using SubgroupMatrix_MatrixMatrixArithmeticTest = SubgroupMatrixArithmeticTest;
+TEST_P(SubgroupMatrix_MatrixMatrixArithmeticTest, MatrixMultiply) {
+ DAWN_TEST_UNSUPPORTED_IF(
+ !adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix));
+
+ MatrixOp op = GetParam().mMatrixOp;
+
+ // Query the supported subgroup matrix configurations.
+ wgpu::AdapterInfo info;
+ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroupMatrixConfigs;
+ info.nextInChain = &subgroupMatrixConfigs;
+ ASSERT_EQ(adapter.GetInfo(&info), wgpu::Status::Success);
+
+ // Test each supported config.
+ for (size_t i = 0; i < subgroupMatrixConfigs.configCount; i++) {
+ auto& config = subgroupMatrixConfigs.configs[i];
+ uint32_t componentByteSize = ComponentTypeToByteSize(config.componentType);
+ uint32_t resultComponentByteSize = ComponentTypeToByteSize(config.resultComponentType);
+
+ // Generate a shader that performs a matrix multiplication that matches the config.
+ std::ostringstream shader;
+ shader << "enable chromium_experimental_subgroup_matrix;\n";
+ if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 ||
+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+ shader << "enable f16;\n";
+ }
+ shader << "\n";
+ shader << "alias ComponentType = " << ComponentTypeToWgslType(config.componentType)
+ << ";\n";
+ shader << "alias ResultComponentType = "
+ << ComponentTypeToWgslType(config.resultComponentType) << ";\n";
+ shader << "\n";
+ shader << "const M = " << config.M << ";\n";
+ shader << "const N = " << config.N << ";\n";
+ shader << "const K = " << config.K << ";\n";
+ shader << R"(
+@group(0) @binding(0) var<storage, read> inputs : array<ComponentType, K*M + N*K>;
+@group(0) @binding(1) var<storage, read_write> output : array<ResultComponentType, M*N>;
+
+@compute @workgroup_size(N, M)
+fn main() {
+ let lhs = subgroupMatrixLoad<subgroup_matrix_left<ComponentType, K, M>>(&inputs, 0, false, M);
+ let rhs = subgroupMatrixLoad<subgroup_matrix_right<ComponentType, N, K>>(&inputs, K*M, false, K);
+)";
+ switch (op) {
+ case MatrixMultiply:
+ shader << "let result = subgroupMatrixMultiply(lhs, rhs);\n";
+ break;
+ case MatrixMultiplyAccumulate:
+ // Perform the multiplication twice, accumulating into a zero matrix the first time.
+ shader << "let zero = subgroup_matrix_result<ResultComponentType, N, M>();\n";
+ shader << "var result = subgroupMatrixMultiplyAccumulate(lhs, rhs, zero);\n";
+ shader << "result = subgroupMatrixMultiplyAccumulate(lhs, rhs, result);\n";
+ break;
+ }
+ shader << R"(
+ subgroupMatrixStore(&output, 0, result, false, M);
+})";
+
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.compute.module = utils::CreateShaderModule(device, shader.str());
+ wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
+
+ // Convert the matrix multiplication result value to the result component type.
+ auto toResultType = [&](auto value) -> uint32_t {
+ switch (config.resultComponentType) {
+ case wgpu::SubgroupMatrixComponentType::F32: {
+ float valueF32 = static_cast<float>(value);
+ return *reinterpret_cast<uint32_t*>(&valueF32);
+ }
+ case wgpu::SubgroupMatrixComponentType::F16: {
+ uint16_t valueF16 = Float32ToFloat16(static_cast<float>(value));
+ return (uint32_t(valueF16) << 16) | valueF16;
+ }
+ case wgpu::SubgroupMatrixComponentType::U32:
+ return uint32_t(value);
+ case wgpu::SubgroupMatrixComponentType::I32: {
+ int32_t valueI32 = static_cast<int32_t>(value);
+ return *reinterpret_cast<uint32_t*>(&valueI32);
+ }
+ }
+ return 0;
+ };
+
+ // Generate the value to fill the input matrices with as a 32-bit word, and generate the
+ // corresponding output value as well. Pack multiple copies of the value together if the
+ // size of the input component type is less than 32 bits.
+ uint32_t inputValue;
+ uint32_t outputValue;
+ switch (config.componentType) {
+ case wgpu::SubgroupMatrixComponentType::F32: {
+ float in = 0.5;
+ float out = in * in * config.K;
+ if (op == MatrixMultiplyAccumulate) {
+ out *= 2;
+ }
+ inputValue = *reinterpret_cast<uint32_t*>(&in);
+ outputValue = toResultType(out);
+ break;
+ }
+ case wgpu::SubgroupMatrixComponentType::F16: {
+ float inF32 = 0.5;
+ uint16_t in = Float32ToFloat16(inF32);
+ float out = inF32 * inF32 * config.K;
+ if (op == MatrixMultiplyAccumulate) {
+ out *= 2;
+ }
+ inputValue = (uint32_t(in) << 16) | in;
+ outputValue = toResultType(out);
+ break;
+ }
+ case wgpu::SubgroupMatrixComponentType::U32:
+ case wgpu::SubgroupMatrixComponentType::I32: {
+ uint32_t in = 2;
+ uint32_t out = in * in * config.K;
+ if (op == MatrixMultiplyAccumulate) {
+ out *= 2;
+ }
+ inputValue = in;
+ outputValue = toResultType(out);
+ break;
+ }
+ }
+
+ uint32_t numInputElements = (config.M + config.N) * config.K;
+ std::vector<uint32_t> inValues(numInputElements * componentByteSize / 4, inputValue);
+ std::vector<uint32_t> expected(config.M * config.N * resultComponentByteSize / 4,
+ outputValue);
+ wgpu::Buffer inputs = utils::CreateBufferFromData(
+ device, inValues.data(), inValues.size() * 4, wgpu::BufferUsage::Storage);
+
+ wgpu::BufferDescriptor outputDescriptor;
+ outputDescriptor.size = config.M * config.N * resultComponentByteSize;
+ outputDescriptor.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage;
+ wgpu::Buffer output = device.CreateBuffer(&outputDescriptor);
+
+ wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+ {{0, inputs}, {1, output}});
+
+ wgpu::CommandBuffer commands;
+ {
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+ pass.SetPipeline(pipeline);
+ pass.SetBindGroup(0, bindGroup);
+ pass.DispatchWorkgroups(1);
+ pass.End();
+
+ commands = encoder.Finish();
+ }
+
+ queue.Submit(1, &commands);
+
+ EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
+ }
+}
+
+DAWN_INSTANTIATE_TEST_P(SubgroupMatrix_MatrixMatrixArithmeticTest,
+ {
+ D3D12Backend(),
+ MetalBackend(),
+ VulkanBackend(),
+ },
+ {
+ MatrixOp::MatrixMultiply,
+ MatrixOp::MatrixMultiplyAccumulate,
+ });
+
} // anonymous namespace
} // namespace dawn