Subgroup Matrix

Background

Subgroup matrices are an abstract matrix data type. Computations on subgroup matrices utilize the SIMD nature of the GPU and are distributed among multiple invocations. Subgroup matrices are a critical primitive for ML workloads: Large matrix-multiply operations can be decomposed into many matrix multiplies over small sub-blocks of the larger matrices. (Note that 1x1 convolutions are mathematically the same as matrix multiply.)

Many GPUs now support this feature, with various parameterizations of block sizes, component data types, and result types. In the following, a typical block matrix multiply operation is:

Accum = A * B

or

Accum = A * B + Accum


where Accum is M-row, N-column; A is M-row, K-column; and B is K-row, N-column.
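
The decomposition described above can be sketched in plain Python. This is only a reference model, assuming the large matrix dimensions are exact multiples of the block sizes; the helper names (matmul_acc, block, tiled_matmul) are hypothetical and not part of any API.

```python
def matmul_acc(a, b, acc):
    """Return a * b + acc for an M x K block a, a K x N block b and an M x N block acc."""
    m, k, n = len(a), len(b), len(b[0])
    return [[acc[r][c] + sum(a[r][i] * b[i][c] for i in range(k))
             for c in range(n)] for r in range(m)]


def block(mat, row0, col0, rows, cols):
    """Extract a rows x cols sub-block whose top-left corner is (row0, col0)."""
    return [row[col0:col0 + cols] for row in mat[row0:row0 + rows]]


def tiled_matmul(a, b, m, n, k):
    """Multiply a by b using only M x K, K x N and M x N block operations."""
    rows, inner, cols = len(a), len(b), len(b[0])
    # Assumption: the full dimensions are multiples of the block sizes.
    assert rows % m == 0 and cols % n == 0 and inner % k == 0
    out = [[0] * cols for _ in range(rows)]
    for r0 in range(0, rows, m):
        for c0 in range(0, cols, n):
            acc = [[0] * n for _ in range(m)]  # Accum starts at zero
            for k0 in range(0, inner, k):
                # Accum = A * B + Accum on one pair of sub-blocks
                acc = matmul_acc(block(a, r0, k0, m, k),
                                 block(b, k0, c0, k, n), acc)
            for r in range(m):
                out[r0 + r][c0:c0 + n] = acc[r]
    return out
```

On a GPU, each call to matmul_acc would correspond to one subgroup matrix multiply-accumulate executed cooperatively by a subgroup.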

Target API support

SPIR-V/Vulkan

Cooperative matrices were originally introduced as an NVIDIA vendor extension (SPV_NV_cooperative_matrix and VK_NV_cooperative_matrix). NVIDIA presented an overview of the feature and its benefits in a 2019 blog post, a 2022 presentation, and an optimization guide. The functionality has since been standardized as SPV_KHR_cooperative_matrix and VK_KHR_cooperative_matrix.

The Vulkan extension has two feature bits:

  • cooperativeMatrix: basic feature support
  • cooperativeMatrixRobustBufferAccess: robust buffer support for cooperative matrix load and store (usual caveats).

There is a supported stages property in VkPhysicalDeviceCooperativeMatrixPropertiesKHR, but I’ve only ever seen compute supported.

OpTypeCooperativeMatrixKHR represents the abstract type. It is parameterized as follows:

  • Component type: Matrix element type (must be a numerical scalar) (Float16/32/64, S/UInt8/16/32/64 currently)
  • Scope: the scope of operations on the type (Device, QueueFamily, Workgroup, or Subgroup)
  • Rows: Number of rows
  • Columns: Number of columns
  • Use: an enum specifying which kind of matrix this is used as:
    • MatrixAKHR: LHS of a multiply (M rows x K columns)
    • MatrixBKHR: RHS of a multiply (K rows x N columns)
    • MatrixAccumulatorKHR: result of multiply-accumulate (M rows x N columns)

SPIR-V can represent a wide variety of cooperative matrices, but actual devices only support a subset of parameterizations. These are queried from the API by calling vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR, which returns an array of the valid enumerations. Preferred variants are listed earlier in the array. At least one entry must have power-of-two values for all of MSize, KSize, and NSize. Each entry indicates whether it supports saturating accumulation. Because the type is abstract, it can only be stored in the Function and Private storage classes. Special load and store instructions are used to translate to/from backing memory.

The fundamental operation is OpCooperativeMatrixMulAddKHR, which performs a multiply-accumulate. It requires consistent values of scope, M, N, and K across all operand types. The component types may differ; the operation is performed using the result type. Signedness of the result and operands is indicated using the Cooperative Matrix Operands, as is saturating accumulation. If saturating accumulation is used and an intermediate result overflows, the result is undefined (the same as for integer dot product).

OpCooperativeMatrixLoadKHR and OpCooperativeMatrixStoreKHR share the following operands:

  • Pointer: a logical pointer into an array (stride is ignored) (limited to Workgroup, StorageBuffer, and PhysicalStorageBuffer storage classes in Vulkan)
  • Memory Layout: an enum for row- or column-majorness
  • Stride: an optional operand indicating the stride between elements (aligned to min(16, align(col/row))). The stride counts the number of elements in the pointee type; it is not a byte count. For example, if the pointee type of the Pointer argument is 2xf16 (4 bytes), then a stride of 8 translates to a byte stride of 8 x (2 x 2) = 32 bytes.

These operations convert to/from the abstract matrix type and memory. The operands must be uniform values for the scope.

Additional operations:

  • OpCooperativeMatrixLengthKHR: Number of accessible components from an invocation for a given cooperative matrix type.
  • Conversions: All standard conversions are allowed (FToU, UToF, FToS, SToF, U, S, F, bitcast).
  • Arithmetic: Negate, add, sub, mul, div using the opcode for the appropriate component type. OpMatrixTimesScalar is also supported.
  • Constants: OpCompositeConstruct and OpConstantComposite can provide a single fill value for a cooperative matrix.

All operations involving a cooperative matrix must only be executed in uniform control flow for the scope of the operation.

Cooperative matrices require the Vulkan memory model in SPIR-V.

OpTypeCooperativeMatrixKHR can only be instantiated by a variable in the Function or Private storage classes.

HLSL/Direct3D

Microsoft calls them WaveMatrices, and added them to SM6.8 as experimental.

Google provided feedback via https://github.com/microsoft/hlsl-specs/issues/72

This experimental feature was later withdrawn.

On Jul 14, 2025, Chris Bieneman posted a revised proposal for HLSL linalg. See https://github.com/microsoft/hlsl-specs/pull/556

MSL/Metal

Apple calls them simdgroup matrices, and support has existed since MSL 2.3. They are limited to relatively few GPUs though (Apple7+ in the feature set). That is iPhone 12+, some newer iPads, and the M1 and later Macs.

MSL supports a small number of matrix variants:

  • simdgroup_float8x8
  • simdgroup_half8x8
  • simdgroup_bfloat8x8 (MSL 3.1)

The simdgroup_float8x8 and simdgroup_half8x8 types are equivalent to a SPIR-V 8x8 cooperative matrix with Subgroup scope and corresponding component type. Unlike SPIR-V and HLSL, MSL does not specialize the matrices based on the use.

MSL also has more limited supported operations on simdgroup matrices. For load/store/creation:

  • simdgroup_matrix<T, M, N>: Create diagonal matrix with single value.
  • make_filled_simdgroup_matrix<T, M, N>: Create single value filled matrix.
  • simdgroup_load/store: specify stride, layout (via transpose), and an offset (via origin)

Similar to SPIR-V, the device or threadgroup pointer points to a matrix component type.

MSL supports multiply and multiply-accumulate (non-saturating) arithmetic operations.

All functions must be called from simdgroup uniform control flow. Presumably load and store parameter values must be simdgroup uniform though I can’t see this called out in the spec.

MSL allows interfaces to be declared as simdgroup_matrix, but no operators exist for them to make them usable. Effectively, this limits them to function storage.

SYCL

SYCL is a single-source CPU/GPU compute language, and may reflect support among GPU hardware, at least as exposed through OpenCL.

SYCL has a corresponding proposal, using the terminology “joint matrix”. See https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_matrix/sycl_ext_oneapi_matrix.asciidoc

Proposal

WGSL

Enable Extension

Add a new enable extension, chromium_experimental_subgroup_matrix. It enables the declaration of subgroup_matrix types, related built-in functions, and limited use of some scalar types (u8 and i8). It implicitly depends on subgroups: WGSL does not require any subgroups enable, but the API does.

Types

Add new types:

  • subgroup_matrix_left<T, K, M>: M rows x K columns matrix of T components.
  • subgroup_matrix_right<T, N, K>: K rows x N columns matrix of T components
  • subgroup_matrix_result<T, N, M>: M rows x N columns matrix of T components
  • M, N, and K are override-expressions of positive integer values
    • For initial prototyping and feedback, we will restrict them to const-expression. We should finalize const vs. override before full standardization.
    • Similar to arrays, two subgroup matrices are the same if:
      • They are the same base type
      • They have the same component type
      • They have matching row counts such that:
        • Both are equal valued const-expressions, or
        • Both resolve to the same override-declaration
      • They have matching column counts such that:
        • Both are equal valued const-expressions, or
        • Both resolve to the same override-declaration
      • Note: like arrays, types cannot match if the row or column counts use an override-expression that is not equivalent to an override identifier.
  • T is the component type and must be one of the entries in the table
    • The scalar shader type is the associated type usable in the shader code for scalar operations and data representation in memory
    • Can be expanded in the future to support more types (e.g. bfloat16) via new enables.
    • The u8 and i8 cases are predeclared types that are not otherwise usable in WGSL. For layout calculations, they are of size 1 byte and have an alignment requirement of 1 byte.
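
| Type | Extra Enable | Element Stride (bytes) | Shader Scalar Type | Min Value | Max Value |
|------|--------------|------------------------|--------------------|-----------|-----------|
| f32  |              | 4                      | f32                |           |           |
| f16  | f16          | 2                      | f16                |           |           |
| u32  |              | 4                      | u32                |           |           |
| i32  |              | 4                      | i32                |           |           |
| u8   |              | 1                      | u32                | 0         | 255       |
| i8   |              | 1                      | i32                | -128      | 127       |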

These types are not considered “composite” in the WGSL taxonomy, because they are not decomposable. You can’t reference a sub-vector or a single component. The numeric dimensions must be override-expressions. These types cannot be part of any interface (i.e. they can only be instantiated in Function and Private address spaces). They are plain types (similar to atomics) so that they can be included in composite types. An important use case is to make arrays of these matrices.

It is a pipeline-creation error if any matrix type is not included in a supported GPUSubgroupMatrixConfig, config, on the device:

  • For subgroup_matrix_left:
    • M equals config.M
    • K equals config.K
    • T matches config.componentType
  • For subgroup_matrix_right:
    • K equals config.K
    • N equals config.N
    • T matches config.componentType
  • For subgroup_matrix_result:
    • M equals config.M
    • N equals config.N
    • T matches config.resultComponentType
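
The per-type check above can be modeled as follows. This is a sketch only: configs are plain dicts whose keys mirror the GPUSubgroupMatrixConfig attributes from the API section, and the function name and the "left"/"right"/"result" strings are hypothetical. The cols and rows arguments follow the WGSL template order <T, Cols, Rows>.

```python
def type_is_supported(kind, component, cols, rows, configs):
    """Return True if some config on the device admits this subgroup matrix type."""
    for cfg in configs:
        if kind == "left":      # M rows x K columns
            expected = (cfg["componentType"], cfg["K"], cfg["M"])
        elif kind == "right":   # K rows x N columns
            expected = (cfg["componentType"], cfg["N"], cfg["K"])
        elif kind == "result":  # M rows x N columns
            expected = (cfg["resultComponentType"], cfg["N"], cfg["M"])
        else:
            raise ValueError("unknown subgroup matrix kind: " + kind)
        if (component, cols, rows) == expected:
            return True
    return False


# Hypothetical device exposing a single f16 -> f32 config with M=8, N=16, K=4.
configs = [{"componentType": "f16", "resultComponentType": "f32",
            "M": 8, "N": 16, "K": 4}]
print(type_is_supported("left", "f16", 4, 8, configs))     # subgroup_matrix_left<f16, 4, 8>
print(type_is_supported("result", "f16", 16, 8, configs))  # wrong result component type
```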

Why use a “left” and “right” matrix type, instead of a single matrix type like Metal has?

  • D3D and Vulkan both need them.
  • There is no free transpose operation. If you had a free transpose, then you could take advantage of: A * B = transpose(transpose(B) * transpose(A)). I haven’t seen a transpose operation in APIs or hardware specs.
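
The transpose identity mentioned above is easy to check numerically. The snippet below is an illustration in plain Python with hypothetical helper names; it says nothing about what any API or hardware provides.

```python
def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain matrix product of a (M x K) and b (K x N)."""
    return [[sum(a[r][i] * b[i][c] for i in range(len(b)))
             for c in range(len(b[0]))] for r in range(len(a))]


a = [[1, 2, 3], [4, 5, 6]]        # 2 rows x 3 columns
b = [[7, 8], [9, 10], [11, 12]]   # 3 rows x 2 columns

direct = matmul(a, b)
via_transposes = transpose(matmul(transpose(b), transpose(a)))
print(direct == via_transposes)
```

With a free transpose, a single matrix type could play both the left and right roles by swapping and transposing the operands, which is why its absence motivates separate types.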

Variables

A variable containing a subgroup matrix can only be instantiated in Function or Private address spaces. (This limitation comes from SPIR-V and Metal). These variables are meant to be used as very temporary scratch space.

Loading and storing

Builtin functions are used to load and store subgroup matrix values from variables in workgroup and storage address spaces. The builtins do two things:

  • Map matrix row and column indices to external memory locations. This mapping has two parameters: majorness (row-major, or column-major), and an integer Stride. Let Base be the byte address of the start of the external matrix, i.e. where the [0,0]’th element of the matrix is stored. Then
    • For row-major:
      • Matrix entry [r,c] maps to the sizeof(T) bytes located at Base + Stride*r*sizeof(T) + sizeof(T)*c
      • Stride >= number of matrix columns.
    • For column-major:
      • Matrix entry [r,c] maps to the sizeof(T) bytes located at Base + Stride*c*sizeof(T) + sizeof(T)*r
      • Stride >= number of matrix rows.
  • Reinterpret data values between the shader scalar type and the external component type T, when those types differ.

For a subgroup_matrix_left/right/result<T, Cols, Rows>, loads and stores are out-of-bounds if the length of the array of the pointer argument is less than Offset + Stride * Rows * Cols.
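
The address mapping and the bounds rule above can be written out as a small reference model. This is a sketch, not an implementation: the function names are hypothetical, addresses are byte addresses, and the bounds check works in element counts exactly as stated in the rule.

```python
def element_byte_address(base, r, c, stride, elem_size, col_major):
    """Byte address of matrix entry [r, c], following the mapping above."""
    if col_major:
        return base + stride * c * elem_size + elem_size * r
    return base + stride * r * elem_size + elem_size * c


def is_out_of_bounds(array_length, offset, stride, rows, cols):
    """True if a load/store is treated as out-of-bounds (all values in elements)."""
    return array_length < offset + stride * rows * cols


# Entry [1, 2] of an f32 matrix (4-byte elements) with a stride of 32 elements.
print(element_byte_address(base=0, r=1, c=2, stride=32, elem_size=4, col_major=False))  # 136
print(element_byte_address(base=0, r=1, c=2, stride=32, elem_size=4, col_major=True))   # 260
```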

Attributes

Add an additional error condition to workgroup_size. It is a pipeline-creation error if the x dimension is not a multiple of GPUAdapterInfo.subgroupMaxSize and the shader statically accesses any subgroup matrix.

Note: this is where the requirement on the subgroups feature stems from in practice.
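
As a sketch, the additional workgroup_size condition amounts to the check below. The function and parameter names are hypothetical; subgroup_max_size stands for GPUAdapterInfo.subgroupMaxSize.

```python
def workgroup_size_x_is_valid(x, subgroup_max_size, uses_subgroup_matrix):
    """Model of the extra pipeline-creation check on the x dimension."""
    if not uses_subgroup_matrix:
        return True  # the extra rule only applies to shaders that access subgroup matrices
    return x % subgroup_max_size == 0


print(workgroup_size_x_is_valid(64, 32, True))
print(workgroup_size_x_is_valid(48, 32, True))
```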

Expressions

Supported expressions:

  • Parenthesization
  • Identifier expressions
  • Indirection
  • Address-of
  • Function call expression
  • Type expressions

Explicitly not supported:

  • Decomposition expressions

Possible future expansion:

  • Arithmetic expressions

Built-in Values

New built-in value: subgroup_id. Input in compute shaders of type u32. It is the current invocation’s subgroup’s ID within the workgroup.

Note: This was not included in the subgroups feature because there was insufficient documentation detailing the mapping in HLSL between workgroups and subgroups. This experiment is not implemented on D3D, so the built-in value is available on all supported platforms. If a guaranteed mapping is provided by Microsoft for D3D, subgroup_id could be added as a language feature on top of the subgroups feature.

Built-in Functions

Calls to these functions:

  • Must only be used in a compute shader stage.
  • Trigger a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove the call is in subgroup uniform control flow.

Value constructors

Overloads:

@must_use fn T(value : S) -> T
@must_use fn T() -> T

Preconditions:
T is a subgroup matrix type whose shader scalar type is S.

Description:
Create a subgroup matrix filled with value.

When no value is provided, use S().

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove value is a subgroup uniform value.

Load/store functions

See Loading and Storing above.

Overload:

@must_use fn
subgroupMatrixLoad<T>(p : ptr<AS, SA, AM>,
                      offset : u32,
                      col_major : bool,
                      stride : u32) -> T

Preconditions:
T is a subgroup matrix type with shader scalar type S.
SA is an array with element type S.
AS is storage or workgroup.
AM is read or read_write.

Description:
Load a subgroup matrix from p, offset elements from the start of the array.

col_major must be an override-expression.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove p, offset, or stride are subgroup uniform values.

stride counts elements of the component type of T. Behavior is undefined if stride is less than:

  • The number of rows of T if col_major is true
  • The number of columns of T if col_major is false

Overload:

fn subgroupMatrixStore(p : ptr<AS, SA, AM>,
                       offset : u32,
                       value : T,
                       col_major : bool,
                       stride : u32)

Preconditions:
T is a subgroup matrix type whose scalar shader type is S.
SA is an array with element type S.
AS is storage or workgroup.
AM is write or read_write.

Description:
Store the subgroup matrix value into p, offset elements from the start of the array.

col_major must be an override-expression.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove p, offset, value, or stride are subgroup uniform values.

stride counts elements of the component type of T. Behavior is undefined if stride is less than:

  • The number of rows of T if col_major is true
  • The number of columns of T if col_major is false

Matrix arithmetic functions

The operands of a subgroup matrix arithmetic function comprise a supported subgroup matrix configuration if the device has a GPUSubgroupMatrixConfig, config, such that all operand types match as below:

  • For L:
    • M equals config.M
    • K equals config.K
    • T matches config.componentType
  • For R:
    • K equals config.K
    • N equals config.N
    • T matches config.componentType
  • For RT:
    • M equals config.M
    • N equals config.N
    • TR matches config.resultComponentType
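
Note that all three operands must match the same config, which is stricter than each type being supported on its own. A minimal sketch, with configs as plain dicts mirroring GPUSubgroupMatrixConfig and each operand given as a hypothetical (component, cols, rows) tuple in WGSL template order:

```python
def operands_supported(left, right, result, configs):
    """Return True if one single config matches L, R and RT together."""
    for cfg in configs:
        if (left == (cfg["componentType"], cfg["K"], cfg["M"])
                and right == (cfg["componentType"], cfg["N"], cfg["K"])
                and result == (cfg["resultComponentType"], cfg["N"], cfg["M"])):
            return True
    return False


# Hypothetical device with two 8x8x8 configs: f16 -> f16 and f32 -> f32.
configs = [
    {"componentType": "f16", "resultComponentType": "f16", "M": 8, "N": 8, "K": 8},
    {"componentType": "f32", "resultComponentType": "f32", "M": 8, "N": 8, "K": 8},
]
f16_8x8 = ("f16", 8, 8)
f32_8x8 = ("f32", 8, 8)
print(operands_supported(f16_8x8, f16_8x8, f16_8x8, configs))
# Each type below appears in some config, but no single config covers all three.
print(operands_supported(f16_8x8, f16_8x8, f32_8x8, configs))
```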

Overload:

@must_use fn
subgroupMatrixMultiply<RT>(left : L, right : R) -> RT

Preconditions:
L is a subgroup_matrix_left<T, K, M>.
R is a subgroup_matrix_right<T, N, K>.
RT is a subgroup_matrix_result<TR, N, M>.

Description:
Matrix multiply.

It is a pipeline-creation error if L, R, and RT do not comprise a supported subgroup matrix configuration.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove left or right are subgroup uniform values.

Overload:

@must_use fn
subgroupMatrixMultiplyAccumulate<RT>(left : L,
                                     right : R,
                                     acc : RT) -> RT

Preconditions:
L is a subgroup_matrix_left<T, K, M>.
R is a subgroup_matrix_right<T, N, K>.
RT is a subgroup_matrix_result<TR, N, M>.

Description:
Matrix multiply add (left * right + acc).

It is a pipeline-creation error if L, R, and RT do not comprise a supported subgroup matrix configuration.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove left or right are subgroup uniform values.

Scalar arithmetic functions

These functions are useful for operations such as applying biases to a model.

Overload:

@must_use fn
subgroupMatrixScalarAdd(matrix : M, value : S) -> M

Preconditions:
M is a subgroup matrix type with scalar shader type S.

Description:
Scalar addition.

value is clamped to a valid range for the component type of M and then added to each element of matrix.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.
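
The clamping step matters for u8 and i8, whose shader scalar types (u32 and i32) are wider than the component type. A minimal model of scalar addition follows, using the min/max values from the component type table. The names are hypothetical, and what happens when the sum itself leaves the component range is not modeled here.

```python
# Value ranges from the component type table; other types are not clamped in this sketch.
COMPONENT_RANGES = {"u8": (0, 255), "i8": (-128, 127)}


def clamp_scalar(value, component_type):
    """Clamp value to the valid range of the component type, if it has one."""
    if component_type not in COMPONENT_RANGES:
        return value
    lo, hi = COMPONENT_RANGES[component_type]
    return max(lo, min(hi, value))


def scalar_add(matrix, value, component_type):
    """Add the clamped scalar to every element of the matrix."""
    clamped = clamp_scalar(value, component_type)
    return [[element + clamped for element in row] for row in matrix]


print(scalar_add([[0, 0]], 300, "u8"))     # 300 is clamped to 255 first
print(scalar_add([[100, 0]], -200, "i8"))  # -200 is clamped to -128 first
```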

Overload:

@must_use fn
subgroupMatrixScalarSubtract(matrix : M, value : S) -> M

Preconditions:
M is a subgroup matrix type with scalar shader type S.

Description:
Scalar subtraction.

value is clamped to a valid range for the component type of M and then subtracted from each element of matrix.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.

Overload:

@must_use fn
subgroupMatrixScalarMultiply(matrix : M, value : S) -> M

Preconditions:
M is a subgroup matrix type with scalar shader type S.

Description:
Scalar multiplication.

value is clamped to a valid range for the component type of M and then each element of matrix is multiplied by it.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.

REMOVED (in an attempt to be safely forward compatible if Metal adds integral components, but no other operators)

Overload:

@must_use fn
subgroupMatrixScalarDivide(matrix : M, value : S) -> M

Preconditions:
M is a subgroup matrix type with scalar shader type S.

Description:
Scalar division.

value is clamped to a valid range for the component type of M and then each element of matrix is divided by it.

Triggers a chromium_experimental.subgroup_matrix_uniformity diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.

Uniformity

Subgroup matrices are an abstraction of data spread among the invocations in a subgroup. As such the built-in functions must only be called from subgroup uniform control flow. Additionally, most of the parameter values must also be subgroup uniform values.

WGSL does not currently represent subgroup uniformity, and empirical testing shows that implementations do not provide reliable reconvergence. Vulkan backends that implement VK_KHR_shader_subgroup_uniform_control_flow or, even better, VK_KHR_shader_maximal_reconvergence would be able to make enabling the analysis at other scopes reasonable, but that is not portable.

For this experimental extension we should add a diagnostic, chromium_experimental.subgroup_matrix_uniformity, that defaults to error. This diagnostic would be triggered based on workgroup uniformity violations (if the control is not off) for all subgroup matrix built-in functions. This allows applications some control. We can gather data about the usefulness of the static analysis for subgroup matrices. It is likely that most use cases would satisfy workgroup uniformity.

Floating-point Accuracy

subgroupMatrixLoad and subgroupMatrixStore should be bit preserving. ULPs are undefined for other operations.

API

New GPUFeatureName chromium-experimental-subgroup-matrix

New immutable array, subgroupMatrixConfigs, added to GPUAdapterInfo.

partial interface GPUAdapterInfo {
  [SameObject] readonly attribute FrozenArray<GPUSubgroupMatrixConfig> subgroupMatrixConfigs;
};

enum GPUSubgroupMatrixComponentType {
  "f32",
  "f16",
  "u32",
  "i32",
  "u8",
  "i8",
};

interface GPUSubgroupMatrixConfig {
  readonly attribute GPUSubgroupMatrixComponentType componentType;
  readonly attribute GPUSubgroupMatrixComponentType resultComponentType;
  readonly attribute unsigned long M;
  readonly attribute unsigned long N;
  readonly attribute unsigned long K;
};

Validation

No specific API validation is necessary; however, it is likely easier to implement the pipeline-creation error checks in Dawn by adding the subgroup matrix configurations to the shader reflection information.

WGSL pipeline-creation checks (repeated for ease of reference):

  • All subgroup matrix types are part of a GPUSubgroupMatrixConfig
  • All subgroup matrix operands in subgroupMatrixMultiply and subgroupMatrixMultiplyAccumulate are part of a single GPUSubgroupMatrixConfig
  • The x-dimension of workgroup_size is a multiple of GPUAdapterInfo.subgroupMaxSize

Mapping

Types

| Type | SPIR-V (1, 2) | MSL (3) |
|------|---------------|---------|
| subgroup_matrix_left<T, K, M> | OpTypeCooperativeMatrixKHR with MatrixAKHR use, M rows, K cols, T component type | simdgroup_matrix<T, 8, 8> |
| subgroup_matrix_right<T, N, K> | OpTypeCooperativeMatrixKHR with MatrixBKHR use, K rows, N cols, T component type | simdgroup_matrix<T, 8, 8> |
| subgroup_matrix_result<T, N, M> | OpTypeCooperativeMatrixKHR with MatrixAccumulatorKHR use, M rows, N cols, T component type | simdgroup_matrix<T, 8, 8> |

  1. All OpTypeCooperativeMatrixKHR use subgroup scope.
  2. Component type enum maps directly to SPIR-V type (e.g. i8 to OpTypeInt 8 1).
  3. MSL types use the usual mappings (e.g. f32 to float).

Built-in Values

| Value | SPIR-V | MSL |
|-------|--------|-----|
| subgroup_id | SubgroupId | simdgroup_index_in_threadgroup |

Functions

| Function | SPIR-V | MSL |
|----------|--------|-----|
| Value constructors | OpCompositeConstruct with appropriate value. const-/override-expressions could use constant instructions. | make_filled_simdgroup_matrix |
| subgroupMatrixLoad | OpCooperativeMatrixLoadKHR. Pointer operand is a direct translation of the WGSL pointer. | simdgroup_load. The WGSL pointer needs to be translated into the origin operand in MSL. |
| subgroupMatrixStore | OpCooperativeMatrixStoreKHR. Pointer operand is a direct translation of the WGSL pointer. | simdgroup_store. The WGSL pointer needs to be translated into the origin operand in MSL. |
| subgroupMatrixMultiply | OpCooperativeMatrixMulAddKHR. C is a zero value matrix matching the result type. | simdgroup_multiply |
| subgroupMatrixMultiplyAccumulate | OpCooperativeMatrixMulAddKHR | simdgroup_multiply_accumulate |
| subgroupMatrixScalarAdd | OpI/FAdd with composite constructed value | simdgroup_multiply_accumulate with a diagonal matrix for b and a matrix filled with the scalar value for c |
| subgroupMatrixScalarSubtract | OpI/FSub with composite constructed value | simdgroup_multiply_accumulate with a diagonal matrix for b and a matrix filled with the negated scalar value for c |
| subgroupMatrixScalarMultiply | OpI/FMul with composite constructed value | simdgroup_multiply with a diagonal matrix of the scalar value for b (1) |
| subgroupMatrixScalarDivide | OpU/S/FDiv with composite constructed scalar, or OpMatrixTimesScalar (for float components with reciprocal scalar value) | simdgroup_multiply with a diagonal matrix of the reciprocal scalar value for b (1) |

  1. This works because MSL only supports floating-point component types.
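
To see why the Metal mappings for the scalar functions work, here is a small numerical model. It is a sketch under assumptions: the helpers are plain Python stand-ins for the MSL functions, the diagonal matrix used for add and subtract is assumed to be the identity, and 2x2 matrices are used for brevity even though MSL matrices are 8x8.

```python
def multiply_accumulate(a, b, c):
    """Return a * b + c, a stand-in for simdgroup_multiply_accumulate."""
    inner = len(b)
    return [[c[r][j] + sum(a[r][i] * b[i][j] for i in range(inner))
             for j in range(len(b[0]))] for r in range(len(a))]


def diagonal(n, value):
    """n x n matrix with value on the diagonal and zero elsewhere."""
    return [[value if r == c else 0.0 for c in range(n)] for r in range(n)]


def filled(rows, cols, value):
    """rows x cols matrix with every element set to value."""
    return [[value] * cols for _ in range(rows)]


def scalar_add(matrix, value):
    n = len(matrix[0])
    # matrix * I + filled(value)
    return multiply_accumulate(matrix, diagonal(n, 1.0), filled(len(matrix), n, value))


def scalar_subtract(matrix, value):
    n = len(matrix[0])
    # matrix * I + filled(-value)
    return multiply_accumulate(matrix, diagonal(n, 1.0), filled(len(matrix), n, -value))


def scalar_multiply(matrix, value):
    n = len(matrix[0])
    # matrix * diag(value); the zero accumulator models a plain simdgroup_multiply
    return multiply_accumulate(matrix, diagonal(n, value), filled(len(matrix), n, 0.0))


m = [[1.0, 2.0], [3.0, 4.0]]
print(scalar_add(m, 0.5))
print(scalar_subtract(m, 0.5))
print(scalar_multiply(m, 2.0))
```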

Properties

Vulkan

Filter the list returned from vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR such that:

  • A type and B type match
  • C type and result type match
  • A, B, C, and result types are one of (-> API enum):
    • VK_COMPONENT_TYPE_FLOAT32_KHR -> f32
    • VK_COMPONENT_TYPE_FLOAT16_KHR -> f16
    • VK_COMPONENT_TYPE_SINT32_KHR -> i32
    • VK_COMPONENT_TYPE_UINT32_KHR -> u32
    • VK_COMPONENT_TYPE_SINT8_KHR -> i8
    • VK_COMPONENT_TYPE_UINT8_KHR -> u8
  • saturatingAccumulation is false
  • Scope is VK_SCOPE_SUBGROUP_KHR

VK_COMPONENT_TYPE_FLOAT16_KHR will need to be filtered out of the device properties if the shader-f16 feature is not requested.
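
The filtering rules above could look roughly like this. It is a sketch, not Dawn code: each property is a plain dict standing in for a VkCooperativeMatrixPropertiesKHR entry (keys mirror the struct's field names), enum values are represented as strings, and the function name is hypothetical.

```python
VK_TO_API = {
    "VK_COMPONENT_TYPE_FLOAT32_KHR": "f32",
    "VK_COMPONENT_TYPE_FLOAT16_KHR": "f16",
    "VK_COMPONENT_TYPE_SINT32_KHR": "i32",
    "VK_COMPONENT_TYPE_UINT32_KHR": "u32",
    "VK_COMPONENT_TYPE_SINT8_KHR": "i8",
    "VK_COMPONENT_TYPE_UINT8_KHR": "u8",
}


def filter_configs(properties, shader_f16_requested):
    """Turn Vulkan cooperative matrix properties into subgroup matrix configs."""
    configs = []
    for p in properties:
        if p["AType"] != p["BType"] or p["CType"] != p["ResultType"]:
            continue  # A/B must match, and C/result must match
        if p["AType"] not in VK_TO_API or p["ResultType"] not in VK_TO_API:
            continue  # component type not exposed by the API enum
        if p["saturatingAccumulation"] or p["scope"] != "VK_SCOPE_SUBGROUP_KHR":
            continue
        component = VK_TO_API[p["AType"]]
        result = VK_TO_API[p["ResultType"]]
        if not shader_f16_requested and "f16" in (component, result):
            continue  # f16 configs require the shader-f16 feature
        configs.append({"componentType": component, "resultComponentType": result,
                        "M": p["MSize"], "N": p["NSize"], "K": p["KSize"]})
    return configs
```

Because the input order is preserved, any preference ordering reported by the driver carries over to the resulting list.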

Metal

Hardcode the following configurations if the feature is supported:

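| componentType | resultComponentType | M | N | K |
|---------------|---------------------|---|---|---|
| f32           | f32                 | 8 | 8 | 8 |
| f16           | f16                 | 8 | 8 | 8 |
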
  1. Filter out f16 from the device properties if shader-f16 is not requested on the device.

Future Expansion

It is likely that more component types can be supported in the future. Metal can already support bfloat16 for example.

We could consider exposing saturating accumulation.