5d_parallelism
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
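# Requires a recent PyTorch release: init_device_mesh and the FSDP2
# fully_shard API are only exposed under these import paths in newer versions.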
# ==========================================
# 0. UTILITY: Differentiable All-to-All
# ==========================================
class AllToAll(torch.autograd.Function):
    """
    A custom autograd function that allows gradients to flow through
    the all_to_all communication primitive.
    """
    @staticmethod
    def forward(ctx, input_tensor, group):
        ctx.group = group
        # input_tensor: [World, Capacity, Hidden]
        output_tensor = torch.empty_like(input_tensor)
        dist.all_to_all_single(output_tensor, input_tensor, group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # The backward of All-to-All is simply another All-to-All
        # (symmetric communication)
        grad_input = torch.empty_like(grad_output)
        dist.all_to_all_single(grad_input, grad_output, group=ctx.group)
        return grad_input, None
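# Minimal sanity check for the autograd wiring above (an illustrative sketch,
# not called by the training script; assumes `group` is an initialized process
# group). Because backward(all_to_all) is just another all_to_all, a sum()
# loss should leave a gradient of ones on the input.
def _check_all_to_all_grad(group):
    world = dist.get_world_size(group)
    x = torch.randn(world, 4, 8, requires_grad=True)
    y = AllToAll.apply(x, group)
    y.sum().backward()
    assert torch.allclose(x.grad, torch.ones_like(x))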
# ==========================================
# 1. COMPONENT: Context Parallel (Ring Attention)
# ==========================================
class RingAttention(nn.Module):
    def __init__(self, hidden_dim, mesh):
        super().__init__()
        self.mesh = mesh
        self.pg = mesh.get_group()
        self.head_dim = hidden_dim // 4
        self.proj_q = nn.Linear(hidden_dim, hidden_dim)
        self.proj_k = nn.Linear(hidden_dim, hidden_dim)
        self.proj_v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        local_seq = x.size(1)
        q = self.proj_q(x)
        k = self.proj_k(x)
        v = self.proj_v(x)
        # Rank calculations
        group_rank = dist.get_rank(self.pg)
        group_size = dist.get_world_size(self.pg)
        next_rank_idx = (group_rank + 1) % group_size
        prev_rank_idx = (group_rank - 1 + group_size) % group_size
        next_global_rank = dist.get_global_rank(self.pg, next_rank_idx)
        prev_global_rank = dist.get_global_rank(self.pg, prev_rank_idx)
        curr_k, curr_v = k, v
        recv_k = torch.zeros_like(k)
        recv_v = torch.zeros_like(v)
        attn_out = torch.zeros_like(q)
        # Ring Loop
        for step in range(group_size):
            # 1. Compute
            scores = torch.matmul(q, curr_k.transpose(1, 2)) / math.sqrt(self.head_dim)
            attn_out += torch.matmul(scores, curr_v)
            if step == group_size - 1:
                break
            # 2. Communicate
            reqs = [
                dist.isend(curr_k, dst=next_global_rank),
                dist.isend(curr_v, dst=next_global_rank),
                dist.irecv(recv_k, src=prev_global_rank),
                dist.irecv(recv_v, src=prev_global_rank)
            ]
            for req in reqs:
                req.wait()
            curr_k = recv_k.clone()
            curr_v = recv_v.clone()
        return self.out(attn_out)
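# Illustrative single-device reference (a sketch, not used by the training
# loop): the ring loop above accumulates Q @ K_j^T / sqrt(d) @ V_j over every
# context-parallel shard j, i.e. a simplified, unnormalized attention with no
# softmax and no head split. With a CP group of size 1 it reduces to this
# dense computation, which is handy for checking the ring exchange.
def _dense_unnormalized_attention(q, k, v, head_dim):
    scores = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
    return torch.matmul(scores, v)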
# ==========================================
# 2. COMPONENT: Expert Parallel (MoE MLP)
# ==========================================
class MoELayer(nn.Module):
    def __init__(self, hidden_dim, ffn_dim, mesh):
        super().__init__()
        self.mesh = mesh
        self.pg = mesh.get_group()
        self.num_experts = dist.get_world_size(self.pg)
        self.router = nn.Linear(hidden_dim, self.num_experts)
        self.expert = nn.Sequential(
            nn.Linear(hidden_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, hidden_dim)
        )

    def forward(self, x):
        b, s, h = x.shape
        x_flat = x.view(-1, h)
        # 1. Route (top-1 expert per token)
        logits = self.router(x_flat)
        _, indices = torch.max(logits, dim=1)
        capacity = x_flat.size(0) // self.num_experts
        dispatched_x = torch.zeros(self.num_experts, capacity, h)
        counts = torch.zeros(self.num_experts, dtype=torch.long)
        # Track which dispatch slot each token landed in so the expert outputs
        # can be scattered back to their original positions afterwards.
        slot_of_token = torch.full((x_flat.size(0),), -1, dtype=torch.long)
        # 2. Assign to slots (Naive CPU implementation)
        # Note: The 'indices' (torch.max) operation is non-differentiable.
        # This breaks gradient flow to the Router, but NOT to the Expert.
        # To fix Router training, we would need to multiply by Softmax probabilities.
        for i in range(x_flat.size(0)):
            dest = indices[i]
            if counts[dest] < capacity:
                dispatched_x[dest, counts[dest]] = x_flat[i]
                slot_of_token[i] = dest * capacity + counts[dest]
                counts[dest] += 1
            # Tokens beyond an expert's capacity are dropped (output stays zero).
        # 3. Dispatch (Differentiable) via the custom autograd All-to-All
        recv_buffer = AllToAll.apply(dispatched_x, self.pg)
        # 4. Expert Compute
        expert_out = self.expert(recv_buffer)
        # 5. Combine (Differentiable): the second All-to-All returns each
        # token's expert output to the rank that dispatched it.
        combined = AllToAll.apply(expert_out, self.pg).reshape(-1, h)
        # 6. Un-permute: scatter expert outputs back to the original token order.
        out_flat = torch.zeros_like(x_flat)
        kept = slot_of_token >= 0
        out_flat[kept] = combined[slot_of_token[kept]]
        return out_flat.view(b, s, h)
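# Sketch of the router-gradient fix hinted at in the comment above (not part
# of the layer; `_top1_gate` is an illustrative helper). Scaling each
# dispatched token by its softmax gate value keeps the router inside the
# autograd graph, e.g. dispatched_x[dest, counts[dest]] = gate[i] * x_flat[i].
def _top1_gate(logits):
    probs = torch.softmax(logits, dim=1)                       # [tokens, experts]
    return probs.gather(1, probs.argmax(dim=1, keepdim=True))  # prob of chosen expert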
# ==========================================
# 3. The 5D Training Loop
# ==========================================
def train_5d(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12391"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # --- Mesh Setup ---
    mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("pp", "ep", "cp"))
    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    cp_mesh = mesh["cp"]
    pp_rank = mesh.get_coordinate()[0]

    if pp_rank == 0:
        model = nn.Sequential(
            RingAttention(hidden_dim=32, mesh=cp_mesh),
            MoELayer(hidden_dim=32, ffn_dim=64, mesh=ep_mesh)
        )
    else:
        model = nn.Sequential(
            RingAttention(hidden_dim=32, mesh=cp_mesh),
            MoELayer(hidden_dim=32, ffn_dim=64, mesh=ep_mesh),
            nn.Linear(32, 10)
        )

    fsdp_model = fully_shard(model, mesh=ep_mesh)
    optimizer = optim.AdamW(fsdp_model.parameters(), lr=1e-3)

    # --- Training ---
    local_seq_len = 8
    hidden_dim = 32
    batch_size = 4

    for step in range(3):
        optimizer.zero_grad()
        if pp_rank == 0:
            # Pipeline stage 0: run the local layers, send activations forward,
            # then receive the activation gradients back from stage 1.
            torch.manual_seed(rank + step)
            inputs = torch.randn(batch_size, local_seq_len, hidden_dim)
            out = fsdp_model(inputs)
            peer = rank + 4
            dist.send(out.detach(), dst=peer)
            grad = torch.zeros_like(out)
            dist.recv(grad, src=peer)
            # Now out.backward() will work because MoELayer is differentiable
            out.backward(grad)
        elif pp_rank == 1:
            # Pipeline stage 1: receive activations, compute the loss, and send
            # the input gradients back to stage 0.
            peer = rank - 4
            inputs = torch.zeros(batch_size, local_seq_len, hidden_dim)
            dist.recv(inputs, src=peer)
            inputs.requires_grad = True
            out = fsdp_model(inputs)
            targets = torch.randint(0, 10, (batch_size * local_seq_len,))
            loss = nn.CrossEntropyLoss()(out.view(-1, 10), targets)
            if step % 2 == 0:
                print(f"Rank {rank} (Stage 1): Loss {loss.item():.4f}")
            loss.backward()
            # Check for NaN or None gradients to debug early
            if inputs.grad is None:
                print(f"Rank {rank}: FATAL - inputs.grad is None!")
            else:
                dist.send(inputs.grad, dst=peer)
        optimizer.step()

    print(f"Rank {rank}: Finished. Barrier...")
    dist.barrier()
    dist.destroy_process_group()
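# Layout note (a derivation, not from the original gist): init_device_mesh
# arranges ranks row-major, so with mesh shape (pp, ep, cp) = (2, 2, 2) the
# global rank is pp * 4 + ep * 2 + cp. Stage 0 therefore owns ranks 0-3 and
# stage 1 owns ranks 4-7, which is why each rank's pipeline peer above is
# rank + 4 (stage 0) or rank - 4 (stage 1).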
if __name__ == "__main__":
    world_size = 8
    mp.set_start_method("spawn", force=True)
    print(f"Running Final 5D Parallelism on {world_size} Processes...")
    mp.spawn(train_5d, args=(world_size,), nprocs=world_size, join=True)