@garrett361
Created June 5, 2025 14:35
set_reduce_scatter_divide_factor Error
torchrun --nproc-per-node 2 set_div_err.py
W0605 14:34:10.112000 783116 torch/distributed/run.py:766]
W0605 14:34:10.112000 783116 torch/distributed/run.py:766] *****************************************
W0605 14:34:10.112000 783116 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0605 14:34:10.112000 783116 torch/distributed/run.py:766] *****************************************
Running with trivial mp_policy
Passed on RANK=0 with mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)
Passed on RANK=1 with mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)
Running with non-trivial mp_policy
FAILED on RANK=0 with mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, output_dtype=None, cast_forward_inputs=True) e=TypeError('PreMulSum Data type must be half, float, or double')
FAILED on RANK=1 with mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, output_dtype=None, cast_forward_inputs=True) e=TypeError('PreMulSum Data type must be half, float, or double')
import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

d_model = 128


def run_test(model, mesh, mp_policy, inputs, div_factor: float = 8.0) -> None:
    # Shard each Linear individually, then the root module.
    for lin in model:
        fully_shard(lin, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    # Setting a custom divide factor is what triggers the failure with a bf16 reduce_dtype.
    model[0].set_reduce_scatter_divide_factor(div_factor)
    try:
        model(inputs).sum().backward()
        print(f"Passed on {RANK=} with {mp_policy=}")
    except Exception as e:
        print(f"FAILED on {RANK=} with {mp_policy=} {e=}")
    dist.barrier()


if __name__ == "__main__":
    RANK = int(os.environ["RANK"])
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{RANK}")
    torch.cuda.set_device(device)
    try:
        dist.init_process_group(backend="nccl", rank=RANK, world_size=WORLD_SIZE, device_id=device)
        mesh = init_device_mesh("cuda", (WORLD_SIZE,))
        model = nn.Sequential(
            *(nn.Linear(d_model, d_model, bias=False, device=device) for _ in range(3))
        )
        model_mp = deepcopy(model)
        inputs = torch.randn(1, d_model, device=device)

        # Trivial mixed-precision policy: reduce-scatter with a divide factor succeeds.
        if not RANK:
            print("Running with trivial mp_policy")
        dist.barrier()
        run_test(model, mesh, mp_policy=MixedPrecisionPolicy(), inputs=inputs, div_factor=8.0)

        # bf16 param/reduce dtypes: the backward's reduce-scatter fails with the PreMulSum TypeError.
        if not RANK:
            print("Running with non-trivial mp_policy")
        dist.barrier()
        run_test(
            model_mp,
            mesh,
            mp_policy=MixedPrecisionPolicy(torch.bfloat16, torch.bfloat16),
            inputs=inputs,
            div_factor=8.0,
        )
    finally:
        dist.destroy_process_group()
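
For reference, the failure seems reproducible without FSDP at all, assuming set_reduce_scatter_divide_factor builds its reduce op via the private torch.distributed._make_nccl_premul_sum helper (my reading of the FSDP2 source): a raw reduce_scatter_tensor with a PreMulSum op works for float32 but rejects bfloat16 with the same TypeError. A minimal sketch, launched the same way (torchrun --nproc-per-node 2 premul_sum_repro.py):

# premul_sum_repro.py: narrower sketch of the underlying restriction, no FSDP involved.
# Assumption: torch.distributed._make_nccl_premul_sum is the (private) helper that
# set_reduce_scatter_divide_factor uses to build its reduce-scatter op.
import os

import torch
import torch.distributed as dist

RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{RANK}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", rank=RANK, world_size=WORLD_SIZE, device_id=device)

# PreMulSum scales each rank's contribution by the factor before summing across ranks.
premul_sum = dist._make_nccl_premul_sum(1.0 / 8.0)

for dtype in (torch.float32, torch.bfloat16):
    inp = torch.ones(4 * WORLD_SIZE, dtype=dtype, device=device)
    out = torch.empty(4, dtype=dtype, device=device)
    try:
        dist.reduce_scatter_tensor(out, inp, op=premul_sum)
        print(f"Passed on {RANK=} with {dtype=}")
    except Exception as e:
        print(f"FAILED on {RANK=} with {dtype=} {e=}")

dist.destroy_process_group()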