@davidberard98
Last active September 23, 2025 22:24
import argparse
import multiprocessing
import os
from time import sleep
import torch
import triton
import triton.language as tl


def get_num_bytes(*args):
    # Total bytes moved: sum of sizes of all tensor arguments.
    num_bytes = sum(
        (x.numel() * x.element_size() for x in args if isinstance(x, torch.Tensor))
    )
    return num_bytes


@triton.autotune(
    [triton.Config({"ROW_BLOCK_SIZE": 8}, num_warps=2, num_stages=3)], key=[]
)
@triton.jit
def kernel_layernorm_2d(
    X,
    Y,
    stride,
    M,
    N,
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program normalizes ROW_BLOCK_SIZE rows of the (M, N) input.
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    X = X + row * stride
    Y = Y + row * stride
    mask_row = row < M
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Load a (ROW_BLOCK_SIZE, BLOCK_SIZE) tile and compute statistics in fp32.
    x = tl.load(
        X[:, None] + cols[None, :], mask=mask_row[:, None] & mask[None, :], other=0.0
    ).to(tl.float32)
    mean = tl.sum(x, axis=1) / N
    var = tl.sum((x - mean[:, None]) * (x - mean[:, None]), axis=1) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Normalize only; no weight/bias is applied in this variant.
    y = (x - mean[:, None]) * rstd[:, None]
    tl.store(Y[:, None] + cols[None, :], y, mask=mask_row[:, None] & mask[None, :])


def triton_layernorm_2d(x, eps, *, return_rstd=True, return_mean=True):
    # mean/rstd are not actually returned in this stripped-down version; the
    # flags are only accepted (and asserted) for API parity.
    assert return_rstd and return_mean
    M, N = x.size()
    out = torch.empty_like(x)
    # BLOCK_SIZE must cover a full row, so round N up to a power of 2.
    BLOCK_SIZE = triton.next_power_of_2(N)

    def grid(meta):
        return (triton.cdiv(M, meta["ROW_BLOCK_SIZE"]),)

    kernel_layernorm_2d[grid](x, out, x.stride(0), M, N, eps, BLOCK_SIZE=BLOCK_SIZE)
    return out
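

# Optional sanity-check sketch (not part of the original gist): compares the Triton
# kernel against torch.nn.functional.layer_norm without weight/bias. The helper
# name `check_correctness` and the tolerances are assumptions for illustration.
def check_correctness(M=128, N=512, eps=1e-3):
    x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
    y_triton = triton_layernorm_2d(x, eps)
    # Reference computed in fp32, cast back to bf16 to match the kernel's output dtype.
    y_ref = torch.nn.functional.layer_norm(x.float(), (N,), eps=eps).to(torch.bfloat16)
    torch.testing.assert_close(y_triton, y_ref, atol=1e-2, rtol=1e-2)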


def benchmark_under_load(fn, cache_clearer=False, warmup_reps=10000, timing_reps=10000):
    # Time fn() with CUDA events: queue many back-to-back launches and report the
    # mean time per call in milliseconds.
    assert not cache_clearer  # not implemented
    begin_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup_reps):
        fn()
    begin_event.record()
    for _ in range(timing_reps):
        fn()
    end_event.record()
    torch.cuda.synchronize()
    return begin_event.elapsed_time(end_event) / timing_reps
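

# Alternative timing sketch (assumption: triton.testing.do_bench is available in the
# installed Triton version). do_bench flushes the L2 cache between reps, so its
# numbers may differ from benchmark_under_load, which replays fn back-to-back.
def benchmark_with_do_bench(fn):
    return triton.testing.do_bench(fn)  # mean time per call in ms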


def main():
    M, K = 2**20, 512
    # Large bf16 input so the benchmark is memory-bandwidth bound.
    x = (torch.rand(M, K, device="cuda") - 0.5) * 2 * 50000
    x = x.to(torch.bfloat16)
    eps = 1e-3

    def fn():
        return triton_layernorm_2d(x, eps)

    ms = benchmark_under_load(fn)

    def gbps(ms):
        # Bytes moved (input + output) per millisecond, converted to GB/s.
        return get_num_bytes(x, fn()) / ms * 1e-6

    print(f"Perf: {ms:.3f} ms ({gbps(ms):.2f} GB/s)")
    print()


if __name__ == "__main__":
    main()