Created
October 22, 2025 05:11
-
-
Save galv/3d7b2d324cdb33875d3eea0a884e025f to your computer and use it in GitHub Desktop.
NVTX interactions with cuda graph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # run like this: `nsys profile -c cudaProfilerApi --cuda-graph-trace=node python nvtx_graph_projection.py` | |
| import torch | |
# Target device shared by the tensors created in main(): CUDA, with no
# explicit device index.
device = torch.device("cuda")
@torch.no_grad()
def add_two_numbers(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor):
    """Store ``a + b`` in ``out`` using two in-place tensor operations.

    ``out`` is first overwritten with the contents of ``a`` (``copy_``) and
    then incremented by ``b`` (``add_``). ``a`` and ``b`` are only read.
    Gradient tracking is disabled for the call by the ``torch.no_grad()``
    decorator.

    Args:
        a: First addend; copied into ``out``.
        b: Second addend; added onto ``out`` in place.
        out: Destination tensor, mutated in place.

    Returns:
        None. The result is delivered through ``out``.
    """
    # Both in-place ops return ``out``, so they can be chained; this still
    # issues a copy followed by an add, in that order.
    out.copy_(a).add_(b)
def main():
    """Capture NVTX-annotated adds into a CUDA graph, replay it, and report the result.

    Steps, in order:

    1. Allocate three 1024x1024 tensors on the module-level ``device`` and run
       ``add_two_numbers`` ten times as warmup.
    2. Create a ``torch.cuda.CUDAGraph`` and a dedicated capture stream, and
       zero ``static_out`` once on that stream.
    3. Call ``cudaProfilerStart()``, then capture ten ``add_two_numbers`` calls
       into the graph, each wrapped in its own ``add_<i>`` NVTX range and all
       of them inside a ``whole workload`` NVTX range.
    4. Copy new input values into the static tensors, replay the graph inside
       a separate NVTX range, synchronize, and call ``cudaProfilerStop()``.
    5. Print the maximum absolute error against ``new_a + new_b`` and a
       checksum (sum) of the output.

    Returns:
        None. Results are printed to stdout.
    """
    torch.cuda.synchronize()

    # Create static tensors that will participate in the graph (addresses must remain stable)
    static_a = torch.ones(1024, 1024, device=device)
    static_b = 2 * torch.ones(1024, 1024, device=device)
    static_out = torch.empty_like(static_a)

    # Warmup
    for _ in range(10):
        add_two_numbers(static_a, static_b, static_out)
    torch.cuda.synchronize()

    # Prepare CUDA Graph + stream
    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        static_out.zero_()
    capture_stream.synchronize()

    cudart = torch.cuda.cudart()
    cudart.cudaProfilerStart()
    try:
        # ---- Capture: record 10 calls, each with its own NVTX range ----
        with torch.cuda.graph(g, stream=capture_stream):
            with torch.cuda.nvtx.range("whole workload"):
                for i in range(10):
                    # Each loop iteration is a distinct NVTX range visible in Nsight Systems/Compute
                    with torch.cuda.nvtx.range(f"add_{i+1}"):
                        add_two_numbers(static_a, static_b, static_out)

        # ---- Replay on new input ----
        new_a = 3 * torch.ones_like(static_a)
        new_b = 4 * torch.ones_like(static_b)

        # Copy new values into the static tensors the graph uses
        static_a.copy_(new_a)
        static_b.copy_(new_b)

        # Launch the captured graph (runs the 10-add workload with NVTX ranges)
        with torch.cuda.nvtx.range("GALVEZ: calling cudaGraphLaunch()"):
            g.replay()
        torch.cuda.synchronize()
    finally:
        # Always close the profiler capture range, even if capture or replay
        # raised, so cudaProfilerStart() is never left unpaired.
        cudart.cudaProfilerStop()

    # Check result (since we overwrite out each time, it should be new_a + new_b)
    reference = new_a + new_b
    max_abs_err = (static_out - reference).abs().max().item()
    print(f"Max abs error vs. reference: {max_abs_err:.3e}")

    # Optional: print a small checksum so you can confirm different inputs change output
    print(f"Output checksum: {static_out.sum().item():.1f}")
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment