| // Shape metadata for the matrix-multiply kernel, bound as a uniform buffer. |
| // The entry point reads aShape.y as the length of each dot product (and as |
| // the row stride of the first matrix), and outShape.y as the row stride of |
| // both the second matrix and the result. |
| // NOTE(review): the .x components are presumably the row counts, and bShape |
| // is not read by the visible entry point -- confirm against the host code. |
| [[block]] |
| struct Uniforms { |
| aShape : vec2<u32>; |
| bShape : vec2<u32>; |
| outShape : vec2<u32>; |
| }; |
| |
| // Storage-buffer layout shared by both inputs and the output: a runtime-sized |
| // flat array of u32 elements. The entry point addresses it with a single |
| // index computed as (column + row * rowStride). |
| [[block]] |
| struct Matrix { |
| numbers : array<u32>; |
| }; |
| |
| // Bind group 0, binding 0: left-hand operand of the product (read-only storage). |
| [[group(0), binding(0)]] var<storage, read> firstMatrix : Matrix; |
| |
| // Bind group 0, binding 1: right-hand operand of the product (read-only storage). |
| [[group(0), binding(1)]] var<storage, read> secondMatrix : Matrix; |
| |
| // Bind group 0, binding 2: destination buffer; the entry point only writes to it. |
| [[group(0), binding(2)]] var<storage, write> resultMatrix : Matrix; |
| |
| // Bind group 0, binding 3: matrix shape metadata (uniform buffer). |
| [[group(0), binding(3)]] var<uniform> uniforms : Uniforms; |
| |
| // Compute entry point: each invocation produces one cell of |
| // resultMatrix = firstMatrix * secondMatrix over u32 elements. |
| // |
| // global_id.y selects the output row and global_id.x the output column. |
| // All three matrices are addressed as flat arrays, one row after another: |
| // the first matrix has aShape.y elements per row; the second matrix and the |
| // result have outShape.y elements per row. |
| // |
| // Workgroups are 2x2 invocations, so the dispatch grid is rounded up to a |
| // multiple of 2 in each direction and may contain invocations that fall |
| // outside the output matrix. Those invocations return without writing. |
| [[stage(compute), workgroup_size(2, 2, 1)]] |
| fn main([[builtin(global_invocation_id)]] global_id : vec3<u32>) { |
| let row : u32 = global_id.y; |
| let col : u32 = global_id.x; |
| let dimInner : u32 = uniforms.aShape.y; |
| let dimOuter : u32 = uniforms.outShape.y; |
| // Skip padding invocations. Without this guard a column >= dimOuter yields |
| // (col + row * dimOuter), which aliases a valid cell of the next row and |
| // overwrites it with a wrong value. |
| // NOTE(review): assumes outShape.x is the number of output rows -- confirm |
| // against the host code that fills the uniform buffer. |
| if ((row >= uniforms.outShape.x) || (col >= dimOuter)) { |
| return; |
| } |
| // Dot product of row `row` of the first matrix with column `col` of the |
| // second matrix. |
| var result : u32 = 0u; |
| for(var i : u32 = 0u; (i < dimInner); i = (i + 1u)) { |
| let a : u32 = (i + (row * dimInner)); |
| let b : u32 = (col + (i * dimOuter)); |
| result = (result + (firstMatrix.numbers[a] * secondMatrix.numbers[b])); |
| } |
| let index : u32 = (col + (row * dimOuter)); |
| resultMatrix.numbers[index] = result; |
| } |