Add feature documentation for subgroup matrix
Change-Id: I8ef2057f00c0fbfb2a827ef20874081e1ebdcda6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/260495
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Auto-Submit: Alan Baker <alanbaker@google.com>
Commit-Queue: Alan Baker <alanbaker@google.com>
diff --git a/docs/dawn/features/subgroup_matrix.md b/docs/dawn/features/subgroup_matrix.md
new file mode 100644
index 0000000..2b51760
--- /dev/null
+++ b/docs/dawn/features/subgroup_matrix.md
@@ -0,0 +1,1156 @@
+# Subgroup Matrix
+
+## Background
+
+Subgroup matrices are an abstract matrix data type. Computations on subgroup matrices utilize the SIMD nature of the GPU and are distributed among multiple invocations. Subgroup matrices are a critical primitive for ML workloads: Large matrix-multiply operations can be decomposed into many matrix multiplies over small sub-blocks of the larger matrices. (Note that 1x1 convolutions are mathematically the same as matrix multiply.)
+
+
+
+* Backgrounder: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
+
+Many GPUs now support this feature, with various parameterizations of block sizes, component data types, and result types. In the following, a typical block matrix multiply operation is:
+
+    Accum = A * B
+
+or
+
+    Accum = A * B + Accum
+
+where:
+
+* Accum is M rows x N columns,
+* A is M rows x K columns,
+* B is K rows x N columns.
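+
+Decomposing a large matrix multiply into block operations of this shape amounts to computing, for each output block,
+
+    Accum[i][j] = sum over k of A[i][k] * B[k][j]
+
+where i, j, and k index blocks rather than scalar elements, and each A[i][k] * B[k][j] term is an M-row x K-column by K-row x N-column block multiply accumulated into the M-row x N-column block Accum[i][j].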
+
+## Target API support
+
+
+### SPIR-V/Vulkan
+
+This functionality was originally introduced as an Nvidia vendor extension ([SPV_NV_cooperative_matrix](https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_cooperative_matrix.html) and [VK_NV_cooperative_matrix](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VK_NV_cooperative_matrix)). NVIDIA presented an overview of the feature and its benefits in a 2019 blog post, a [2022 presentation](https://www.khronos.org/assets/uploads/developers/presentations/Cooperative_Matrix_May22.pdf), and an [optimization guide](https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html). The functionality has since been standardized as [SPV_KHR_cooperative_matrix](https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html) and [VK_KHR_cooperative_matrix](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VK_KHR_cooperative_matrix).
+
+The Vulkan extension has two feature bits:
+
+
+
+* cooperativeMatrix: basic feature support
+* cooperativeMatrixRobustBufferAccess: robust buffer support for cooperative matrix load and store (usual caveats).
+
+There is a supported stages property in <code>[VkPhysicalDeviceCooperativeMatrixPropertiesKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VkPhysicalDeviceCooperativeMatrixPropertiesKHR)</code>, but I’ve only ever seen compute supported.
+
+OpTypeCooperativeMatrixKHR represents the abstract type. It is parameterized as follows:
+
+
+
+* Component type: Matrix element type (must be a numerical scalar) (Float16/32/64, S/UInt8/16/32/64 currently)
+* Scope: the scope of operations on the type (Device, QueueFamily, Workgroup, or Subgroup)
+* Rows: Number of rows
+* Columns: Number of columns
+* Use: an enum specifying which kind of matrix this is used as:
+ * MatrixAKHR: LHS of a multiply (M rows x K columns)
+ * MatrixBKHR: RHS of a multiply (K rows x N columns)
+ * MatrixAccumulatorKHR: result of multiply-accumulate (M rows x N columns)
+
+SPIR-V can represent a wide variety of cooperative matrices, but actual devices only support a subset of parameterizations. These are queried from the API by calling [vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR). It returns a linked list of [valid enumerations](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VkCooperativeMatrixPropertiesKHR). Preferred variants are listed earlier in the chain. At least one entry in the list must have power of two values for all of MSize, KSize, and NSize. The enumerations indicate whether they support saturating accumulation. Because the type is abstract it can only be stored in the Function and Private storage classes. Special load and store instructions are used to translate to/from backing memory.
+
+The fundamental operation is OpCooperativeMatrixMulAddKHR, which performs a multiply-accumulate. It requires consistent values of scope, M, N, and K for all types. The component types may differ; the operations occur using the result type. Signedness of results and operands is indicated using the Cooperative Matrix Operands (as is saturating accumulation). The result is undefined if saturating accumulation is used but intermediate results overflow (the same as for integer dot product).
+
+OpCooperativeMatrixLoadKHR and OpCooperativeMatrixStoreKHR share the following operands:
+
+
+
+* Pointer: a logical pointer into an array (stride is ignored) (limited to Workgroup, StorageBuffer, and PhysicalStorageBuffer storage classes in Vulkan)
+* Memory Layout: an enum for row- or column-majorness
+* Stride: an optional operand indicating the stride between elements (aligned to min(16, align(col/row))). The stride counts the number of elements in the pointee type; it is not a byte count. For example, if the pointee type of the Pointer argument is 2xf16 then a stride of 8 translates to a byte stride of 8 x (2 x 2) = 32 bytes.
+
+These operations convert to/from the abstract matrix type and memory. The operands **must** be uniform values for the scope.
+
+Additional operations:
+
+
+
+* OpCooperativeMatrixLengthKHR: Number of accessible components from an invocation for a given cooperative matrix type.
+* Conversions: All standard conversions are allowed (FToU, UToF, FToS, SToF, U, S, F, bitcast).
+* Arithmetic: Negate, add, sub, mul, div using the opcode for the appropriate component type. Additionally, OpMatrixTimesScalar is also supported.
+* Constants: OpCompositeConstruct and OpConstantComposite can provide a single fill value for a cooperative matrix.
+
+All operations involving a cooperative matrix **must** only be executed in uniform control flow for scope of the operation.
+
+Cooperative matrices **require** the Vulkan memory model in SPIR-V.
+
+OpTypeCooperativeMatrixKHR can only be instantiated by a variable in the Function or Private storage classes.
+
+
+### HLSL/Direct3D
+
+Microsoft calls them WaveMatrices, and added them to [SM6.8](https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_8_WaveMatrix.html) as experimental.
+
+Google provided feedback via https://github.com/microsoft/hlsl-specs/issues/72
+
+This experimental feature was later withdrawn.
+
+On Jul 14, 2025, [Chris Bieneman](mailto:cbieneman@microsoft.com) posted a revised proposal for HLSL linalg. See https://github.com/microsoft/hlsl-specs/pull/556
+
+
+### MSL/Metal
+
+Apple calls them simdgroup matrices, and support has existed since MSL 2.3. They are limited to relatively few GPUs though (Apple7 and later in the [feature set](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf)). That is iPhone 12 and later, some newer iPads, and M1 and later Macs.
+
+MSL supports a small number of matrix variants:
+
+
+
+* simdgroup_float8x8
+* simdgroup_half8x8
+* simdgroup_bfloat8x8 (MSL 3.1)
+
+The simdgroup_float8x8 and simdgroup_half8x8 types are equivalent to a SPIR-V 8x8 cooperative matrix with Subgroup scope and corresponding component type. Unlike SPIR-V and HLSL, MSL does not specialize the matrices based on the use.
+
+MSL also has more limited supported operations on simdgroup matrices. For load/store/creation:
+
+
+
+* simdgroup_matrix<_T_, _M_, _N_>: Create diagonal matrix with single value.
+* make_filled_simdgroup_matrix<_T_, _M_, _N_>: Create single value filled matrix.
+* simdgroup_load/store: specify stride, layout (via transpose), and an offset (via origin)
+
+Similar to SPIR-V the device or threadgroup pointer points to a matrix component type.
+
+MSL supports multiply and multiply-accumulate (non-saturating) arithmetic operations.
+
+All functions **must** be called from simdgroup uniform control flow. **Presumably** load and store parameter values **must** be simdgroup uniform though I can’t see this called out in the spec.
+
+MSL allows interface variables to be declared as simdgroup_matrix, but no operators exist to make them usable there. Effectively, this limits them to function storage.
+
+
+### SYCL
+
+SYCL is a single-source CPU/GPU compute language; its support may be indicative of what GPU hardware exposes, at least through OpenCL.
+
+SYCL has a corresponding proposal, using the terminology “joint matrix”. See https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_matrix/sycl_ext_oneapi_matrix.asciidoc
+
+
+## Proposal
+
+
+### WGSL
+
+
+#### Enable Extension
+
+Add a new enable extension, `chromium_experimental_subgroup_matrix`. It enables the declaration of subgroup_matrix types, the related built-in functions, and limited use of some scalar types (u8 and i8). It implicitly depends on subgroups: WGSL does not require a subgroups enable, but the API does.
+
+
+#### Types
+
+Add new types:
+
+
+
+* subgroup_matrix_left<_T_, _K_, _M_>: M rows x K columns matrix of T components.
+* subgroup_matrix_right<_T_, _N_, _K_>: K rows x N columns matrix of T components
+* subgroup_matrix_result<_T_, _N_, _M_>: M rows x N columns matrix of T components
+* M, N, and K are override-expressions of positive integer values
+ * For _initial prototyping and feedback_, we will restrict them to const-expression. We should finalize const vs. override before full standardization.
+ * Similar to arrays, two subgroup matrices are the same if:
+ * They are the same base type
+ * They have the same component type
+ * They have matching row counts such that:
+ * Both are equal valued const-expressions, or
+ * Both resolve to the same override-declaration
+ * They have matching column counts such that:
+ * Both are equal valued const-expressions, or
+ * Both resolve to the same override-declaration
+ * Note: like arrays, types cannot match if the row or column counts use an override-expression that is not equivalent to an override identifier.
+* T is the component type and must be one of the entries in the table
+ * The scalar shader type is the associated type usable in the shader code for scalar operations and data representation in memory
+ * Can be expanded in the future to support more types (e.g. bfloat16) via new enables.
+ * The u8 and i8 cases are predeclared types that are not otherwise usable in WGSL. For layout calculations, they are of size 1 byte and have an alignment requirement of 1 byte.
+
+<table>
+ <tr>
+ <td>
+<strong>Type</strong>
+ </td>
+ <td><strong>Extra Enable</strong>
+ </td>
+ <td><strong>Element Stride (bytes)</strong>
+ </td>
+ <td><strong>Scalar Shader Type</strong>
+ </td>
+ <td><strong>Minimum Value</strong>
+ </td>
+ <td><strong>Maximum Value</strong>
+ </td>
+ </tr>
+ <tr>
+ <td>f32
+ </td>
+ <td>
+ </td>
+ <td>4
+ </td>
+ <td>f32
+ </td>
+ <td>
+ </td>
+ <td>
+ </td>
+ </tr>
+ <tr>
+ <td>f16
+ </td>
+ <td>f16
+ </td>
+ <td>2
+ </td>
+ <td>f16
+ </td>
+ <td>
+ </td>
+ <td>
+ </td>
+ </tr>
+ <tr>
+ <td>u32
+ </td>
+ <td>
+ </td>
+ <td>4
+ </td>
+ <td>u32
+ </td>
+ <td>
+ </td>
+ <td>
+ </td>
+ </tr>
+ <tr>
+ <td>i32
+ </td>
+ <td>
+ </td>
+ <td>4
+ </td>
+ <td>i32
+ </td>
+ <td>
+ </td>
+ <td>
+ </td>
+ </tr>
+ <tr>
+ <td>u8
+ </td>
+ <td>
+ </td>
+ <td>1
+ </td>
+ <td>u32
+ </td>
+ <td>0
+ </td>
+ <td>255
+ </td>
+ </tr>
+ <tr>
+ <td>i8
+ </td>
+ <td>
+ </td>
+ <td>1
+ </td>
+ <td>i32
+ </td>
+ <td>-128
+ </td>
+ <td>127
+ </td>
+ </tr>
+</table>
+
+
+These types are not considered “composite” in the WGSL taxonomy, because they are not decomposable. You can’t reference a sub-vector or a single component. The numeric dimensions must be override-expressions. These types cannot be part of any interface (i.e. they can only be instantiated in Function and Private address spaces). They are plain types (similar to atomics) so that they can be included in composite types. An important use case is to make arrays of these matrices.
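+
+As a concrete sketch (assuming the device supports an 8x8x8 f32 configuration and a subgroupMaxSize of 32), a kernel can keep several accumulator tiles in a function-address-space array and reset them with the value constructor:
+
+```
+enable chromium_experimental_subgroup_matrix;
+
+@compute @workgroup_size(32)  // assumes subgroupMaxSize is 32 on this device
+fn init_tiles() {
+  // An array of abstract 8x8 result matrices; each element is distributed
+  // across the invocations of the subgroup.
+  var acc : array<subgroup_matrix_result<f32, 8, 8>, 4>;
+  for (var i = 0u; i < 4u; i++) {
+    acc[i] = subgroup_matrix_result<f32, 8, 8>();  // zero-filled
+  }
+}
+```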
+
+It is a pipeline-creation error if any subgroup matrix type used in the shader does not match a supported `GPUSubgroupMatrixConfig`, `config`, on the device as follows:
+
+
+
+* For `subgroup_matrix_left`:
+ * `M` equals `config.M`
+ * `K` equals `config.K`
+ * `T` matches `config.componentType`
+* For `subgroup_matrix_right`:
+ * `K` equals `config.K`
+ * `N` equals `config.N`
+ * `T` matches `config.componentType`
+* For `subgroup_matrix_result`:
+ * `M` equals `config.M`
+ * `N` equals `config.N`
+ * `T` matches `config.resultType`
+
+Why use a “left” and “right” matrix type, instead of a single matrix type like Metal has?
+
+
+
+* D3D and Vulkan both need them.
+* There is no free transpose operation. If there were, you could take advantage of A \* B = transpose(transpose(B) \* transpose(A)). I haven’t seen a transpose operation in APIs or hardware specs.
+
+
+#### Variables
+
+A variable containing a subgroup matrix can only be instantiated in Function or Private address spaces. (This limitation comes from SPIR-V and Metal). These variables are meant to be used as very temporary scratch space.
+
+
+#### Loading and storing
+
+Builtin functions are used to load and store subgroup matrix values from and to variables in the workgroup and storage address spaces. The builtins do two things:
+
+
+
+* Map matrix row and column indices to external memory locations. This mapping has two parameters: majorness (row-major, or column-major), and an integer Stride. Let Base be the byte address of the start of the external matrix, i.e. where the [0,0]’th element of the matrix is stored. Then
+ * For row-major:
+ * Matrix entry [r,c] maps to the sizeof(T) bytes located at Base + Stride\*r\*sizeof(T) + sizeof(T)\*c
+ * Stride >= number of matrix columns.
+ * For column-major:
+ * Matrix entry [r,c] maps to the sizeof(T) bytes located at Base + Stride\*c\*sizeof(T) + sizeof(T)\*r
+ * Stride >= number of matrix rows.
+* Reinterpret data values between the shader scalar type and the external component type T, when those types differ.
+
+For a subgroup_matrix_left/right/result<T, Cols, Rows>, loads and stores are out-of-bounds if the length of the array pointed to by the pointer argument is less than Offset + Stride \* Rows \* Cols.
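+
+For example, to load the 8x8 left-hand tile whose top-left element sits at row r0, column c0 of a larger row-major matrix stored as a flat array, the offset is the element index of that corner and the stride is the column count of the containing matrix. A minimal sketch (the buffer name, matrix width, and subgroup size are illustrative assumptions):
+
+```
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read> lhs : array<f32>;  // 64x64, row-major; must be large enough for the tile's footprint
+
+const COLS = 64u;  // columns of the containing matrix
+
+@compute @workgroup_size(32)  // assumes subgroupMaxSize is 32
+fn load_tile(@builtin(workgroup_id) wg : vec3<u32>) {
+  let r0 = wg.y * 8u;
+  let c0 = wg.x * 8u;
+  // Tile element [r, c] maps to lhs[(r0 + r) * COLS + (c0 + c)]:
+  // offset = r0 * COLS + c0, row-major layout (col_major = false), stride = COLS.
+  let tile = subgroupMatrixLoad<subgroup_matrix_left<f32, 8, 8>>(
+      &lhs, r0 * COLS + c0, false, COLS);
+}
+```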
+
+
+#### Attributes
+
+Add an additional error condition to `workgroup_size`. It is a pipeline-creation error if the x dimension is not a multiple of `GPUAdapterInfo.subgroupMaxSize` and the shader statically accesses any subgroup matrix.
+
+Note: this is where the requirement on the subgroups feature stems from in practice.
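+
+For instance, on a device reporting a subgroupMaxSize of 32 (an assumption for this sketch), a shader that statically accesses a subgroup matrix must pick an x dimension that is a multiple of 32:
+
+```
+enable chromium_experimental_subgroup_matrix;
+
+@compute @workgroup_size(64)  // 64 is a multiple of the assumed subgroupMaxSize of 32
+fn main() {
+  var m = subgroup_matrix_result<f32, 8, 8>();  // static access makes the rule apply
+}
+```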
+
+
+#### Expressions
+
+Supported expressions:
+
+
+
+* Parenthesization
+* Identifier expressions
+* Indirection
+* Address-of
+* Function call expression
+* Type expressions
+
+Explicitly not supported:
+
+
+
+* Decomposition expressions
+
+Possible future expansion:
+
+
+
+* Arithmetic expressions
+
+
+#### Built-in Values
+
+
+<table>
+ <tr>
+ <td><strong>Name</strong>
+ </td>
+ <td>subgroup_id
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Stage</strong>
+ </td>
+ <td>compute
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Direction</strong>
+ </td>
+ <td>input
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Type</strong>
+ </td>
+ <td>u32
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>The current invocation’s subgroup’s ID within the workgroup.
+ </td>
+ </tr>
+</table>
+
+
+_Note: This was not included in the subgroups feature because there was insufficient documentation detailing the mapping between workgroups and subgroups in HLSL. This experimental feature is not implemented on D3D, so the built-in value is available on all supported platforms. If Microsoft provides a guaranteed mapping for D3D, subgroup_id could be added as a language feature on top of the subgroups feature._
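+
+A short sketch of how the new built-in value might be consumed (the tiling scheme is illustrative):
+
+```
+enable chromium_experimental_subgroup_matrix;
+
+@compute @workgroup_size(64)  // two subgroups per workgroup if subgroupMaxSize is 32
+fn main(@builtin(workgroup_id) wg : vec3<u32>,
+        @builtin(subgroup_id) sg_id : u32) {
+  // sg_id is the same for every invocation in a subgroup, so it can hand each
+  // subgroup its own output tile. Note that values derived from sg_id are only
+  // subgroup uniform, not workgroup uniform, so using them as load/store
+  // offsets would trigger the workgroup-based uniformity diagnostic described below.
+  let tile_index = wg.x * 2u + sg_id;
+}
+```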
+
+
+#### Built-in Functions
+
+Calls to these functions:
+
+
+
+* Must only be used in a compute shader stage.
+* Trigger a `chromium_experimental.subgroup_matrix_uniformity` diagnostic if uniformity analysis cannot prove the call is in subgroup uniform control flow.
+
+
+##### Value constructors
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn T(value : S) -> T
+@must_use fn T() -> T</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>T is a subgroup_matrix type whose scalar shader type is S.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Create a subgroup matrix filled with value.
+<p>
+When no value is provided, use S().
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove value is a subgroup uniform value.
+ </td>
+ </tr>
+</table>
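+
+For example (a sketch; the 8x8 f16 shape assumes a matching device configuration):
+
+```
+enable f16;
+enable chromium_experimental_subgroup_matrix;
+
+@compute @workgroup_size(32)  // assumes subgroupMaxSize is 32
+fn constructors() {
+  let zeros  = subgroup_matrix_result<f16, 8, 8>();    // every element is f16()
+  let halves = subgroup_matrix_left<f16, 8, 8>(0.5h);  // every element is 0.5
+}
+```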
+
+
+
+##### Load/store functions
+
+See the "Loading and storing" section above.
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixLoad<T>(p : ptr<AS, SA, AM>,
+ offset : u32,
+ col_major : bool,
+ stride : u32) -> T</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>T is a subgroup matrix type with scalar shader type S.
+<p>
+SA is an array with element type S.
+<p>
+AS is storage or workgroup.
+<p>
+AM is read or read_write.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Load a subgroup matrix from p, offset elements from the start of the array.
+<p>
+<code>col_major</code> must be an override-expression.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove p, offset, or stride are subgroup uniform values.
+<p>
+stride counts elements of the component type of T. Behavior is undefined if stride is less than:<ul>
+
+<li>The number of rows of T if col_major is true
+<li>The number of columns of T if col_major is false</li></ul>
+
+ </td>
+ </tr>
+</table>
+
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixStore(p : ptr<AS, SA, AM>,
+ offset : u32,
+ value : T,
+ col_major : bool,
+ stride : u32)</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>T is a subgroup matrix type whose scalar shader type is S.
+<p>
+SA is an array with element type S.
+<p>
+AS is storage or workgroup.
+<p>
+AM is write or read_write.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Store the subgroup matrix value into p, offset elements from the start of the array.
+<p>
+<code>col_major</code> must be an override-expression.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove p, offset, value, or stride are subgroup uniform values.
+<p>
+stride counts elements of the component type of T. Behavior is undefined if stride is less than:<ul>
+
+<li>The number of rows of T if col_major is true
+<li>The number of columns of T if col_major is false</li></ul>
+
+ </td>
+ </tr>
+</table>
+
+
+
+##### Matrix arithmetic functions
+
+The operands of a subgroup matrix arithmetic function comprise a **supported subgroup matrix configuration** if the device has a `GPUSubgroupMatrixConfig`, `config`, such that all operand types match as below:
+
+
+
+* For `L`:
+ * `M` equals `config.M`
+ * `K` equals `config.K`
+ * `T` matches `config.componentType`
+* For `R`:
+ * `K` equals `config.K`
+ * `N` equals `config.N`
+ * `T` matches `config.componentType`
+* For `RT`:
+ * `M` equals `config.M`
+ * `N` equals `config.N`
+ * `TR` matches `config.resultType`
+
+<table>
+ <tr>
+ <td>
+<strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixMultiply<TR>(left : L, right : R) -> RT</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>L is a subgroup_matrix_left<T, K, M>.
+<p>
+R is a subgroup_matrix_right<T, N, K>.
+<p>
+RT is a subgroup_matrix_result<TR, N, M>.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Matrix multiply.
+<p>
+It is a pipeline-creation error if L, R, and RT do not comprise a supported subgroup matrix configuration.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove left or right are subgroup uniform values.
+ </td>
+ </tr>
+</table>
+
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixMultiplyAccumulate<TR>(left : L,
+ right : R,
+ acc : RT) -> RT</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>L is a subgroup_matrix_left<T, K, M>.
+<p>
+R is a subgroup_matrix_right<T, N, K>.
+<p>
+RT is a subgroup_matrix_result<TR, N, M>.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Matrix multiply add (left * right + acc).
+<p>
+It is a pipeline-creation error if L, R, and RT do not comprise a supported subgroup matrix configuration.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove left or right are subgroup uniform values.
+ </td>
+ </tr>
+</table>
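+
+Putting the load, multiply-accumulate, and store built-ins together, a minimal tiled matrix-multiply kernel might look like the following. This is a sketch, not a tuned implementation; it assumes an 8x8x8 f32 configuration, a subgroupMaxSize of 32, 64x64 row-major matrices, and one output tile per workgroup.
+
+```
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read>       A : array<f32>;  // 64x64, row-major
+@group(0) @binding(1) var<storage, read>       B : array<f32>;  // 64x64, row-major
+@group(0) @binding(2) var<storage, read_write> C : array<f32>;  // 64x64, row-major
+
+const N = 64u;
+
+@compute @workgroup_size(32)  // one subgroup per workgroup; 32 assumes subgroupMaxSize
+fn matmul(@builtin(workgroup_id) wg : vec3<u32>) {
+  // This workgroup's subgroup computes the 8x8 tile of C at (row0, col0).
+  let row0 = wg.y * 8u;
+  let col0 = wg.x * 8u;
+
+  var acc = subgroup_matrix_result<f32, 8, 8>();  // zero-filled accumulator
+  for (var k = 0u; k < N; k += 8u) {
+    let a = subgroupMatrixLoad<subgroup_matrix_left<f32, 8, 8>>(&A, row0 * N + k, false, N);
+    let b = subgroupMatrixLoad<subgroup_matrix_right<f32, 8, 8>>(&B, k * N + col0, false, N);
+    acc = subgroupMatrixMultiplyAccumulate<f32>(a, b, acc);
+  }
+  subgroupMatrixStore(&C, row0 * N + col0, acc, false, N);
+}
+```
+
+Dispatching 8x8 workgroups covers all 64 output tiles in this layout. All pointers, offsets, and strides here are derived from workgroup-uniform values, so the default uniformity diagnostic is satisfied.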
+
+
+
+##### Scalar arithmetic functions
+
+These functions are useful for operations such as applying biases to a model.
+
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixScalarAdd(matrix : M, value : S) -> M</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>M is a subgroup matrix type with scalar shader type S.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Scalar addition.
+<p>
+value is clamped to a valid range for the component type of M and then added to each element of matrix.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.
+ </td>
+ </tr>
+</table>
+
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixScalarSubtract(matrix : M, value : S) -> M</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>M is a subgroup matrix type with scalar shader type S.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Scalar subtract.
+<p>
+Value is clamped to a valid range for the component type of M and then subtracted from each element of matrix.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.
+ </td>
+ </tr>
+</table>
+
+
+
+<table>
+ <tr>
+ <td><strong>Overload</strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixScalarMultiply(matrix : M, value : S) -> M</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Preconditions</strong>
+ </td>
+ <td>M is a subgroup matrix type with scalar shader type S.
+ </td>
+ </tr>
+ <tr>
+ <td><strong>Description</strong>
+ </td>
+ <td>Scalar multiply.
+<p>
+Value is clamped to a valid range for the component type of M and then each element of matrix is multiplied by it.
+<p>
+Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.
+ </td>
+ </tr>
+</table>
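+
+For example, a uniform bias and scale could be applied to an accumulator like so (a sketch; the values are illustrative and must be subgroup uniform):
+
+```
+enable chromium_experimental_subgroup_matrix;
+
+@compute @workgroup_size(32)  // assumes subgroupMaxSize is 32
+fn bias_and_scale() {
+  var acc = subgroup_matrix_result<f32, 8, 8>();  // stand-in for a computed result
+  acc = subgroupMatrixScalarAdd(acc, 0.5);        // add 0.5 to every element
+  acc = subgroupMatrixScalarMultiply(acc, 2.0);   // then scale every element by 2
+}
+```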
+
+
+**REMOVED** (in an attempt to stay safely forward compatible in case Metal adds integral component types but no additional operators)
+
+
+<table>
+ <tr>
+ <td><strong><del>Overload</del></strong>
+ </td>
+ <td>
+
+
+
+<pre class="prettyprint">@must_use fn
+subgroupMatrixScalarDivide(matrix : M, value : S) -> M</pre>
+
+
+ </td>
+ </tr>
+ <tr>
+ <td><strong><del>Preconditions</del></strong>
+ </td>
+ <td><del>M is a subgroup matrix type with scalar shader type S.</del>
+ </td>
+ </tr>
+ <tr>
+ <td><strong><del>Description</del></strong>
+ </td>
+ <td><del>Scalar division.</del>
+<p>
+<del>Value is clamped to a valid range for the component type of M and then each element of matrix is divided by it.</del>
+<p>
+<del>Triggers a <code>chromium_experimental.subgroup_matrix_uniformity</code> diagnostic if uniformity analysis cannot prove matrix or value are subgroup uniform values.</del>
+ </td>
+ </tr>
+</table>
+
+
+
+#### Uniformity
+
+Subgroup matrices are an abstraction of data spread among the invocations in a subgroup. As such the built-in functions must only be called from subgroup uniform control flow. Additionally, most of the parameter values must also be subgroup uniform values.
+
+WGSL does not currently represent subgroup uniformity, and empirical testing shows that implementations do not provide reliable reconvergence. Vulkan backends that implement VK_KHR_shader_subgroup_uniform_control_flow or, even better, VK_KHR_shader_maximal_reconvergence could reasonably enable the analysis at subgroup scope, but that is not portable.
+
+For this experimental extension we should add a diagnostic, `chromium_experimental.subgroup_matrix_uniformity`, that defaults to `error`. This diagnostic would be triggered based on workgroup uniformity violations (if the control is not `off`) for all subgroup matrix built-in functions. This allows applications some control. We can gather data about the usefulness of the static analysis for subgroup matrices. It is likely that most use cases would satisfy workgroup uniformity.
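+
+For example, an application that has verified its call sites are subgroup uniform (even though the analysis cannot prove it) could relax the diagnostic at module scope (a sketch of the intended control):
+
+```
+diagnostic(warning, chromium_experimental.subgroup_matrix_uniformity);
+```
+
+The same rule name could also be used in an `@diagnostic` attribute for a narrower scope.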
+
+
+#### Floating-point Accuracy
+
+`subgroupMatrixLoad` and `subgroupMatrixStore` should be bit preserving. ULPs are undefined for other operations.
+
+
+### API
+
+New GPUFeatureName `chromium-experimental-subgroup-matrix`
+
+
+
+* Requires `subgroups` (for maxSubgroupSize)
+* Vulkan:
+ * <code>[vkPhysicalDeviceCooperativeMatrixFeaturesKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VkPhysicalDeviceCooperativeMatrixFeaturesKHR)::cooperativeMatrix</code> is <code>VK_TRUE</code>
+ * <code>[vkPhysicalDeviceCooperativeMatrixPropertiesKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VkPhysicalDeviceCooperativeMatrixPropertiesKHR)::cooperativeMatrixSupportedStages</code> includes <code>VK_SHADER_STAGE_COMPUTE_BIT</code>
+ * The supported configurations list (see Properties, Vulkan below) is non-empty
+ * <code>[vkPhysicalDeviceVulkanMemoryModelFeatures](https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#VkPhysicalDeviceVulkanMemoryModelFeatures)::vulkanMemoryModel</code> is <code>VK_TRUE</code>
+ * <code>[subgroupSizeControl](https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#features-subgroupSizeControl)</code> feature must be enabled
+ * [computeFullSubgroups](https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#features-computeFullSubgroups) feature must be enabled
+ * Vulkan pipelines will need to be compiled with <code>VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT</code> and <code>VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT</code>, or the SPIR-V module must be version 1.6 or later
+* Metal:
+ * Family is Apple 7+
+* D3D: no support
+
+New immutable array, <code>subgroupMatrixConfigs</code>, added to <code>GPUAdapterInfo</code>.
+
+
+```
+partial interface GPUAdapterInfo {
+ [SameObject] readonly attribute FrozenArray<GPUSubgroupMatrixConfig> subgroupMatrixConfigs;
+};
+
+enum GPUSubgroupMatrixComponentType {
+ "f32",
+ "f16",
+ "u32",
+ "i32",
+ "u8",
+ "i8",
+};
+
+interface GPUSubgroupMatrixConfig {
+ readonly attribute GPUSubgroupMatrixComponentType componentType;
+ readonly attribute GPUSubgroupMatrixComponentType resultComponentType;
+ readonly attribute unsigned long M;
+ readonly attribute unsigned long N;
+ readonly attribute unsigned long K;
+};
+```
+
+
+
+#### Validation
+
+No specific API validation is necessary; however, it is likely easier to implement the pipeline-creation error checks in Dawn by adding the subgroup matrix configurations to the shader reflection information.
+
+WGSL pipeline-creation checks (repeated for ease of reference):
+
+
+
+* All subgroup matrix types are part of a `GPUSubgroupMatrixConfig`
+* All subgroup matrix operands in `subgroupMatrixMultiply` and `subgroupMatrixMultiplyAccumulate` are part of a single `GPUSubgroupMatrixConfig`
+* The x-dimension of `workgroup_size` is a multiple of `GPUAdapterInfo.subgroupMaxSize`
+
+
+### Mapping
+
+
+#### Types
+
+
+<table>
+ <tr>
+ <td><strong>Type</strong>
+ </td>
+ <td><strong>SPIR-V<sup>1,2</sup></strong>
+ </td>
+ <td><strong>MSL<sup>3</sup></strong>
+ </td>
+ </tr>
+ <tr>
+ <td>subgroup_matrix_left
+<p>
+<T, K, M>
+ </td>
+ <td>OpTypeCooperativeMatrixKHR
+<p>
+MatrixAKHR use
+<p>
+M rows
+<p>
+K cols
+<p>
+T component type
+ </td>
+ <td>simdgroup_matrix<T, 8, 8>
+ </td>
+ </tr>
+ <tr>
+ <td>subgroup_matrix_right
+<p>
+<T, N, K>
+ </td>
+ <td>OpTypeCooperativeMatrixKHR
+<p>
+MatrixBKHR use
+<p>
+K rows
+<p>
+N cols
+<p>
+T component type
+ </td>
+ <td>simdgroup_matrix<T, 8, 8>
+ </td>
+ </tr>
+ <tr>
+ <td>subgroup_matrix_result
+<p>
+<T, N, M>
+ </td>
+ <td>OpTypeCooperativeMatrixKHR
+<p>
+MatrixAccumulatorKHR use
+<p>
+M rows
+<p>
+N cols
+<p>
+T component type
+ </td>
+ <td>simdgroup_matrix<T, 8, 8>
+ </td>
+ </tr>
+</table>
+
+
+
+
+1. All OpTypeCooperativeMatrixKHR use subgroup scope.
+2. Component type enum maps directly to SPIR-V type (e.g. i8 to OpTypeInt 8 1).
+3. MSL types use the usual mappings (e.g. f32 to float).
+
+
+#### Built-in Values
+
+
+<table>
+ <tr>
+ <td><strong>Value</strong>
+ </td>
+ <td><strong>SPIR-V</strong>
+ </td>
+ <td><strong>MSL</strong>
+ </td>
+ </tr>
+ <tr>
+ <td>subgroup_id
+ </td>
+ <td>SubgroupId
+ </td>
+ <td>simdgroup_index_in_threadgroup
+ </td>
+ </tr>
+</table>
+
+
+
+#### Functions
+
+
+<table>
+ <tr>
+ <td><strong>Function</strong>
+ </td>
+ <td><strong>SPIR-V</strong>
+ </td>
+ <td><strong>MSL</strong>
+ </td>
+ </tr>
+ <tr>
+ <td>Value constructors
+ </td>
+ <td>OpCompositeConstruct with appropriate value
+<p>
+const-/override-expressions could use constant instructions
+ </td>
+ <td>make_filled_simdgroup_matrix
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixLoad
+ </td>
+ <td>OpCooperativeMatrixLoadKHR
+<p>
+Pointer operand is a direct translation of the WGSL pointer
+ </td>
+ <td>simdgroup_matrix_load
+<p>
+The WGSL pointer needs to be translated into the origin operand in MSL
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixStore
+ </td>
+ <td>OpCooperativeMatrixStoreKHR
+<p>
+Pointer operand is a direct translation of the WGSL pointer
+ </td>
+ <td>simdgroup_matrix_store
+<p>
+The WGSL pointer needs to be translated into the origin operand in MSL
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixMultiply
+ </td>
+ <td>OpCooperativeMatrixMulAddKHR
+<p>
+C is a zero value matrix matching the result type
+ </td>
+ <td>simdgroup_multiply
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixMultiplyAccumulate
+ </td>
+ <td>OpCooperativeMatrixMulAddKHR
+<p>
+Cooperative matrix operands must include signed values appropriately based on operand types
+ </td>
+ <td>simdgroup_multiply_accumulate
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixScalarAdd
+ </td>
+ <td>OpI/FAdd with composite constructed scalar
+ </td>
+ <td>simdgroup_multiply_accumulate with an identity matrix for b and a matrix filled with the scalar value for c
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixScalarSubtract
+ </td>
+ <td>OpI/FSub with composite constructed scalar
+ </td>
+ <td>simdgroup_multiply_accumulate with an identity matrix for b and a matrix filled with the negated scalar value for c
+ </td>
+ </tr>
+ <tr>
+ <td>subgroupMatrixScalarMultiply
+ </td>
+ <td>OpI/FMul with composite constructed scalar or OpMatrixTimesScalar (for float components)
+ </td>
+ <td>simdgroup_multiply with a diagonal matrix of the scalar value for b<sup>1</sup>
+ </td>
+ </tr>
+ <tr>
+ <td><del>subgroupMatrixScalarDivide</del>
+ </td>
+ <td><del>OpU/S/FDiv with composite constructed scalar or OpMatrixTimesScalar (for float components with reciprocal scalar value)</del>
+ </td>
+ <td><del>simdgroup_multiply with diagonal matrix reciprocal scalar value for b<sup>1</sup></del>
+ </td>
+ </tr>
+</table>
+
+
+
+
+1. This works because MSL only supports floating-point component types.
+
+
+#### Properties
+
+
+##### Vulkan
+
+Filter the list returned from [vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR) such that:
+
+
+
+* A type and B type match
+* C type and result type match
+* A, B, C, and result types are one of (-> API enum):
+ * VK_COMPONENT_TYPE_FLOAT32_KHR -> f32
+ * VK_COMPONENT_TYPE_FLOAT16_KHR -> f16
+ * VK_COMPONENT_TYPE_SINT32_KHR -> i32
+ * VK_COMPONENT_TYPE_UINT32_KHR -> u32
+ * VK_COMPONENT_TYPE_SINT8_KHR -> i8
+ * VK_COMPONENT_TYPE_UINT8_KHR -> u8
+* saturatingAccumulation is false
+* Scope is `VK_SCOPE_SUBGROUP_KHR`
+
+`VK_COMPONENT_TYPE_FLOAT16_KHR` will need to be filtered out of the device properties if the `shader-f16` feature is not requested.
+
+
+##### Metal
+
+Hardcode the following configurations if the feature is supported:
+
+
+<table>
+ <tr>
+ <td><strong>componentType</strong>
+ </td>
+ <td><strong>resultComponentType</strong>
+ </td>
+ <td><strong>M</strong>
+ </td>
+ <td><strong>N</strong>
+ </td>
+ <td><strong>K</strong>
+ </td>
+ </tr>
+ <tr>
+ <td>f32
+ </td>
+ <td>f32
+ </td>
+ <td>8
+ </td>
+ <td>8
+ </td>
+ <td>8
+ </td>
+ </tr>
+ <tr>
+ <td>f16<sup>1</sup>
+ </td>
+ <td>f16
+ </td>
+ <td>8
+ </td>
+ <td>8
+ </td>
+ <td>8
+ </td>
+ </tr>
+</table>
+
+
+
+
+1. Filter out f16 from the device properties if `shader-f16` is not requested on the device.
+
+
+## Future Expansion
+
+It is likely that more component types can be supported in the future. Metal already supports bfloat16, for example.
+
+We could consider exposing saturating accumulation.
+