import torch
import unittest
from torch import Tensor
from torch.distributed.tensor import (
    DTensor,
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.distributed._local_tensor import (
    local_tensor_mode,
    LocalTensor,
    LocalTensorMode,
)
import traceback
from torch.distributed.tensor._sharding_prop import ShardingPropagator

# Short aliases for the DTensor placement types.
S = Shard
R = Replicate()
_SS = _StridedShard

def product(it):
    """Multiply together all elements of an iterable."""
    x = 1
    for i in it:
        x *= i
    return x


def arange_nd(*sizes):
    """Return an arange tensor reshaped to `sizes` (varargs or a single list/tuple)."""
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        sizes = sizes[0]
    return torch.arange(product(sizes)).view(sizes)
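
# Illustrative self-check (added; not part of the original gist): both calling
# conventions of arange_nd should produce the same 2x3 tensor.
assert arange_nd(2, 3).equal(arange_nd((2, 3)))
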
def reconcile(l: Tensor):
    """Asserts that a LocalTensor is the same on all ranks, and returns the single Tensor."""
    if isinstance(l, LocalTensor):
        return l.reconcile()
    return l
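
# Sanity sketch (added; not part of the original gist): outside LocalTensorMode,
# reconcile is a passthrough for plain tensors.
assert reconcile(torch.ones(3)).equal(torch.ones(3))
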
def exit_local_tensor_mode():
    """Exit any active LocalTensorMode(s), whichever registry the build uses."""
    from torch.distributed import _local_tensor

    if getattr(_local_tensor, "_LOCAL_TENSOR_MODE", None):
        for lm in list(reversed(_local_tensor._LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)
    elif getattr(_local_tensor, "_GLOBAL_LOCAL_TENSOR_MODE", None):
        # Fixed: the original iterated _GLOBAL_TENSOR_MODE, which does not
        # match the attribute checked in the guard above.
        for lm in list(reversed(_local_tensor._GLOBAL_LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)

def init_local_tensor_mode(world_size):
    """(Re)initialize a fake process group and enter LocalTensorMode."""
    exit_local_tensor_mode()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # No process group was initialized yet.
        pass
    torch.distributed.init_process_group(
        "fake",
        rank=0,
        world_size=world_size,
    )
    lm = LocalTensorMode(world_size)
    lm.__enter__()
    return world_size

def init_fake_tensor_mode(world_size):
    """(Re)initialize a fake process group without entering LocalTensorMode."""
    exit_local_tensor_mode()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    torch.distributed.init_process_group(
        "fake",
        rank=0,
        world_size=world_size,
    )
    return world_size

# Demo: in-place add of a Partial DTensor into a Replicate DTensor
# on a 1-D mesh of 4 ranks, run under LocalTensorMode.
world_size = init_local_tensor_mode(4)
mesh = init_device_mesh("cpu", (4,), mesh_dim_names=("x",))
a = DTensor.from_local(arange_nd(4).float(), mesh, [R])
b = DTensor.from_local(torch.ones(4), mesh, [Partial()])
a += b
print(a)
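
# Illustrative follow-up (added; not part of the original gist, and it assumes
# the in-place add above succeeds): inspect the resulting placement and the
# per-rank values. `placements` and `to_local()` are public DTensor APIs;
# `reconcile` is the helper defined above, which asserts all ranks agree.
print(a.placements)
print(reconcile(a.to_local()))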