@papamoose · Created October 3, 2025 20:19
PyTorch code to make NVIDIA GPUs do work
#!/usr/bin/env python3
"""
pytorch_bench.py – Multi-GPU stress test with safety checks

Install
-------
python3 -m venv ./venv
./venv/bin/pip install tqdm
./venv/bin/pip install torch torchvision --index-url https://download.pytorch.org/whl/cu129

Features
--------
* Automatic detection of all CUDA devices.
* `--devices` lets you pick any subset (comma-separated, e.g. 0,2,3).
* If you request more GPUs than exist, the script aborts with a clear error.
* Works with a single GPU (no multiprocessing overhead) or many GPUs
  (one process per device via torch.multiprocessing.spawn).
* Mixed-precision (fp16/bf16) and a cuDNN-benchmark toggle.
"""
import argparse
import signal
import sys
import time
from typing import List

import torch
import torch.multiprocessing as mp
from tqdm import tqdm


# Graceful Ctrl-C handling (installed in the parent; spawned children
# receive the terminal's SIGINT directly via the process group)
def install_sigint_handler() -> None:
    """Install a SIGINT handler that exits cleanly."""
    def handler(sig, frame):
        print("\n\nCaught SIGINT - exiting cleanly. GPU memory will be released.")
        sys.exit(0)

    signal.signal(signal.SIGINT, handler)

# Argument parsing
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="GPU stress test using large matrix multiplies in PyTorch"
    )
    parser.add_argument(
        "-d",
        "--devices",
        type=str,
        default=None,
        help=(
            "Comma-separated list of CUDA device IDs to use (e.g. 0,2,3). "
            "If omitted, the script will use all visible GPUs."
        ),
    )
    parser.add_argument(
        "-s",
        "--size",
        type=int,
        default=8192,
        help=(
            "Linear dimension N of the square matrices (default: 8192). "
            "Memory usage ≈ 3·N²·sizeof(dtype) bytes (two inputs plus the output)."
        ),
    )
    parser.add_argument(
        "-i",
        "--iters",
        type=int,
        default=200,
        help="Number of matmul iterations (default: 200).",
    )
    parser.add_argument(
        "-p",
        "--precision",
        choices=["fp32", "fp16", "bf16"],
        default="fp32",
        help="Arithmetic precision to use. fp16/bf16 require GPU support.",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=5,
        help="Number of warm-up iterations (default: 5).",
    )
    parser.add_argument(
        "--no-cudnn-bench",
        dest="no_cudnn_bench",
        action="store_true",
        help="Disable torch.backends.cudnn.benchmark (keeps algorithm selection deterministic).",
    )
    return parser.parse_args()

# Turn the user-provided device string into a list of ints
def parse_device_list(devices_str: str | None) -> List[int]:
    """
    Return a list of device IDs.

    * None    → all devices (0 .. torch.cuda.device_count()-1)
    * ''      → all devices as well
    * '0,2,3' → [0, 2, 3] (duplicates removed, order preserved)
    """
    if devices_str is None or devices_str.strip() == "":
        return list(range(torch.cuda.device_count()))
    parts = [p.strip() for p in devices_str.split(",") if p.strip() != ""]
    ids = []
    for p in parts:
        if not p.isdigit():
            raise ValueError(f"Device ID must be an integer, got '{p}'")
        ids.append(int(p))
    # Deduplicate while preserving order
    seen = set()
    uniq = []
    for d in ids:
        if d not in seen:
            uniq.append(d)
            seen.add(d)
    return uniq
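
# Example: parse_device_list("0,2,2,3") -> [0, 2, 3]; parse_device_list(None) -> all visible GPUs.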

# Validate that the requested IDs actually exist on the system
def validate_device_list(requested: List[int]) -> List[int]:
    """
    Ensure every ID in ``requested`` is a valid CUDA ordinal.

    If the list contains an invalid ID, print a friendly error and exit.
    Returns the unchanged list when everything is OK.
    """
    available = torch.cuda.device_count()
    if available == 0:
        print("No CUDA devices detected on this machine. Exiting.")
        sys.exit(1)
    out_of_range = [d for d in requested if d < 0 or d >= available]
    if out_of_range:
        print(
            f"You asked for GPU id(s) {out_of_range} but only "
            f"{available} CUDA device(s) are present (IDs 0-{available - 1})."
        )
        print("Use `nvidia-smi` or `torch.cuda.device_count()` to see what is available.")
        print("Either drop the invalid IDs or run the script without `--devices` to use all GPUs.")
        sys.exit(1)
    return requested

# Memory-estimation helper
def estimate_memory(size: int, dtype: torch.dtype) -> float:
    """Return memory (GiB) needed for the two inputs plus the matmul output."""
    bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
    mem_bytes = 3 * size * size * bytes_per_elem  # a, b, and c = a @ b
    return mem_bytes / (1024**3)  # GiB
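
# Worked example: N=8192 in fp32 → 3 * 8192**2 * 4 B = 768 MiB ≈ 0.75 GiB.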

# The per-GPU worker (run in its own process when >1 GPU)
def worker(rank: int, device_ids: List[int], args: argparse.Namespace) -> None:
    """
    rank       – 0-based index of the spawned process.
    device_ids – full list of requested GPU IDs.
    args       – the argparse.Namespace shared across processes.
    """
    dev_id = device_ids[rank]
    device = torch.device(f"cuda:{dev_id}")
    torch.cuda.set_device(device)
    print(
        f"\nProcess {rank} → GPU {dev_id} : "
        f"{torch.cuda.get_device_name(device)}"
    )
    # ---- precision ----------------------------------------------------
    if args.precision == "fp32":
        dtype = torch.float32
    elif args.precision == "fp16":
        dtype = torch.float16
    elif args.precision == "bf16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError("Unsupported precision")
    # ---- memory sanity check -------------------------------------------
    needed = estimate_memory(args.size, dtype)
    free = torch.cuda.mem_get_info(device=device)[0] / (1024**3)
    print(
        f"[GPU {dev_id}] Matrix size: {args.size}×{args.size}  "
        f"Precision: {args.precision} ({torch.tensor([], dtype=dtype).element_size()} B)  "
        f"Needed ≈ {needed:.2f} GiB  Free ≈ {free:.2f} GiB"
    )
    if needed > free * 0.85:
        print(
            f"[GPU {dev_id}] Not enough free memory. Reduce --size or run on fewer GPUs."
        )
        sys.exit(1)
    # ---- cudnn benchmark -----------------------------------------------
    if not args.no_cudnn_bench:
        torch.backends.cudnn.benchmark = True
        print(f"[GPU {dev_id}] cudnn.benchmark = True")
    # ---- allocate tensors ------------------------------------------------
    torch.manual_seed(0)  # same seed in every process → identical inputs on each GPU
    a = torch.randn((args.size, args.size), dtype=dtype, device=device)
    b = torch.randn((args.size, args.size), dtype=dtype, device=device)
    # ---- warm-up (not timed) ---------------------------------------------
    if args.warmup > 0:
        print(f"[GPU {dev_id}] Warm-up: {args.warmup} iterations (not timed)")
        for _ in tqdm(
            range(args.warmup),
            desc=f"GPU{dev_id} warm-up",
            unit="iter",
            ncols=80,
            leave=False,
        ):
            c = torch.matmul(a, b)
            c = c.contiguous()
            torch.cuda.synchronize()
            del c
    # ---- timed benchmark ---------------------------------------------------
    print(f"[GPU {dev_id}] Starting timed run: {args.iters} iterations")
    start = time.time()
    for _ in tqdm(
        range(args.iters),
        desc=f"GPU{dev_id} run",
        unit="iter",
        ncols=80,
        leave=False,
    ):
        c = torch.matmul(a, b)
        _ = c.sum()  # forces materialisation of the result
        torch.cuda.synchronize()
        del c
    elapsed = time.time() - start
    avg_sec = elapsed / args.iters
    flops_per_iter = 2 * (args.size ** 3)  # 2·N³ FLOPs per N×N GEMM
    tflops = (flops_per_iter / 1e12) / avg_sec
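    # Worked example: N=8192 → 2 * 8192**3 ≈ 1.10e12 FLOPs, i.e. ~1.1 TFLOP per iteration.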
print(f"\n==== GPU {dev_id} RESULTS ====")
print(f"Total runtime : {elapsed:.2f} s")
print(f"Avg per iteration : {avg_sec*1e3:.2f} ms")
print(f"Approx. TFLOPs/iter : {(flops_per_iter/1e12):.2f} TFLOPs")
print(f"Sustained performance : {tflops:.2f} TFLOPs")
print("-" * 30)

# Main entry point – decides whether to spawn or run directly
def main() -> None:
    install_sigint_handler()
    args = parse_args()

    # ---- Resolve and validate the device list -------------------------
    raw_device_ids = parse_device_list(args.devices)
    device_ids = validate_device_list(raw_device_ids)
    if not device_ids:
        print("No CUDA devices selected after validation. Exiting.")
        sys.exit(1)
    print(f"Detected {torch.cuda.device_count()} CUDA device(s).")
    print(f"Using device list: {device_ids}")

    # ---- Single-GPU fast path -----------------------------------------
    if len(device_ids) == 1:
        worker(rank=0, device_ids=device_ids, args=args)
        return

    # ---- Multi-GPU: spawn one process per GPU -------------------------
    print(f"Spawning {len(device_ids)} processes (one per GPU)…")
    mp.set_start_method("spawn", force=True)
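    # mp.spawn calls worker(rank, device_ids, args) once per process,
    # injecting the 0-based rank as the first positional argument.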
    mp.spawn(
        fn=worker,
        args=(device_ids, args),
        nprocs=len(device_ids),
        join=True,
    )

if __name__ == "__main__":
    main()
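
For reference, a typical invocation after the install steps in the docstring (flag names as defined in parse_args; save the script as pytorch_bench.py):

./venv/bin/python pytorch_bench.py                          # all visible GPUs, fp32 defaults
./venv/bin/python pytorch_bench.py -d 0,2 -s 16384 -p bf16 --iters 500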