@MeetThePatel
Last active May 20, 2025 16:15
FMA in Adam(fused=True) benchmarks
import subprocess

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.benchmark import Timer

DEVICE = "cuda"


def print_git_info():
    # Report which branch/commit is being benchmarked, so runs on
    # upstream/main and cuda-adam-lerp can be told apart in the logs.
    branch = (
        subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            stderr=subprocess.DEVNULL,
        )
        .decode("utf-8")
        .strip()
    )
    commit = (
        subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            stderr=subprocess.DEVNULL,
        )
        .decode("utf-8")
        .strip()
    )
    print(f"[Git] Branch: {branch}")
    print(f"[Git] Commit: {commit[:8]}")


class TransformerModel(nn.Module):
    def __init__(
        self,
        d_model=1024,
        nhead=16,
        num_encoder_layers=12,
        dim_feedforward=4096,
        seq_len=512,
    ):
        super().__init__()
        self.embedding = nn.Embedding(10000, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, seq_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers
        )
        self.linear = nn.Linear(d_model, 10000)

    def forward(self, x):
        x = self.embedding(x) + self.positional_encoding[:, : x.size(1), :]
        x = self.transformer_encoder(x)
        return self.linear(x)


def prepare_model_and_grads():
    model = TransformerModel().to(DEVICE)
    for param in model.parameters():
        param.data.normal_(mean=0.0, std=0.02)
        param.grad = torch.randn_like(param.data)  # simulate gradients
    return model


def benchmark_optimizer_step():
    model = prepare_model_and_grads()
    optimizer = Adam(model.parameters(), lr=1e-3, fused=True)

    # Warm-up optimizer steps
    for _ in range(10):
        optimizer.step()
    torch.cuda.synchronize()

    # Benchmark only the optimizer step
    t = Timer(
        stmt="optimizer.step()",
        setup="torch.cuda.synchronize()",
        # Expose torch explicitly so the setup string can synchronize.
        globals={"optimizer": optimizer, "torch": torch},
        label="Optimizer Step",
        sub_label="Fused Adam",
        description="torch.optim.Adam(fused=True)",
    )
    results = t.blocked_autorange(min_run_time=5.0)
    print(f"{results.label}: {results.sub_label}")
    print(f"{results.description}")
    print(f"Mean time per run: {results.mean * 1e6:.4f} µs")
    print(f"Median: {results.median * 1e6:.4f} µs")


if __name__ == "__main__":
    print_git_info()
    benchmark_optimizer_step()
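
For context, the branch name (cuda-adam-lerp) and the gist title suggest the change under test rewrites the fused Adam kernel's moment updates from a separate multiply and add into a single lerp, which lowers to fused multiply-add (FMA) instructions. A minimal sketch of the two mathematically equivalent first-moment formulations (an illustration of the idea, not the kernel code itself):

import torch

beta1 = 0.9
grad = torch.randn(1024)
exp_avg = torch.randn(1024)

# Baseline formulation: m = beta1 * m + (1 - beta1) * g  (multiply, then add)
m_mul_add = exp_avg.mul(beta1).add(grad, alpha=1 - beta1)

# Lerp formulation: m = m + (1 - beta1) * (g - m), a single fused multiply-add
m_lerp = exp_avg.lerp(grad, 1 - beta1)

torch.testing.assert_close(m_mul_add, m_lerp)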
MeetThePatel commented May 20, 2025

This benchmark was run on:

  • Ryzen 9 5900X
  • Nvidia RTX 5080

Run this script on each branch with:

nsys profile --trace=cuda,nvtx python benchmark.py

Convert each .nsys-rep report to CSV with:

nsys stats --report nvtx_sum --format csv <report>.nsys-rep
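
Note that the nvtx_sum report only aggregates explicitly annotated NVTX ranges. A minimal sketch of annotating the timed region via torch.cuda.nvtx (the nvtx_step wrapper and the "optimizer_step" label are illustrative, not part of the script above):

import torch

def nvtx_step(optimizer):
    # Push a named NVTX range around the step so that
    # `nsys stats --report nvtx_sum` has a region to aggregate.
    torch.cuda.nvtx.range_push("optimizer_step")
    optimizer.step()
    torch.cuda.nvtx.range_pop()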

The benchmark results are (the percentages are relative time reductions):

upstream/main               average time (ns):    642631.8
meetthepatel/cuda-adam-lerp average time (ns):    503889.5
----------------------------------------------------------
time reduction:                                     21.59%

upstream/main               median time (ns):     617061.5
meetthepatel/cuda-adam-lerp median time (ns):     463619.5
----------------------------------------------------------
time reduction:                                     24.87%
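
Equivalently, the mean-time numbers correspond to roughly a 1.28x throughput speedup. A quick check of the arithmetic in plain Python, with the values copied from the table above:

t_main, t_lerp = 642631.8, 503889.5
reduction = (t_main - t_lerp) / t_main  # 0.2159 -> 21.59% less time per step
speedup = t_main / t_lerp               # 1.2753 -> ~1.28x throughput
print(f"{reduction:.2%} reduction, {speedup:.2f}x speedup")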

MeetThePatel commented

Updated the benchmark script to use torch.utils.benchmark; the results are:

[Git] Branch: main
[Git] Commit: 2e56ce09
Optimizer Step: Fused Adam
torch.optim.Adam(fused=True)
Mean time per run: 6715.9338 µs
Median: 6716.0217 µs

[Git] Branch: cuda-adam-lerp
[Git] Commit: 76d4316f
Optimizer Step: Fused Adam
torch.optim.Adam(fused=True)
Mean time per run: 6143.7060 µs
Median: 6152.7162 µs
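
Under this methodology the mean step time drops from 6715.93 µs to 6143.71 µs, a reduction of about 8.5% (median: about 8.4%).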
