Created June 4, 2024 18:22
Torch Profile Comms Compute Overlap
collectives.py (imported by the profiling script below):
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist

if torch.cuda.is_available():
    accel = torch.cuda
    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"
else:
    import intel_extension_for_pytorch as ipex  # noqa
    import oneccl_bindings_for_pytorch as torch_ccl  # noqa

    print(
        f"Using Versions: {torch.__version__=}, {ipex.__version__=}, {torch_ccl.__version__=}",
        flush=True,
    )
    accel = torch.xpu
    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"


class Collective(ABC):
    """
    Unified abstract class for collectives. Repeatedly runs collectives on the same tensors.
    """

    name: str

    def __init__(
        self, numel: int, device: torch.device, dtype: torch.dtype, world_size: int
    ) -> None:
        """
        Sets up tensors to run collectives on. numel sets the number of elements in the largest
        tensor involved in the collective, e.g. the input tensor for reduce_scatter and the output
        tensor for all_gather.
        """
        self.numel = numel
        self.device = device
        self.dtype = dtype
        self.world_size = world_size
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """
        Performs any necessary setup.
        """

    @abstractmethod
    def __call__(self) -> None:
        """
        Runs the collective.
        """

    @abstractmethod
    def get_bit_s(self, time_s: float, iters: int) -> float:
        """
        Computes the bandwidth in bits/s, based on the NVIDIA conventions:
        https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#performance-reported-by-nccl-tests
        """

    def get_Gbit_s(self, time_s: float, iters: int) -> float:
        """
        Computes the bandwidth in Gbit/s, based on the NVIDIA conventions:
        https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#performance-reported-by-nccl-tests
        """
        return self.get_bit_s(time_s, iters) / 1e9

    def get_GiB_s(self, time_s: float, iters: int) -> float:
        """
        Computes the bandwidth in GiB/s, based on the NVIDIA conventions:
        https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#performance-reported-by-nccl-tests
        """
        return self.get_bit_s(time_s, iters) / 2**33
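
# A worked example of the bus-bandwidth convention above (illustrative numbers,
# not from the source): with world_size n = 8 and a bf16 tensor of numel = 2**30,
# an all_reduce moves msg_bits = 8 * 2 * 2**30 = 2**34 bits per iteration and has
# a bus-bandwidth factor of 2 * (n - 1) / n = 1.75. If one iteration takes 0.1 s,
# the bus bandwidth is 1.75 * 2**34 / 0.1 ≈ 300.6 Gbit/s. For reduce_scatter and
# all_gather the factor is (n - 1) / n instead, since each element crosses the
# ring once rather than twice.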

class ReduceScatter(Collective):
    """
    For ReduceScatter, numel sets the size of the input tensor.
    """

    name = "reduce_scatter"

    def _make_numel_divisible(self) -> None:
        # Forces self.numel to be divisible by the world_size
        new_numel = self.world_size * (self.numel // self.world_size)
        if new_numel != self.numel:
            print(
                f"Adjusting original {self.numel=} to {new_numel} in order",
                f"to be divisible by {self.world_size=}",
                flush=True,
            )
        self.numel = new_numel

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = [
            torch.randn(self.numel // self.world_size, dtype=self.dtype, device=self.device)
            for _ in range(self.world_size)
        ]
        self._out = torch.empty(self.numel // self.world_size, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.reduce_scatter(self._out, self._in, op=dist.ReduceOp.SUM)

    def get_bit_s(self, time_s: float, iters: int) -> float:
        msg_bits = 8 * self.dtype.itemsize * self.numel
        algo_factor = (self.world_size - 1) / self.world_size
        bit_s = algo_factor * msg_bits * iters / time_s
        return bit_s


class ReduceScatterTensor(ReduceScatter):
    name = "reduce_scatter_tensor"

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = torch.randn(self.numel, dtype=self.dtype, device=self.device)
        self._out = torch.empty(self.numel // self.world_size, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.reduce_scatter_tensor(self._out, self._in, op=dist.ReduceOp.SUM)


class AllReduce(Collective):
    name = "all_reduce"

    def _setup(self) -> None:
        self.t = torch.empty(self.numel, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.all_reduce(self.t, op=dist.ReduceOp.SUM)
        # Also divide by the world size to avoid blowup
        self.t = self.t / self.world_size

    def get_bit_s(self, time_s: float, iters: int) -> float:
        msg_bits = 8 * self.dtype.itemsize * self.numel
        algo_factor = 2 * (self.world_size - 1) / self.world_size
        bit_s = algo_factor * msg_bits * iters / time_s
        return bit_s


class AllGather(Collective):
    """
    For AllGather, numel sets the size of the output tensor.
    """

    name = "all_gather"

    def _make_numel_divisible(self) -> None:
        # Forces self.numel to be divisible by the world_size
        new_numel = self.world_size * (self.numel // self.world_size)
        if new_numel != self.numel:
            print(
                f"Adjusting original {self.numel=} to {new_numel} in order",
                f"to be divisible by {self.world_size=}",
                flush=True,
            )
        self.numel = new_numel

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = torch.randn(self.numel // self.world_size, dtype=self.dtype, device=self.device)
        self._out = [
            torch.empty(self.numel // self.world_size, dtype=self.dtype, device=self.device)
            for _ in range(self.world_size)
        ]

    def __call__(self) -> None:
        dist.all_gather(self._out, self._in)

    def get_bit_s(self, time_s: float, iters: int) -> float:
        msg_bits = 8 * self.dtype.itemsize * self.numel
        algo_factor = (self.world_size - 1) / self.world_size
        bit_s = algo_factor * msg_bits * iters / time_s
        return bit_s


class AllGatherIntoTensor(AllGather):
    name = "all_gather_into_tensor"

    def _setup(self) -> None:
        self._make_numel_divisible()
        self._in = torch.randn(self.numel // self.world_size, dtype=self.dtype, device=self.device)
        self._out = torch.empty(self.numel, dtype=self.dtype, device=self.device)

    def __call__(self) -> None:
        dist.all_gather_into_tensor(self._out, self._in)

# Auto-populate dict of collectives.
COLLECTIVES_DICT: dict[str, type[Collective]] = {}


def populate_collectives_dict(cls) -> None:
    for sub_cls in cls.__subclasses__():
        if sub_cls.name in COLLECTIVES_DICT:
            raise ValueError(f"{sub_cls.name} already exists!")
        COLLECTIVES_DICT[sub_cls.name] = sub_cls
        # Recurse so that subclasses of subclasses are registered too.
        populate_collectives_dict(sub_cls)


populate_collectives_dict(Collective)
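
The classes can also be exercised on their own to measure bus bandwidth. A minimal sketch (not part of the gist; it assumes one process per device, launched e.g. via torchrun so that LOCAL_RANK is set):

import os
import time

import torch
import torch.distributed as dist

from collectives import BACKEND, COLLECTIVES_DICT, DEVICE_TYPE, accel

dist.init_process_group(backend=BACKEND)
local_rank = int(os.getenv("LOCAL_RANK", 0))
accel.set_device(local_rank)
device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
world_size = dist.get_world_size()

# Build a collective from the registry; numel here is illustrative.
collective = COLLECTIVES_DICT["all_reduce"](
    numel=2**28, device=device, dtype=torch.bfloat16, world_size=world_size
)

# Time several iterations, synchronizing so the device work is fully captured.
iters = 10
accel.synchronize()
start = time.perf_counter()
for _ in range(iters):
    collective()
accel.synchronize()
elapsed_s = time.perf_counter() - start

if not dist.get_rank():
    print(f"{collective.name}: {collective.get_GiB_s(elapsed_s, iters):.1f} GiB/s bus bandwidth")
dist.destroy_process_group()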
The profiling script, which imports from collectives.py:
| """ | |
| Minimal profiling script for profiling compute/comms overlap. | |
| """ | |
| import argparse | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import torch.distributed as dist | |
| from collectives import COLLECTIVES_DICT, Collective | |
| from torch.profiler import ProfilerActivity | |
| if torch.cuda.is_available(): | |
| from torch import cuda as accel # noqa | |
| DEVICE_TYPE = "cuda" | |
| BACKEND = "nccl" | |
| else: | |
| # Note all of the instructions for ipex profiling. In particular, need to | |
| # export IPEX_ZE_TRACING=1 to make GPU traces visible | |
| # https://github.com/intel/intel-extension-for-pytorch/blob/1296c267c4247a7027d2103d05204b6b556b3d63/docs/tutorials/features/profiler_kineto.md#L24-L24 | |
| import intel_extension_for_pytorch as ipex # noqa | |
| from torch import xpu as accel # noqa | |
| import oneccl_bindings_for_pytorch # noqa | |
| DEVICE_TYPE = "xpu" | |
| BACKEND = "ccl" | |
| SEQ_LEN = 2**13 | |
| D_MODEL = 2**14 | |
| N_LAYERS = 3 | |
| COMMS_NUMEL = 2**30 | |
| WARMUP = 3 | |
| ACTIVE = 2 | |
| RANK = int(os.getenv("RANK", 0)) | |
| LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) | |
| WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) | |
| DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}") | |

class Model(torch.nn.Module):
    def __init__(self, d_model: int, device: torch.device) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, d_model, bias=False, device=device) for _ in range(N_LAYERS)]
        )

    def forward(self, x) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


def get_profiler() -> torch.profiler.profile:
    activities = [ProfilerActivity.CPU]
    if DEVICE_TYPE == "xpu":
        activities.append(ProfilerActivity.XPU)
    elif DEVICE_TYPE == "cuda":
        activities.append(ProfilerActivity.CUDA)
    else:
        raise ValueError(f"Unexpected device type {DEVICE_TYPE=}")
    return torch.profiler.profile(
        activities=activities,
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
    )


def run_one_iter(model, batch, collective: Collective, comms_stream) -> None:
    # Launch compute on the default stream and the collective on a side stream.
    # The collective's tensors are independent of the model's, so no cross-stream
    # synchronization is required and the two kinds of work are free to overlap.
    with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
        model(batch)
    with accel.stream(comms_stream):
        collective()

def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int, default=WARMUP)
    parser.add_argument("--active", type=int, default=ACTIVE)
    parser.add_argument("--seq-len", type=int, default=SEQ_LEN)
    parser.add_argument("--d-model", type=int, default=D_MODEL)
    parser.add_argument("--comms-numel", type=int, default=COMMS_NUMEL)
    parser.add_argument(
        "--base-dir",
        type=str,
        default=None,
        help="Base directory where traces go. Defaults to user's home dir.",
    )
    parser.add_argument("-c", "--collective", type=str, default="all_reduce")
    args = parser.parse_args()
    return args

def main() -> None:
    args = get_args()
    model = Model(device=DEVICE, d_model=args.d_model)
    torch_profiler = get_profiler()
    batch = torch.randn((1, args.seq_len, args.d_model), device=DEVICE)
    collective = COLLECTIVES_DICT[args.collective](
        numel=args.comms_numel, device=DEVICE, dtype=torch.bfloat16, world_size=WORLD_SIZE
    )
    comms_stream = accel.Stream(device=DEVICE)
    if not RANK:
        print(f"Profiling {collective.name} overlap.", flush=True)

    # Warmups
    for _ in range(args.warmup):
        run_one_iter(model, batch, collective, comms_stream)
    dist.barrier()
    accel.synchronize()

    # Profile
    with torch_profiler as p:
        for _ in range(args.active):
            run_one_iter(model, batch, collective, comms_stream)

    # Write out traces
    profiler_output_dir = Path(args.base_dir or Path.home()).absolute() / "torch_profiler"
    profiler_output_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"profile.rank_{RANK}.{args.collective}.chrome_trace.json.gz"
    export_path_str = str(profiler_output_dir / file_name)
    p.export_chrome_trace(export_path_str)


if __name__ == "__main__":
    assert WORLD_SIZE > 1, "Profiling collectives requires WORLD_SIZE > 1"
    try:
        dist.init_process_group(backend=BACKEND)
        main()
    finally:
        dist.destroy_process_group()
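
Since the script reads the RANK, LOCAL_RANK, and WORLD_SIZE environment variables, a multi-rank run can be launched with torchrun, along the lines of `torchrun --nproc_per_node 8 profile_overlap.py --collective all_gather_into_tensor` (the script filename here is a placeholder; the gist does not name it). One chrome trace per rank is written to `torch_profiler/` under the home directory unless `--base-dir` is given.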