
@yiliu30
Created November 19, 2025 02:12
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 69c03d8efb8..f3668018c43 100755
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -930,6 +930,22 @@ class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
                       router_weights,
                       permuted_weights=True,
                       activation="silu"):
+        enable_moe_chunk = hasattr(self.orig_mod, "enable_moe_chunk") and self.orig_mod.enable_moe_chunk
+        if not enable_moe_chunk:
+            return self._forward_quant(
+                hidden_states, expert_routing_table, router_weights, permuted_weights, activation
+            )
+        else:
+            return self._chunk_moe(
+                hidden_states, expert_routing_table, router_weights, permuted_weights, activation
+            )
+
+    def _forward_quant(self,
+                       hidden_states,
+                       expert_routing_table,
+                       router_weights,
+                       permuted_weights=True,
+                       activation="silu"):
         tokens_num, hidden_dim = hidden_states.shape
         extra_kwargs = self._get_extra_kwargs(tokens_num)
         experts_range = range(self.experts_used)
@@ -956,6 +972,73 @@ class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
             )
         return output
 
+    def _chunk_moe(self, x, expert_routing_table, router_weights, permuted_weights=True, activation="silu"):
+        batched_tokens = x.shape[0]
+        kwargs = {}
+        orig_mod = self.orig_mod
+
+        if orig_mod.enable_moe_chunk:
+            chunk_size = orig_mod.chunk_size_list[-1]
+            for idx, threshold in enumerate(orig_mod.token_boundary_list):
+                if batched_tokens <= threshold:
+                    chunk_size = orig_mod.chunk_size_list[idx]
+                    break
+            kwargs = {
+                "chunk_size": chunk_size,
+                "total_experts": 256,
+            }
+
+        qinput = self.quant_input(x)
+        # tokens_num, hidden_dim = hidden_states.shape
+        # extra_kwargs = self._get_extra_kwargs(tokens_num)
+        extra_kwargs = kwargs
+        experts_range = range(self.experts_used)
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
+        scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
+        scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
+
+        def _inner_forward(cur_qinput, cur_expert_routing_table, cur_router_weights, scale_input, extra_kwargs):
+            output = self.dynamic_moe_op(
+                hidden_states=cur_qinput,
+                expert_routing_table=cur_expert_routing_table,
+                router_weights=cur_router_weights,
+                w12=w1_list,
+                w3=w2_list,
+                d_scale_w12=scale_w1,
+                d_scale_w3=scale_w2,
+                d_scale_hidden_states=scale_input,
+                d_scale_intermediate_hidden_states=self.scale_intermediate,
+                permuted_weights=False,
+                activation=activation,
+                experts_min=self.experts_min,
+                experts_max=self.experts_max,
+                **extra_kwargs,
+            )
+            return output
+
+        if batched_tokens > orig_mod.moe_slice_length:
+            final_hidden_states_list = []
+            n_slice = (batched_tokens + orig_mod.moe_slice_length - 1) // orig_mod.moe_slice_length
+            for i in range(n_slice):
+                s = i * orig_mod.moe_slice_length
+                e = batched_tokens if i == (n_slice - 1) else (i + 1) * orig_mod.moe_slice_length
+                cur_qinput = qinput[s:e, ...]
+                cur_expert_routing_table = expert_routing_table[s:e, ...]
+                cur_router_weights = router_weights[s:e, ...]
+                scale_input = self.scale_input
+                cur_out = _inner_forward(
+                    cur_qinput, cur_expert_routing_table, cur_router_weights, scale_input, extra_kwargs
+                )
+                final_hidden_states_list.append(cur_out)
+            final_hidden_states = torch.cat(final_hidden_states_list, dim=0)
+        else:
+            final_hidden_states = _inner_forward(
+                qinput, expert_routing_table, router_weights, self.scale_input, extra_kwargs
+            )
+
+        return final_hidden_states.view(-1, x.shape[1])
+
     def forward_dynamic_quant(
         self, hidden_states, expert_routing_table, router_weights, permuted_weights=True, layer=None, activation="silu"
     ):
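
For reference, the patch routes the quantized MoE forward through `_chunk_moe` only when the wrapped module exposes `enable_moe_chunk`; otherwise the original path runs unchanged as `_forward_quant`. Below is a minimal, self-contained sketch of the two pieces of bookkeeping `_chunk_moe` adds: the chunk-size lookup against `token_boundary_list` / `chunk_size_list`, and the ceiling-division slicing by `moe_slice_length`. The helper names and all numeric values here are made-up examples for illustration, not the values vLLM actually configures on `orig_mod`.

```python
# Sketch only: in the patch these attributes live on the wrapped module
# (self.orig_mod); the example values below are assumptions.
token_boundary_list = [64, 512, 4096]
chunk_size_list = [64, 256, 1024]
moe_slice_length = 2048


def pick_chunk_size(batched_tokens):
    # Fall back to the last entry when no boundary matches, as the patch does.
    chunk_size = chunk_size_list[-1]
    for idx, threshold in enumerate(token_boundary_list):
        if batched_tokens <= threshold:
            chunk_size = chunk_size_list[idx]
            break
    return chunk_size


def slice_ranges(batched_tokens):
    # Same ceiling-division slicing as the patch; the last slice may be shorter.
    n_slice = (batched_tokens + moe_slice_length - 1) // moe_slice_length
    return [
        (i * moe_slice_length, min((i + 1) * moe_slice_length, batched_tokens))
        for i in range(n_slice)
    ]


print(pick_chunk_size(300))   # 256 with the example lists above
print(slice_ranges(5000))     # [(0, 2048), (2048, 4096), (4096, 5000)]
```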
yiliu30 commented Nov 19, 2025

Caution

Please set VLLM_SUPPORT_MOE_CHUNK=true; VLLM_SUPPORT_MOE_CHUNK defaults to false.
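
For example, assuming the flag is read from the process environment at startup, it can be set before vLLM initializes (equivalently, export it in the shell before launching the server):

```python
# Illustrative only: the gist states the flag's name and default, not how it
# is consumed. Set it before vLLM reads its environment configuration.
import os

os.environ["VLLM_SUPPORT_MOE_CHUNK"] = "true"
```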
