| // Shapes of the two input matrices and of the output matrix, supplied |
| // through the uniform binding. The compute entry point reads aShape.y as the |
| // shared inner dimension and outShape.y as the row stride of both |
| // secondMatrix and resultMatrix. |
| // NOTE(review): presumably .x is the row count and .y the column count of |
| // each matrix - confirm against the host code that fills this buffer. |
| struct Uniforms { |
| aShape : vec2<u32>; |
| bShape : vec2<u32>; |
| outShape : vec2<u32>; |
| }; |
| // Storage-buffer view of a matrix: a runtime-sized flat array of u32 |
| // elements. The compute entry point addresses it as column + row * stride. |
| struct Matrix { |
| numbers: array<u32>; |
| }; |
| |
| // Bind group 0 resources: |
| //   binding 0 - left operand, read-only storage buffer |
| //   binding 1 - right operand, read-only storage buffer |
| //   binding 2 - product, storage buffer that the shader only writes |
| //   binding 3 - matrix shapes, uniform buffer |
| // NOTE(review): newer WGSL revisions appear to accept only `read` and |
| // `read_write` for storage buffers; the `write` mode on binding 2 may need |
| // to become `read_write` when the toolchain is upgraded - confirm. |
| @group(0) @binding(0) var<storage, read> firstMatrix : Matrix; |
| @group(0) @binding(1) var<storage, read> secondMatrix : Matrix; |
| @group(0) @binding(2) var<storage, write> resultMatrix : Matrix; |
| @group(0) @binding(3) var<uniform> uniforms : Uniforms; |
| |
| // Compute entry point: each invocation produces one cell of |
| // resultMatrix = firstMatrix * secondMatrix, with all three matrices stored |
| // as flat u32 arrays indexed as column + row * stride. |
| // global_id.y selects the output row and global_id.x the output column. |
| @stage(compute) @workgroup_size(2,2,1) |
| fn main(@builtin(global_invocation_id) global_id : vec3<u32>) { |
|   let resultCell : vec2<u32> = vec2<u32>(global_id.y, global_id.x); |
|   let dimInner : u32 = uniforms.aShape.y; |
|   let dimOuter : u32 = uniforms.outShape.y; |
| |
|   // Workgroups are 2x2, so a dispatch that rounds the grid up launches |
|   // invocations outside the output whenever a dimension is odd. Without |
|   // this guard, a column >= dimOuter yields a flat index that aliases a |
|   // cell in the next row and overwrites its correct value. |
|   // NOTE(review): assumes outShape.x is the output row count - confirm |
|   // against the host code. |
|   if (resultCell.x >= uniforms.outShape.x || resultCell.y >= dimOuter) { |
|     return; |
|   } |
| |
|   // Dot product of row resultCell.x of firstMatrix with column |
|   // resultCell.y of secondMatrix. |
|   var result : u32 = 0u; |
|   for (var i : u32 = 0u; i < dimInner; i = i + 1u) { |
|     let a : u32 = i + resultCell.x * dimInner; |
|     let b : u32 = resultCell.y + i * dimOuter; |
|     result = result + firstMatrix.numbers[a] * secondMatrix.numbers[b]; |
|   } |
| |
|   let index : u32 = resultCell.y + resultCell.x * dimOuter; |
|   resultMatrix.numbers[index] = result; |
| } |