@garrett361
Created June 4, 2024 18:22
Torch Profile Comms Compute Overlap
# collectives.py
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist

if torch.cuda.is_available():
    accel = torch.cuda
    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"
else:
    import intel_extension_for_pytorch as ipex  # noqa
    import oneccl_bindings_for_pytorch as torch_ccl  # noqa

    print(
        f"Using Versions: {torch.__version__=}, {ipex.__version__=}, {torch_ccl.__version__=}",
        flush=True,
    )
    accel = torch.xpu
    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"


class Collective(ABC):
    """
    Unified abstract base class for collectives. Repeatedly runs a collective on the same tensors.
    """

    name: str

    def __init__(
        self, numel: int, device: torch.device, dtype: torch.dtype, world_size: int
    ) -> None:
        """
        Sets up the tensors the collective runs on. numel is the number of elements in the
        largest tensor involved in the collective, e.g. the input tensor for reduce_scatter and
        the output tensor for all_gather.
        """
        self.numel = numel
        self.device = device
        self.dtype = dtype
        self.world_size = world_size
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """
        Performs any necessary setup.
        """

    @abstractmethod
    def __call__(self) -> None:
        """
        Runs the collective.
        """

    @abstractmethod
    def get_bit_s(self, time_s: float, iters: int) -> float:
        """
        Computes the bandwidth in bits/s, based on the NVIDIA conventions:
        https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#performance-reported-by-nccl-tests
        """

    def get_Gbit_s(self, time_s: float, iters: int) -> float:
        """
        Computes the bandwidth in Gbit/s, based on the same NVIDIA conventions.
        """
        return self.get_bit_s(time_s, iters) / 1e9

    def get_GiB_s(self, time_s: float, iters: int) -> float:
        """
        Computes the bandwidth in GiB/s, based on the same NVIDIA conventions.
        """
        return self.get_bit_s(time_s, iters) / 2**33
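
# Worked example of the bus-bandwidth convention above (illustrative numbers,
# not measurements): an all_reduce of a 2**30-element bf16 tensor has
# msg_bits = 8 * 2 * 2**30. With world_size=8 the ring algo factor is
# 2 * (8 - 1) / 8 = 1.75, so a single 40 ms iteration would be reported as
# 1.75 * 8 * 2 * 2**30 / 0.040 ≈ 7.5e11 bit/s ≈ 752 Gbit/s.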


class ReduceScatter(Collective):
    """
    For ReduceScatter, numel sets the size of the input, i.e. the total number of elements
    across the input list.
    """

    name = "reduce_scatter"

    def _make_numel_divisible(self) -> None:
        # Force self.numel to be divisible by the world size.
        new_numel = self.world_size * (self.numel // self.world_size)
        if new_numel != self.numel:
            print(
                f"Adjusting original {self.numel=} to {new_numel} in order",
                f"to be divisible by {self.world_size=}",
                flush=True,
            )
            self.numel = new_numel

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = [
            torch.randn(self.numel // self.world_size, dtype=self.dtype, device=self.device)
            for _ in range(self.world_size)
        ]
        self._out = torch.empty(self.numel // self.world_size, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.reduce_scatter(self._out, self._in, op=dist.ReduceOp.SUM)

    def get_bit_s(self, time_s: float, iters: int) -> float:
        msg_bits = 8 * self.dtype.itemsize * self.numel
        algo_factor = (self.world_size - 1) / self.world_size
        return algo_factor * msg_bits * iters / time_s


class ReduceScatterTensor(ReduceScatter):
    name = "reduce_scatter_tensor"

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = torch.randn(self.numel, dtype=self.dtype, device=self.device)
        self._out = torch.empty(self.numel // self.world_size, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.reduce_scatter_tensor(self._out, self._in, op=dist.ReduceOp.SUM)


class AllReduce(Collective):
    name = "all_reduce"

    def _setup(self) -> None:
        self.t = torch.empty(self.numel, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.all_reduce(self.t, op=dist.ReduceOp.SUM)
        # Also divide by the world size so repeated reductions don't blow up the values.
        self.t = self.t / self.world_size

    def get_bit_s(self, time_s: float, iters: int) -> float:
        msg_bits = 8 * self.dtype.itemsize * self.numel
        algo_factor = 2 * (self.world_size - 1) / self.world_size
        return algo_factor * msg_bits * iters / time_s


class AllGather(Collective):
    """
    For AllGather, numel sets the size of the output, i.e. the total number of elements
    across the output list.
    """

    name = "all_gather"

    def _make_numel_divisible(self) -> None:
        # Force self.numel to be divisible by the world size.
        new_numel = self.world_size * (self.numel // self.world_size)
        if new_numel != self.numel:
            print(
                f"Adjusting original {self.numel=} to {new_numel} in order",
                f"to be divisible by {self.world_size=}",
                flush=True,
            )
            self.numel = new_numel

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = torch.randn(self.numel // self.world_size, dtype=self.dtype, device=self.device)
        self._out = [
            torch.empty(self.numel // self.world_size, dtype=self.dtype, device=self.device)
            for _ in range(self.world_size)
        ]

    def __call__(self) -> None:
        dist.all_gather(self._out, self._in)

    def get_bit_s(self, time_s: float, iters: int) -> float:
        msg_bits = 8 * self.dtype.itemsize * self.numel
        algo_factor = (self.world_size - 1) / self.world_size
        return algo_factor * msg_bits * iters / time_s


class AllGatherIntoTensor(AllGather):
    name = "all_gather_into_tensor"

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = torch.randn(self.numel // self.world_size, dtype=self.dtype, device=self.device)
        self._out = torch.empty(self.numel, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.all_gather_into_tensor(self._out, self._in)


# Auto-populate the dict mapping collective names to their classes.
COLLECTIVES_DICT: dict[str, type[Collective]] = {}


def populate_collectives_dict(cls: type[Collective]) -> None:
    # Walk the subclass tree recursively so grandchildren are registered too.
    for sub_cls in cls.__subclasses__():
        if sub_cls.name in COLLECTIVES_DICT:
            raise ValueError(f"{sub_cls.name} already exists!")
        COLLECTIVES_DICT[sub_cls.name] = sub_cls
        populate_collectives_dict(sub_cls)


populate_collectives_dict(Collective)
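
A minimal usage sketch for the module above, assuming it is saved as collectives.py and launched under torchrun (the collective name, numel, and iteration count below are arbitrary choices for illustration):

import os
import time

import torch
import torch.distributed as dist

from collectives import BACKEND, COLLECTIVES_DICT, DEVICE_TYPE, accel

dist.init_process_group(backend=BACKEND)
device = torch.device(f"{DEVICE_TYPE}:{int(os.getenv('LOCAL_RANK', 0))}")
accel.set_device(device)
coll = COLLECTIVES_DICT["all_gather_into_tensor"](
    numel=2**24, device=device, dtype=torch.bfloat16, world_size=dist.get_world_size()
)
iters = 10
accel.synchronize()
start = time.perf_counter()
for _ in range(iters):
    coll()
accel.synchronize()
print(f"{coll.name}: {coll.get_Gbit_s(time.perf_counter() - start, iters):.1f} Gbit/s")
dist.destroy_process_group()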
"""
Minimal profiling script for profiling compute/comms overlap.
"""
import argparse
import os
from pathlib import Path
import torch
import torch.distributed as dist
from collectives import COLLECTIVES_DICT, Collective
from torch.profiler import ProfilerActivity
if torch.cuda.is_available():
from torch import cuda as accel # noqa
DEVICE_TYPE = "cuda"
BACKEND = "nccl"
else:
# Note all of the instructions for ipex profiling. In particular, need to
# export IPEX_ZE_TRACING=1 to make GPU traces visible
# https://github.com/intel/intel-extension-for-pytorch/blob/1296c267c4247a7027d2103d05204b6b556b3d63/docs/tutorials/features/profiler_kineto.md#L24-L24
import intel_extension_for_pytorch as ipex # noqa
from torch import xpu as accel # noqa
import oneccl_bindings_for_pytorch # noqa
DEVICE_TYPE = "xpu"
BACKEND = "ccl"
SEQ_LEN = 2**13
D_MODEL = 2**14
N_LAYERS = 3
COMMS_NUMEL = 2**30
WARMUP = 3
ACTIVE = 2
RANK = int(os.getenv("RANK", 0))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}")


class Model(torch.nn.Module):
    def __init__(self, d_model: int, device: torch.device) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, d_model, bias=False, device=device) for _ in range(N_LAYERS)]
        )

    def forward(self, x) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


def get_profiler() -> torch.profiler.profile:
    activities = [ProfilerActivity.CPU]
    if DEVICE_TYPE == "xpu":
        activities.append(ProfilerActivity.XPU)
    elif DEVICE_TYPE == "cuda":
        activities.append(ProfilerActivity.CUDA)
    else:
        raise ValueError(f"Unexpected device type {DEVICE_TYPE=}")
    return torch.profiler.profile(
        activities=activities,
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
    )
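

# A hedged alternative to the manual warmup loop in main() below: torch.profiler
# also supports a built-in schedule whose wait/warmup/active windows advance on
# each p.step() call. A minimal sketch, unused by this script:
def get_scheduled_profiler(warmup: int, active: int) -> torch.profiler.profile:
    activities = [ProfilerActivity.CPU]
    activities.append(ProfilerActivity.XPU if DEVICE_TYPE == "xpu" else ProfilerActivity.CUDA)
    return torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=0, warmup=warmup, active=active),
    )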


def run_one_iter(model, batch, collective: Collective, comms_stream) -> None:
    # Compute runs on the default stream; the collective is enqueued on a side
    # stream so the two can overlap.
    with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
        model(batch)
    with accel.stream(comms_stream):
        collective()
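

# Note on stream semantics: the compute and the collective above touch disjoint
# tensors, so no cross-stream ordering is needed for correctness. If the
# collective's output fed the compute, the streams would need explicit ordering.
# A hedged sketch of that variant, unused here (assumes the accelerator's Stream
# supports wait_stream, as CUDA streams do):
def run_one_iter_with_sync(model, batch, collective: Collective, comms_stream) -> None:
    # Order the comms stream after work already queued on the default stream.
    comms_stream.wait_stream(accel.current_stream())
    with accel.stream(comms_stream):
        collective()
    # Block the default stream on the collective before any consumer reads its output.
    accel.current_stream().wait_stream(comms_stream)
    with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
        model(batch)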


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int, default=WARMUP)
    parser.add_argument("--active", type=int, default=ACTIVE)
    parser.add_argument("--seq-len", type=int, default=SEQ_LEN)
    parser.add_argument("--d-model", type=int, default=D_MODEL)
    parser.add_argument("--comms-numel", type=int, default=COMMS_NUMEL)
    parser.add_argument(
        "--base-dir",
        type=str,
        default=None,
        help="Base directory where traces go. Defaults to the user's home dir.",
    )
    parser.add_argument("-c", "--collective", type=str, default="all_reduce")
    args = parser.parse_args()
    return args


def main() -> None:
    args = get_args()
    model = Model(device=DEVICE, d_model=args.d_model)
    torch_profiler = get_profiler()
    batch = torch.randn((1, args.seq_len, args.d_model), device=DEVICE)
    collective = COLLECTIVES_DICT[args.collective](
        numel=args.comms_numel, device=DEVICE, dtype=torch.bfloat16, world_size=WORLD_SIZE
    )
    comms_stream = accel.Stream(device=DEVICE)
    if not RANK:
        print(f"Profiling {collective.name} overlap.", flush=True)

    # Warmups
    for _ in range(args.warmup):
        run_one_iter(model, batch, collective, comms_stream)
    dist.barrier()
    accel.synchronize()

    # Profile
    with torch_profiler as p:
        for _ in range(args.active):
            run_one_iter(model, batch, collective, comms_stream)

    # Write out traces
    profiler_output_dir = Path(args.base_dir or Path.home()).absolute() / "torch_profiler"
    profiler_output_dir.mkdir(exist_ok=True)
    file_name = f"profile.rank_{RANK}.{args.collective}.chrome_trace.json.gz"
    p.export_chrome_trace(str(profiler_output_dir / file_name))


if __name__ == "__main__":
    assert WORLD_SIZE > 1, "Collectives need more than one rank; launch with torchrun."
    try:
        dist.init_process_group(backend=BACKEND)
        main()
    finally:
        dist.destroy_process_group()
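
Example launch (the script filename here is illustrative; torchrun supplies the RANK, LOCAL_RANK, and WORLD_SIZE environment variables read above):

torchrun --nproc-per-node 8 profile_comms_compute_overlap.py --collective all_gather_into_tensor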