Code testing MoE (mixture-of-experts) implementations with torch.compile
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

torch.manual_seed(0)

# Shape legend:
# T      tokens
# E      experts
# D      dim
# I      intermediate dim
# A      activated experts per token
# T'(e)  tokens routed to expert e
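# Concrete sizes used in this file (my annotation, derived from the code below):
# T = 2 tokens per forward call, E = 8 experts, D = 4, I = 16, A = 2.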
class MOEFeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(4, 8, bias=False)
        self.cond_ffn = ConditionalFeedForward()
        self.dim = 4
        self.num_activated_experts = 2

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.shape[0]
        x = x.view(-1, self.dim)  # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights = F.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
        out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts)
        return out.reshape(batch_size, -1, self.dim)
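# Worked routing example (my addition, not in the original gist): with E = 4
# and A = 2, a token whose softmaxed gate scores are [0.1, 0.4, 0.2, 0.3]
# keeps experts [1, 3] with weights [0.4, 0.3], which then renormalize to
# [0.4/0.7, 0.3/0.7] ~= [0.571, 0.429].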
# THIS IS WHERE COMPILE KEEPS FAILING
def get_indices_and_weights_per_expert(expert_indices, expert_weights, num_experts):
    num_tokens, experts_per_token = expert_indices.shape
    # Goal: extract the tokens used by each expert by sorting expert_indices by
    # expert, then slicing the sorted list based on how many tokens each expert gets.
    # expert_indices = [[0, 1], [1, 3], [0, 2]], i.e. token 0 goes to experts 0 and 1,
    # token 1 goes to experts 1 and 3, etc.
    # num_tokens_per_expert = [2, 2, 1, 1], i.e. expert 0 gets tokens 0 and 2 ([[0*, 1], [1, 3], [0*, 2]])
    # sorted_tokens_by_expert = [0, 2, 0, 1, 2, 1]; break it into groups of size num_tokens_per_expert:
    # tok_indices_per_expert = [| 0, 2 | 0, 1 | 2 | 1 |]
    #                       -> [[0, 2], [0, 1], [2], [1]]
    # ALL 3 VERSIONS BELOW DO THE SAME THING: GET THE TOKEN INDICES FOR EACH EXPERT.
    # IS THERE A COMPILE-COMPATIBLE WAY TO DO THIS?
    # "split":  uses torch.split
    # "arange": uses torch.arange
    # "index":  uses slicing to token shuffle
    # `version` is a module-level global, selected near the bottom of the file.
    if version == "split":
        sorted_token_activation_by_expert = expert_indices.view(-1).argsort(stable=True)
        sorted_tokens_by_expert = (sorted_token_activation_by_expert / experts_per_token).floor().to(torch.int)
        # [1:] drops the empty leading [-1, 0) bin so that group e lines up with expert e
        num_tokens_per_expert = torch.histc(expert_indices, bins=num_experts + 1, min=-1, max=num_experts)[1:]
        # arrange weights the same way as the tokens, then group the weights by expert
        sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert].view(-1, 1)
        tok_indices_per_expert = torch.split(sorted_tokens_by_expert, num_tokens_per_expert.tolist())  # fails on tolist, breaks cudagraph
        tok_weights_per_expert = torch.split(sorted_weights_by_expert, num_tokens_per_expert.tolist())
    elif version == "arange":
        sorted_token_activation_by_expert = expert_indices.view(-1).argsort(stable=True)
        sorted_tokens_by_expert = (sorted_token_activation_by_expert / experts_per_token).floor().to(torch.int)
        cum_tokens_per_expert = torch.histc(expert_indices, bins=num_experts + 1, min=-1, max=num_experts).cumsum(0)
        sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert]
        tok_indices_per_expert = []
        tok_weights_per_expert = []
        for i in range(num_experts):
            indices = torch.arange(cum_tokens_per_expert[i], cum_tokens_per_expert[i + 1])  # fails on arange for fullgraph
            tok_indices_per_expert.append(sorted_tokens_by_expert[indices])
            tok_weights_per_expert.append(sorted_weights_by_expert[indices].view(-1, 1))
    elif version == "index":
        sorted_token_activation_by_expert = expert_indices.view(-1).argsort(stable=True)
        sorted_tokens_by_expert = (sorted_token_activation_by_expert / experts_per_token).floor().to(torch.int)
        cum_tokens_per_expert = torch.histc(expert_indices, bins=num_experts + 1, min=-1, max=num_experts).cumsum(0)
        sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert]
        tok_indices_per_expert = []
        tok_weights_per_expert = []
        for i in range(num_experts):
            tok_indices_per_expert.append(sorted_tokens_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i + 1]])  # fails on indexing for fullgraph
            tok_weights_per_expert.append(sorted_weights_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i + 1]].view(-1, 1))
    else:
        raise Exception("unimplemented")
    return tok_indices_per_expert, tok_weights_per_expert
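# A standalone demo of the grouping logic above (my addition), using the worked
# example from the comments: expert_indices = [[0, 1], [1, 3], [0, 2]].
_demo_indices = torch.tensor([[0, 1], [1, 3], [0, 2]])
_demo_order = _demo_indices.view(-1).argsort(stable=True)        # activation slots sorted by expert
_demo_tokens = torch.div(_demo_order, 2, rounding_mode='floor')  # token id of each sorted activation
# _demo_tokens == tensor([0, 2, 0, 1, 2, 1]), i.e. sorted_tokens_by_expert above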
class ConditionalFeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(8, 16, 4))  # E, I, D
        self.w2 = nn.Parameter(torch.randn(8, 4, 16))  # E, D, I
        self.w3 = nn.Parameter(torch.randn(8, 16, 4))  # E, I, D
        self.num_experts = 8

    def forward(
        self,
        x: Tensor,  # T, D
        expert_indices: Tensor,  # T, A
        expert_weights: Tensor,  # T, A
        num_activated_experts: int,
    ) -> Tensor:
        # `algorithm` is a module-level global, selected near the bottom of the file.
        if algorithm == "gptfast":
            # This compiles, but the memory usage of the w(n)_weights grows as
            # O(num_activated_experts), whereas it can be constant with an
            # expert-by-expert implementation (forloop).
            w1_weights = self.w1[expert_indices]  # [T, A, I, D]
            w3_weights = self.w3[expert_indices]  # [T, A, I, D]
            w2_weights = self.w2[expert_indices]  # [T, A, D, I]
            x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
            x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
            expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
            return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
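            # Shape note (my addition): each gathered w*_weights tensor above
            # holds T * A * I * D elements, duplicating an expert's weights once
            # per token routed to it; the forloop branch below instead touches
            # one [I, D] expert weight at a time.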
        # This works for both cases but doesn't compile.
        elif algorithm == "forloop":
            tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts)
            # tok_indices_per_expert, tok_weights_per_expert: list([T'(e0), T'(e1), ...])
            outs = []
            for expert in range(self.num_experts):
                w1 = self.w1[expert]  # I, D
                w2 = self.w2[expert]  # D, I
                w3 = self.w3[expert]  # I, D
                tok_indices = tok_indices_per_expert[expert]
                cur_x = x[tok_indices]  # T', D
                cur_weights = tok_weights_per_expert[expert]  # T', 1
                # scaling the (linear) w3 branch by the routing weight is equivalent to scaling the output
                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x * cur_weights, w3), w2)  # T', D
                outs.append(cur_out)
            mixed_out = torch.cat(outs, dim=0)  # T*A, D
            final_out = torch.zeros_like(x)  # T, D
            # token id of each row of mixed_out; expand so scatter_add writes every column, not just column 0
            indices = expert_indices.view(-1).argsort(stable=True).div(num_activated_experts).floor().to(torch.int64).unsqueeze(-1).expand(-1, x.shape[-1])
            final_out = final_out.scatter_add(dim=0, index=indices, src=mixed_out)
            return final_out
        elif algorithm == "carefulindex":
            ordered_token_activations = expert_indices.view(-1).argsort(stable=True)
            ordered_tokens = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64)
            cum_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts).cumsum(0)
            expert_starts = cum_tokens_per_expert[:-1]
            expert_ends = cum_tokens_per_expert[1:]
            mega_input = x[ordered_tokens]
            outs = []
            # expert_list = [e for e in expert_indices.unique()]  # does this work?
            for expert in range(self.num_experts):
                w1 = self.w1[expert]  # I, D
                w2 = self.w2[expert]  # D, I
                w3 = self.w3[expert]  # I, D
                cur_x = mega_input[expert_starts[expert]:expert_ends[expert]]
                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2)  # T', D
                outs.append(cur_out)
            ordered_outs = torch.cat(outs, dim=0)
            weight_order = expert_weights.view(-1)[ordered_token_activations].view(-1, 1)
            weighted_ordered_outs = ordered_outs * weight_order
            final_out = torch.zeros_like(x)
            # expand to [T*A, D] rather than hardcoding (4, 4) so this tracks the actual shapes
            final_out = final_out.scatter_add(dim=0, index=ordered_tokens.unsqueeze(-1).expand(-1, x.shape[-1]), src=weighted_ordered_outs)
            return final_out
        else:
            raise Exception("unimplemented")
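# A sketch of a shape-static alternative (my addition, not one of the gist's
# tested algorithms): instead of splitting tokens into variable-length
# per-expert groups, build a dense [T, E] routing-weight matrix and let every
# expert see every token, with zeros for unrouted pairs. All shapes are fixed,
# so there is no tolist()/arange/data-dependent slicing for fullgraph to trace;
# the tradeoff is that each expert processes all T tokens.
def dense_routing_weights(expert_indices: Tensor, expert_weights: Tensor, num_experts: int) -> Tensor:
    routing = torch.zeros(expert_indices.shape[0], num_experts,
                          device=expert_weights.device, dtype=expert_weights.dtype)
    routing.scatter_(1, expert_indices, expert_weights)  # [T, E]; weight of expert e for token t
    return routing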
# different algorithms I've tried
algorithm = "carefulindex"
# algorithm = "forloop"  # when using forloop, also uncomment one `version` below
# version = "index"
# version = "split"
# version = "arange"
# algorithm = "gptfast"
moe = MOEFeedForward().to("cuda").to(torch.bfloat16)
input1 = torch.randn(1, 2, 4).to("cuda").to(torch.bfloat16)
input2 = torch.randn(1, 2, 4).to("cuda").to(torch.bfloat16)
with torch.no_grad():
    out1 = moe(input1)
    print(out1.sum())
    out2 = moe(input2)
    print(out2.sum())

    # moe_c = torch.compile(moe, mode="reduce-overhead")  # working
    moe_c = torch.compile(moe, mode="reduce-overhead", fullgraph=True)  # this fails on the token shuffle part
    moe_c(input1)  # warm-up runs for reduce-overhead
    moe_c(input2)
    out1c = moe_c(input1)
    print(out1c.sum())
    out2c = moe_c(input2)
    print(out2c.sum())
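    # A hedged sanity check (my addition): compare the most recent compiled
    # output against eager. Under reduce-overhead/cudagraphs, earlier outputs
    # (out1c) may be overwritten by later calls unless cloned, so only out2c is
    # compared here. Tolerances are a guess for these unnormalized bfloat16
    # weights and may need loosening.
    torch.testing.assert_close(out2c, out2, atol=0.5, rtol=0.05)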