[NVBUG: 5804406] Auto detect MOE layers #900
**`moe_utils.py`** (new file, +74 lines)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for Mixture-of-Experts (MoE) model export."""

from pathlib import Path

import torch.nn as nn


def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None):
    """Collect expert_token_count from all quantized MoE layers and save as an HTML table.

    The table has rows for each MoE layer and columns for each expert, with cell values
    showing the number of tokens routed to that expert during calibration.

    Args:
        model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
        output_dir: Directory to save the HTML file. Defaults to current directory.
    """
    rows = []
    for name, module in model.named_modules():
        if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
            rows.append((name, module.expert_token_count))

    if not rows:
        return

    num_experts = rows[0][1].shape[0]
    html_parts = [
        "<html><head><style>",
        "table { border-collapse: collapse; font-family: monospace; }",
        "th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }",
        "th { background: #f0f0f0; }",
        "</style></head><body>",
        "<h2>Expert Token Counts (per MoE layer)</h2>",
        "<table><tr><th>Layer/Expert</th>",
    ]
    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
    html_parts.append("</tr>")

    for name, counts in rows:
        avg = counts.float().mean().item()
        html_parts.append(f"<tr><td>{name}</td>")
        for c in counts.tolist():
            if avg > 0 and c < avg * 0.05:
                style = ' style="background: #ff6666;"'
            elif avg > 0 and c < avg * 0.1:
                style = ' style="background: #ffcccc;"'
            else:
                style = ""
            html_parts.append(f"<td{style}>{c}</td>")
        html_parts.append("</tr>")

    html_parts.append("</table></body></html>")
    html_content = "\n".join(html_parts)

    if output_dir is None:
        output_dir = Path(".")
    output_path = Path(output_dir) / ".moe.html"
    output_path.write_text(html_content)
    print(f"Expert token count table saved to {output_path}")
```

**Comment on lines +57 to +60** (Contributor):

Assumes all MoE layers have the same number of experts. `num_experts = rows[0][1].shape[0]` is used to build the table header once, but subsequent rows each emit one `<td>` per their own expert count. If any two MoE layers in a model have different expert counts (e.g., fine-grained vs. shared experts in DeepSeek-style models, or future heterogeneous architectures), the HTML table will have misaligned columns: rows with fewer experts than the header will be missing cells, and rows with more will overflow.

🛡️ Proposed fix: build the header dynamically from the maximum width across rows. With this fix, each row loop may also need padding:

```diff
 for c in counts.tolist():
     ...
     html_parts.append(f"<td{style}>{c}</td>")
+# Pad missing expert columns for layers with fewer experts
+for _ in range(num_experts - len(counts)):
+    html_parts.append("<td>-</td>")
```
**Export module** (defines `export_hf_checkpoint`)

```diff
@@ -76,6 +76,7 @@
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
+from .moe_utils import save_expert_token_count_table
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
```

```diff
@@ -1003,6 +1004,8 @@ def export_hf_checkpoint(
     try:
         post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
 
+        save_expert_token_count_table(model, export_dir)
+
         if hf_quant_config is not None:
             # Save hf_quant_config.json for backward compatibility
             with open(f"{export_dir}/hf_quant_config.json", "w") as file:
```

**Contributor** (on the `save_expert_token_count_table` call): The call sits inside the export `try` block, so a failure in this diagnostic step would fail the whole checkpoint export.

♻️ Proposed fix: isolate the diagnostic step:

```diff
+        try:
             save_expert_token_count_table(model, export_dir)
+        except Exception as report_err:
+            warnings.warn(
+                f"Failed to save expert token count table: {report_err}. "
+                "Model export will continue."
+            )
```
**HuggingFace quantization plugin** (`_QuantSparseMoe` and module registry)

```diff
@@ -450,20 +450,50 @@ class _QuantSparseMoe(QuantModule):
     """
 
     def _setup(self):
-        pass
+        num_experts = 0
+        if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
+            num_experts = self.gate.num_experts
+        elif hasattr(self, "num_experts"):
+            num_experts = self.num_experts
+        elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
+            num_experts = self.experts.num_experts
+
+        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        self._count_expert_tokens = False
+
+        if hasattr(self, "gate"):
+            self.gate.register_forward_hook(self._gate_forward_hook)
```
**Comment on lines 452 to +472** (Contributor):

Silent failure when `num_experts` cannot be determined: if none of the three attribute lookups (`self.gate.num_experts`, `self.num_experts`, `self.experts.num_experts`) matches, `num_experts` stays 0 and expert token counting is silently disabled for that layer.

🛡️ Proposed fix:

```diff
+import warnings
 ...
         self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        if num_experts == 0:
+            warnings.warn(
+                f"Could not determine num_experts for {type(self).__name__}; "
+                "expert_token_count will not be tracked.",
+                stacklevel=2,
+            )
```
```diff
+    def _gate_forward_hook(self, module, input, output):
+        if not self._count_expert_tokens:
+            return
+        with torch.no_grad():
+            if isinstance(output, tuple) and len(output) >= 3:
+                # v5.x TopKRouter: returns (logits, scores, indices)
+                indices = output[2]
+            else:
+                # v4.x nn.Linear gate: returns logits tensor
+                logits = output if not isinstance(output, tuple) else output[0]
+                top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
+                _, indices = torch.topk(logits.float(), top_k, dim=-1)
+            counts = torch.bincount(
+                indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
+            )
+            self.expert_token_count += counts
```

**Contributor** (on `indices.reshape(-1).cpu()`): Calling `.cpu()` here will add unnecessary CPU-GPU sync overhead. Can we accumulate into a GPU tensor?
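A sketch of what the suggested on-device accumulation could look like as a drop-in replacement for the hook above; the one-time device migration is my assumption, not part of the PR:

```python
import torch

def _gate_forward_hook(self, module, input, output):
    # Same routing extraction as above, but the bincount and the add stay on
    # the indices' device; the buffer moves to GPU once instead of syncing
    # to CPU on every calibration step.
    if not self._count_expert_tokens:
        return
    with torch.no_grad():
        if isinstance(output, tuple) and len(output) >= 3:
            indices = output[2]
        else:
            logits = output if not isinstance(output, tuple) else output[0]
            top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
            _, indices = torch.topk(logits.float(), top_k, dim=-1)
        flat = indices.reshape(-1)
        if self.expert_token_count.device != flat.device:
            # One-time migration (assumption): all later adds are sync-free.
            self.expert_token_count = self.expert_token_count.to(flat.device)
        self.expert_token_count += torch.bincount(
            flat, minlength=self.expert_token_count.numel()
        )
```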
**Comment on lines +477 to +489**

**Contributor:** Can we have a try/except here? This looks risky and could fail.

**Collaborator (Author):** Could you share why this is risky?

**Contributor:** This seems too tailored to the model. What if we run on a model for which this pattern is not compatible?

**Contributor:** This is telemetry code anyway, right? We can afford for it to fail and still run PTQ and QAT without any issues.
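A sketch of the defensive variant this thread converges on, assuming (per the last comment) that counting is telemetry and may fail without affecting PTQ/QAT; `_count_routed_tokens` is a hypothetical helper holding the counting body shown above:

```python
import warnings

def _gate_forward_hook(self, module, input, output):
    if not self._count_expert_tokens:
        return
    try:
        self._count_routed_tokens(output)  # hypothetical helper: the counting logic above
    except Exception as err:  # telemetry only: never abort calibration
        warnings.warn(f"Expert token counting failed ({err}); disabling for this layer.")
        self._count_expert_tokens = False
```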
```diff
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
+        is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
+        if is_calib:
             # If any of the experts are in calibration mode, we will forward all tokens to all experts
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
             if TRANSFORMERS_VERSION_GE_5_0:
                 assert hasattr(self, "gate")
                 # Path for transformers >= 5.0
-                original_top_k = self.gate.topk
-                self.gate.topk = self.gate.num_experts
+                original_top_k = self.gate.top_k
+                self.gate.top_k = self.gate.num_experts
                 super().forward(hidden_states)
-                self.gate.topk = original_top_k
+                self.gate.top_k = original_top_k
             else:
                 # Path for transformers < 5.0
                 original_top_k = self.top_k
```

```diff
@@ -475,7 +505,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 raise ValueError(f"Could not find num_experts in module {self}")
             super().forward(hidden_states)
             self.top_k = original_top_k
-        return super().forward(hidden_states)
+        # Enable counting only for the real-routing forward during calibration
+        self._count_expert_tokens = is_calib
+        output = super().forward(hidden_states)
+        self._count_expert_tokens = False
+        return output
 
 
 class _QuantLlama4TextExperts(QuantModule):
```
```diff
@@ -765,10 +799,7 @@ def unpack_weight(self):
 try:
-    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
-
-    if Llama4TextMoe not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
+    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts
 
     if Llama4TextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
```

```diff
@@ -791,16 +822,6 @@ def unpack_weight(self):
 except ImportError:
     pass
 
-try:
-    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-    if MixtralSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
 try:
     from transformers.models.falcon.modeling_falcon import FalconLinear
```

```diff
@@ -809,36 +830,6 @@ def unpack_weight(self):
 except ImportError:
     pass
 
-try:
-    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-    if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
-
-    if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
-
-    if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
 try:
     from compressed_tensors.linear.compressed_linear import CompressedLinear
```

```diff
@@ -850,15 +841,7 @@ def unpack_weight(self):
     pass
 
 try:
-    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-        Qwen3VLMoeTextExperts,
-        Qwen3VLMoeTextSparseMoeBlock,
-    )
-
-    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register(
-            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
-        )(_QuantSparseMoe)
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
 
     if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
```
```diff
@@ -989,15 +972,55 @@ def register_falcon_linears_on_the_fly(model):
         QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
 
 
-def register_minimax_m2_moe_on_the_fly(model):
-    """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+def _is_sparse_moe_block(module):
+    """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.
 
-    MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
+    All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
+    share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
+    (``top_k`` and ``num_experts``), and an ``experts`` sub-module.
+
+    This function detects that pattern instead of relying on class names, making it forward-compatible
+    with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but
+    use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom
+    ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives.
     """
-    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
-        moe_type = type(model.model.layers[0].block_sparse_moe)
-        if QuantModuleRegistry.get(moe_type) is None:
-            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+    if not hasattr(module, "experts"):
+        return False
+
+    # Primary: gate sub-module has top_k + num_experts (standard TopKRouter pattern)
+    if hasattr(module, "gate"):
+        gate = module.gate
+        has_topk = hasattr(gate, "top_k")
+        has_num_experts = hasattr(gate, "num_experts")
+        if has_topk and has_num_experts:
+            return True
+
+    # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
+    return hasattr(module, "top_k") and hasattr(module, "num_experts")
+
+
+def register_sparse_moe_on_the_fly(model):
+    """Auto-detect and register MOE modules as _QuantSparseMoe.
+
+    Walks the model tree, identifies MoE blocks by their structural attributes
+    (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
+    """
+    registered_types = set()
+    for name, module in model.named_modules():
+        mod_type = type(module)
+
+        # Avoid duplicate registration: skip if we already processed this type
+        # in this walk, or if it was previously registered in the QuantModuleRegistry.
+        if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
+            continue
+
+        if _is_sparse_moe_block(module):
+            print(
+                f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, "
+                f"registering with _QuantSparseMoe.\033[0m"
+            )
+            QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
+            registered_types.add(mod_type)
 
 
 def _is_supported_hf_model(model):
```
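To illustrate what the structural check accepts and rejects, a self-contained sketch with toy modules (these classes are invented for the example and stand in for real HF blocks):

```python
import torch.nn as nn

class ToyRouter(nn.Linear):
    """Stands in for a TopKRouter: a gate exposing top_k and num_experts."""
    def __init__(self, hidden: int = 16, num_experts: int = 8):
        super().__init__(hidden, num_experts)
        self.top_k = 2
        self.num_experts = num_experts

class ToyMoeBlock(nn.Module):
    """Matches the gate + experts pattern _is_sparse_moe_block looks for."""
    def __init__(self):
        super().__init__()
        self.gate = ToyRouter()
        self.experts = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))

assert _is_sparse_moe_block(ToyMoeBlock())          # gate has top_k + num_experts
assert not _is_sparse_moe_block(nn.Linear(16, 16))  # no `experts` sub-module
```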
```diff
@@ -1065,7 +1088,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
     [
         register_falcon_linears_on_the_fly,
         register_dbrx_moe_on_the_fly,
-        register_minimax_m2_moe_on_the_fly,
+        register_sparse_moe_on_the_fly,
         register_hf_attentions_on_the_fly,
         convert_hf_parallel_linears_on_the_fly,
     ]
```
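End to end, the new path is exercised through the usual quantize-then-export flow. A hedged sketch (the model id, quantization config, and calibration loop are illustrative, and the entry points follow modelopt's public API as used elsewhere in this repo):

```python
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

# Illustrative model choice: any HuggingFace MoE model should work.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

def forward_loop(m):
    ...  # run calibration batches through m

# register_sparse_moe_on_the_fly() runs during quantize(): MoE blocks are
# detected structurally (no per-architecture registration needed), converted
# to _QuantSparseMoe, and expert token counts accumulate during calibration.
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)

# export_hf_checkpoint() now also writes <export_dir>/.moe.html
export_hf_checkpoint(model, export_dir="./qmodel")
```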