High-GPU-utilization PyTorch multiprocessing pipeline example
"""
MIT License 2025 John Meade
Writing custom multiprocessing code for PyTorch can be tricky.
Generally you should use `accelerate` if possible, but this is not an option for bigger pipelines.
This gist provides a starting point for high GPU utilization multiprocessing pipelines for PyTorch models.
Some symptoms of CPU bottlenecks when using this framework are:
* input queues to GPU workers frequently having size 0
* high GPU utilization on one GPU, but low util on others
"""
# PyTorch requires the "spawn" method to work properly. This may be handled already by torch.multiprocessing.
if __name__ == '__main__':
    import multiprocessing as py_mp
    py_mp.set_start_method("spawn")
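
# An alternative sketch (not in the original gist): instead of mutating the global
# start method, you can request a spawn context explicitly and build queues/processes
# from it, e.g. (names below are placeholders):
#   ctx = torch.multiprocessing.get_context("spawn")
#   some_queue = ctx.Queue(100)
#   some_proc = ctx.Process(target=some_worker, kwargs=dict(in_queue=some_queue), daemon=True)
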
from threading import Thread
from time import perf_counter
from itertools import count
from queue import Queue as LocalQueue
import torch
from torch import nn, Tensor
# NOTE: pytorch multiprocessing wrapper
import torch.multiprocessing as mp
# each pytorch process uses multithreading, so you
# can threadbomb yourself if this is too high
torch.set_num_threads(2)
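
# Rough sizing sketch (an assumption, not part of the original gist): keep the total
# thread count in the ballpark of the core count, e.g. something like
#   import os
#   n_workers = 32 + 4  # n_cpu + len(gpus) from __main__ below
#   torch.set_num_threads(max(1, os.cpu_count() // n_workers))
# Otherwise many multi-threaded PyTorch processes can still oversubscribe a small machine.
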
def cpu_worker(tag: str, stop_event, in_queue: mp.Queue, out_queue: mp.Queue):
    "mock CPU task"
    print(f"[{tag}] Starting main loop")
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        out_queue.put(obj + 1)
    print(f"[{tag}] Received stop signal, exiting")

@torch.inference_mode()
def gpu_worker_infer_thread(tag: str, dim: int, device: str, stop_event, in_queue: LocalQueue, out_queue: LocalQueue):
    "mock GPU model infer task -- no device casting"
    print(f"[{tag}] Starting main loop")
    # load your model here
    model = nn.Sequential(*[ nn.Linear(dim, dim) for _ in range(16) ]).to(device)
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        # simulate high load
        for _ in range(20):
            obj = model(obj / 100)
        out_queue.put(obj)
    print(f"[{tag}] Received stop signal, exiting")
def gpu_worker_cpu_cast_thread(tag: str, stop_event, in_queue: LocalQueue, out_queue: LocalQueue):
    "Cast GPU model outputs to CPU and put in IPC queue"
    print(f"[{tag}] Starting main loop")
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        out_queue.put(obj.cpu())
    print(f"[{tag}] Received stop signal, exiting")

def gpu_worker(tag: str, dim: int, device: str, stop_event, in_queue: mp.Queue, out_queue: mp.Queue):
    """
    Orchestrate GPU model inference on a process.
    This requires extra threads to avoid device casting bottlenecks: the main loop below
    pulls from the IPC input queue and casts tensors to the device, the infer thread runs
    the model, and the cast thread copies results back to CPU before putting them on the
    IPC output queue, so host<->device transfers overlap with inference.
    """
    print(f"[{tag}] Starting main loop")
    # internal thread-local queues
    infer_in_q = LocalQueue(100)
    infer_out_q = LocalQueue(100)
    # internal threads
    infer_thread = Thread(target=gpu_worker_infer_thread, kwargs=dict(
        tag=tag + " infer thread",
        dim=dim,
        device=device,
        stop_event=stop_event,
        in_queue=infer_in_q,
        out_queue=infer_out_q,
    ), daemon=True)
    infer_thread.start()
    cast_out_thread = Thread(target=gpu_worker_cpu_cast_thread, kwargs=dict(
        tag=tag + " cpu cast thread",
        stop_event=stop_event,
        in_queue=infer_out_q,
        out_queue=out_queue,
    ), daemon=True)
    cast_out_thread.start()
    # main collection loop
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        obj = obj.to(device)
        infer_in_q.put(obj)
    print(f"[{tag}] Received stop signal, joining threads")
    infer_thread.join()
    cast_out_thread.join()
    print(f"[{tag}] Exiting")

def flush_worker(tag: str, stop_event, in_queue: mp.Queue):
    "final stop in the queue chain, e.g. this might write files to disk"
    print(f"[{tag}] Starting main loop")
    n = 0
    t_start = perf_counter()
    while not stop_event.is_set():
        obj: Tensor = in_queue.get()
        n += 1
        if (n in [1, 10, 50]) or (n % 100 == 0):
            dt = perf_counter() - t_start
            print(f"[{tag}] Received {n:,} items in {1000 * dt:,.0f} ms ({n / dt:.2f} items / s)")
    print(f"[{tag}] Received stop signal, exiting")

def print_queue_size(name, q):
    # NOTE: Queue.qsize() is approximate and is not implemented on some platforms (e.g. macOS)
    s = q.qsize()
    print(f"> {name}: {s:,}" + ("" if s > 0 else " => bottlenecked"))
if __name__ == '__main__':
    tensor_dim = 8192
    qmax = 100
    n_cpu = 32
    gpus = [0, 1, 2, 3]
    gpu_devices = [f"cuda:{i}" for i in gpus]

    # synchronization and queues -- pytorch should use shared memory automatically for Tensor objects
    stop_event = mp.Event()
    cpu_in_q = mp.Queue(qmax)
    cpu_to_gpu_q = mp.Queue(qmax)  # NOTE: CPU worker output links to GPU worker input
    gpu_to_flush_q = mp.Queue(qmax)

    # Create CPU processes
    cpu_procs = [
        mp.Process(
            target=cpu_worker,
            kwargs=dict(
                tag=f"CPU worker {i}",
                stop_event=stop_event,
                in_queue=cpu_in_q,
                out_queue=cpu_to_gpu_q,
            ),
            # daemon mode terminates BG processes if the main thread exits / crashes
            daemon=True,
        )
        for i in range(n_cpu)
    ]

    # Create GPU processes
    gpu_procs = [
        mp.Process(
            target=gpu_worker,
            kwargs=dict(
                tag=f"GPU worker {device}",
                stop_event=stop_event,
                dim=tensor_dim,
                device=device,
                in_queue=cpu_to_gpu_q,
                out_queue=gpu_to_flush_q,
            ),
            daemon=True,
        )
        for device in gpu_devices
    ]

    # Create flush worker
    flush_proc = mp.Process(
        target=flush_worker,
        kwargs=dict(
            tag="Flush worker",
            stop_event=stop_event,
            in_queue=gpu_to_flush_q,
        ),
        daemon=True,
    )

    # Start procs
    print("Starting processes")
    for p in cpu_procs:
        p.start()
    for p in gpu_procs:
        p.start()
    flush_proc.start()

    # Start test
    print("Starting main loop")
    try:
        for i in count():
            x = torch.randn(tensor_dim, tensor_dim)
            cpu_in_q.put(x)
            if i % 100 == 0:
                print(f"Queue sizes at {i=}")
                print_queue_size("cpu_in_q", cpu_in_q)
                print_queue_size("cpu_to_gpu_q", cpu_to_gpu_q)
                print_queue_size("gpu_to_flush_q", gpu_to_flush_q)
    except KeyboardInterrupt:
        print("Handling Interrupt")
    finally:
        print("Joining processes")
        stop_event.set()
        # NOTE: workers blocked on an empty queue only notice stop_event once another item
        # arrives; on Ctrl-C the daemon children typically receive SIGINT as well and exit.
        for p in cpu_procs:
            p.join()
        for p in gpu_procs:
            p.join()
        flush_proc.join()
        print("Done")