blob: d9a4b829bc431793e80055606237359c25199ccc [file] [log] [blame]
#include <metal_stdlib>
using namespace metal;
// Matrix dimensions consumed by the multiply kernel (passed to it as the
// buffer(3) argument). The /* 0x.... */ markers give each field's byte offset.
struct Uniforms {
// Exclusive upper bound on row indices of firstMatrix reads and resultMatrix writes.
/* 0x0000 */ uint dimAOuter;
// Shared dimension: column bound and row stride of firstMatrix, row bound of secondMatrix.
/* 0x0004 */ uint dimInner;
// Column bound and row stride of both secondMatrix and resultMatrix.
/* 0x0008 */ uint dimBOuter;
};
// Flat float storage for a matrix held in a device buffer.
// The accessors in this file index `numbers` with (row * stride) + col.
struct Matrix {
// Declared with one element but indexed with computed offsets by mm_readA,
// mm_readB and mm_write; presumably a stand-in for a runtime-sized array whose
// real length is that of the bound buffer - confirm against the host code.
/* 0x0000 */ float numbers[1];
};
// Struct wrapper around a 64-float array; used as the element type (one row)
// of tint_array_wrapper. Wrapping lets the array be assigned by value.
struct tint_array_wrapper_1 {
float arr[64];
};
// 64 x 64 float tile (64 rows of tint_array_wrapper_1). The kernel declares two
// of these in threadgroup memory, one for each input matrix.
struct tint_array_wrapper {
tint_array_wrapper_1 arr[64];
};
// 16-float array wrapper; the kernel uses it for its per-thread accumulator,
// indexed as (innerRow * ColPerThread) + innerCol.
struct tint_array_wrapper_2 {
float arr[16];
};
// 4-float array wrapper; the kernel uses it to cache ColPerThread values of the
// second-matrix tile for the current k.
struct tint_array_wrapper_3 {
float arr[4];
};
// Tiling parameters. NOTE: the array extents in the tint_array_wrapper* structs
// (64, 16, 4) are written as literals rather than derived from these constants,
// so the two must be kept in agreement by hand.
// Number of consecutive output rows computed by each thread.
constant uint RowPerThread = 4u;
// Number of consecutive output columns computed by each thread.
constant uint ColPerThread = 4u;
// Not referenced by the code visible in this file; presumably the tile height
// along firstMatrix's outer dimension.
constant uint TileAOuter = 64u;
// Not referenced by the code visible in this file; presumably the tile width
// along secondMatrix's outer dimension.
constant uint TileBOuter = 64u;
// Tile extent along the shared (inner) dimension; also the k-loop trip count.
constant uint TileInner = 64u;
// Bounds-checked read of firstMatrix at (row, col).
// Returns 0.0f when row >= dimAOuter or col >= dimInner, so callers can read
// past the matrix edge without touching the buffer.
float mm_readA(constant Uniforms& uniforms, const device Matrix& firstMatrix, uint row, uint col) {
  // Reject out-of-range coordinates before forming the flat index.
  if ((row >= uniforms.dimAOuter) || (col >= uniforms.dimInner)) {
    return 0.0f;
  }
  // Row stride of firstMatrix is dimInner.
  return firstMatrix.numbers[(row * uniforms.dimInner) + col];
}
// Bounds-checked read of secondMatrix at (row, col).
// Returns 0.0f when row >= dimInner or col >= dimBOuter.
float mm_readB(constant Uniforms& uniforms, const device Matrix& secondMatrix, uint row, uint col) {
  bool const inside = (row < uniforms.dimInner) && (col < uniforms.dimBOuter);
  if (!inside) {
    return 0.0f;
  }
  // Row stride of secondMatrix is dimBOuter.
  uint const offset = (row * uniforms.dimBOuter) + col;
  return secondMatrix.numbers[offset];
}
// Bounds-checked store of `value` into resultMatrix at (row, col).
// Writes outside dimAOuter x dimBOuter are silently dropped.
void mm_write(constant Uniforms& uniforms, device Matrix& resultMatrix, uint row, uint col, float value) {
  if ((row >= uniforms.dimAOuter) || (col >= uniforms.dimBOuter)) {
    return;
  }
  // Row stride of resultMatrix is dimBOuter.
  resultMatrix.numbers[(row * uniforms.dimBOuter) + col] = value;
}
// Tiled matrix multiply: resultMatrix = firstMatrix * secondMatrix.
//
// Each thread accumulates a RowPerThread x ColPerThread block of the result.
// For every tile along the inner dimension the threadgroup cooperatively
// stages a 64 x 64 tile of each input in threadgroup memory, synchronises,
// and then every thread multiplies its rows of the A tile by its columns of
// the B tile.
//
// Arguments:
//   local_id / global_id    thread position in the threadgroup / grid.
//   local_invocation_index  linear thread index; thread 0 zero-fills the tiles.
//   uniforms     [[buffer(3)]]  matrix dimensions.
//   firstMatrix  [[buffer(0)]]  left operand, read through mm_readA.
//   secondMatrix [[buffer(1)]]  right operand, read through mm_readB.
//   resultMatrix [[buffer(2)]]  output, written through mm_write.
//
// NOTE(review): the `/ 16u` below and the 64-wide tiles imply a 16 x 16
// threadgroup (local_id.x/y in 0..15); the threadgroup size is set by the
// host and is not visible here - confirm. Larger threadgroups would index the
// tiles out of bounds.
kernel void tint_symbol(uint3 local_id [[thread_position_in_threadgroup]], uint3 global_id [[thread_position_in_grid]], uint local_invocation_index [[thread_index_in_threadgroup]], constant Uniforms& uniforms [[buffer(3)]], const device Matrix& firstMatrix [[buffer(0)]], const device Matrix& secondMatrix [[buffer(1)]], device Matrix& resultMatrix [[buffer(2)]]) {
  threadgroup tint_array_wrapper tileA;
  threadgroup tint_array_wrapper tileB;

  // One thread zero-fills both threadgroup tiles; the barrier publishes the
  // zeros to the rest of the threadgroup before anyone touches the tiles.
  if (local_invocation_index == 0u) {
    tint_array_wrapper const zeroTile = {.arr={}};
    tileA = zeroTile;
    tileB = zeroTile;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Top-left corner of this thread's output block, in tile and global coordinates.
  uint const tileRow = local_id.y * RowPerThread;
  uint const tileCol = local_id.x * ColPerThread;
  uint const globalRow = global_id.y * RowPerThread;
  uint const globalCol = global_id.x * ColPerThread;

  // ceil(dimInner / TileInner). The explicit zero case avoids the unsigned
  // wrap of (dimInner - 1u), which would otherwise request ~2^26 tiles that
  // all contribute zero. The value is uniform across the threadgroup, so the
  // barriers inside the tile loop are reached by every thread.
  uint const numTiles = (uniforms.dimInner == 0u) ? 0u : (((uniforms.dimInner - 1u) / TileInner) + 1u);

  // Per-thread accumulator; `= {}` already zero-initialises every element.
  tint_array_wrapper_2 acc = {};
  float ACached = 0.0f;
  tint_array_wrapper_3 BCached = {};

  // Share of the tile each thread loads: columns of A, rows of B.
  uint const ColPerThreadA = TileInner / 16u;
  uint const tileColA = local_id.x * ColPerThreadA;
  uint const RowPerThreadB = TileInner / 16u;
  uint const tileRowB = local_id.y * RowPerThreadB;

  for (uint t = 0u; t < numTiles; t = t + 1u) {
    // Stage this thread's RowPerThread x ColPerThreadA piece of the A tile.
    for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
      for (uint innerCol = 0u; innerCol < ColPerThreadA; innerCol = innerCol + 1u) {
        uint const inputRow = tileRow + innerRow;
        uint const inputCol = tileColA + innerCol;
        tileA.arr[inputRow].arr[inputCol] = mm_readA(uniforms, firstMatrix, (globalRow + innerRow), ((t * TileInner) + inputCol));
      }
    }

    // Stage this thread's RowPerThreadB x ColPerThread piece of the B tile.
    // The tile is indexed [row][col] to match the [k][tileCol + inner] reads
    // below. The previous code stored at row `innerCol`, which left every row
    // past the first ColPerThread unwritten and made threads with different
    // local_id.y race on the same slots.
    for (uint innerRow = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
      for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
        uint const inputRow = tileRowB + innerRow;
        uint const inputCol = tileCol + innerCol;
        tileB.arr[inputRow].arr[inputCol] = mm_readB(uniforms, secondMatrix, ((t * TileInner) + inputRow), (globalCol + innerCol));
      }
    }

    // Both tiles must be fully staged before any thread reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint k = 0u; k < TileInner; k = k + 1u) {
      // Cache row k of the B tile for this thread's columns.
      for (uint inner = 0u; inner < ColPerThread; inner = inner + 1u) {
        BCached.arr[inner] = tileB.arr[k].arr[(tileCol + inner)];
      }
      // Multiply-accumulate column k of the A tile against the cached B values.
      for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
        ACached = tileA.arr[(tileRow + innerRow)].arr[k];
        for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
          uint const index = (innerRow * ColPerThread) + innerCol;
          acc.arr[index] = acc.arr[index] + (ACached * BCached.arr[innerCol]);
        }
      }
    }

    // Nobody may overwrite the tiles for the next iteration until every
    // thread has finished reading the current ones.
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Store the finished block; mm_write drops elements outside the result.
  for (uint innerRow = 0u; innerRow < RowPerThread; innerRow = innerRow + 1u) {
    for (uint innerCol = 0u; innerCol < ColPerThread; innerCol = innerCol + 1u) {
      uint const index = (innerRow * ColPerThread) + innerCol;
      mm_write(uniforms, resultMatrix, (globalRow + innerRow), (globalCol + innerCol), acc.arr[index]);
    }
  }
}