Created June 5, 2025 14:35
set_reduce_scatter_divide_factor Error
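Minimal reproduction: FSDP2's set_reduce_scatter_divide_factor works under the default (trivial) MixedPrecisionPolicy, but the backward pass fails with TypeError('PreMulSum Data type must be half, float, or double') when the policy sets param_dtype and reduce_dtype to torch.bfloat16.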
torchrun --nproc-per-node 2 set_div_err.py
W0605 14:34:10.112000 783116 torch/distributed/run.py:766]
W0605 14:34:10.112000 783116 torch/distributed/run.py:766] *****************************************
W0605 14:34:10.112000 783116 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0605 14:34:10.112000 783116 torch/distributed/run.py:766] *****************************************
Running with trivial mp_policy
Passed on RANK=0 with mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)
Passed on RANK=1 with mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)
Running with non-trivial mp_policy
FAILED on RANK=0 with mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, output_dtype=None, cast_forward_inputs=True) e=TypeError('PreMulSum Data type must be half, float, or double')
FAILED on RANK=1 with mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, output_dtype=None, cast_forward_inputs=True) e=TypeError('PreMulSum Data type must be half, float, or double')
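The failure only appears with a bfloat16 reduce_dtype. Judging from the error message, the divide factor is applied through NCCL's PreMulSum reduction op, which, per the TypeError above, accepts only half, float, or double, so a bfloat16 reduce-scatter has no supported PreMulSum path. If that reading is right, one possible workaround (an assumption, not verified in this gist) is to keep the gradient reduction in float32 while still computing in bfloat16:

    # Hedged workaround sketch: reduce-scatter in fp32, which PreMulSum accepts
    # per the error message above. Assumption only; not verified in this gist.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # parameters/compute in bf16
        reduce_dtype=torch.float32,   # gradient reduction (PreMulSum) in fp32
    )

The full reproduction script follows.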
import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

d_model = 128


def run_test(model, mesh, mp_policy, inputs, div_factor: float = 8.0) -> None:
    # Shard each Linear individually, then the root module, as usual for FSDP2.
    for lin in model:
        fully_shard(lin, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    # The call under test: scale reduce-scattered gradients by 1/div_factor.
    model[0].set_reduce_scatter_divide_factor(div_factor)
    try:
        model(inputs).sum().backward()
        print(f"Passed on {RANK=} with {mp_policy=}")
    except Exception as e:
        print(f"FAILED on {RANK=} with {mp_policy=} {e=}")
    dist.barrier()


if __name__ == "__main__":
    RANK = int(os.environ["RANK"])
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{LOCAL_RANK}")  # one GPU per local rank
    torch.cuda.set_device(device)
    try:
        dist.init_process_group(backend="nccl", rank=RANK, world_size=WORLD_SIZE, device_id=device)
        mesh = init_device_mesh("cuda", (WORLD_SIZE,))
        model = nn.Sequential(
            *(nn.Linear(d_model, d_model, bias=False, device=device) for _ in range(3))
        )
        model_mp = deepcopy(model)  # identical copy for the mixed-precision run
        inputs = torch.randn(1, d_model, device=device)

        if not RANK:
            print("Running with trivial mp_policy")
        dist.barrier()
        # Trivial policy: everything stays in fp32. This run passes.
        run_test(model, mesh, mp_policy=MixedPrecisionPolicy(), inputs=inputs, div_factor=8.0)

        if not RANK:
            print("Running with non-trivial mp_policy")
        dist.barrier()
        # bf16 params and bf16 reduce-scatter. This run raises
        # TypeError('PreMulSum Data type must be half, float, or double').
        run_test(
            model_mp,
            mesh,
            mp_policy=MixedPrecisionPolicy(torch.bfloat16, torch.bfloat16),
            inputs=inputs,
            div_factor=8.0,
        )
    finally:
        dist.destroy_process_group()
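Run on two GPUs, as in the log above:

    torchrun --nproc-per-node 2 set_div_err.py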