@garrett361 · Created December 11, 2024 03:03
DTensor double-sharded random
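
The script below spawns WORLD_SIZE = 8 CPU processes, each of which joins a gloo process group, builds a 2 x 4 device mesh, shards a zeros tensor along dim 0 across both mesh dimensions, and fills it in-place with normal_(). Every rank calls torch.manual_seed(42) with the same value, so the run isolates how DTensor itself handles random number generation for a double-sharded tensor.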
import multiprocessing as mp
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.placement_types import Shard

WORLD_SIZE = 8


def set_env(rank: int) -> None:
    # Minimal single-host env vars expected by init_process_group.
    os.environ["RANK"] = os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"


def target() -> None:
    # A 2 x 4 mesh over the 8 ranks. Both mesh dimensions shard dim 0,
    # so the tensor is sharded twice along the same dimension and each
    # rank ends up owning a single element.
    device_mesh = init_device_mesh("cpu", mesh_shape=(2, 4))
    placements = (Shard(0), Shard(0))
    dt_zeros = distribute_tensor(
        torch.zeros(WORLD_SIZE), device_mesh=device_mesh, placements=placements
    )
    # In-place random fill on the double-sharded DTensor.
    dt_zeros.normal_()


def wrapper(rank: int) -> None:
    # Identical seed on every rank, so any divergence in the random
    # values comes from DTensor's own RNG handling, not from seeding.
    torch.manual_seed(42)
    set_env(rank)
    dist.init_process_group(backend="gloo")
    try:
        target()
    finally:
        # Only tear down the group once it has actually been created.
        dist.destroy_process_group()


if __name__ == "__main__":
    print(f"{torch.__version__=}")
    processes = [mp.Process(target=wrapper, args=(rank,)) for rank in range(WORLD_SIZE)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
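
To inspect what each rank actually holds, one possible addition at the end of target() (not part of the original gist) prints the local shard and the reassembled global tensor. to_local() returns the shard a rank physically owns, while full_tensor() is a collective that gathers the complete tensor, so every rank must reach it.

    # Hypothetical inspection lines, appended at the end of target():
    rank = dist.get_rank()
    # With a (2, 4) mesh over 8 elements, each rank owns one element.
    print(f"rank {rank}: local shard = {dt_zeros.to_local()}")
    # full_tensor() is a collective; all ranks must call it.
    print(f"rank {rank}: full tensor = {dt_zeros.full_tensor()}")

Comparing the eight full-tensor printouts then shows directly whether the double-sharded normal_() produced the same global values on every rank.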