| #include <metal_stdlib> |
| using namespace metal; |
| |
| // Fixed-size array wrapper holding N elements of type T. |
| // operator[] is overloaded once per address space the wrapper itself can be |
| // qualified with, so the returned reference carries the matching qualifier. |
| // No bounds checking is performed; the index is forwarded to `elements` as-is. |
| template<typename T, size_t N> |
| struct tint_array { |
| // Backing storage; the only data member. |
| T elements[N]; |
| |
| // Mutable element access (device, thread, threadgroup). |
| device T& operator[](size_t idx) device { return elements[idx]; } |
| thread T& operator[](size_t idx) thread { return elements[idx]; } |
| threadgroup T& operator[](size_t idx) threadgroup { return elements[idx]; } |
| |
| // Read-only element access (constant, device, thread, threadgroup). |
| const constant T& operator[](size_t idx) const constant { return elements[idx]; } |
| const device T& operator[](size_t idx) const device { return elements[idx]; } |
| const thread T& operator[](size_t idx) const thread { return elements[idx]; } |
| const threadgroup T& operator[](size_t idx) const threadgroup { return elements[idx]; } |
| }; |
| |
| // Buffer layout for a matrix of floats stored as a flat array. |
| // NOTE(review): `numbers` is declared with one element, yet mm_readA, mm_readB |
| // and mm_write index it up to a limit derived from tint_storage_buffer_sizes, |
| // so it presumably stands in for a runtime-sized array - confirm. |
| struct Matrix { |
| // Byte offset 0x0000. |
| tint_array<float, 1> numbers; |
| }; |
| |
| // Matrix dimensions supplied to the kernel. |
| struct Uniforms { |
| // Byte offset 0x0000: row bound checked by mm_readA and mm_write. |
| uint dimAOuter; |
| // Byte offset 0x0004: column bound in mm_readA and row bound in mm_readB; |
| // also the row stride used to index firstMatrix. |
| uint dimInner; |
| // Byte offset 0x0008: column bound in mm_readB and mm_write; also the row |
| // stride used to index secondMatrix and resultMatrix. |
| uint dimBOuter; |
| }; |
| |
| // Bundle of pointers to every module-scope resource the shader uses. The kernel |
| // entry point fills it in and passes it by value to the helper functions. |
| // Member order is kept as-is because the entry point uses a designated initialiser. |
| struct tint_module_vars_struct { |
| // Left-hand input matrix; read by mm_readA. |
| const device Matrix* firstMatrix; |
| // Right-hand input matrix; read by mm_readB. |
| const device Matrix* secondMatrix; |
| // Output matrix; written by mm_write. |
| device Matrix* resultMatrix; |
| // Dimensions used for bounds checks and index arithmetic. |
| const constant Uniforms* uniforms; |
| // 64x64 threadgroup tile filled from firstMatrix. |
| threadgroup tint_array<tint_array<float, 64>, 64>* mm_Asub; |
| // 64x64 threadgroup tile filled from secondMatrix. |
| threadgroup tint_array<tint_array<float, 64>, 64>* mm_Bsub; |
| // Element [0] supplies .x/.y/.z, which are divided by 4 to clamp indices into |
| // firstMatrix/secondMatrix/resultMatrix respectively. |
| // NOTE(review): presumably the buffers' sizes in bytes - confirm. |
| const constant tint_array<uint4, 1>* tint_storage_buffer_sizes; |
| }; |
| |
| // Threadgroup allocation passed to the kernel at [[threadgroup(0)]]. |
| // It packs the two 64x64 float tiles into a single argument. |
| struct tint_symbol_3 { |
| // Tile the kernel exposes as mm_Asub. |
| tint_array<tint_array<float, 64>, 64> tint_symbol_1; |
| // Tile the kernel exposes as mm_Bsub. |
| tint_array<tint_array<float, 64>, 64> tint_symbol_2; |
| }; |
| |
| // Reads element (row, col) of firstMatrix, stored with a row stride of dimInner. |
| // Returns 0.0f when row >= dimAOuter or col >= dimInner. |
| // The flat index is clamped to (tint_storage_buffer_sizes[0].x / 4) - 1 before the load. |
| float mm_readA(uint row, uint col, tint_module_vars_struct tint_module_vars) { |
| const constant Uniforms* const dims = tint_module_vars.uniforms; |
| // Guard clause: anything outside the matrix reads as zero. |
| if ((row >= dims->dimAOuter) || (col >= dims->dimInner)) { |
| return 0.0f; |
| } |
| uint const flat_index = ((row * dims->dimInner) + col); |
| uint const last_index = (((*tint_module_vars.tint_storage_buffer_sizes)[0u].x / 4u) - 1u); |
| return (*tint_module_vars.firstMatrix).numbers[min(flat_index, last_index)]; |
| } |
| |
| // Reads element (row, col) of secondMatrix, stored with a row stride of dimBOuter. |
| // Returns 0.0f when row >= dimInner or col >= dimBOuter. |
| // The flat index is clamped to (tint_storage_buffer_sizes[0].y / 4) - 1 before the load. |
| float mm_readB(uint row, uint col, tint_module_vars_struct tint_module_vars) { |
| const constant Uniforms* const dims = tint_module_vars.uniforms; |
| // Guard clause: anything outside the matrix reads as zero. |
| if ((row >= dims->dimInner) || (col >= dims->dimBOuter)) { |
| return 0.0f; |
| } |
| uint const flat_index = ((row * dims->dimBOuter) + col); |
| uint const last_index = (((*tint_module_vars.tint_storage_buffer_sizes)[0u].y / 4u) - 1u); |
| return (*tint_module_vars.secondMatrix).numbers[min(flat_index, last_index)]; |
| } |
| |
| // Stores `value` at element (row, col) of resultMatrix, using a row stride of dimBOuter. |
| // Does nothing when row >= dimAOuter or col >= dimBOuter. |
| // The flat index is clamped to (tint_storage_buffer_sizes[0].z / 4) - 1 before the store. |
| void mm_write(uint row, uint col, float value, tint_module_vars_struct tint_module_vars) { |
| const constant Uniforms* const dims = tint_module_vars.uniforms; |
| // Guard clause: writes outside the matrix are dropped. |
| if ((row >= dims->dimAOuter) || (col >= dims->dimBOuter)) { |
| return; |
| } |
| uint const flat_index = (col + (row * dims->dimBOuter)); |
| uint const last_index = (((*tint_module_vars.tint_storage_buffer_sizes)[0u].z / 4u) - 1u); |
| (*tint_module_vars.resultMatrix).numbers[min(flat_index, last_index)] = value; |
| } |
| |
| // Unsigned division that never divides by zero: a zero `rhs` is treated as 1, |
| // so the call returns `lhs` unchanged in that case. |
| uint tint_div_u32(uint lhs, uint rhs) { |
| uint const divisor = ((rhs == 0u) ? 1u : rhs); |
| return (lhs / divisor); |
| } |
| |
| // Per-invocation body of the tiled matrix multiply |
| // resultMatrix = firstMatrix * secondMatrix. |
| // |
| // Each invocation accumulates a 4x4 block of the result. For every 64-wide tile |
| // along the inner dimension, the invocations stage a 64x64 tile of each input in |
| // threadgroup memory (mm_Asub / mm_Bsub), synchronise, multiply out of the |
| // staged tiles, and synchronise again before the tiles are overwritten. |
| // |
| // Parameters: |
| //   local_id         - thread position in the threadgroup; selects the 4x4 block inside a tile. |
| //   global_id        - thread position in the grid; selects the 4x4 block of the result. |
| //   tint_local_index - linear thread index in the threadgroup; used to zero the tiles. |
| //   tint_module_vars - pointers to the buffers, uniforms and threadgroup tiles. |
| void tint_symbol_inner(uint3 local_id, uint3 global_id, uint tint_local_index, tint_module_vars_struct tint_module_vars) { |
| // Zero both 64x64 tiles (4096 elements each). Each invocation clears every |
| // 256th element starting at its own linear index. |
| // NOTE(review): the stride of 256 presumably equals the threadgroup size - confirm. |
| for (uint flat = tint_local_index; flat < 4096u; flat += 256u) { |
| (*tint_module_vars.mm_Asub)[(flat / 64u)][(flat % 64u)] = 0.0f; |
| (*tint_module_vars.mm_Bsub)[(flat / 64u)][(flat % 64u)] = 0.0f; |
| } |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
| uint const tileRow = (local_id.y * 4u); |
| uint const tileCol = (local_id.x * 4u); |
| uint const globalRow = (global_id.y * 4u); |
| uint const globalCol = (global_id.x * 4u); |
| |
| // Number of 64-wide tiles needed to cover the inner dimension (rounded up). |
| // dimInner == 0 is handled explicitly: `dimInner - 1u` would wrap around and |
| // request tens of millions of tiles. Zero tiles leaves the accumulators at 0. |
| uint const dimInner = (*tint_module_vars.uniforms).dimInner; |
| uint const numTiles = ((dimInner == 0u) ? 0u : (tint_div_u32((dimInner - 1u), 64u) + 1u)); |
| |
| // `= {}` zero-initialises the accumulators, so no separate clearing loop is needed. |
| tint_array<float, 16> acc = {}; |
| float ACached = 0.0f; |
| tint_array<float, 4> BCached = {}; |
| |
| uint const ColPerThreadA = 4u; |
| uint const tileColA = (local_id.x * ColPerThreadA); |
| uint const RowPerThreadB = 4u; |
| uint const tileRowB = (local_id.y * RowPerThreadB); |
| |
| for (uint t = 0u; t < numTiles; t++) { |
| // Stage this invocation's 4x4 block of the current A tile. |
| for (uint innerRow = 0u; innerRow < 4u; innerRow++) { |
| for (uint innerCol = 0u; innerCol < ColPerThreadA; innerCol++) { |
| uint const inputRow = (tileRow + innerRow); |
| uint const inputCol = (tileColA + innerCol); |
| (*tint_module_vars.mm_Asub)[min(inputRow, 63u)][min(inputCol, 63u)] = mm_readA((globalRow + innerRow), ((t * 64u) + inputCol), tint_module_vars); |
| } |
| } |
| |
| // Stage this invocation's 4x4 block of the current B tile. |
| // The tile row must be inputRow: the multiply loop below reads mm_Bsub[k] for |
| // every k in 0..63. Indexing the row by innerCol filled only rows 0..3. |
| for (uint innerRow = 0u; innerRow < RowPerThreadB; innerRow++) { |
| for (uint innerCol = 0u; innerCol < 4u; innerCol++) { |
| uint const inputRow = (tileRowB + innerRow); |
| uint const inputCol = (tileCol + innerCol); |
| (*tint_module_vars.mm_Bsub)[min(inputRow, 63u)][min(inputCol, 63u)] = mm_readB(((t * 64u) + inputRow), (globalCol + innerCol), tint_module_vars); |
| } |
| } |
| |
| // Every invocation must finish staging before any of them reads the tiles. |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
| // Multiply the staged tiles, accumulating into this invocation's 4x4 block. |
| for (uint k = 0u; k < 64u; k++) { |
| // Cache the four B values of row k that this invocation needs. |
| for (uint inner = 0u; inner < 4u; inner++) { |
| BCached[min(inner, 3u)] = (*tint_module_vars.mm_Bsub)[min(k, 63u)][min((tileCol + inner), 63u)]; |
| } |
| for (uint innerRow = 0u; innerRow < 4u; innerRow++) { |
| ACached = (*tint_module_vars.mm_Asub)[min((tileRow + innerRow), 63u)][min(k, 63u)]; |
| for (uint innerCol = 0u; innerCol < 4u; innerCol++) { |
| uint const index = ((innerRow * 4u) + innerCol); |
| acc[min(index, 15u)] = (acc[min(index, 15u)] + (ACached * BCached[min(innerCol, 3u)])); |
| } |
| } |
| } |
| |
| // Every invocation must finish reading before the tiles are overwritten. |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| } |
| |
| // Write the accumulated 4x4 block; mm_write drops out-of-range elements. |
| for (uint innerRow = 0u; innerRow < 4u; innerRow++) { |
| for (uint innerCol = 0u; innerCol < 4u; innerCol++) { |
| uint const index = ((innerRow * 4u) + innerCol); |
| mm_write((globalRow + innerRow), (globalCol + innerCol), acc[min(index, 15u)], tint_module_vars); |
| } |
| } |
| } |
| |
| // Compute kernel entry point. Collects the buffer and threadgroup arguments into |
| // a tint_module_vars_struct and forwards everything to tint_symbol_inner. |
| kernel void tint_symbol( |
| uint3 local_id [[thread_position_in_threadgroup]], |
| uint3 global_id [[thread_position_in_grid]], |
| uint tint_local_index [[thread_index_in_threadgroup]], |
| const device Matrix* firstMatrix [[buffer(2)]], |
| const device Matrix* secondMatrix [[buffer(3)]], |
| device Matrix* resultMatrix [[buffer(1)]], |
| const constant Uniforms* uniforms [[buffer(0)]], |
| threadgroup tint_symbol_3* v_6 [[threadgroup(0)]], |
| const constant tint_array<uint4, 1>* tint_storage_buffer_sizes [[buffer(30)]]) { |
| // The two tiles live side by side in the single threadgroup allocation v_6. |
| tint_module_vars_struct const module_vars = tint_module_vars_struct{ |
| .firstMatrix=firstMatrix, |
| .secondMatrix=secondMatrix, |
| .resultMatrix=resultMatrix, |
| .uniforms=uniforms, |
| .mm_Asub=(&v_6->tint_symbol_1), |
| .mm_Bsub=(&v_6->tint_symbol_2), |
| .tint_storage_buffer_sizes=tint_storage_buffer_sizes |
| }; |
| tint_symbol_inner(local_id, global_id, tint_local_index, module_vars); |
| } |