| #include <metal_stdlib> |
| |
| using namespace metal; |
|/* Uniform block passed to the kernel at [[buffer(3)]]. Per the byte offsets below, each shape vector starts on a 16-byte boundary and the int8_t arrays fill the gaps between members. */ struct Uniforms { |
| /* 0x0000: not read by any code visible in this file */ float tint_symbol; |
| /* 0x0004: padding up to 0x0010 */ int8_t tint_pad[12]; |
| /* 0x0010: .y and .z are read by mm_readA_i1_i1_ and main_1 */ packed_int3 aShape; |
| /* 0x001c: padding up to 0x0020 */ int8_t tint_pad_1[4]; |
| /* 0x0020: .y and .z are read by mm_readB_i1_i1_, .z also by main_1 */ packed_int3 bShape; |
| /* 0x002c: padding up to 0x0030 */ int8_t tint_pad_2[4]; |
| /* 0x0030: not read by any code visible in this file */ packed_int3 outShape; |
| /* 0x003c: padding up to 0x0040 */ int8_t tint_pad_3[4]; |
| /* 0x0040: .x and .y scale coords.x and coords.y in getOutputFlatIndex_vi3_ */ packed_int2 outShapeStrides; |
| }; |
|/* Output storage buffer: bound at [[buffer(0)]] in the kernel and written by setOutput_i1_f1_. */ struct ssbOut { |
| /* 0x0000: declared with one element but indexed with a computed flat index; presumably a stand-in for a runtime-sized array - confirm against the generator */ float result[1]; |
| }; |
|/* Read-only input buffer: bound at [[buffer(1)]] in the kernel and read by mm_readA_i1_i1_. */ struct ssbA { |
| /* 0x0000: declared with one element but indexed with a computed offset; presumably a stand-in for a runtime-sized array - confirm against the generator */ float A[1]; |
| }; |
|/* Read-only input buffer: bound at [[buffer(2)]] in the kernel and read by mm_readB_i1_i1_. */ struct ssbB { |
| /* 0x0000: declared with one element but indexed with a computed offset; presumably a stand-in for a runtime-sized array - confirm against the generator */ float B[1]; |
| }; |
|/* Struct wrapper around a fixed array of 64 floats; it is the element type of tint_array_wrapper. */ struct tint_array_wrapper_1 { |
| float arr[64]; |
| }; |
|/* 64 x 64 float array (64 entries of tint_array_wrapper_1). The kernel declares one instance in threadgroup memory and mm_matMul_i1_i1_i1_ fills it with values returned by mm_readA_i1_i1_. */ struct tint_array_wrapper { |
| tint_array_wrapper_1 arr[64]; |
| }; |
|/* Struct wrapper around a one-element float array; it is the element type of tint_array_wrapper_2 and tint_array_wrapper_4 and is also used directly for a local in mm_matMul_i1_i1_i1_. */ struct tint_array_wrapper_3 { |
| float arr[1]; |
| }; |
|/* 64 x 1 float array (64 entries of tint_array_wrapper_3). The kernel declares one instance in threadgroup memory and mm_matMul_i1_i1_i1_ fills it with values returned by mm_readB_i1_i1_. */ struct tint_array_wrapper_2 { |
| tint_array_wrapper_3 arr[64]; |
| }; |
|/* 1 x 1 float array (one entry of tint_array_wrapper_3); the type of the per-invocation accumulator in mm_matMul_i1_i1_i1_. */ struct tint_array_wrapper_4 { |
| tint_array_wrapper_3 arr[1]; |
| }; |
| |
| // Returns true when both components of *coord are >= 0 and strictly less than |
| // the matching component of *shape. The upper-bound comparison is only made |
| // once the lower-bound comparison has passed. |
| bool coordsInBounds_vi2_vi2_(thread int2* const coord, thread int2* const shape) { |
|   int2 const position = *(coord); |
|   if (!all((position >= int2(0, 0)))) { |
|     return false; |
|   } |
|   int2 const extent = *(shape); |
|   return all((position < extent)); |
| } |
| |
| // Bounds-checked read from buffer A. |
| //   row, col        : coordinate to read. |
| //   tint_symbol_5/6 : exclusive upper bounds for row and col; tint_symbol_6 is |
| //                     also the multiplier applied to row in the offset. |
| //   tint_symbol_7   : multiplied by aShape.y * aShape.z to form the leading |
| //                     part of the offset. |
| // Returns A[tint_symbol_7 * (aShape.y * aShape.z) + row * tint_symbol_6 + col], |
| // or 0.0 when (row, col) fails coordsInBounds_vi2_vi2_. |
| float mm_readA_i1_i1_(constant Uniforms& x_48, const device ssbA& x_165, thread int* const row, thread int* const col, thread int* const tint_symbol_5, thread int* const tint_symbol_6, thread int* const tint_symbol_7) { |
|   int const batchStride = (x_48.aShape.y * x_48.aShape.z); |
|   int2 position = int2(*(row), *(col)); |
|   int2 extent = int2(*(tint_symbol_5), *(tint_symbol_6)); |
|   if (!coordsInBounds_vi2_vi2_(&(position), &(extent))) { |
|     return 0.0f; |
|   } |
|   int const offset = (((*(tint_symbol_7) * batchStride) + (*(row) * *(tint_symbol_6))) + *(col)); |
|   return x_165.A[offset]; |
| } |
| |
| // Bounds-checked read from buffer B. |
| //   row_1, col_1    : coordinate to read. |
| //   tint_symbol_8/9 : exclusive upper bounds for row_1 and col_1; tint_symbol_9 |
| //                     is also the multiplier applied to row_1 in the offset. |
| //   tint_symbol_10  : multiplied by bShape.y * bShape.z to form the leading |
| //                     part of the offset. |
| // Returns B[tint_symbol_10 * (bShape.y * bShape.z) + row_1 * tint_symbol_9 + col_1], |
| // or 0.0 when (row_1, col_1) fails coordsInBounds_vi2_vi2_. |
| float mm_readB_i1_i1_(constant Uniforms& x_48, const device ssbB& x_185, thread int* const row_1, thread int* const col_1, thread int* const tint_symbol_8, thread int* const tint_symbol_9, thread int* const tint_symbol_10) { |
|   int const batchStride = (x_48.bShape.y * x_48.bShape.z); |
|   int2 position = int2(*(row_1), *(col_1)); |
|   int2 extent = int2(*(tint_symbol_8), *(tint_symbol_9)); |
|   if (!coordsInBounds_vi2_vi2_(&(position), &(extent))) { |
|     return 0.0f; |
|   } |
|   int const offset = (((*(tint_symbol_10) * batchStride) + (*(row_1) * *(tint_symbol_9))) + *(col_1)); |
|   return x_185.B[offset]; |
| } |
| |
| // Converts a 3-component output coordinate into a flat element index: |
| //   coords.x * outShapeStrides.x + coords.y * outShapeStrides.y + coords.z |
| // The sum is evaluated in 32-bit integer arithmetic on purpose: a float dot |
| // product represents integers exactly only up to 2^24, so routing the index |
| // through float would round large flat indices to a neighbouring element. |
| int getOutputFlatIndex_vi3_(constant Uniforms& x_48, thread int3* const coords) { |
|   int3 const position = *(coords); |
|   int const strideX = x_48.outShapeStrides.x; |
|   int const strideY = x_48.outShapeStrides.y; |
|   return (((position.x * strideX) + (position.y * strideY)) + position.z); |
| } |
| |
| // Stores *value into x_54.result at element *flatIndex. No bounds check is |
| // made here; callers are responsible for passing a valid index. |
| void setOutput_i1_f1_(device ssbOut& x_54, thread int* const flatIndex, thread float* const value) { |
|   int const slot = *(flatIndex); |
|   x_54.result[slot] = *(value); |
| } |
| |
| // Writes *value_1 to the output buffer at coordinate (*d0, *d1, *d2): the |
| // coordinate is flattened with getOutputFlatIndex_vi3_ and the store itself is |
| // delegated to setOutput_i1_f1_. |
| void setOutput_i1_i1_i1_f1_(constant Uniforms& x_48, device ssbOut& x_54, thread int* const d0, thread int* const d1, thread int* const d2, thread float* const value_1) { |
|   int3 position = int3(*(d0), *(d1), *(d2)); |
|   int flatSlot = getOutputFlatIndex_vi3_(x_48, &(position)); |
|   float payload = *(value_1); |
|   setOutput_i1_f1_(x_54, &(flatSlot), &(payload)); |
| } |
| |
| // Writes *value_2 to the output buffer at coordinate |
| // (*tint_symbol_11, *row_2, *col_2) by forwarding to setOutput_i1_i1_i1_f1_. |
| // The arguments are copied into locals first because the callee takes |
| // pointers to thread-local storage. |
| void mm_write_i1_i1_f1_(constant Uniforms& x_48, device ssbOut& x_54, thread int* const row_2, thread int* const col_2, thread float* const value_2, thread int* const tint_symbol_11) { |
|   int d0Value = *(tint_symbol_11); |
|   int d1Value = *(row_2); |
|   int d2Value = *(col_2); |
|   float payload = *(value_2); |
|   setOutput_i1_i1_i1_f1_(x_48, x_54, &(d0Value), &(d1Value), &(d2Value), &(payload)); |
| } |
| |
| // Tiled matrix multiply for one invocation. |
| // |
| // For each of ((*dimInner - 1) / 64) + 1 tiles the invocation: |
| //   1. fills one row (64 entries) of *tint_symbol_17 from mm_readA_i1_i1_, |
| //   2. fills one entry of *tint_symbol_19 from mm_readB_i1_i1_, |
| //   3. waits on a threadgroup barrier, |
| //   4. accumulates the 64 products A[row][k] * B[k][col] into acc, |
| //   5. waits on a second barrier before the arrays are overwritten. |
| // Afterwards the accumulated value is written through mm_write_i1_i1_f1_ when |
| // the output column is < *dimBOuter and the output row is < *dimAOuter. |
| // |
| //   tint_symbol_12 : supplies the row/column used to index the threadgroup arrays. |
| //   tint_symbol_13 : supplies the row/column used for buffer reads and the final write. |
| //   tint_symbol_14/15/16 : forwarded to mm_readA_i1_i1_ (bounds and leading index). |
| //   tint_symbol_15/18/16 : forwarded to mm_readB_i1_i1_ (bounds and leading index). |
| // |
| // The loops bounded by 1 and the "* 1" factors are kept from the generated |
| // form; presumably they are per-invocation work sizes of 1 - confirm against |
| // the generator. |
| // NOTE(review): tint_symbol_19's inner array has a single element, so the |
| // column taken from tint_symbol_12.x must be 0 to stay in bounds; presumably |
| // the threadgroup is one invocation wide in x - confirm against the dispatch. |
| void mm_matMul_i1_i1_i1_(constant Uniforms& x_48, const device ssbA& x_165, const device ssbB& x_185, device ssbOut& x_54, thread int* const dimAOuter, thread int* const dimInner, thread int* const dimBOuter, thread uint3* const tint_symbol_12, thread uint3* const tint_symbol_13, thread int* const tint_symbol_14, thread int* const tint_symbol_15, thread int* const tint_symbol_16, threadgroup tint_array_wrapper* const tint_symbol_17, thread int* const tint_symbol_18, threadgroup tint_array_wrapper_2* const tint_symbol_19) { |
|   // Positions derived from the two id vectors (bit-cast from uint to int). |
|   int const localRow = (as_type<int>((*(tint_symbol_12)).y) * 1); |
|   int const localCol = (as_type<int>((*(tint_symbol_12)).x) * 1); |
|   int const outRow = (as_type<int>((*(tint_symbol_13)).y) * 1); |
|   int const outCol = (as_type<int>((*(tint_symbol_13)).x) * 1); |
|   int const tileCount = (((*(dimInner) - 1) / 64) + 1); |
| |
|   // Accumulator, explicitly cleared element by element. |
|   tint_array_wrapper_4 acc = {}; |
|   for (int r = 0; r < 1; r = r + 1) { |
|     for (int c = 0; c < 1; c = c + 1) { |
|       acc.arr[r].arr[c] = 0.0f; |
|     } |
|   } |
| |
|   int const aColBase = (as_type<int>((*(tint_symbol_12)).x) * 64); |
|   int const bRowBase = (as_type<int>((*(tint_symbol_12)).y) * 1); |
|   tint_array_wrapper_3 bValues = {}; |
| |
|   for (int tile = 0; tile < tileCount; tile = tile + 1) { |
|     // Stage values from buffer A into *tint_symbol_17. |
|     for (int r = 0; r < 1; r = r + 1) { |
|       for (int c = 0; c < 64; c = c + 1) { |
|         int const dstRow = (localRow + r); |
|         int const dstCol = (aColBase + c); |
|         int srcRow = (outRow + r); |
|         int srcCol = ((tile * 64) + dstCol); |
|         float const loaded = mm_readA_i1_i1_(x_48, x_165, &(srcRow), &(srcCol), tint_symbol_14, tint_symbol_15, tint_symbol_16); |
|         (*(tint_symbol_17)).arr[dstRow].arr[dstCol] = loaded; |
|       } |
|     } |
| |
|     // Stage values from buffer B into *tint_symbol_19. |
|     for (int r = 0; r < 1; r = r + 1) { |
|       for (int c = 0; c < 1; c = c + 1) { |
|         int const dstRow = (bRowBase + r); |
|         int const dstCol = (localCol + c); |
|         int srcRow = ((tile * 64) + dstRow); |
|         int srcCol = (outCol + c); |
|         float const loaded = mm_readB_i1_i1_(x_48, x_185, &(srcRow), &(srcCol), tint_symbol_15, tint_symbol_18, tint_symbol_16); |
|         (*(tint_symbol_19)).arr[dstRow].arr[dstCol] = loaded; |
|       } |
|     } |
| |
|     // Both arrays must be completely written before any invocation reads them. |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|     for (int k = 0; k < 64; k = k + 1) { |
|       for (int c = 0; c < 1; c = c + 1) { |
|         bValues.arr[c] = (*(tint_symbol_19)).arr[k].arr[(localCol + c)]; |
|       } |
|       for (int r = 0; r < 1; r = r + 1) { |
|         float const aValue = (*(tint_symbol_17)).arr[(localRow + r)].arr[k]; |
|         for (int c = 0; c < 1; c = c + 1) { |
|           float const bValue = bValues.arr[c]; |
|           float const running = acc.arr[r].arr[c]; |
|           acc.arr[r].arr[c] = (running + (aValue * bValue)); |
|         } |
|       } |
|     } |
| |
|     // Every invocation must finish reading before the next tile overwrites the arrays. |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
|   } |
| |
|   // Emit the result, skipping positions outside the output extents. |
|   for (int r = 0; r < 1; r = r + 1) { |
|     for (int c = 0; c < 1; c = c + 1) { |
|       bool inRange = ((outCol + c) < *(dimBOuter)); |
|       if (inRange) { |
|         inRange = ((outRow + r) < *(dimAOuter)); |
|       } |
|       if (inRange) { |
|         int dstRow = (outRow + r); |
|         int dstCol = (outCol + c); |
|         float total = acc.arr[r].arr[c]; |
|         mm_write_i1_i1_f1_(x_48, x_54, &(dstRow), &(dstCol), &(total), tint_symbol_16); |
|       } |
|     } |
|   } |
| } |
| |
| // Shader body: publishes the uniform dimensions and the z id through the |
| // caller-provided pointers, then runs mm_matMul_i1_i1_i1_. |
| //   *tint_symbol_20 <- aShape.y   (also forwarded as dimAOuter) |
| //   *tint_symbol_21 <- aShape.z   (also forwarded as dimInner) |
| //   *tint_symbol_22 <- bShape.z   (also forwarded as dimBOuter) |
| //   *tint_symbol_24 <- tint_symbol_23->z, bit-cast to int |
| void main_1(constant Uniforms& x_48, const device ssbA& x_165, const device ssbB& x_185, device ssbOut& x_54, thread int* const tint_symbol_20, thread int* const tint_symbol_21, thread int* const tint_symbol_22, thread uint3* const tint_symbol_23, thread int* const tint_symbol_24, thread uint3* const tint_symbol_25, threadgroup tint_array_wrapper* const tint_symbol_26, threadgroup tint_array_wrapper_2* const tint_symbol_27) { |
|   *(tint_symbol_20) = x_48.aShape.y; |
|   *(tint_symbol_21) = x_48.aShape.z; |
|   *(tint_symbol_22) = x_48.bShape.z; |
|   *(tint_symbol_24) = as_type<int>((*(tint_symbol_23)).z); |
| |
|   // The callee takes pointers, so hand it private copies read back after all stores. |
|   int dimAOuterArg = *(tint_symbol_20); |
|   int dimInnerArg = *(tint_symbol_21); |
|   int dimBOuterArg = *(tint_symbol_22); |
|   mm_matMul_i1_i1_i1_(x_48, x_165, x_185, x_54, &(dimAOuterArg), &(dimInnerArg), &(dimBOuterArg), tint_symbol_25, tint_symbol_23, tint_symbol_20, tint_symbol_21, tint_symbol_24, tint_symbol_26, tint_symbol_22, tint_symbol_27); |
| } |
| |
| // Compute entry point. Declares the two threadgroup arrays and the per-invocation |
| // scratch variables, has invocation 0 zero-fill the threadgroup arrays, waits on a |
| // barrier so that fill is complete for the whole threadgroup, then copies the |
| // built-in ids into thread-local storage and calls main_1. |
| kernel void tint_symbol_1(uint3 gl_LocalInvocationID_param [[thread_position_in_threadgroup]], uint3 gl_GlobalInvocationID_param [[thread_position_in_grid]], uint local_invocation_index [[thread_index_in_threadgroup]], constant Uniforms& x_48 [[buffer(3)]], const device ssbA& x_165 [[buffer(1)]], const device ssbB& x_185 [[buffer(2)]], device ssbOut& x_54 [[buffer(0)]]) { |
|   threadgroup tint_array_wrapper sharedA; |
|   threadgroup tint_array_wrapper_2 sharedB; |
|   thread uint3 localId = 0u; |
|   thread uint3 globalId = 0u; |
|   thread int scratch0 = 0; |
|   thread int scratch1 = 0; |
|   thread int scratch2 = 0; |
|   thread int scratch3 = 0; |
| |
|   // Only the first invocation of the threadgroup clears the shared arrays. |
|   if ((local_invocation_index == 0u)) { |
|     tint_array_wrapper const zeroA = {}; |
|     sharedA = zeroA; |
|     tint_array_wrapper_2 const zeroB = {}; |
|     sharedB = zeroB; |
|   } |
|   threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|   localId = gl_LocalInvocationID_param; |
|   globalId = gl_GlobalInvocationID_param; |
|   main_1(x_48, x_165, x_185, x_54, &(scratch0), &(scratch1), &(scratch2), &(globalId), &(scratch3), &(localId), &(sharedA), &(sharedB)); |
| } |
| |