// Shape parameters supplied by the host for the matrix multiply.
// The entry point in this file reads `aShape.y` as the shared inner
// dimension (loop trip count and row stride of the first matrix) and
// `outShape.y` as the row stride of both the second matrix and the result.
// NOTE(review): each shape is presumably (rows, columns) — confirm against
// the host code that fills this buffer.
struct Uniforms {
  aShape : vec2<u32>;
  bShape : vec2<u32>;
  outShape : vec2<u32>;
};

// Storage-buffer view of a matrix as a flat, runtime-sized array of u32.
// The entry point addresses elements as `column + row * rowStride`
// (row-major); the dimensions themselves are not stored here.
struct Matrix {
  numbers : array<u32>;
};

// Left operand of the product (read-only storage buffer).
[[group(0), binding(0)]] var<storage, read> firstMatrix : Matrix;

// Right operand of the product (read-only storage buffer).
[[group(0), binding(1)]] var<storage, read> secondMatrix : Matrix;

// Destination buffer; each invocation stores one element into it.
[[group(0), binding(2)]] var<storage, write> resultMatrix : Matrix;

// Shape parameters for the operands and the result.
[[group(0), binding(3)]] var<uniform> uniforms : Uniforms;

// Compute entry point: each invocation produces one element of
// resultMatrix = firstMatrix * secondMatrix (u32 elements, row-major).
//
// global_id.y selects the output row and global_id.x the output column.
// Workgroups are 2x2 invocations, so the host has to round the dispatch up
// whenever an output dimension is not a multiple of 2; the surplus
// invocations exit early instead of touching the buffers.
[[stage(compute), workgroup_size(2, 2, 1)]]
fn main([[builtin(global_invocation_id)]] global_id : vec3<u32>) {
  let row : u32 = global_id.y;
  let col : u32 = global_id.x;
  let dimInner : u32 = uniforms.aShape.y;
  let dimOuter : u32 = uniforms.outShape.y;

  // Without this guard a surplus invocation with col == dimOuter would
  // compute index (col + row * dimOuter), which aliases cell (row + 1, 0)
  // and races a wrong value against the invocation that owns that cell.
  // NOTE(review): assumes outShape.x is the output row count — confirm
  // against the host code that fills the uniform buffer.
  if ((row >= uniforms.outShape.x) || (col >= dimOuter)) {
    return;
  }

  // Dot product of row `row` of firstMatrix with column `col` of secondMatrix.
  var result : u32 = 0u;
  for(var i : u32 = 0u; (i < dimInner); i = (i + 1u)) {
    let a : u32 = (i + (row * dimInner));
    let b : u32 = (col + (i * dimOuter));
    result = (result + (firstMatrix.numbers[a] * secondMatrix.numbers[b]));
  }

  let index : u32 = (col + (row * dimOuter));
  resultMatrix.numbers[index] = result;
}
