code testing MoE implementations with torch.compile
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from dataclasses import dataclass
torch.manual_seed(0)
# T tokens
# E experts
# D dim
# I intermediate dim
# A activated experts
# T'(e) tokens for expert e
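# concrete sizes in this toy script: T = 1*2 = 2, E = 8, D = 4, I = 16, A = 2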
class MOEFeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(4, 8, bias=False)
        self.cond_ffn = ConditionalFeedForward()
        self.dim = 4
        self.num_activated_experts = 2

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.shape[0]
        x = x.view(-1, self.dim)  # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights = F.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
        out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts)
        return out.reshape(batch_size, -1, self.dim)
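
# A minimal eager-mode sketch (my addition, not part of the original gist) of the
# routing math above on concrete numbers, independent of the classes in this file:
def _routing_demo():
    scores = torch.tensor([[1.0, 2.0, 0.5, 3.0]])  # [T=1, E=4] router logits
    weights = F.softmax(scores, dim=-1)  # full distribution over experts
    topw, topi = torch.topk(weights, 2, dim=-1)  # keep the A=2 best experts
    topw = topw / topw.sum(dim=-1, keepdim=True)  # renormalize over the winners
    return topi, topw  # here experts 3 and 1 win
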
# THIS IS WHERE COMPILE KEEPS FAILING
def get_indices_and_weights_per_expert(expert_indices, expert_weights, num_experts):
    num_tokens, experts_per_token = expert_indices.shape
    # goal: extract the tokens used by each expert by sorting expert_indices by expert,
    # then indexing the sorted list based on how many tokens each expert gets
    # expert_indices = [[0, 1], [1, 3], [0, 2]] i.e. token 0 goes to experts 0 and 1,
    # token 1 goes to experts 1 and 3, ...etc
    # num_tokens_per_expert = [2, 2, 1, 1] i.e. expert 0 gets tokens 0 and 2 ([[0*, 1], [1, 3], [0*, 2]])
    # sorted_tokens_by_expert = [0, 2, 0, 1, 2, 1], which needs to be broken up into groups of size num_tokens_per_expert
    # tok_indices_per_expert = [| 0, 2 | 0, 1 | 2 | 1 |]
    #                       -> [[0, 2], [0, 1], [2], [1]]
    # ALL 3 OF THESE DO THE SAME THING: GET THE TOKEN INDICES FOR EACH EXPERT.
    # IS THERE A COMPILE-COMPATIBLE WAY TO DO THIS?
    # "split":  uses torch.split
    # "arange": uses torch.arange
    # "index":  uses slicing to do the token shuffle
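    # worked intermediate values (my addition) for the example above, with E=4, A=2:
    # expert_indices.view(-1)        = [0, 1, 1, 3, 0, 2]
    # argsort(stable=True)           = [0, 4, 1, 2, 5, 3]  (activation slots grouped by expert)
    # argsort // experts_per_token   = [0, 2, 0, 1, 2, 1]  = sorted_tokens_by_expert
    # histc(bins=E+1, min=-1, max=E) = [0, 2, 2, 1, 1]     (leading bin covers [-1, 0) and is always empty)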
    if version == "split":
        sorted_token_activation_by_expert = expert_indices.view(-1).argsort(stable=True)
        sorted_tokens_by_expert = (sorted_token_activation_by_expert / experts_per_token).floor().to(torch.int)
        num_tokens_per_expert = torch.histc(expert_indices, bins=num_experts + 1, min=-1, max=num_experts)
        # arrange weights in the same way as tokens and then group weights by expert
        sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert].view(-1, 1)
        # drop the always-empty leading histc bin ([-1, 0)) so the groups line up with experts 0..E-1
        split_sizes = num_tokens_per_expert[1:].tolist()  # fails on tolist, breaks cudagraph
        tok_indices_per_expert = torch.split(sorted_tokens_by_expert, split_sizes)
        tok_weights_per_expert = torch.split(sorted_weights_by_expert, split_sizes)
    elif version == "arange":
        sorted_token_activation_by_expert = expert_indices.view(-1).argsort(stable=True)
        sorted_tokens_by_expert = (sorted_token_activation_by_expert / experts_per_token).floor().to(torch.int)
        # cast to int64 so the cumulative counts can be used as index bounds
        cum_tokens_per_expert = torch.histc(expert_indices, bins=num_experts + 1, min=-1, max=num_experts).cumsum(0).to(torch.int64)
        sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert]
        tok_indices_per_expert = []
        tok_weights_per_expert = []
        for i in range(num_experts):
            # cum[i]..cum[i+1] brackets expert i's tokens (the leading histc bin is always empty)
            indices = torch.arange(cum_tokens_per_expert[i], cum_tokens_per_expert[i + 1], device=sorted_tokens_by_expert.device)  # fails on arange for fullgraph
            tok_indices_per_expert.append(sorted_tokens_by_expert[indices])
            tok_weights_per_expert.append(sorted_weights_by_expert[indices].view(-1, 1))
    elif version == "index":
        sorted_token_activation_by_expert = expert_indices.view(-1).argsort(stable=True)
        sorted_tokens_by_expert = (sorted_token_activation_by_expert / experts_per_token).floor().to(torch.int)
        cum_tokens_per_expert = torch.histc(expert_indices, bins=num_experts + 1, min=-1, max=num_experts).cumsum(0).to(torch.int64)
        sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert]
        tok_indices_per_expert = []
        tok_weights_per_expert = []
        for i in range(num_experts):
            tok_indices_per_expert.append(sorted_tokens_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i + 1]])  # fails on indexing for fullgraph
            tok_weights_per_expert.append(sorted_weights_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i + 1]].view(-1, 1))
    else:
        raise Exception("unimplemented")
    return tok_indices_per_expert, tok_weights_per_expert
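
# A hedged alternative sketch (my addition, not from the original gist): one way to
# sidestep the data-dependent shapes that break fullgraph is to keep everything
# dense and statically shaped, scattering the top-k weights into a [T, E] matrix.
# Memory grows to O(T*E) instead of O(T*A), so this only pays off for small E.
def get_dense_routing_weights(expert_indices, expert_weights, num_experts):
    num_tokens = expert_indices.shape[0]
    dense = torch.zeros(num_tokens, num_experts,
                        dtype=expert_weights.dtype, device=expert_weights.device)
    dense.scatter_(1, expert_indices, expert_weights)  # zero weight for unused experts
    return dense  # [T, E]; every shape here is known at compile time
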
class ConditionalFeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(8, 16, 4))  # E, I, D
        self.w2 = nn.Parameter(torch.randn(8, 4, 16))  # E, D, I
        self.w3 = nn.Parameter(torch.randn(8, 16, 4))  # E, I, D
        self.num_experts = 8
        # note: `algorithm` is read from the module-level toggle at the bottom of the file
    def forward(
        self,
        x: Tensor,  # T, D
        expert_indices: Tensor,  # T, A
        expert_weights: Tensor,  # T, A
        num_activated_experts: int,
    ) -> Tensor:
        if algorithm == "gptfast":
            # This compiles, but the memory usage of the w(n)_weights gathers grows as
            # O(num_activated_experts), whereas it can be constant with an
            # expert-by-expert implementation (forloop)
            w1_weights = self.w1[expert_indices]  # [T, A, I, D]
            w3_weights = self.w3[expert_indices]  # [T, A, I, D]
            w2_weights = self.w2[expert_indices]  # [T, A, D, I]
            x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
            x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
            expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
            return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
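        # with the toy sizes here those gathers are only [2, 2, 16, 4], but in a real
        # model self.w1[expert_indices] materializes a copy of every selected expert's
        # weights per token, which is the O(num_activated_experts) memory cost above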
        # This works for both cases but doesn't compile
        elif algorithm == "forloop":
            tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts)
            expert_list = list(range(self.num_experts))
            # tok_indices_per_expert, tok_weights_per_expert: list([T'(e0), T'(e1), ...])
            outs = []
            for activated_expert_idx, expert in enumerate(expert_list):
                w1 = self.w1[expert]  # I, D
                w2 = self.w2[expert]  # D, I
                w3 = self.w3[expert]  # I, D
                tok_indices = tok_indices_per_expert[activated_expert_idx]
                cur_x = x[tok_indices]  # T', D
                cur_weights = tok_weights_per_expert[activated_expert_idx]  # T', 1
                # applying the per-token weight inside the w3 branch is equivalent to
                # scaling the expert output, since everything downstream of it is linear
                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x * cur_weights, w3), w2)  # T', D
                outs.append(cur_out)
            mixed_out = torch.cat(outs, dim=0)  # T*A, D
            final_out = torch.zeros_like(x)  # T, D
            # map each row of mixed_out (grouped by expert) back to its token index
            indices = expert_indices.view(-1).argsort(stable=True).div(num_activated_experts).floor().to(torch.int64).unsqueeze(-1)
            # expand the index so scatter_add covers all D columns, not just column 0
            final_out = final_out.scatter_add(dim=0, index=indices.expand(-1, x.shape[-1]), src=mixed_out)
            return final_out
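            # worked mapping (my addition), reusing the example from
            # get_indices_and_weights_per_expert: for expert_indices = [[0, 1], [1, 3], [0, 2]],
            # mixed_out rows are ordered [t0(e0), t2(e0), t0(e1), t1(e1), t2(e2), t1(e3)],
            # indices = [0, 2, 0, 1, 2, 1], and scatter_add sums the A contributions per token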
        elif algorithm == "carefulindex":
            ordered_token_activations = expert_indices.view(-1).argsort(stable=True)
            ordered_tokens = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64)
            # cumulative token counts per expert; the leading histc bin ([-1, 0)) is always empty
            cum_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts).cumsum(0).to(torch.int64)
            expert_starts = cum_tokens_per_expert[:-1]
            expert_ends = cum_tokens_per_expert[1:]
            mega_input = x[ordered_tokens]  # [T*A, D], tokens grouped by expert
            outs = []
            expert_list = list(range(self.num_experts))
            # expert_list = [x for x in expert_indices.unique()]  # does this work?
            for expert in expert_list:
                w1 = self.w1[expert]  # I, D
                w2 = self.w2[expert]  # D, I
                w3 = self.w3[expert]  # I, D
                cur_x = mega_input[expert_starts[expert]:expert_ends[expert]]
                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2)  # T', D
                outs.append(cur_out)
            ordered_outs = torch.cat(outs, dim=0)  # T*A, D
            weight_order = expert_weights.view(-1, 1)[ordered_token_activations].view(-1, 1)
            weighted_ordered_outs = ordered_outs * weight_order
            final_out = torch.zeros_like(x)  # T, D
            # expand to [T*A, D] instead of hardcoding expand(4, 4), so scatter_add covers all columns for any shape
            final_out = final_out.scatter_add(dim=0, index=ordered_tokens.unsqueeze(-1).expand(-1, x.shape[-1]), src=weighted_ordered_outs)
            return final_out
        else:
            raise Exception("unimplemented")
# different algorithms I've tried
algorithm = "carefulindex"
# algorithm = "forloop"
# algorithm = "gptfast"
# `version` picks the index helper and is only read by the "forloop" algorithm
version = "index"
# version = "split"
# version = "arange"
moe = MOEFeedForward().to("cuda").to(torch.bfloat16)
input1 = torch.randn(1, 2, 4).to("cuda").to(torch.bfloat16)
input2 = torch.randn(1, 2, 4).to("cuda").to(torch.bfloat16)
with torch.no_grad():
    out1 = moe(input1)
    print(out1.sum())
    out2 = moe(input2)
    print(out2.sum())
    # moe_c = torch.compile(moe, mode="reduce-overhead")  # working
    moe_c = torch.compile(moe, mode="reduce-overhead", fullgraph=True)  # this fails on the token shuffle part
    moe_c(input1)  # warm-up calls for reduce-overhead / cudagraph capture
    moe_c(input2)
    out1c = moe_c(input1)
    print(out1c.sum())
    out2c = moe_c(input2)
    print(out2c.sum())
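    # optional sanity check (my addition, not in the original gist): when the chosen
    # algorithm compiles, eager and compiled outputs should agree up to bf16 tolerance
    print(torch.allclose(out1, out1c, atol=1e-2, rtol=1e-2))
    print(torch.allclose(out2, out2c, atol=1e-2, rtol=1e-2))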