Skip to content

Instantly share code, notes, and snippets.

@galv
Created October 22, 2025 05:11
Show Gist options
  • Select an option

  • Save galv/3d7b2d324cdb33875d3eea0a884e025f to your computer and use it in GitHub Desktop.

Select an option

Save galv/3d7b2d324cdb33875d3eea0a884e025f to your computer and use it in GitHub Desktop.
NVTX interactions with CUDA graphs
# run like this: `nsys profile -c cudaProfilerApi --cuda-graph-trace=node python nvtx_graph_projection.py`
import torch
# Every tensor in this script is allocated on the (single) CUDA device.
device = torch.device("cuda")
@torch.no_grad()
def add_two_numbers(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor):
    """Store ``a + b`` into ``out`` in place, without autograd tracking.

    Deliberately issues two separate in-place ops (a copy followed by an
    add) so the captured CUDA graph contains two kernel nodes per call.
    """
    # Tensor.copy_ returns self, so the two in-place ops chain cleanly.
    out.copy_(a).add_(b)
def main():
    """Capture a 10-iteration add workload in a CUDA graph, replay it on new
    inputs, and verify the result.

    NVTX ranges are opened both during capture (around each recorded call)
    and around the replay, to show how they project onto graph nodes in
    Nsight Systems.
    """
    torch.cuda.synchronize()

    # Buffers that participate in the capture. Their device addresses must
    # stay fixed, because the recorded graph bakes the pointers in.
    buf_a = torch.ones(1024, 1024, device=device)
    buf_b = 2 * torch.ones(1024, 1024, device=device)
    buf_out = torch.empty_like(buf_a)

    # Warm up the kernels so capture doesn't record one-time init work.
    for _ in range(10):
        add_two_numbers(buf_a, buf_b, buf_out)
    torch.cuda.synchronize()

    # Graph object plus a dedicated side stream for capture.
    graph = torch.cuda.CUDAGraph()
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        buf_out.zero_()
    side_stream.synchronize()

    torch.cuda.cudart().cudaProfilerStart()

    # ---- Capture: record 10 calls, each with its own NVTX range ----
    with torch.cuda.graph(graph, stream=side_stream):
        with torch.cuda.nvtx.range("whole workload"):
            for i in range(10):
                # Each iteration shows up as a distinct NVTX range in
                # Nsight Systems / Nsight Compute.
                with torch.cuda.nvtx.range(f"add_{i+1}"):
                    add_two_numbers(buf_a, buf_b, buf_out)

    # ---- Replay on new input ----
    fresh_a = 3 * torch.ones_like(buf_a)
    fresh_b = 4 * torch.ones_like(buf_b)
    # The graph reads from the original buffers, so copy the new values in.
    buf_a.copy_(fresh_a)
    buf_b.copy_(fresh_b)

    # Launch the captured graph (runs the 10-add workload with NVTX ranges).
    with torch.cuda.nvtx.range("GALVEZ: calling cudaGraphLaunch()"):
        graph.replay()
    torch.cuda.synchronize()
    torch.cuda.cudart().cudaProfilerStop()

    # Each recorded call overwrites buf_out, so it must equal fresh_a + fresh_b.
    expected = fresh_a + fresh_b
    worst_err = (buf_out - expected).abs().max().item()
    print(f"Max abs error vs. reference: {worst_err:.3e}")
    # A small checksum confirms that different inputs change the output.
    print(f"Output checksum: {float(buf_out.sum().item()):.1f}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment