Created
September 16, 2024 01:04
-
-
Save hvaara/e3002c3258c689f9b6a2dd56fb15317c to your computer and use it in GitHub Desktop.
Tinygrad fused kernel example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
METAL_XCODE=1 DISABLE_COMPILER_CACHE=1 DEBUG=4 python3 -c "from tinygrad import Tensor;
N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
opened device METAL from pid:15248
opened device NPY from pid:15248
*** CUSTOM 1 custom_random mem 0.00 GB
*** CUSTOM 2 custom_random mem 0.01 GB
TENSOR CORES [(1, 1024, 1)] [(0, 1024, 1024)] WMMA_8_8_8_float_float
r_32_8_2_4_2_2_4_128_8_2_4_4
UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=5, upcasted=4, dont_use_locals=False), src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(
    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 2, 4, 2, 2, 4, 1, 1, 2, 4, 4), strides=(32768, 128, 2, 1024, 4, 4096, 32, 0, 0, 1, 8192, 8), offset=0, mask=None, contiguous=False),)), src=()),
    UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7,)), src=(
      UOp(UOps.WMMA, dtypes.float, arg=('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((9, 2),), ((9, 2),), ((9, 2),)), (8,)), src=(
        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
          UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
          UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 2, 4, 2, 2, 4, 128, 8, 2, 4, 4), strides=(32768, 0, 2, 1024, 4, 4096, 0, 8, 0, 1, 8192, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
          UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
          UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 2, 4, 2, 2, 4, 128, 8, 2, 4, 4), strides=(0, 128, 2, 1024, 4, 4096, 32, 8192, 0, 1, 0, 8), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
[Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4)]
#include <metal_stdlib>
using namespace metal;

// 8x8x8 float simdgroup matrix-multiply-accumulate: returns m*n + o.
// Each thread of the simdgroup contributes its 2 thread-local elements of
// each 8x8 operand (Metal's simdgroup_float8x8 thread_elements view),
// packed into a float2. The full 8x8 tiles are assembled across the
// simdgroup by simdgroup_multiply_accumulate.
// NOTE(review): must be called by all threads of a simdgroup with the
// tinygrad-chosen element layout -- the per-thread element-to-tile mapping
// is fixed by Metal, not by this code.
float2 __WMMA_8_8_8_float_float(float2 m, float2 n, float2 o) {
  simdgroup_float8x8 a, b, c;
  a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y;
  b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y;
  c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y;
  simdgroup_multiply_accumulate(c, a, b, c);  // c = a*b + c
  return float2(c.thread_elements()[0], c.thread_elements()[1]);
}
// tinygrad-generated fused kernel: computes
//   c = (a.reshape(N,1,N) * b.T.reshape(1,N,N)).sum(axis=2)  with N=1024,
// i.e. a 1024x1024 matmul, via the simdgroup 8x8x8 float MMA helper
// __WMMA_8_8_8_float_float (defined earlier in this file).
//
// Inputs:  data1 = a, data2 = b (both 1024x1024 floats, read-only).
// Output:  data0 = c (1024x1024 floats).
// Each thread keeps 16 float2 accumulators (a 4x4 grid of 8x8 MMA tiles,
// 2 thread-local floats per tile) and reduces over a 128-iteration K loop
// (128 * 8 = 1024 elements of the contracted axis).
// NOTE(review): the /* n */ extents on the index reads imply a dispatch of
// (8, 32) threadgroups with (16, 2, 4) threads per group -- confirm against
// the actual launch parameters recorded by tinygrad.
kernel void r_32_8_2_4_2_2_4_128_8_2_4_4(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  int gidx0 = gid.x; /* 8 */
  int gidx1 = gid.y; /* 32 */
  int lidx0 = lid.x; /* 16 */
  int lidx1 = lid.y; /* 2 */
  int lidx2 = lid.z; /* 4 */
  // Zero initializer shared by all 16 accumulators.
  float2 cast0 = float2(0.0f,0.0f);
  // Index arithmetic (strides in floats, derived from the ShapeTracker
  // views in the log above): alu0..alu6 decompose the thread's position
  // into row/column offsets; alu7 is the output base offset.
  int alu0 = (gidx0*128);
  int alu1 = (gidx1*32768);
  int alu2 = (lidx1*4096);
  int alu3 = (lidx2*32);
  int alu4 = ((lidx0/8)*4);
  int alu5 = ((lidx0%2)*2);
  int alu6 = (((lidx0/2)%4)*1024);
  int alu7 = (alu1+alu0+alu5+alu6+alu4+alu2+alu3);
  // 4x4 grid of float2 accumulators, all starting at zero.
  float2 acc0 = cast0;
  float2 acc1 = cast0;
  float2 acc2 = cast0;
  float2 acc3 = cast0;
  float2 acc4 = cast0;
  float2 acc5 = cast0;
  float2 acc6 = cast0;
  float2 acc7 = cast0;
  float2 acc8 = cast0;
  float2 acc9 = cast0;
  float2 acc10 = cast0;
  float2 acc11 = cast0;
  float2 acc12 = cast0;
  float2 acc13 = cast0;
  float2 acc14 = cast0;
  float2 acc15 = cast0;
  // K-reduction: each iteration advances 8 floats into data1 (stride 8)
  // and 8192 floats into data2 (stride 8192), feeding 16 MMAs
  // (4 tiles of a x 4 tiles of b).
  for (int ridx0 = 0; ridx0 < 128; ridx0++) {
    int alu8 = (alu1+alu5+alu6+alu4+alu2+(ridx0*8));
    int alu9 = (alu0+alu5+alu6+alu4+alu2+alu3+(ridx0*8192));
    // Four row-tiles of data1, 8192 floats (8 rows of 1024) apart.
    float2 val0 = *((device float2*)(data1+alu8+8192));
    float2 val1 = *((device float2*)(data1+alu8+16384));
    float2 val2 = *((device float2*)(data1+alu8+24576));
    float2 val3 = *((device float2*)(data1+alu8));
    // Column-tile 1 of data2, multiplied against all four row-tiles.
    float2 val4 = *((device float2*)(data2+alu9+8));
    float2 wmma0 = __WMMA_8_8_8_float_float(val0, val4, acc5);
    float2 wmma1 = __WMMA_8_8_8_float_float(val1, val4, acc9);
    float2 wmma2 = __WMMA_8_8_8_float_float(val2, val4, acc13);
    float2 wmma3 = __WMMA_8_8_8_float_float(val3, val4, acc1);
    // Column-tile 2.
    float2 val5 = *((device float2*)(data2+alu9+16));
    float2 wmma4 = __WMMA_8_8_8_float_float(val0, val5, acc6);
    float2 wmma5 = __WMMA_8_8_8_float_float(val1, val5, acc10);
    float2 wmma6 = __WMMA_8_8_8_float_float(val2, val5, acc14);
    float2 wmma7 = __WMMA_8_8_8_float_float(val3, val5, acc2);
    // Column-tile 3.
    float2 val6 = *((device float2*)(data2+alu9+24));
    float2 wmma8 = __WMMA_8_8_8_float_float(val0, val6, acc7);
    float2 wmma9 = __WMMA_8_8_8_float_float(val1, val6, acc11);
    float2 wmma10 = __WMMA_8_8_8_float_float(val2, val6, acc15);
    float2 wmma11 = __WMMA_8_8_8_float_float(val3, val6, acc3);
    // Column-tile 0.
    float2 val7 = *((device float2*)(data2+alu9));
    float2 wmma12 = __WMMA_8_8_8_float_float(val0, val7, acc4);
    float2 wmma13 = __WMMA_8_8_8_float_float(val1, val7, acc8);
    float2 wmma14 = __WMMA_8_8_8_float_float(val2, val7, acc12);
    float2 wmma15 = __WMMA_8_8_8_float_float(val3, val7, acc0);
    // Write the MMA results back into the accumulators for the next
    // K iteration.
    acc0 = wmma15;
    acc1 = wmma3;
    acc2 = wmma7;
    acc3 = wmma11;
    acc4 = wmma12;
    acc5 = wmma0;
    acc6 = wmma4;
    acc7 = wmma8;
    acc8 = wmma13;
    acc9 = wmma1;
    acc10 = wmma5;
    acc11 = wmma9;
    acc12 = wmma14;
    acc13 = wmma2;
    acc14 = wmma6;
    acc15 = wmma10;
  }
  // Store the 4x4 accumulator grid: columns 8 floats apart, rows
  // 8192 floats (8 output rows) apart, relative to base offset alu7.
  *((device float2*)(data0+alu7+8)) = acc1;
  *((device float2*)(data0+alu7+16)) = acc2;
  *((device float2*)(data0+alu7+24)) = acc3;
  *((device float2*)(data0+alu7+8192)) = acc4;
  *((device float2*)(data0+alu7+8200)) = acc5;
  *((device float2*)(data0+alu7+8208)) = acc6;
  *((device float2*)(data0+alu7+8216)) = acc7;
  *((device float2*)(data0+alu7+16384)) = acc8;
  *((device float2*)(data0+alu7+16392)) = acc9;
  *((device float2*)(data0+alu7+16400)) = acc10;
  *((device float2*)(data0+alu7+16408)) = acc11;
  *((device float2*)(data0+alu7+24576)) = acc12;
  *((device float2*)(data0+alu7+24584)) = acc13;
  *((device float2*)(data0+alu7+24592)) = acc14;
  *((device float2*)(data0+alu7+24600)) = acc15;
  *((device float2*)(data0+alu7)) = acc0;
}
*** METAL 3 r_32_8_2_4_2_2_4_128_8_2_4_4 mem 0.01 GB tm 878.75us/ 0.88ms ( 2443.79 GFLOPS 14.3|310.2 GB/s) ['__mul__', 'sum']
opened device CLANG from pid:15248
*** CLANG 4 copy 4.19M, CLANG <- METAL mem 0.02 GB tm 590.25us/ 1.47ms ( 0.00 GFLOPS 7.1|7.1 GB/s)
*** CLANG 5 copy 4.19M, CLANG <- METAL mem 0.02 GB tm 164.08us/ 1.63ms ( 0.00 GFLOPS 25.6|25.6 GB/s)
*** CLANG 6 copy 4.19M, CLANG <- METAL mem 0.02 GB tm 130.38us/ 1.76ms ( 0.00 GFLOPS 32.2|32.2 GB/s)
-1.1816155e-07
avg: 1217.77 GFLOPS 14.27 GB/s total: 6 kernels 2.15 GOPS 0.03 GB 1.76 ms
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment