| #include <metal_stdlib> |
| |
| using namespace metal; |
| |
| // Fixed-size array wrapper holding N elements of type T. |
| // operator[] is overloaded once per address-space qualifier (constant, device, |
| // thread, threadgroup) so the returned reference lives in the same address |
| // space as the array object itself. Only the constant overload lacks a mutable |
| // variant. No bounds checking is performed on the index. |
| template<typename T, size_t N> |
| struct tint_array { |
|   const constant T& operator[](size_t index) const constant { |
|     return elements[index]; |
|   } |
|   device T& operator[](size_t index) device { |
|     return elements[index]; |
|   } |
|   const device T& operator[](size_t index) const device { |
|     return elements[index]; |
|   } |
|   thread T& operator[](size_t index) thread { |
|     return elements[index]; |
|   } |
|   const thread T& operator[](size_t index) const thread { |
|     return elements[index]; |
|   } |
|   threadgroup T& operator[](size_t index) threadgroup { |
|     return elements[index]; |
|   } |
|   const threadgroup T& operator[](size_t index) const threadgroup { |
|     return elements[index]; |
|   } |
|   // Backing storage. |
|   T elements[N]; |
| }; |
| |
| // TINT_ISOLATE_UB(FLAG_NAME) declares a volatile bool called FLAG_NAME, set to |
| // true, and opens an `if (FLAG_NAME)` that guards whatever statement follows |
| // the macro invocation. Each use in the same scope needs a distinct name. |
| // NOTE(review): judging by the macro name, the volatile read presumably stops |
| // the compiler from propagating undefined behaviour in the guarded statement |
| // to the surrounding code - confirm against the Tint generator. |
| #define TINT_ISOLATE_UB(FLAG_NAME) \ |
| volatile bool FLAG_NAME = true; \ |
| if (FLAG_NAME) |
| |
| // Writes 0.0f to every element of the two 64x64 threadgroup float arrays and |
| // then issues a threadgroup memory barrier. |
| // The calling invocation starts at flat element `local_idx` and steps by 256, |
| // covering flat indices below 4096 (64 * 64); each flat index is split into a |
| // row (index / 64) and a column (index % 64). |
| void tint_zero_workgroup_memory(uint local_idx, threadgroup tint_array<tint_array<float, 64>, 64>* const tint_symbol_5, threadgroup tint_array<tint_array<float, 64>, 64>* const tint_symbol_6) { |
|   TINT_ISOLATE_UB(tint_volatile_true) for(uint flat = local_idx; (flat < 4096u); flat += 256u) { |
|     uint const row = (flat / 64u); |
|     uint const col = (flat % 64u); |
|     (*tint_symbol_5)[row][col] = 0.0f; |
|     (*tint_symbol_6)[row][col] = 0.0f; |
|   } |
|   // The barrier is outside the loop, so it is reached whatever the trip count. |
|   threadgroup_barrier(mem_flags::mem_threadgroup); |
| } |
| |
| // Matrix dimensions consumed by the kernel through a `constant` pointer. |
| // The leading block comments give each field's byte offset. |
| struct Uniforms { |
|   // Row bound checked by mm_readA and mm_write. |
|   /* 0x0000 */ uint dimAOuter; |
|   // Column bound (and row stride) in mm_readA; row bound in mm_readB. |
|   /* 0x0004 */ uint dimInner; |
|   // Column bound and row stride in mm_readB and mm_write. |
|   /* 0x0008 */ uint dimBOuter; |
| }; |
| |
| // Buffer layout for a matrix stored as a flat float array. |
| // The array is declared with a single element, yet the accessors in this file |
| // index it with `row * stride + col`. |
| // NOTE(review): presumably this stands in for a runtime-sized array whose real |
| // length is set by the bound buffer - confirm against the host code. |
| struct Matrix { |
|   /* 0x0000 */ tint_array<float, 1> numbers; |
| }; |
| |
| // Reads element (row, col) from the matrix behind tint_symbol_8, addressed as |
| // row * dimInner + col. |
| // Returns 0.0f instead of reading when row >= dimAOuter or col >= dimInner. |
| float mm_readA(uint row, uint col, const constant Uniforms* const tint_symbol_7, const device Matrix* const tint_symbol_8) { |
|   uint const rowCount = tint_symbol_7->dimAOuter; |
|   uint const colCount = tint_symbol_7->dimInner; |
|   if ((row >= rowCount) || (col >= colCount)) { |
|     return 0.0f; |
|   } |
|   return tint_symbol_8->numbers[((row * colCount) + col)]; |
| } |
| |
| // Reads element (row, col) from the matrix behind tint_symbol_10, addressed as |
| // row * dimBOuter + col. |
| // Returns 0.0f instead of reading when row >= dimInner or col >= dimBOuter. |
| float mm_readB(uint row, uint col, const constant Uniforms* const tint_symbol_9, const device Matrix* const tint_symbol_10) { |
|   uint const rowCount = tint_symbol_9->dimInner; |
|   uint const colCount = tint_symbol_9->dimBOuter; |
|   if ((row >= rowCount) || (col >= colCount)) { |
|     return 0.0f; |
|   } |
|   return tint_symbol_10->numbers[((row * colCount) + col)]; |
| } |
| |
| // Stores `value` at element (row, col) of the matrix behind tint_symbol_12, |
| // addressed as row * dimBOuter + col. |
| // Does nothing when row >= dimAOuter or col >= dimBOuter. |
| void mm_write(uint row, uint col, float value, const constant Uniforms* const tint_symbol_11, device Matrix* const tint_symbol_12) { |
|   uint const rowCount = tint_symbol_11->dimAOuter; |
|   uint const colCount = tint_symbol_11->dimBOuter; |
|   if ((row >= rowCount) || (col >= colCount)) { |
|     return; |
|   } |
|   tint_symbol_12->numbers[(col + (row * colCount))] = value; |
| } |
| |
| // Unsigned division that never divides by zero: a zero `rhs` is replaced by 1, |
| // so tint_div(x, 0u) returns x. |
| uint tint_div(uint lhs, uint rhs) { |
|   uint const divisor = ((rhs == 0u) ? 1u : rhs); |
|   return (lhs / divisor); |
| } |
| |
| // Per-invocation body of the tiled matrix multiply. |
| // |
| // Parameters: |
| //   local_id, global_id     - invocation position in the threadgroup / grid; |
| //                             component 0 selects columns, component 1 rows. |
| //   local_invocation_index  - flat index in the threadgroup, used for zeroing. |
| //   tint_symbol_13          - 64x64 threadgroup tile filled from mm_readA. |
| //   tint_symbol_14          - 64x64 threadgroup tile filled from mm_readB. |
| //   tint_symbol_15          - matrix dimensions. |
| //   tint_symbol_16 / _17    - input matrices read by mm_readA / mm_readB. |
| //   tint_symbol_18          - output matrix written by mm_write. |
| // |
| // Each invocation accumulates a 4x4 block of the output in `acc`. The shared |
| // dimension is walked in tiles of 64: every invocation loads a 4x4 patch of |
| // each tile, a barrier publishes the tiles, the partial products are |
| // accumulated, and a second barrier protects the tiles before the next load. |
| void tint_symbol_inner(uint3 local_id, uint3 global_id, uint local_invocation_index, threadgroup tint_array<tint_array<float, 64>, 64>* const tint_symbol_13, threadgroup tint_array<tint_array<float, 64>, 64>* const tint_symbol_14, const constant Uniforms* const tint_symbol_15, const device Matrix* const tint_symbol_16, const device Matrix* const tint_symbol_17, device Matrix* const tint_symbol_18) { |
|   // Zero-fills both tiles and ends with a threadgroup barrier. |
|   tint_zero_workgroup_memory(local_invocation_index, tint_symbol_13, tint_symbol_14); |
|   uint const tileRow = (local_id[1] * 4u); |
|   uint const tileCol = (local_id[0] * 4u); |
|   uint const globalRow = (global_id[1] * 4u); |
|   uint const globalCol = (global_id[0] * 4u); |
|   // Number of 64-wide tiles needed to cover dimInner (ceiling division). |
|   // dimInner == 0 is handled explicitly: `dimInner - 1u` would wrap around and |
|   // request ~67 million tile iterations that only ever add zero. |
|   uint const dimInner = (*(tint_symbol_15)).dimInner; |
|   uint const numTiles = ((dimInner == 0u) ? 0u : (tint_div((dimInner - 1u), 64u) + 1u)); |
|   tint_array<float, 16> acc = {}; |
|   float ACached = 0.0f; |
|   tint_array<float, 4> BCached = {}; |
|   TINT_ISOLATE_UB(tint_volatile_true_1) for(uint index = 0u; (index < 16u); index = (index + 1u)) { |
|     acc[index] = 0.0f; |
|   } |
|   uint const ColPerThreadA = 4u; |
|   uint const tileColA = (local_id[0] * ColPerThreadA); |
|   uint const RowPerThreadB = 4u; |
|   uint const tileRowB = (local_id[1] * RowPerThreadB); |
|   TINT_ISOLATE_UB(tint_volatile_true_2) for(uint t = 0u; (t < numTiles); t = (t + 1u)) { |
|     // Load this invocation's 4x4 patch of the A tile. |
|     TINT_ISOLATE_UB(tint_volatile_true_3) for(uint innerRow = 0u; (innerRow < 4u); innerRow = (innerRow + 1u)) { |
|       TINT_ISOLATE_UB(tint_volatile_true_4) for(uint innerCol = 0u; (innerCol < ColPerThreadA); innerCol = (innerCol + 1u)) { |
|         uint const inputRow = (tileRow + innerRow); |
|         uint const inputCol = (tileColA + innerCol); |
|         uint const tint_symbol_1 = inputRow; |
|         uint const tint_symbol_2 = inputCol; |
|         (*(tint_symbol_13))[tint_symbol_1][tint_symbol_2] = mm_readA((globalRow + innerRow), ((t * 64u) + inputCol), tint_symbol_15, tint_symbol_16); |
|       } |
|     } |
|     // Load this invocation's 4x4 patch of the B tile. |
|     TINT_ISOLATE_UB(tint_volatile_true_5) for(uint innerRow = 0u; (innerRow < RowPerThreadB); innerRow = (innerRow + 1u)) { |
|       TINT_ISOLATE_UB(tint_volatile_true_6) for(uint innerCol = 0u; (innerCol < 4u); innerCol = (innerCol + 1u)) { |
|         uint const inputRow = (tileRowB + innerRow); |
|         uint const inputCol = (tileCol + innerCol); |
|         // The tile row must be inputRow, matching the row passed to mm_readB. |
|         // Using innerCol here would fill only rows 0..3 of the tile, while |
|         // the multiply loop below reads rows 0..63. |
|         uint const tint_symbol_3 = inputRow; |
|         uint const tint_symbol_4 = inputCol; |
|         (*(tint_symbol_14))[tint_symbol_3][tint_symbol_4] = mm_readB(((t * 64u) + inputRow), (globalCol + innerCol), tint_symbol_15, tint_symbol_17); |
|       } |
|     } |
|     // Make both tiles visible to the whole threadgroup before they are read. |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
|     TINT_ISOLATE_UB(tint_volatile_true_7) for(uint k = 0u; (k < 64u); k = (k + 1u)) { |
|       // Cache the four B values of row k that this invocation needs. |
|       TINT_ISOLATE_UB(tint_volatile_true_8) for(uint inner = 0u; (inner < 4u); inner = (inner + 1u)) { |
|         BCached[inner] = (*(tint_symbol_14))[k][(tileCol + inner)]; |
|       } |
|       TINT_ISOLATE_UB(tint_volatile_true_9) for(uint innerRow = 0u; (innerRow < 4u); innerRow = (innerRow + 1u)) { |
|         ACached = (*(tint_symbol_13))[(tileRow + innerRow)][k]; |
|         TINT_ISOLATE_UB(tint_volatile_true_10) for(uint innerCol = 0u; (innerCol < 4u); innerCol = (innerCol + 1u)) { |
|           uint const index = ((innerRow * 4u) + innerCol); |
|           acc[index] = (acc[index] + (ACached * BCached[innerCol])); |
|         } |
|       } |
|     } |
|     // Do not let any invocation overwrite the tiles while others still read them. |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
|   } |
|   // Write the accumulated 4x4 block; mm_write drops out-of-range elements. |
|   TINT_ISOLATE_UB(tint_volatile_true_11) for(uint innerRow = 0u; (innerRow < 4u); innerRow = (innerRow + 1u)) { |
|     TINT_ISOLATE_UB(tint_volatile_true_12) for(uint innerCol = 0u; (innerCol < 4u); innerCol = (innerCol + 1u)) { |
|       uint const index = ((innerRow * 4u) + innerCol); |
|       mm_write((globalRow + innerRow), (globalCol + innerCol), acc[index], tint_symbol_15, tint_symbol_18); |
|     } |
|   } |
| } |
| |
| // Compute entry point. Declares the two 64x64 threadgroup float tiles and |
| // forwards them, with the bound buffers and invocation IDs, to |
| // tint_symbol_inner. |
| // Buffer slots: 0 = Uniforms, 2 and 3 = read-only input matrices (passed on to |
| // mm_readA and mm_readB respectively), 1 = writable output matrix. |
| kernel void tint_symbol(const constant Uniforms* tint_symbol_21 [[buffer(0)]], const device Matrix* tint_symbol_22 [[buffer(2)]], const device Matrix* tint_symbol_23 [[buffer(3)]], device Matrix* tint_symbol_24 [[buffer(1)]], uint3 local_id [[thread_position_in_threadgroup]], uint3 global_id [[thread_position_in_grid]], uint local_invocation_index [[thread_index_in_threadgroup]]) { |
|   threadgroup tint_array<tint_array<float, 64>, 64> tileA; |
|   threadgroup tint_array<tint_array<float, 64>, 64> tileB; |
|   tint_symbol_inner(local_id, global_id, local_invocation_index, &tileA, &tileB, tint_symbol_21, tint_symbol_22, tint_symbol_23, tint_symbol_24); |
| } |
| |