@jeromeku, forked from ezyang/localtensor.py, created November 20, 2025 22:37

import torch
import unittest
from torch import Tensor
from torch.distributed.tensor import (
    DTensor,
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.distributed._local_tensor import (
    local_tensor_mode,
    LocalTensor,
    LocalTensorMode,
)
import traceback
from torch.distributed.tensor._sharding_prop import ShardingPropagator
S = Shard
R = Replicate()
_SS = _StridedShard
def product(it):
    """Return the product of the elements of an iterable."""
    x = 1
    for i in it:
        x *= i
    return x

def arange_nd(*sizes):
    """Return torch.arange over the product of sizes, reshaped to those sizes."""
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        sizes = sizes[0]
    return torch.arange(product(sizes)).view(sizes)

def reconcile(l: Tensor):
    """Asserts that a LocalTensor is the same on all ranks, and returns the single Tensor."""
    if isinstance(l, LocalTensor):
        return l.reconcile()
    return l

def exit_local_tensor_mode():
    """Exit any LocalTensorMode contexts that are still active."""
    from torch.distributed import _local_tensor

    if getattr(_local_tensor, "_LOCAL_TENSOR_MODE", None):
        for lm in list(reversed(_local_tensor._LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)
    elif getattr(_local_tensor, "_GLOBAL_LOCAL_TENSOR_MODE", None):
        for lm in list(reversed(_local_tensor._GLOBAL_LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)

def init_local_tensor_mode(world_size):
    """Tear down prior state, start a fake process group, and enter LocalTensorMode."""
    exit_local_tensor_mode()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    torch.distributed.init_process_group(
        "fake",
        rank=0,
        world_size=world_size,
    )
    lm = LocalTensorMode(world_size)
    lm.__enter__()
    return world_size

def init_fake_tensor_mode(world_size):
    """Tear down prior state and start a fake process group without LocalTensorMode."""
    exit_local_tensor_mode()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    torch.distributed.init_process_group(
        "fake",
        rank=0,
        world_size=world_size,
    )
    return world_size

# Simulate a 4-rank world in a single process and build a 1-D CPU mesh.
world_size = init_local_tensor_mode(4)
mesh = init_device_mesh("cpu", (4,), mesh_dim_names=("x",))

# `a` is replicated across the mesh; `b` carries a pending (Partial) reduction.
a = DTensor.from_local(arange_nd(4).float(), mesh, [R])
b = DTensor.from_local(torch.ones(4), mesh, [Partial()])
a += b
print(a)
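
# A minimal follow-up sketch, not part of the original gist: under LocalTensorMode,
# `to_local()` is expected to return a LocalTensor holding one value per simulated
# rank, and `reconcile()` (defined above) collapses it to a single Tensor only when
# every rank agrees. The placements `a` ends up with after the in-place add depend
# on DTensor's sharding propagation, so printing them is just a sanity check.
print(a.placements)
print(reconcile(a.redistribute(mesh, [R]).to_local()))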