@awni
Last active October 9, 2025 19:19
MLX Tiled Matmul
import mlx.core as mx
# Possible tile size for tensor cores
TS = 32
# Matrix dimension (M = N = K = D)
D = 2048
A = mx.random.uniform(shape=(D, D))
B = mx.random.uniform(shape=(D, D))
# Reshape and transpose so each TS x TS tile sits in the last two dimensions:
# A_tiled[a, b] == A[a*TS:(a+1)*TS, b*TS:(b+1)*TS]
A_tiled = A.reshape((D // TS, TS, D // TS, TS)).swapaxes(1, 2)
B_tiled = B.reshape((D // TS, TS, D // TS, TS)).swapaxes(1, 2)
# Each thread group computes one tile of the output.
# Here tile (i, j) is built by summing tile products along the K dimension:
i = 1
j = 1
C_ij = sum(A_tiled[i, k] @ B_tiled[k, j] for k in range(D // TS))
# Reference: the full (untiled) product
C = A @ B
# Get the `i, j` tile of the output
C_ij_expected = C[i * TS:(i + 1) * TS, j * TS:(j + 1) * TS]
assert mx.allclose(C_ij, C_ij_expected)
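The script checks a single output tile. To show the full loop structure a kernel would follow, here is a plain-Python sketch (lists only, no MLX) that computes every tile: the two outer loops pick an output tile (the work of one thread group), and the inner `k` loop accumulates `A_tile[i, k] @ B_tile[k, j]` into a TS x TS accumulator. `matmul_tiled` and `matmul_naive` are hypothetical illustration names, not MLX APIs, and the sketch assumes square matrices whose size is a multiple of the tile size, as in the script.

```python
# Plain-Python sketch of tiled matmul (illustrative only, not an MLX API).
# Assumes square D x D matrices with D divisible by the tile size ts.


def matmul_naive(A, B):
    """Reference triple-loop matmul on lists of lists."""
    n, k, m = len(A), len(B), len(B[0])
    return [[sum(A[r][t] * B[t][c] for t in range(k)) for c in range(m)]
            for r in range(n)]


def matmul_tiled(A, B, ts):
    """Compute A @ B one ts x ts output tile at a time."""
    d = len(A)
    if d % ts != 0:
        raise ValueError("D must be a multiple of the tile size")
    nt = d // ts  # number of tiles along each dimension
    C = [[0.0] * d for _ in range(d)]
    for i in range(nt):          # output tile row (one "thread group" per (i, j))
        for j in range(nt):      # output tile column
            acc = [[0.0] * ts for _ in range(ts)]  # accumulator for tile C[i, j]
            for k in range(nt):  # walk the tiles along the shared K dimension
                # acc += A_tile[i, k] @ B_tile[k, j]
                for r in range(ts):
                    for c in range(ts):
                        s = 0.0
                        for t in range(ts):
                            s += A[i * ts + r][k * ts + t] * B[k * ts + t][j * ts + c]
                        acc[r][c] += s
            # Write the finished tile into its slot in C
            for r in range(ts):
                for c in range(ts):
                    C[i * ts + r][j * ts + c] = acc[r][c]
    return C


# Small demo: 8 x 8 matrices split into a 2 x 2 grid of 4 x 4 tiles.
# Integer-valued inputs keep every sum exact, so == is a safe comparison.
A = [[float(r * 8 + c) for c in range(8)] for r in range(8)]
B = [[float((r + 2 * c) % 5) for c in range(8)] for r in range(8)]
assert matmul_tiled(A, B, 4) == matmul_naive(A, B)
```

Because each output tile depends only on row-tile `i` of A and column-tile `j` of B, the `(i, j)` iterations are independent, which is what lets a GPU hand each one to a separate thread group.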
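The same tile layout can also compute all output tiles at once and then undo the reshape/transpose to recover the full D x D product. The sketch below is written in NumPy rather than MLX so it runs without Apple hardware; `to_tiles`, `from_tiles`, and `tiled_matmul` are hypothetical helper names. MLX exposes `reshape` and `swapaxes` with the same meaning, so porting back to `mx` should be mostly mechanical, though that has not been tested here.

```python
import numpy as np

# NumPy sketch of the tiled layout used in the MLX script (illustrative helpers).


def to_tiles(X, ts):
    """(D, D) -> (D//ts, D//ts, ts, ts); tile [a, b] is X[a*ts:(a+1)*ts, b*ts:(b+1)*ts]."""
    d = X.shape[0]
    return X.reshape(d // ts, ts, d // ts, ts).swapaxes(1, 2)


def from_tiles(T):
    """Inverse of to_tiles: (nt, nt, ts, ts) -> (nt*ts, nt*ts)."""
    nt, _, ts, _ = T.shape
    return T.swapaxes(1, 2).reshape(nt * ts, nt * ts)


def tiled_matmul(A, B, ts):
    """Compute every output tile C[i, j] = sum_k A_tiled[i, k] @ B_tiled[k, j]."""
    At, Bt = to_tiles(A, ts), to_tiles(B, ts)
    # k indexes tiles along the shared dimension, b indexes elements inside a tile
    C_tiles = np.einsum("ikab,kjbc->ijac", At, Bt)
    return from_tiles(C_tiles)


# Usage, mirroring the script with a smaller D
D, TS = 256, 32
rng = np.random.default_rng(0)
A = rng.uniform(size=(D, D))
B = rng.uniform(size=(D, D))
C = tiled_matmul(A, B, TS)
assert np.allclose(C, A @ B)
```

The `einsum` subscript `ikab,kjbc->ijac` is the script's `sum(A_tiled[i, k] @ B_tiled[k, j] for k ...)` written with explicit indices, evaluated for every `(i, j)` at once.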