| #include <metal_stdlib> |
| |
| using namespace metal; |
| struct Uniforms { /* Shape parameters read through tint_symbol_1 in tint_symbol_inner; the leading comment on each field is its byte offset. */ |
| /* 0x0000 */ uint2 aShape; /* Component [1] is the loop bound and the row stride used to index the first input buffer. */ |
| /* 0x0008 */ uint2 bShape; /* Not read by any code visible in this file. */ |
| /* 0x0010 */ uint2 outShape; /* Component [1] is the row stride used to index the second input buffer and the output buffer. */ |
| }; |
| struct Matrix { /* Flat buffer of unsigned integers used for both inputs and the output of the kernel. */ |
| /* 0x0000 */ uint numbers[1]; /* Indexed with runtime-computed offsets by tint_symbol_inner; the declared length of 1 presumably stands in for a runtime-sized array - TODO confirm. */ |
| }; |
| |
| /* Computes one element of the output buffer for the thread at global_id: row = global_id.y, column = global_id.x. */ |
| /* tint_symbol_1 supplies the shapes, tint_symbol_2 and tint_symbol_3 are the inputs, tint_symbol_4 receives the result. */ |
| /* NOTE: global_id is not checked against any shape before the buffers are read or written. */ |
| void tint_symbol_inner(uint3 global_id, const constant Uniforms* const tint_symbol_1, const device Matrix* const tint_symbol_2, const device Matrix* const tint_symbol_3, device Matrix* const tint_symbol_4) { |
|   uint const row = global_id.y; |
|   uint const col = global_id.x; |
|   uint const innerDim = tint_symbol_1->aShape.y; |
|   uint const outerDim = tint_symbol_1->outShape.y; |
|   /* Sum of first[row * innerDim + k] * second[k * outerDim + col] for k in [0, innerDim), in unsigned arithmetic. */ |
|   uint sum = 0u; |
|   uint k = 0u; |
|   while (k < innerDim) { |
|     uint const lhs = tint_symbol_2->numbers[(row * innerDim) + k]; |
|     uint const rhs = tint_symbol_3->numbers[(k * outerDim) + col]; |
|     sum += (lhs * rhs); |
|     ++k; |
|   } |
|   /* The output uses the same row stride (outerDim) as the second input. */ |
|   tint_symbol_4->numbers[(row * outerDim) + col] = sum; |
| } |
| |
| /* Compute entry point: passes the thread's grid position and the four bound buffers straight to tint_symbol_inner. */ |
| /* Buffer slots: 0 = uniforms, 2 = first input, 3 = second input, 1 = output. */ |
| kernel void tint_symbol( |
|     const constant Uniforms* tint_symbol_5 [[buffer(0)]], |
|     const device Matrix* tint_symbol_6 [[buffer(2)]], |
|     const device Matrix* tint_symbol_7 [[buffer(3)]], |
|     device Matrix* tint_symbol_8 [[buffer(1)]], |
|     uint3 global_id [[thread_position_in_grid]]) { |
|   tint_symbol_inner(global_id, tint_symbol_5, tint_symbol_6, tint_symbol_7, tint_symbol_8); |
| } |
| |