High-GPU-utilization PyTorch multiprocessing pipeline example
| """ | |
| MIT License 2025 John Meade | |
| Writing custom multiprocessing code for PyTorch can be tricky. | |
| Generally you should use `accelerate` if possible, but this is not an option for bigger pipelines. | |
| This gist provides a starting point for high GPU utilization multiprocessing pipelines for PyTorch models. | |
| Some symptoms of CPU bottlenecks when using this framework are: | |
| * input queues to GPU workers frequently having size 0 | |
| * high GPU utilization on one GPU, but low util on others | |
| """ | |

# PyTorch requires the "spawn" method to work properly. This may be handled already by torch.multiprocessing.
if __name__ == '__main__':
    import multiprocessing as py_mp
    py_mp.set_start_method("spawn")

from threading import Thread
from time import perf_counter
from itertools import count
from queue import Queue as LocalQueue

import torch
from torch import nn, Tensor
# NOTE: pytorch multiprocessing wrapper
import torch.multiprocessing as mp

# each pytorch process uses multithreading, so you
# can threadbomb yourself if this is too high
torch.set_num_threads(2)

def cpu_worker(tag: str, stop_event, in_queue: mp.Queue, out_queue: mp.Queue):
    "mock CPU task"
    print(f"[{tag}] Starting main loop")
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        out_queue.put(obj + 1)
    print(f"[{tag}] Received stop signal, exiting")


@torch.inference_mode()
def gpu_worker_infer_thread(tag: str, dim: int, device: str, stop_event, in_queue: LocalQueue, out_queue: LocalQueue):
    "mock GPU model infer task -- no device casting"
    print(f"[{tag}] Starting main loop")
    # load your model here
    model = nn.Sequential(*[ nn.Linear(dim, dim) for _ in range(16) ]).to(device)
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        # simulate high load
        for _ in range(20):
            obj = model(obj / 100)
        out_queue.put(obj)
    print(f"[{tag}] Received stop signal, exiting")


def gpu_worker_cpu_cast_thread(tag: str, stop_event, in_queue: LocalQueue, out_queue: mp.Queue):
    "Cast GPU model outputs to CPU and put in IPC queue"
    print(f"[{tag}] Starting main loop")
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        out_queue.put(obj.cpu())
    print(f"[{tag}] Received stop signal, exiting")

def gpu_worker(tag: str, dim: int, device: str, stop_event, in_queue: mp.Queue, out_queue: mp.Queue):
    """
    Orchestrate GPU model inference on a process.
    This requires extra threads to avoid device casting bottlenecks.
    """
    print(f"[{tag}] Starting main loop")
    # internal thread-local queues
    infer_in_q = LocalQueue(100)
    infer_out_q = LocalQueue(100)
    # internal threads
    infer_thread = Thread(target=gpu_worker_infer_thread, kwargs=dict(
        tag=tag + " infer thread",
        dim=dim,
        device=device,
        stop_event=stop_event,
        in_queue=infer_in_q,
        out_queue=infer_out_q,
    ), daemon=True)
    infer_thread.start()
    cast_out_thread = Thread(target=gpu_worker_cpu_cast_thread, kwargs=dict(
        tag=tag + " cpu cast thread",
        stop_event=stop_event,
        in_queue=infer_out_q,
        out_queue=out_queue,
    ), daemon=True)
    cast_out_thread.start()
    # main collection loop
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        obj = obj.to(device)
        infer_in_q.put(obj)
    print(f"[{tag}] Received stop signal, joining threads")
    infer_thread.join()
    cast_out_thread.join()
    print(f"[{tag}] Exiting")

def flush_worker(tag: str, stop_event, in_queue: mp.Queue):
    "final stage in the queue chain, e.g. this might write files to disk"
    print(f"[{tag}] Starting main loop")
    n = 0
    t_start = perf_counter()
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        n += 1
        if (n in [1, 10, 50]) or (n % 100 == 0):
            dt = perf_counter() - t_start
            print(f"[{tag}] Received {n:,} items in {1000 * dt:,.0f} ms ({n / dt:.2f} items / s)")
    print(f"[{tag}] Received stop signal, exiting")


def print_queue_size(name, q):
    s = q.qsize()
    print(f"> {name}: {s:,}" + ("" if s > 0 else " => bottlenecked"))

if __name__ == '__main__':
    tensor_dim = 8192
    qmax = 100
    n_cpu = 32
    gpus = [0, 1, 2, 3]
    gpu_devices = [f"cuda:{i}" for i in gpus]

    # synchronization and queues -- pytorch should use shared memory automatically for Tensor objects
    stop_event = mp.Event()
    cpu_in_q = mp.Queue(qmax)
    cpu_to_gpu_q = mp.Queue(qmax)  # NOTE: CPU worker output links to GPU worker input
    gpu_to_flush_q = mp.Queue(qmax)

    # Create CPU processes
    cpu_procs = [
        mp.Process(
            target=cpu_worker,
            kwargs=dict(
                tag=f"CPU worker {i}",
                stop_event=stop_event,
                in_queue=cpu_in_q,
                out_queue=cpu_to_gpu_q,
            ),
            # daemon mode terminates BG processes if the main thread exits / crashes
            daemon=True,
        )
        for i in range(n_cpu)
    ]

    # Create GPU processes
    gpu_procs = [
        mp.Process(
            target=gpu_worker,
            kwargs=dict(
                tag=f"GPU worker {device}",
                stop_event=stop_event,
                dim=tensor_dim,
                device=device,
                in_queue=cpu_to_gpu_q,
                out_queue=gpu_to_flush_q,
            ),
            daemon=True,
        )
        for device in gpu_devices
    ]

    # Create flush worker
    flush_proc = mp.Process(
        target=flush_worker,
        kwargs=dict(
            tag="Flush worker",
            stop_event=stop_event,
            in_queue=gpu_to_flush_q,
        ),
        daemon=True,
    )

    # Start procs
    print("Starting processes")
    [p.start() for p in cpu_procs]
    [p.start() for p in gpu_procs]
    flush_proc.start()

    # Start test
    print("Starting main loop")
    try:
        for i in count():
            x = torch.randn(tensor_dim, tensor_dim)
            cpu_in_q.put(x)
            if i % 100 == 0:
                print(f"Queue sizes at {i=}")
                print_queue_size("cpu_in_q", cpu_in_q)
                print_queue_size("cpu_to_gpu_q", cpu_to_gpu_q)
                print_queue_size("gpu_to_flush_q", gpu_to_flush_q)
    except KeyboardInterrupt:
        print("Handling Interrupt")
    finally:
        print("Joining processes")
        stop_event.set()
        [p.join() for p in cpu_procs]
        [p.join() for p in gpu_procs]
        flush_proc.join()
        print("Done")