blob: 72db750d2b5597c3aeee88e6f88721a79976da60 [file] [log] [blame]
// Shapes of the two operands and of the product, supplied by the host.
// The entry point reads `aShape.y` as the shared inner dimension and
// `outShape.y` as the number of columns (row stride) of the result.
// NOTE(review): the `.x` components are presumably row counts -- confirm
// against the host code that fills this buffer.
[[block]] struct Uniforms {
  // Shape of the first operand; `.y` is its column count.
  aShape : vec2<u32>;
  // Shape of the second operand (not read by the shader code in this file).
  bShape : vec2<u32>;
  // Shape of the result; `.y` is its column count.
  outShape : vec2<u32>;
};
// Storage-buffer view of a matrix as a flat, runtime-sized array of u32.
// The entry point addresses it as `row * columnCount + column`.
[[block]] struct Matrix {
  numbers : array<u32>;
};
// Left-hand operand (read-only storage buffer).
[[group(0), binding(0)]]
var<storage, read> firstMatrix : Matrix;

// Right-hand operand (read-only storage buffer).
[[group(0), binding(1)]]
var<storage, read> secondMatrix : Matrix;

// Product written by the compute entry point (write-only storage buffer).
[[group(0), binding(2)]]
var<storage, write> resultMatrix : Matrix;

// Operand and result shapes.
[[group(0), binding(3)]]
var<uniform> uniforms : Uniforms;
// Compute entry point: each invocation produces one cell of
// resultMatrix = firstMatrix * secondMatrix.
//
// `global_id.y` selects the output row and `global_id.x` the output column.
// All three matrices are addressed as flat arrays using
// `row * columnCount + column`.
[[stage(compute), workgroup_size(2,2,1)]]
fn main([[builtin(global_invocation_id)]] global_id : vec3<u32>) {
  let row : u32 = global_id.y;
  let col : u32 = global_id.x;
  let dimInner : u32 = uniforms.aShape.y;
  let dimOuter : u32 = uniforms.outShape.y;

  // The dispatch is rounded up to whole 2x2 workgroups, so some invocations
  // can fall outside the result. Without this guard a column >= dimOuter
  // aliases onto a cell of a later row (col + row * dimOuter) and clobbers a
  // valid result, and an excess row indexes past the end of the buffers.
  // NOTE(review): assumes outShape.x holds the result's row count -- confirm
  // against the host code that fills the uniform buffer.
  if (row >= uniforms.outShape.x || col >= dimOuter) {
    return;
  }

  // Dot product of row `row` of firstMatrix with column `col` of
  // secondMatrix. u32 arithmetic wraps on overflow.
  var acc : u32 = 0u;
  for (var i : u32 = 0u; i < dimInner; i = i + 1u) {
    let a : u32 = firstMatrix.numbers[row * dimInner + i];
    let b : u32 = secondMatrix.numbers[i * dimOuter + col];
    acc = acc + a * b;
  }

  resultMatrix.numbers[row * dimOuter + col] = acc;
}