Skip to content

Instantly share code, notes, and snippets.

@deebuls
Forked from Chillee/1-pw_op_fusion.py
Created January 22, 2023 11:35
Show Gist options
  • Select an option

  • Save deebuls/3c74d150a5c86215eee5767b18d68ad1 to your computer and use it in GitHub Desktop.

Select an option

Save deebuls/3c74d150a5c86215eee5767b18d68ad1 to your computer and use it in GitHub Desktop.
PT 2.0 Benchmarks
# Benchmark setup: PT 2.0 pointwise-op fusion demo (TorchDynamo + Inductor).
import torch
import torch._inductor.config
import time
# Disable CUDA graphs so timings isolate kernel fusion, not graph replay.
torch._inductor.config.triton.cudagraphs = False
# Allow TF32 on Ampere+ GPUs for float32 matmuls (faster, slightly lower precision).
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time callable ``f`` and report microseconds per iteration.

    Runs ``warmup`` untimed calls, optionally exports a Chrome trace when
    ``profile`` is set, then times ``iters`` calls.

    Returns the raw microseconds-per-iteration float when ``name`` is None,
    otherwise a formatted ``"name: X.XXus"`` string. Prints the result when
    ``display`` is True.
    """
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    # Synchronize only when CUDA is present so the harness also runs on CPU-only
    # builds (the original unconditional synchronize raised there).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    begin = time.perf_counter()  # monotonic, higher resolution than time.time()
    for _ in range(iters):
        f()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters
    if name is None:
        res = us_per_iter
    else:
        # Two-decimal format for consistency with the second copy of bench().
        res = f"{name}: {us_per_iter:.2f}us"
    if display:
        print(res)
    return res
def f1(a, b, c, d):
    """Pointwise chain used to exercise operator fusion.

    Computes ``(relu(a) * tanh(b) + cos(c + 2)) * d`` elementwise.
    """
    activated = torch.relu(a)
    squashed = torch.tanh(b)
    shifted = torch.cos(c + 2)
    return (activated * squashed + shifted) * d
# Four 16M-element (2**24) float32 CUDA tensors as inputs — requires a GPU.
inp = [torch.randn(2**24, device='cuda') for _ in range(4)]
f = f1
# Compile f1 with the PT 2.0 stack so Inductor can fuse the pointwise ops.
nf = torch.compile(f)
# Compare eager vs compiled throughput on the same inputs.
bench(lambda: f(*inp), name="eager")
bench(lambda: nf(*inp), name="PT 2.0")
# --- Second gist file: ResNet-18 torch.compile benchmark ---
import torch
from torch.nn import *
# Allow TF32 on Ampere+ GPUs for float32 matmuls.
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time callable ``f`` and report microseconds per iteration.

    Runs ``warmup`` untimed calls, optionally exports a Chrome trace when
    ``profile`` is set, then times ``iters`` calls.

    Returns the raw microseconds-per-iteration float when ``name`` is None,
    otherwise a formatted ``"name: X.XXus"`` string. Prints the result when
    ``display`` is True.
    """
    import time  # local import, as in the original gist file
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    # Synchronize only when CUDA is present so the harness also runs on CPU-only
    # builds (the original unconditional synchronize raised there).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    begin = time.perf_counter()  # monotonic, higher resolution than time.time()
    for _ in range(iters):
        f()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.2f}us"
    if display:
        print(res)
    return res
# torchvision is a third-party dependency of this benchmark file.
import torchvision.models as models

# ResNet-18 in eval mode on GPU. "reduce-overhead" mode uses CUDA graphs to
# cut per-call launch overhead, which dominates at batch size 1.
mod = models.resnet18().eval().cuda()
opt_mod = torch.compile(mod, mode="reduce-overhead")
inp = torch.randn(1, 3, 224, 224).cuda()
# The scrape flattened this block: the bench calls must live inside no_grad().
with torch.no_grad():
    # Eager: 1938.18us
    bench(lambda: mod(inp), "Eager")
    # torch.compile (default): 953.96us
    # torch.compile (reduce-overhead): 744.02us
    bench(lambda: opt_mod(inp), "torch.compile (reduce-overhead)")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment