Created November 19, 2025 02:12
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 69c03d8efb8..f3668018c43 100755
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -930,6 +930,22 @@ class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
                       router_weights,
                       permuted_weights=True,
                       activation="silu"):
+        enable_moe_chunk = hasattr(self.orig_mod, "enable_moe_chunk") and self.orig_mod.enable_moe_chunk
+        if not enable_moe_chunk:
+            return self._forward_quant(
+                hidden_states, expert_routing_table, router_weights, permuted_weights, activation
+            )
+        else:
+            return self._chunk_moe(
+                hidden_states, expert_routing_table, router_weights, permuted_weights, activation
+            )
+
+    def _forward_quant(self,
+                       hidden_states,
+                       expert_routing_table,
+                       router_weights,
+                       permuted_weights=True,
+                       activation="silu"):
         tokens_num, hidden_dim = hidden_states.shape
         extra_kwargs = self._get_extra_kwargs(tokens_num)
         experts_range = range(self.experts_used)
@@ -956,6 +972,73 @@ class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
         )
         return output

+    def _chunk_moe(self, x, expert_routing_table, router_weights, permuted_weights=True, activation="silu"):
+        batched_tokens = x.shape[0]
+        kwargs = {}
+        orig_mod = self.orig_mod
+
+        if orig_mod.enable_moe_chunk:
+            chunk_size = orig_mod.chunk_size_list[-1]
+            for idx, threshold in enumerate(orig_mod.token_boundary_list):
+                if batched_tokens <= threshold:
+                    chunk_size = orig_mod.chunk_size_list[idx]
+                    break
+            kwargs = {
+                "chunk_size": chunk_size,
+                "total_experts": 256,
+            }
+
+        qinput = self.quant_input(x)
+        # tokens_num, hidden_dim = hidden_states.shape
+        # extra_kwargs = self._get_extra_kwargs(tokens_num)
+        extra_kwargs = kwargs
+        experts_range = range(self.experts_used)
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
+        scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
+        scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
+
+        def _inner_forward(cur_qinput, cur_expert_routing_table, cur_router_weights, scale_input, extra_kwargs):
+            output = self.dynamic_moe_op(
+                hidden_states=cur_qinput,
+                expert_routing_table=cur_expert_routing_table,
+                router_weights=cur_router_weights,
+                w12=w1_list,
+                w3=w2_list,
+                d_scale_w12=scale_w1,
+                d_scale_w3=scale_w2,
+                d_scale_hidden_states=scale_input,
+                d_scale_intermediate_hidden_states=self.scale_intermediate,
+                permuted_weights=False,
+                activation=activation,
+                experts_min=self.experts_min,
+                experts_max=self.experts_max,
+                **extra_kwargs,
+            )
+            return output
+
+        if batched_tokens > orig_mod.moe_slice_length:
+            final_hidden_states_list = []
+            n_slice = (batched_tokens + orig_mod.moe_slice_length - 1) // orig_mod.moe_slice_length
+            for i in range(n_slice):
+                s = i * orig_mod.moe_slice_length
+                e = batched_tokens if i == (n_slice - 1) else (i + 1) * orig_mod.moe_slice_length
+                cur_qinput = qinput[s:e, ...]
+                cur_expert_routing_table = expert_routing_table[s:e, ...]
+                cur_router_weights = router_weights[s:e, ...]
+                scale_input = self.scale_input
+                cur_out = _inner_forward(
+                    cur_qinput, cur_expert_routing_table, cur_router_weights, scale_input, extra_kwargs
+                )
+                final_hidden_states_list.append(cur_out)
+            final_hidden_states = torch.cat(final_hidden_states_list, dim=0)
+        else:
+            final_hidden_states = _inner_forward(
+                qinput, expert_routing_table, router_weights, self.scale_input, extra_kwargs
+            )
+
+        return final_hidden_states.view(-1, x.shape[1])
+
     def forward_dynamic_quant(
         self, hidden_states, expert_routing_table, router_weights, permuted_weights=True, layer=None, activation="silu"
     ):
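For clarity, here is a small, self-contained sketch of the chunking arithmetic the patch adds: how `_chunk_moe` picks a `chunk_size` from `token_boundary_list`/`chunk_size_list` and how it splits a large batch into `moe_slice_length`-sized slices. The boundary, chunk-size, and slice-length values below are made-up examples; in the patched module they come from `self.orig_mod`.

```python
# Standalone sketch of _chunk_moe's chunk-size selection and token slicing.
# token_boundary_list / chunk_size_list / moe_slice_length are assumed example
# values; the real ones are attributes of the original (unpatched) MoE module.
import math

token_boundary_list = [64, 512, 4096]    # assumed token-count boundaries
chunk_size_list = [64, 256, 1024, 2048]  # assumed chunk sizes; last entry is the fallback
moe_slice_length = 2048                  # assumed maximum tokens per MoE slice


def select_chunk_size(batched_tokens: int) -> int:
    """Mirror the loop in _chunk_moe: take the chunk size of the first
    boundary the batch fits under, else fall back to the last entry."""
    chunk_size = chunk_size_list[-1]
    for idx, threshold in enumerate(token_boundary_list):
        if batched_tokens <= threshold:
            chunk_size = chunk_size_list[idx]
            break
    return chunk_size


def slice_bounds(batched_tokens: int):
    """Yield the (start, end) ranges _chunk_moe uses when the batch is larger
    than moe_slice_length: ceil(batched_tokens / moe_slice_length) slices,
    with the last slice absorbing the remainder."""
    n_slice = math.ceil(batched_tokens / moe_slice_length)
    for i in range(n_slice):
        s = i * moe_slice_length
        e = batched_tokens if i == n_slice - 1 else (i + 1) * moe_slice_length
        yield s, e


if __name__ == "__main__":
    for tokens in (32, 1000, 5000):
        print(tokens, "->", select_chunk_size(tokens), list(slice_bounds(tokens)))
```

With these example values, a 5000-token batch falls past every boundary, so it gets the fallback chunk size of 2048 and is run as three slices of (0, 2048), (2048, 4096), and (4096, 5000) whose outputs are concatenated, matching the `torch.cat` branch in the patch.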
Caution
Please set `VLLM_SUPPORT_MOE_CHUNK=true`; the default value of `VLLM_SUPPORT_MOE_CHUNK` is `false`.
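For example, a minimal sketch of enabling the flag, assuming `VLLM_SUPPORT_MOE_CHUNK` is read from the process environment before the model is loaded and patched:

```python
# Set the flag before importing/constructing vLLM and the quantized model
# (assumption: the variable is only consulted at startup).
import os

os.environ["VLLM_SUPPORT_MOE_CHUNK"] = "true"  # default is "false"
```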