diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 69c03d8efb8..f3668018c43 100755
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -930,6 +930,22 @@ class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
                       router_weights,
                       permuted_weights=True,
                       activation="silu"):
+        enable_moe_chunk = hasattr(self.orig_mod, "enable_moe_chunk") and self.orig_mod.enable_moe_chunk
+        if enable_moe_chunk:
+            return self._chunk_moe(
+                hidden_states, expert_routing_table, router_weights, permuted_weights, activation
+            )
+        else:
+            return self._forward_quant(
+                hidden_states, expert_routing_table, router_weights, permuted_weights, activation
+            )
+
+    def _forward_quant(self,
+                       hidden_states,
+                       expert_routing_table,
+                       router_weights,
+                       permuted_weights=True,
+                       activation="silu"):
         tokens_num, hidden_dim = hidden_states.shape
         extra_kwargs = self._get_extra_kwargs(tokens_num)
         experts_range = range(self.experts_used)
@@ -956,6 +972,73 @@ class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
         )
         return output

+    def _chunk_moe(self, x, expert_routing_table, router_weights, permuted_weights=True, activation="silu"):
+        batched_tokens = x.shape[0]
+        kwargs = {}
+        orig_mod = self.orig_mod
+
+        if orig_mod.enable_moe_chunk:
+            chunk_size = orig_mod.chunk_size_list[-1]
+            for idx, threshold in enumerate(orig_mod.token_boundary_list):
+                if batched_tokens <= threshold:
+                    chunk_size = orig_mod.chunk_size_list[idx]
+                    break
+            kwargs = {
+                "chunk_size": chunk_size,
+                "total_experts": 256,
+            }
+
+        qinput = self.quant_input(x)
+        # tokens_num, hidden_dim = hidden_states.shape
+        # extra_kwargs = self._get_extra_kwargs(tokens_num)
+        extra_kwargs = kwargs
+        experts_range = range(self.experts_used)
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
+        scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
+        scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
+
+        def _inner_forward(cur_qinput, cur_expert_routing_table, cur_router_weights, scale_input, extra_kwargs):
+            output = self.dynamic_moe_op(
+                hidden_states=cur_qinput,
+                expert_routing_table=cur_expert_routing_table,
+                router_weights=cur_router_weights,
+                w12=w1_list,
+                w3=w2_list,
+                d_scale_w12=scale_w1,
+                d_scale_w3=scale_w2,
+                d_scale_hidden_states=scale_input,
+                d_scale_intermediate_hidden_states=self.scale_intermediate,
+                permuted_weights=False,
+                activation=activation,
+                experts_min=self.experts_min,
+                experts_max=self.experts_max,
+                **extra_kwargs,
+            )
+            return output
+
+        if batched_tokens > orig_mod.moe_slice_length:
+            final_hidden_states_list = []
+            n_slice = (batched_tokens + orig_mod.moe_slice_length - 1) // orig_mod.moe_slice_length
+            for i in range(n_slice):
+                s = i * orig_mod.moe_slice_length
+                e = batched_tokens if i == (n_slice - 1) else (i + 1) * orig_mod.moe_slice_length
+                cur_qinput = qinput[s:e, ...]
+                cur_expert_routing_table = expert_routing_table[s:e, ...]
+                cur_router_weights = router_weights[s:e, ...]
+                scale_input = self.scale_input
+                cur_out = _inner_forward(
+                    cur_qinput, cur_expert_routing_table, cur_router_weights, scale_input, extra_kwargs
+                )
+                final_hidden_states_list.append(cur_out)
+            final_hidden_states = torch.cat(final_hidden_states_list, dim=0)
+        else:
+            final_hidden_states = _inner_forward(
+                qinput, expert_routing_table, router_weights, self.scale_input, extra_kwargs
+            )
+
+        return final_hidden_states.view(-1, x.shape[1])
+
     def forward_dynamic_quant(
         self, hidden_states, expert_routing_table, router_weights, permuted_weights=True, layer=None, activation="silu"
     ):
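For orientation, the sketch below restates, outside of the patch, the two pieces of control flow that _chunk_moe adds: choosing a chunk size from the token-count boundaries and splitting the batch into slices of at most moe_slice_length tokens. The helper names (select_chunk_size, iter_slices) and the example boundary and chunk values are illustrative assumptions, not part of the patch or of neural_compressor.

def select_chunk_size(batched_tokens, token_boundary_list, chunk_size_list):
    """Pick the chunk size whose token-count boundary covers this batch.

    Mirrors the lookup in _chunk_moe: the first boundary that is >= the batch
    size selects the corresponding chunk size; chunk_size_list[-1] is the fallback.
    """
    for boundary, chunk_size in zip(token_boundary_list, chunk_size_list):
        if batched_tokens <= boundary:
            return chunk_size
    return chunk_size_list[-1]


def iter_slices(batched_tokens, slice_length):
    """Yield (start, end) pairs covering [0, batched_tokens) in slices of at
    most slice_length tokens, like the n_slice loop in _chunk_moe."""
    n_slice = (batched_tokens + slice_length - 1) // slice_length  # ceiling division
    for i in range(n_slice):
        start = i * slice_length
        end = batched_tokens if i == n_slice - 1 else (i + 1) * slice_length
        yield start, end


if __name__ == "__main__":
    # Example boundaries and chunk sizes only; real values come from the patched module.
    assert select_chunk_size(100, [128, 512, 2048], [64, 256, 1024]) == 64
    assert select_chunk_size(4096, [128, 512, 2048], [64, 256, 1024]) == 1024
    assert list(iter_slices(10, 4)) == [(0, 4), (4, 8), (8, 10)]

In the patch itself, each slice of qinput, expert_routing_table, and router_weights produced this way is passed through dynamic_moe_op, and the per-slice outputs are concatenated with torch.cat along dim 0 before the final view.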
Caution
The MoE chunking logic in this patch is incorrect! Please set VLLM_SUPPORT_MOE_CHUNK=false; the default value of VLLM_SUPPORT_MOE_CHUNK is false.
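Since the workaround is an environment variable, the usual way to apply it is to export VLLM_SUPPORT_MOE_CHUNK=false before launching the server. The snippet below only illustrates a common way a boolean flag like this is read from the environment in Python; the actual parsing inside vLLM is not shown in this gist and may differ.

import os

# Illustrative pattern only, not vLLM's implementation: treat common truthy
# strings as "enabled" and fall back to the stated default of "false".
support_moe_chunk = os.environ.get("VLLM_SUPPORT_MOE_CHUNK", "false").strip().lower() in ("1", "true", "yes")

print("MoE chunking requested:", support_moe_chunk)  # False when the variable is unset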