| #include <metal_stdlib> |
| |
| using namespace metal; |
| // Matrix dimensions shared by the accessors below. They bound every read and |
| // write and supply the row stride for the row-major index computations. |
| struct Uniforms { |
|   uint dimAOuter;  // offset 0x0000: row bound for firstMatrix and resultMatrix |
|   uint dimInner;   // offset 0x0004: column bound of firstMatrix, row bound of secondMatrix |
|   uint dimBOuter;  // offset 0x0008: column bound of secondMatrix and resultMatrix |
| }; |
| // Float storage behind each device buffer. The array is declared with a single |
| // element but is indexed with computed offsets by mm_readA, mm_readB and mm_write. |
| // NOTE(review): the real element count comes from the bound buffer - confirm on the host side. |
| struct Matrix { |
|   float numbers[1];  // offset 0x0000 |
| }; |
| // Struct wrapper around a fixed array of 64 floats; used as one row of tint_array_wrapper. |
| struct tint_array_wrapper_1 { |
|   float arr[64]; |
| }; |
| // 64 rows of tint_array_wrapper_1, i.e. a 64x64 float tile. The kernel declares |
| // two of these in threadgroup memory. |
| struct tint_array_wrapper { |
|   tint_array_wrapper_1 arr[64]; |
| }; |
| // Struct wrapper around 16 floats; tint_symbol_inner uses it for its per-invocation accumulator. |
| struct tint_array_wrapper_2 { |
|   float arr[16]; |
| }; |
| // Struct wrapper around 4 floats; tint_symbol_inner uses it to cache one strip of the B tile. |
| struct tint_array_wrapper_3 { |
|   float arr[4]; |
| }; |
| |
| // Output block computed by a single invocation (rows x columns). |
| constant uint RowPerThread = 4u; |
| constant uint ColPerThread = 4u; |
| // Tile extents. TileInner is the step along the shared dimension per tile iteration. |
| constant uint TileAOuter   = 64u; |
| constant uint TileBOuter   = 64u; |
| constant uint TileInner    = 64u; |
| // Bounds-checked read of firstMatrix at (row, col), using row-major indexing |
| // with a row stride of uniforms.dimInner. |
| // Returns 0.0f when row is not below dimAOuter or col is not below dimInner. |
| float mm_readA(constant Uniforms& uniforms, const device Matrix& firstMatrix, uint row, uint col) { |
|   bool const inBounds = ((row < uniforms.dimAOuter) && (col < uniforms.dimInner)); |
|   // Only the selected operand is evaluated, so the buffer is never touched out of range. |
|   return inBounds ? firstMatrix.numbers[((row * uniforms.dimInner) + col)] : 0.0f; |
| } |
| |
| // Bounds-checked read of secondMatrix at (row, col), using row-major indexing |
| // with a row stride of uniforms.dimBOuter. |
| // Returns 0.0f when row is not below dimInner or col is not below dimBOuter. |
| float mm_readB(constant Uniforms& uniforms, const device Matrix& secondMatrix, uint row, uint col) { |
|   bool const inBounds = ((row < uniforms.dimInner) && (col < uniforms.dimBOuter)); |
|   // Only the selected operand is evaluated, so the buffer is never touched out of range. |
|   return inBounds ? secondMatrix.numbers[((row * uniforms.dimBOuter) + col)] : 0.0f; |
| } |
| |
| // Bounds-checked store of value into resultMatrix at (row, col), using row-major |
| // indexing with a row stride of uniforms.dimBOuter. |
| // Does nothing when row is not below dimAOuter or col is not below dimBOuter. |
| void mm_write(constant Uniforms& uniforms, device Matrix& resultMatrix, uint row, uint col, float value) { |
|   bool const inBounds = ((row < uniforms.dimAOuter) && (col < uniforms.dimBOuter)); |
|   if (!inBounds) { |
|     return; |
|   } |
|   resultMatrix.numbers[((row * uniforms.dimBOuter) + col)] = value; |
| } |
| |
| // Threadgroup body of a tiled matrix multiply. Each invocation accumulates a |
| // RowPerThread x ColPerThread block of firstMatrix * secondMatrix and stores it |
| // into resultMatrix. |
| // |
| //   uniforms                 - matrix dimensions used for bounds checks and indexing. |
| //   firstMatrix/secondMatrix - inputs, read via mm_readA / mm_readB (out-of-range reads yield 0). |
| //   resultMatrix             - output, written via mm_write (out-of-range writes are dropped). |
| //   local_id / global_id     - invocation position within the threadgroup / the grid. |
| //   local_invocation_index   - linear index in the threadgroup; strides the zero-fill below. |
| //   tint_symbol_1            - threadgroup scratch for the current tile of firstMatrix. |
| //   tint_symbol_2            - threadgroup scratch for the current tile of secondMatrix. |
| // |
| // NOTE(review): the 256u stride and the "/ 16u" splits presume a 16x16 threadgroup - |
| // confirm against the dispatch configuration. |
| void tint_symbol_inner(constant Uniforms& uniforms, const device Matrix& firstMatrix, const device Matrix& secondMatrix, device Matrix& resultMatrix, uint3 local_id, uint3 global_id, uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol_1, threadgroup tint_array_wrapper* const tint_symbol_2) { |
|   // Zero both 64x64 tiles; each invocation clears every 256th element. |
|   for(uint idx = local_invocation_index; (idx < 4096u); idx = (idx + 256u)) { |
|     uint const i = (idx / 64u); |
|     uint const i_1 = (idx % 64u); |
|     (*(tint_symbol_1)).arr[i].arr[i_1] = float(); |
|     (*(tint_symbol_2)).arr[i].arr[i_1] = float(); |
|   } |
|   threadgroup_barrier(mem_flags::mem_threadgroup); |
|   uint const tileRow = (local_id.y * RowPerThread); |
|   uint const tileCol = (local_id.x * ColPerThread); |
|   uint const globalRow = (global_id.y * RowPerThread); |
|   uint const globalCol = (global_id.x * ColPerThread); |
|   // Ceiling of dimInner / TileInner. The zero case is handled explicitly because |
|   // (0u - 1u) wraps around and would otherwise request tens of millions of empty tiles. |
|   uint const numTiles = ((uniforms.dimInner == 0u) ? 0u : (((uniforms.dimInner - 1u) / TileInner) + 1u)); |
|   // Value-initialization zeroes every accumulator element, so no explicit clearing loop is needed. |
|   tint_array_wrapper_2 acc = {}; |
|   float ACached = 0.0f; |
|   tint_array_wrapper_3 BCached = {}; |
|   // Share of each tile loaded by one invocation: columns of A, rows of B. |
|   uint const ColPerThreadA = (TileInner / 16u); |
|   uint const tileColA = (local_id.x * ColPerThreadA); |
|   uint const RowPerThreadB = (TileInner / 16u); |
|   uint const tileRowB = (local_id.y * RowPerThreadB); |
|   for(uint t = 0u; (t < numTiles); t = (t + 1u)) { |
|     // Load this invocation's share of the A tile. |
|     for(uint innerRow = 0u; (innerRow < RowPerThread); innerRow = (innerRow + 1u)) { |
|       for(uint innerCol = 0u; (innerCol < ColPerThreadA); innerCol = (innerCol + 1u)) { |
|         uint const inputRow = (tileRow + innerRow); |
|         uint const inputCol = (tileColA + innerCol); |
|         (*(tint_symbol_1)).arr[inputRow].arr[inputCol] = mm_readA(uniforms, firstMatrix, (globalRow + innerRow), ((t * TileInner) + inputCol)); |
|       } |
|     } |
|     // Load this invocation's share of the B tile. |
|     for(uint innerRow = 0u; (innerRow < RowPerThreadB); innerRow = (innerRow + 1u)) { |
|       for(uint innerCol = 0u; (innerCol < ColPerThread); innerCol = (innerCol + 1u)) { |
|         uint const inputRow = (tileRowB + innerRow); |
|         uint const inputCol = (tileCol + innerCol); |
|         // Store at the tile row (inputRow): the k-loop below reads arr[k] for every |
|         // k in [0, TileInner), so each of those rows must be populated. Indexing by |
|         // innerCol would fill only rows 0..3 and make invocations overwrite each other. |
|         (*(tint_symbol_2)).arr[inputRow].arr[inputCol] = mm_readB(uniforms, secondMatrix, ((t * TileInner) + inputRow), (globalCol + innerCol)); |
|       } |
|     } |
|     // All tile stores must be visible before any invocation starts reading. |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
|     for(uint k = 0u; (k < TileInner); k = (k + 1u)) { |
|       // Cache the strip of B row k that this invocation needs. |
|       for(uint inner = 0u; (inner < ColPerThread); inner = (inner + 1u)) { |
|         BCached.arr[inner] = (*(tint_symbol_2)).arr[k].arr[(tileCol + inner)]; |
|       } |
|       for(uint innerRow = 0u; (innerRow < RowPerThread); innerRow = (innerRow + 1u)) { |
|         ACached = (*(tint_symbol_1)).arr[(tileRow + innerRow)].arr[k]; |
|         for(uint innerCol = 0u; (innerCol < ColPerThread); innerCol = (innerCol + 1u)) { |
|           uint const index = ((innerRow * ColPerThread) + innerCol); |
|           acc.arr[index] = (acc.arr[index] + (ACached * BCached.arr[innerCol])); |
|         } |
|       } |
|     } |
|     // Keep the tiles intact until every invocation has finished with them. |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
|   } |
|   // Write the accumulated block; mm_write drops elements outside the result bounds. |
|   for(uint innerRow = 0u; (innerRow < RowPerThread); innerRow = (innerRow + 1u)) { |
|     for(uint innerCol = 0u; (innerCol < ColPerThread); innerCol = (innerCol + 1u)) { |
|       uint const index = ((innerRow * ColPerThread) + innerCol); |
|       mm_write(uniforms, resultMatrix, (globalRow + innerRow), (globalCol + innerCol), acc.arr[index]); |
|     } |
|   } |
| } |
| |
| // Compute entry point. Declares the two threadgroup tiles and forwards them, |
| // together with the built-in ids and buffer bindings, to tint_symbol_inner. |
| // Bindings: firstMatrix = buffer(0), secondMatrix = buffer(1), resultMatrix = buffer(2), uniforms = buffer(3). |
| kernel void tint_symbol(uint3 local_id [[thread_position_in_threadgroup]], uint3 global_id [[thread_position_in_grid]], uint local_invocation_index [[thread_index_in_threadgroup]], constant Uniforms& uniforms [[buffer(3)]], const device Matrix& firstMatrix [[buffer(0)]], const device Matrix& secondMatrix [[buffer(1)]], device Matrix& resultMatrix [[buffer(2)]]) { |
|   threadgroup tint_array_wrapper tileA;  // passed as tint_symbol_1 (tile of firstMatrix) |
|   threadgroup tint_array_wrapper tileB;  // passed as tint_symbol_2 (tile of secondMatrix) |
|   tint_symbol_inner(uniforms, firstMatrix, secondMatrix, resultMatrix, local_id, global_id, local_invocation_index, &tileA, &tileB); |
| } |
| |