#include <metal_stdlib>

using namespace metal;
// Uniform block consumed by the kernel entry point (bound there at buffer(3)).
// The /* 0x.... */ annotations give each member's byte offset; the int8_t
// tint_pad* members are explicit filler so the member that follows lands on
// the annotated offset.
struct Uniforms {
  // Not read by any function visible in this file.
  /* 0x0000 */ float tint_symbol;
  /* 0x0004 */ int8_t tint_pad[12];
  // Shape of operand A. main_1 reads .y and .z; mm_readA uses .y * .z as the
  // per-batch element count.
  /* 0x0010 */ packed_int3 aShape;
  /* 0x001c */ int8_t tint_pad_1[4];
  // Shape of operand B. main_1 reads .z; mm_readB uses .y * .z as the
  // per-batch element count.
  /* 0x0020 */ packed_int3 bShape;
  /* 0x002c */ int8_t tint_pad_2[4];
  // Not read by any function visible in this file.
  /* 0x0030 */ packed_int3 outShape;
  /* 0x003c */ int8_t tint_pad_3[4];
  // Multipliers applied to the first two output coordinates by
  // getOutputFlatIndex_vi3_ (the third coordinate is multiplied by 1).
  /* 0x0040 */ packed_int2 outShapeStrides;
};
// Layout of the output buffer (bound as `device ssbOut&` at buffer(0) by the kernel).
struct ssbOut {
  // Declared with one element but written at a runtime flat index by
  // setOutput_i1_f1_; presumably the usable length is that of the bound
  // buffer -- confirm against the host-side allocation.
  /* 0x0000 */ float result[1];
};
// Layout of the read-only buffer holding operand A (bound at buffer(1) by the kernel).
struct ssbA {
  // Declared with one element but read at a runtime flat index by
  // mm_readA_i1_i1_; presumably the usable length is that of the bound
  // buffer -- confirm against the host-side allocation.
  /* 0x0000 */ float A[1];
};
// Layout of the read-only buffer holding operand B (bound at buffer(2) by the kernel).
struct ssbB {
  // Declared with one element but read at a runtime flat index by
  // mm_readB_i1_i1_; presumably the usable length is that of the bound
  // buffer -- confirm against the host-side allocation.
  /* 0x0000 */ float B[1];
};
// Struct wrapper around a 64-element float array; used as one row of
// tint_array_wrapper.
struct tint_array_wrapper_1 {
  float arr[64];
};
// 64 x 64 float grid. The kernel declares one instance in threadgroup memory
// and mm_matMul_i1_i1_i1_ stores values read from A into it.
struct tint_array_wrapper {
  tint_array_wrapper_1 arr[64];
};
// Struct wrapper around a single-element float array; used as the row type of
// tint_array_wrapper_2 and tint_array_wrapper_4, and as the BCached local in
// mm_matMul_i1_i1_i1_.
struct tint_array_wrapper_3 {
  float arr[1];
};
// 64 x 1 float grid. The kernel declares one instance in threadgroup memory
// and mm_matMul_i1_i1_i1_ stores values read from B into it.
struct tint_array_wrapper_2 {
  tint_array_wrapper_3 arr[64];
};
// 1 x 1 float grid; type of the per-thread accumulator `acc` in
// mm_matMul_i1_i1_i1_.
struct tint_array_wrapper_4 {
  tint_array_wrapper_3 arr[1];
};

// Returns true when both components of *coord are >= 0 and strictly less than
// the matching component of *shape.
bool coordsInBounds_vi2_vi2_(thread int2* const coord, thread int2* const shape) {
  int2 const point = *(coord);
  // Short-circuit: the upper-bound comparison only runs once the lower bound holds.
  return all((point >= int2(0, 0))) && all((point < *(shape)));
}

// Bounds-checked read of one element of A.
//
// (*row, *col) is tested against the extent (*tint_symbol_3, *tint_symbol_4).
// In range: returns A[*tint_symbol_5 * (aShape.y * aShape.z) + *row * *tint_symbol_4 + *col].
// Out of range: returns 0.0f.
// All index arithmetic is carried out on uint so that overflow wraps.
float mm_readA_i1_i1_(constant Uniforms& x_48, const device ssbA& x_165, thread int* const row, thread int* const col, thread int* const tint_symbol_3, thread int* const tint_symbol_4, thread int* const tint_symbol_5) {
  int2 position = int2(*(row), *(col));
  int2 extent = int2(*(tint_symbol_3), *(tint_symbol_4));
  if (!coordsInBounds_vi2_vi2_(&(position), &(extent))) {
    return 0.0f;
  }
  int const shapeY = x_48.aShape.y;
  int const shapeZ = x_48.aShape.z;
  uint const elementsPerBatch = as_type<uint>(shapeY) * as_type<uint>(shapeZ);
  uint const batchOffset = as_type<uint>(*(tint_symbol_5)) * elementsPerBatch;
  uint const rowOffset = as_type<uint>(*(row)) * as_type<uint>(*(tint_symbol_4));
  // Reinterpret back to int so the subscript has the same signedness as before.
  int const flatIndex = as_type<int>(((batchOffset + rowOffset) + as_type<uint>(*(col))));
  return x_165.A[flatIndex];
}

// Bounds-checked read of one element of B.
//
// (*row_1, *col_1) is tested against the extent (*tint_symbol_6, *tint_symbol_7).
// In range: returns B[*tint_symbol_8 * (bShape.y * bShape.z) + *row_1 * *tint_symbol_7 + *col_1].
// Out of range: returns 0.0f.
// All index arithmetic is carried out on uint so that overflow wraps.
float mm_readB_i1_i1_(constant Uniforms& x_48, const device ssbB& x_185, thread int* const row_1, thread int* const col_1, thread int* const tint_symbol_6, thread int* const tint_symbol_7, thread int* const tint_symbol_8) {
  int2 position = int2(*(row_1), *(col_1));
  int2 extent = int2(*(tint_symbol_6), *(tint_symbol_7));
  if (!coordsInBounds_vi2_vi2_(&(position), &(extent))) {
    return 0.0f;
  }
  int const shapeY = x_48.bShape.y;
  int const shapeZ = x_48.bShape.z;
  uint const elementsPerBatch = as_type<uint>(shapeY) * as_type<uint>(shapeZ);
  uint const batchOffset = as_type<uint>(*(tint_symbol_8)) * elementsPerBatch;
  uint const rowOffset = as_type<uint>(*(row_1)) * as_type<uint>(*(tint_symbol_7));
  // Reinterpret back to int so the subscript has the same signedness as before.
  int const flatIndex = as_type<int>(((batchOffset + rowOffset) + as_type<uint>(*(col_1))));
  return x_185.B[flatIndex];
}

// Converts a 3-component output coordinate into a flat element index:
//   coords.x * outShapeStrides.x + coords.y * outShapeStrides.y + coords.z
//
// The computation is done in integer arithmetic. A 32-bit float dot product
// cannot represent every integer above 2^24, so a float-based version can
// return a rounded (wrong) index once the output is large enough. The
// multiply/add is performed on uint so overflow wraps, matching the index
// arithmetic in mm_readA_i1_i1_ / mm_readB_i1_i1_.
int getOutputFlatIndex_vi3_(constant Uniforms& x_48, thread int3* const coords) {
  int3 const position = *(coords);
  int const strideX = x_48.outShapeStrides.x;
  int const strideY = x_48.outShapeStrides.y;
  uint const termX = as_type<uint>(position.x) * as_type<uint>(strideX);
  uint const termY = as_type<uint>(position.y) * as_type<uint>(strideY);
  return as_type<int>(((termX + termY) + as_type<uint>(position.z)));
}

// Stores *value into the output buffer at element index *flatIndex.
// No bounds check is performed here.
void setOutput_i1_f1_(device ssbOut& x_54, thread int* const flatIndex, thread float* const value) {
  x_54.result[*(flatIndex)] = *(value);
}

// Writes *value_1 to the output element addressed by the 3-component
// coordinate (*d0, *d1, *d2): the coordinate is flattened with
// getOutputFlatIndex_vi3_ and the store is delegated to setOutput_i1_f1_.
void setOutput_i1_i1_i1_f1_(constant Uniforms& x_48, device ssbOut& x_54, thread int* const d0, thread int* const d1, thread int* const d2, thread float* const value_1) {
  int3 position = int3(*(d0), *(d1), *(d2));
  int flatIndex = getOutputFlatIndex_vi3_(x_48, &(position));
  float element = *(value_1);
  setOutput_i1_f1_(x_54, &(flatIndex), &(element));
}

// Writes *value_2 to the output at coordinate (*tint_symbol_9, *row_2, *col_2)
// by forwarding to setOutput_i1_i1_i1_f1_. The arguments are copied into
// locals first so the callee receives pointers to private copies.
void mm_write_i1_i1_f1_(constant Uniforms& x_48, device ssbOut& x_54, thread int* const row_2, thread int* const col_2, thread float* const value_2, thread int* const tint_symbol_9) {
  int coord0 = *(tint_symbol_9);
  int coord1 = *(row_2);
  int coord2 = *(col_2);
  float element = *(value_2);
  setOutput_i1_i1_i1_f1_(x_48, x_54, &(coord0), &(coord1), &(coord2), &(element));
}

// Tiled matrix multiply for one invocation.
//
// Pointer parameters, as used by the body:
//   dimAOuter / dimBOuter  - row / column limits checked before the final write.
//   dimInner               - reduction length; split into tiles of 64.
//   tint_symbol_10         - invocation id used for tile-local indexing.
//   tint_symbol_11         - invocation id used for buffer-level row/column.
//   tint_symbol_12/13/14   - forwarded to mm_readA as (row extent, col extent, batch).
//   tint_symbol_13/16/14   - forwarded to mm_readB as (row extent, col extent, batch).
//   tint_symbol_15         - threadgroup 64x64 tile filled from A.
//   tint_symbol_17         - threadgroup 64x1 tile filled from B.
//
// Each pass over a tile: fill both threadgroup tiles, barrier, accumulate 64
// products into `acc`, barrier. After the last tile the accumulator is written
// out through mm_write_i1_i1_f1_ if the target row/column are in range.
// Index sums are computed on uint and reinterpreted as int so overflow wraps.
void mm_matMul_i1_i1_i1_(constant Uniforms& x_48, const device ssbA& x_165, const device ssbB& x_185, device ssbOut& x_54, thread int* const dimAOuter, thread int* const dimInner, thread int* const dimBOuter, thread uint3* const tint_symbol_10, thread uint3* const tint_symbol_11, thread int* const tint_symbol_12, thread int* const tint_symbol_13, thread int* const tint_symbol_14, threadgroup tint_array_wrapper* const tint_symbol_15, thread int* const tint_symbol_16, threadgroup tint_array_wrapper_2* const tint_symbol_17) {
  int const tileRow = as_type<int>((*(tint_symbol_10)).y);
  int const tileCol = as_type<int>((*(tint_symbol_10)).x);
  int const globalRow = as_type<int>((*(tint_symbol_11)).y);
  int const globalCol = as_type<int>((*(tint_symbol_11)).x);
  // (dimInner - 1) / 64 + 1, with a signed division between two wrapping steps.
  int const numTiles = as_type<int>((as_type<uint>((as_type<int>((as_type<uint>(*(dimInner)) - 1u)) / 64)) + 1u));

  // Clear the accumulator.
  tint_array_wrapper_4 acc = {};
  for (int accRow = 0; accRow < 1; ++accRow) {
    for (int accCol = 0; accCol < 1; ++accCol) {
      acc.arr[accRow].arr[accCol] = 0.0f;
    }
  }

  int const tileColA = as_type<int>(((*(tint_symbol_10)).x * 64u));
  int const tileRowB = as_type<int>((*(tint_symbol_10)).y);

  for (int tile = 0; tile < numTiles; ++tile) {
    int const tileBase = as_type<int>((as_type<uint>(tile) * 64u));

    // Fill this invocation's part of the A tile.
    for (int r = 0; r < 1; ++r) {
      for (int c = 0; c < 64; ++c) {
        int const dstRow = as_type<int>((as_type<uint>(tileRow) + as_type<uint>(r)));
        int const dstCol = as_type<int>((as_type<uint>(tileColA) + as_type<uint>(c)));
        int srcRow = as_type<int>((as_type<uint>(globalRow) + as_type<uint>(r)));
        int srcCol = as_type<int>((as_type<uint>(tileBase) + as_type<uint>(dstCol)));
        float const loaded = mm_readA_i1_i1_(x_48, x_165, &(srcRow), &(srcCol), tint_symbol_12, tint_symbol_13, tint_symbol_14);
        (*(tint_symbol_15)).arr[dstRow].arr[dstCol] = loaded;
      }
    }

    // Fill this invocation's part of the B tile.
    for (int r = 0; r < 1; ++r) {
      for (int c = 0; c < 1; ++c) {
        int const dstRow = as_type<int>((as_type<uint>(tileRowB) + as_type<uint>(r)));
        int const dstCol = as_type<int>((as_type<uint>(tileCol) + as_type<uint>(c)));
        int srcRow = as_type<int>((as_type<uint>(tileBase) + as_type<uint>(dstRow)));
        int srcCol = as_type<int>((as_type<uint>(globalCol) + as_type<uint>(c)));
        float const loaded = mm_readB_i1_i1_(x_48, x_185, &(srcRow), &(srcCol), tint_symbol_13, tint_symbol_16, tint_symbol_14);
        (*(tint_symbol_17)).arr[dstRow].arr[dstCol] = loaded;
      }
    }

    // Both tiles must be fully written before any invocation reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 0; k < 64; ++k) {
      tint_array_wrapper_3 BCached = {};
      for (int inner = 0; inner < 1; ++inner) {
        BCached.arr[inner] = (*(tint_symbol_17)).arr[k].arr[as_type<int>((as_type<uint>(tileCol) + as_type<uint>(inner)))];
      }
      for (int r = 0; r < 1; ++r) {
        float const ACached = (*(tint_symbol_15)).arr[as_type<int>((as_type<uint>(tileRow) + as_type<uint>(r)))].arr[k];
        for (int c = 0; c < 1; ++c) {
          float const previous = acc.arr[r].arr[c];
          acc.arr[r].arr[c] = (previous + (ACached * BCached.arr[c]));
        }
      }
    }

    // Reads of the tiles must finish before the next pass overwrites them.
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write the accumulated result if its row/column fall inside the output.
  for (int r = 0; r < 1; ++r) {
    for (int c = 0; c < 1; ++c) {
      int outRow = as_type<int>((as_type<uint>(globalRow) + as_type<uint>(r)));
      int outCol = as_type<int>((as_type<uint>(globalCol) + as_type<uint>(c)));
      if ((outCol < *(dimBOuter)) && (outRow < *(dimAOuter))) {
        float outValue = acc.arr[r].arr[c];
        mm_write_i1_i1_f1_(x_48, x_54, &(outRow), &(outCol), &(outValue), tint_symbol_14);
      }
    }
  }
}

// Publishes the matmul dimensions and batch index through the caller-provided
// pointers, then runs mm_matMul_i1_i1_i1_.
//
//   *tint_symbol_18 <- aShape.y      *tint_symbol_19 <- aShape.z
//   *tint_symbol_20 <- bShape.z      *tint_symbol_22 <- (*tint_symbol_21).z
//
// The three dimensions are also copied into locals, so the callee gets both a
// private copy (first three pointer arguments) and the shared storage.
void main_1(constant Uniforms& x_48, const device ssbA& x_165, const device ssbB& x_185, device ssbOut& x_54, thread int* const tint_symbol_18, thread int* const tint_symbol_19, thread int* const tint_symbol_20, thread uint3* const tint_symbol_21, thread int* const tint_symbol_22, thread uint3* const tint_symbol_23, threadgroup tint_array_wrapper* const tint_symbol_24, threadgroup tint_array_wrapper_2* const tint_symbol_25) {
  // The stores keep their original order in case any of the pointers alias.
  *(tint_symbol_18) = x_48.aShape.y;
  *(tint_symbol_19) = x_48.aShape.z;
  *(tint_symbol_20) = x_48.bShape.z;
  *(tint_symbol_22) = as_type<int>((*(tint_symbol_21)).z);

  int outerA = *(tint_symbol_18);
  int innerDim = *(tint_symbol_19);
  int outerB = *(tint_symbol_20);
  mm_matMul_i1_i1_i1_(x_48, x_165, x_185, x_54, &(outerA), &(innerDim), &(outerB), tint_symbol_23, tint_symbol_21, tint_symbol_18, tint_symbol_19, tint_symbol_22, tint_symbol_24, tint_symbol_20, tint_symbol_25);
}

// Compute entry point. Zero-fills the two threadgroup tiles, synchronises, then
// hands the invocation ids and per-thread scratch storage to main_1.
//
// NOTE(review): the zero-fill indexes the 64-entry B tile directly with
// local_invocation_index and strides the A-tile fill by 64, so it appears to
// assume 64 invocations per threadgroup -- confirm against the dispatch code.
kernel void tint_symbol_1(uint3 gl_LocalInvocationID_param [[thread_position_in_threadgroup]], uint3 gl_GlobalInvocationID_param [[thread_position_in_grid]], uint local_invocation_index [[thread_index_in_threadgroup]], constant Uniforms& x_48 [[buffer(3)]], const device ssbA& x_165 [[buffer(1)]], const device ssbB& x_185 [[buffer(2)]], device ssbOut& x_54 [[buffer(0)]]) {
  threadgroup tint_array_wrapper_2 sharedB;
  threadgroup tint_array_wrapper sharedA;

  // One B-tile element per invocation (the inner index is always 0).
  sharedB.arr[local_invocation_index].arr[(local_invocation_index % 1u)] = float();
  // 64 * 64 A-tile elements, shared out across invocations in steps of 64.
  for (uint flat = local_invocation_index; (flat < 4096u); flat = (flat + 64u)) {
    sharedA.arr[(flat / 64u)].arr[(flat % 64u)] = float();
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  thread uint3 localId = gl_LocalInvocationID_param;
  thread uint3 globalId = gl_GlobalInvocationID_param;
  // Scratch slots that main_1 fills in before running the multiply.
  thread int slotOuterA = 0;
  thread int slotInner = 0;
  thread int slotOuterB = 0;
  thread int slotBatch = 0;
  main_1(x_48, x_165, x_185, x_54, &(slotOuterA), &(slotInner), &(slotOuterB), &(globalId), &(slotBatch), &(localId), &(sharedA), &(sharedB));
}

