Skip to content

Instantly share code, notes, and snippets.

@hvaara
Created September 16, 2024 01:04
Show Gist options
  • Select an option

  • Save hvaara/e3002c3258c689f9b6a2dd56fb15317c to your computer and use it in GitHub Desktop.

Select an option

Save hvaara/e3002c3258c689f9b6a2dd56fb15317c to your computer and use it in GitHub Desktop.
Tinygrad fused kernel example
METAL_XCODE=1 DISABLE_COMPILER_CACHE=1 DEBUG=4 python3 -c "from tinygrad import Tensor;
N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
opened device METAL from pid:15248
opened device NPY from pid:15248
*** CUSTOM 1 custom_random mem 0.00 GB
*** CUSTOM 2 custom_random mem 0.01 GB
TENSOR CORES [(1, 1024, 1)] [(0, 1024, 1024)] WMMA_8_8_8_float_float
r_32_8_2_4_2_2_4_128_8_2_4_4
UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=5, upcasted=4, dont_use_locals=False), src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 2, 4, 2, 2, 4, 1, 1, 2, 4, 4), strides=(32768, 128, 2, 1024, 4, 4096, 32, 0, 0, 1, 8192, 8), offset=0, mask=None, contiguous=False),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7,)), src=(
UOp(UOps.WMMA, dtypes.float, arg=('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((9, 2),), ((9, 2),), ((9, 2),)), (8,)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 2, 4, 2, 2, 4, 128, 8, 2, 4, 4), strides=(32768, 0, 2, 1024, 4, 4096, 0, 8, 0, 1, 8192, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 2, 4, 2, 2, 4, 128, 8, 2, 4, 4), strides=(0, 128, 2, 1024, 4, 4096, 32, 8192, 0, 1, 0, 8), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
[Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4)]
#include <metal_stdlib>
using namespace metal;
// One 8x8x8 simdgroup matrix-multiply-accumulate step: returns c' = m*n + o,
// where each thread of the simdgroup contributes its two fragment elements
// (packed in a float2) of the 8x8 operand tiles.
float2 __WMMA_8_8_8_float_float(float2 m, float2 n, float2 o) {
  simdgroup_float8x8 mat_a, mat_b, mat_c;
  // Scatter this thread's two lanes of each operand into the simdgroup tiles.
  mat_a.thread_elements()[0] = m.x;
  mat_a.thread_elements()[1] = m.y;
  mat_b.thread_elements()[0] = n.x;
  mat_b.thread_elements()[1] = n.y;
  mat_c.thread_elements()[0] = o.x;
  mat_c.thread_elements()[1] = o.y;
  // c = a*b + c, performed cooperatively across the simdgroup.
  simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);
  // Gather this thread's two result lanes back into a float2.
  return float2(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);
}
// tinygrad-generated Metal kernel. Per the surrounding log this realizes the
// fused (a.reshape(N,1,N) * b.T.reshape(1,N,N)).sum(axis=2) for N=1024 — i.e.
// a matmul — via simdgroup WMMA tensor-core ops (NOTE(review): fusion intent
// inferred from the invocation above; the code itself only shows the loads,
// WMMA accumulation, and stores).
// data0: output buffer; data1, data2: the two input buffers.
// The kernel name encodes the loop/upcast dimensions chosen by the optimizer.
kernel void r_32_8_2_4_2_2_4_128_8_2_4_4(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  // Launch-grid extents are given by the generator in the trailing comments.
  int gidx0 = gid.x; /* 8 */
  int gidx1 = gid.y; /* 32 */
  int lidx0 = lid.x; /* 16 */
  int lidx1 = lid.y; /* 2 */
  int lidx2 = lid.z; /* 4 */
  float2 cast0 = float2(0.0f,0.0f);
  // Flat-index offset components; the multipliers are the strides from the
  // ShapeTracker views dumped above (e.g. 32768 = 32 rows * 1024, 8192 = 8*1024).
  int alu0 = (gidx0*128);
  int alu1 = (gidx1*32768);
  int alu2 = (lidx1*4096);
  int alu3 = (lidx2*32);
  // lidx0 (0..15) is decomposed into simdgroup-lane sub-offsets.
  int alu4 = ((lidx0/8)*4);
  int alu5 = ((lidx0%2)*2);
  int alu6 = (((lidx0/2)%4)*1024);
  // Base output offset for this thread's tile of data0.
  int alu7 = (alu1+alu0+alu5+alu6+alu4+alu2+alu3);
  // 16 float2 accumulators: this thread's fragments of a 4x4 grid of 8x8
  // simdgroup output tiles, all zero-initialized.
  float2 acc0 = cast0;
  float2 acc1 = cast0;
  float2 acc2 = cast0;
  float2 acc3 = cast0;
  float2 acc4 = cast0;
  float2 acc5 = cast0;
  float2 acc6 = cast0;
  float2 acc7 = cast0;
  float2 acc8 = cast0;
  float2 acc9 = cast0;
  float2 acc10 = cast0;
  float2 acc11 = cast0;
  float2 acc12 = cast0;
  float2 acc13 = cast0;
  float2 acc14 = cast0;
  float2 acc15 = cast0;
  // Reduction loop: 128 iterations, each consuming an 8-wide slice
  // (ridx0*8 / ridx0*8192 steps) -> 128*8 = 1024 = full reduce axis.
  for (int ridx0 = 0; ridx0 < 128; ridx0++) {
    int alu8 = (alu1+alu5+alu6+alu4+alu2+(ridx0*8));    // data1 base for this K-slice
    int alu9 = (alu0+alu5+alu6+alu4+alu2+alu3+(ridx0*8192)); // data2 base for this K-slice
    // Four row-tiles of data1, 8192 floats (8 rows) apart.
    float2 val0 = *((device float2*)(data1+alu8+8192));
    float2 val1 = *((device float2*)(data1+alu8+16384));
    float2 val2 = *((device float2*)(data1+alu8+24576));
    float2 val3 = *((device float2*)(data1+alu8));
    // Each data2 column-tile (offsets +8/+16/+24/+0) is multiplied against all
    // four data1 row-tiles, accumulating into the matching acc slot. The
    // val/acc pairing below fixes the 4x4 tile layout — do not reorder.
    float2 val4 = *((device float2*)(data2+alu9+8));
    float2 wmma0 = __WMMA_8_8_8_float_float(val0, val4, acc5);
    float2 wmma1 = __WMMA_8_8_8_float_float(val1, val4, acc9);
    float2 wmma2 = __WMMA_8_8_8_float_float(val2, val4, acc13);
    float2 wmma3 = __WMMA_8_8_8_float_float(val3, val4, acc1);
    float2 val5 = *((device float2*)(data2+alu9+16));
    float2 wmma4 = __WMMA_8_8_8_float_float(val0, val5, acc6);
    float2 wmma5 = __WMMA_8_8_8_float_float(val1, val5, acc10);
    float2 wmma6 = __WMMA_8_8_8_float_float(val2, val5, acc14);
    float2 wmma7 = __WMMA_8_8_8_float_float(val3, val5, acc2);
    float2 val6 = *((device float2*)(data2+alu9+24));
    float2 wmma8 = __WMMA_8_8_8_float_float(val0, val6, acc7);
    float2 wmma9 = __WMMA_8_8_8_float_float(val1, val6, acc11);
    float2 wmma10 = __WMMA_8_8_8_float_float(val2, val6, acc15);
    float2 wmma11 = __WMMA_8_8_8_float_float(val3, val6, acc3);
    float2 val7 = *((device float2*)(data2+alu9));
    float2 wmma12 = __WMMA_8_8_8_float_float(val0, val7, acc4);
    float2 wmma13 = __WMMA_8_8_8_float_float(val1, val7, acc8);
    float2 wmma14 = __WMMA_8_8_8_float_float(val2, val7, acc12);
    float2 wmma15 = __WMMA_8_8_8_float_float(val3, val7, acc0);
    // Copy the WMMA results back into the accumulators for the next iteration
    // (SSA-style codegen: each wmmaN feeds the same accN it was seeded from).
    acc0 = wmma15;
    acc1 = wmma3;
    acc2 = wmma7;
    acc3 = wmma11;
    acc4 = wmma12;
    acc5 = wmma0;
    acc6 = wmma4;
    acc7 = wmma8;
    acc8 = wmma13;
    acc9 = wmma1;
    acc10 = wmma5;
    acc11 = wmma9;
    acc12 = wmma14;
    acc13 = wmma2;
    acc14 = wmma6;
    acc15 = wmma10;
  }
  // Write the 4x4 grid of accumulated fragments: column offsets +0/+8/+16/+24,
  // row offsets +0/+8192/+16384/+24576 from the thread's base offset alu7.
  *((device float2*)(data0+alu7+8)) = acc1;
  *((device float2*)(data0+alu7+16)) = acc2;
  *((device float2*)(data0+alu7+24)) = acc3;
  *((device float2*)(data0+alu7+8192)) = acc4;
  *((device float2*)(data0+alu7+8200)) = acc5;
  *((device float2*)(data0+alu7+8208)) = acc6;
  *((device float2*)(data0+alu7+8216)) = acc7;
  *((device float2*)(data0+alu7+16384)) = acc8;
  *((device float2*)(data0+alu7+16392)) = acc9;
  *((device float2*)(data0+alu7+16400)) = acc10;
  *((device float2*)(data0+alu7+16408)) = acc11;
  *((device float2*)(data0+alu7+24576)) = acc12;
  *((device float2*)(data0+alu7+24584)) = acc13;
  *((device float2*)(data0+alu7+24592)) = acc14;
  *((device float2*)(data0+alu7+24600)) = acc15;
  *((device float2*)(data0+alu7)) = acc0;
}
*** METAL 3 r_32_8_2_4_2_2_4_128_8_2_4_4 mem 0.01 GB tm 878.75us/ 0.88ms ( 2443.79 GFLOPS 14.3|310.2 GB/s) ['__mul__', 'sum']
opened device CLANG from pid:15248
*** CLANG 4 copy 4.19M, CLANG <- METAL mem 0.02 GB tm 590.25us/ 1.47ms ( 0.00 GFLOPS 7.1|7.1 GB/s)
*** CLANG 5 copy 4.19M, CLANG <- METAL mem 0.02 GB tm 164.08us/ 1.63ms ( 0.00 GFLOPS 25.6|25.6 GB/s)
*** CLANG 6 copy 4.19M, CLANG <- METAL mem 0.02 GB tm 130.38us/ 1.76ms ( 0.00 GFLOPS 32.2|32.2 GB/s)
-1.1816155e-07
avg: 1217.77 GFLOPS 14.27 GB/s total: 6 kernels 2.15 GOPS 0.03 GB 1.76 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment