Skip to content

Instantly share code, notes, and snippets.

@garrett361
Created November 27, 2024 22:07
Show Gist options
  • Select an option

  • Save garrett361/2608c8e26e4ed49cc6660cc75b275f2b to your computer and use it in GitHub Desktop.

Select an option

Save garrett361/2608c8e26e4ed49cc6660cc75b275f2b to your computer and use it in GitHub Desktop.
overwrite reduce scatter
import argparse
import multiprocessing as mp
import os
import torch
import torch.distributed as dist
def print_rank(s: str) -> None:
    """Print *s* prefixed with this process's rank (read from the RANK env var)."""
    print(f'[rank={os.environ["RANK"]}] {s}')
def set_env(world_size: int, rank: int) -> None:
    """Populate the environment variables torch.distributed's env:// rendezvous
    reads, for a single-node run on localhost."""
    env = {
        "RANK": str(rank),
        "LOCAL_RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",
    }
    os.environ.update(env)
@torch.compile
def reduce_scatter_tensor_wrapper(outputs: torch.Tensor, inputs: torch.Tensor) -> None:
    """Compiled wrapper around ``dist.reduce_scatter_tensor``.

    Tracing through ``torch.compile`` routes the collective through the
    functional op, which is where the ``py_impl`` override below kicks in.
    """
    return dist.reduce_scatter_tensor(outputs, inputs)
@torch.ops._c10d_functional.reduce_scatter_tensor.default.py_impl(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
def reduce_scatter_tensor(inputs, reduce_op, group_size, group_name) -> torch.Tensor:
    """Override the functional reduce_scatter op: all_reduce a copy of the
    input, then slice out this rank's shard along dim 0.

    Bug fix: the original returned ``inputs_copy[:shard]`` — the *first*
    shard on every rank — which only matches real reduce_scatter output when
    all shards happen to be equal (as in the all-ones demo below). Slice by
    this rank's index instead.

    NOTE(review): ``reduce_op`` is ignored; ``dist.all_reduce`` defaults to
    SUM, so this only matches a sum reduce_scatter — confirm for other ops.
    """
    print_rank("USING OVERWRITE")
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    reduced = torch.clone(inputs)  # don't clobber the caller's tensor
    dist.all_reduce(reduced)
    shard = reduced.shape[0] // world_size
    return reduced[rank * shard : (rank + 1) * shard]
def target(world_size: int, rank: int, no_override: bool) -> None:
    """Per-process worker: initialize gloo, run a reduce_scatter, verify result.

    Args:
        world_size: total number of ranks.
        rank: this process's rank.
        no_override: if True, call eager ``dist.reduce_scatter_tensor``
            directly; otherwise go through the compiled wrapper so the
            ``py_impl`` override is exercised.
    """
    try:
        set_env(world_size, rank)
        dist.init_process_group(backend="gloo")
        inputs = torch.ones(world_size, device=f"cpu:{rank}")
        print_rank(f"Before {inputs=}")
        outputs = torch.zeros(1, device=f"cpu:{rank}")
        if no_override:
            dist.reduce_scatter_tensor(outputs, inputs)
        else:
            reduce_scatter_tensor_wrapper(outputs, inputs)
        print_rank(f"After {inputs=}, {outputs=}")
        # Summing all-ones across world_size ranks gives world_size per element.
        torch.testing.assert_close(outputs, world_size * torch.ones_like(outputs))
    finally:
        # Bug fix: only tear down when init succeeded — otherwise
        # destroy_process_group raises and masks the original exception.
        if dist.is_initialized():
            dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int, default=2)
    parser.add_argument("--no_override", action="store_true")
    args = parser.parse_args()

    # Launch one worker per rank; start all before joining any so the
    # gloo rendezvous can complete.
    workers = [
        mp.Process(target=target, args=(args.world_size, rank, args.no_override))
        for rank in range(args.world_size)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment