Merged
Changes from 4 commits
74 changes: 74 additions & 0 deletions modelopt/torch/export/moe_utils.py
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for Mixture-of-Experts (MoE) model export."""

from pathlib import Path

import torch.nn as nn


def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None):
    """Collect expert_token_count from all quantized MoE layers and save as an HTML table.

    The table has rows for each MoE layer and columns for each expert, with cell values
    showing the number of tokens routed to that expert during calibration.

    Args:
        model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
        output_dir: Directory to save the HTML file. Defaults to current directory.
    """
    rows = []
    for name, module in model.named_modules():
        if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
            rows.append((name, module.expert_token_count))

    if not rows:
        return

    num_experts = rows[0][1].shape[0]
    html_parts = [
        "<html><head><style>",
        "table { border-collapse: collapse; font-family: monospace; }",
        "th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }",
        "th { background: #f0f0f0; }",
        "</style></head><body>",
        "<h2>Expert Token Counts (per MoE layer)</h2>",
        "<table><tr><th>Layer/Expert</th>",
    ]
    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
    html_parts.append("</tr>")
Comment on lines +41 to +55
Contributor

⚠️ Potential issue | 🟡 Minor

Assumes all MoE layers have the same number of experts.

num_experts = rows[0][1].shape[0] is used to build the table header once, but subsequent rows each emit one <td> per their own expert count. If any two MoE layers in a model have different expert counts (e.g., fine-grained vs. shared experts in DeepSeek-style models, or future heterogeneous architectures), the HTML table will have misaligned columns — rows with fewer experts than the header will be missing cells, and rows with more will overflow.

🛡️ Proposed fix: build header dynamically from max width
-num_experts = rows[0][1].shape[0]
+num_experts = max(counts.shape[0] for _, counts in rows)
 html_parts = [
     ...
     "<table><tr><th>Layer/Expert</th>",
 ]
 html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))

With this fix, each row loop may also need padding:

 for c in counts.tolist():
     ...
     html_parts.append(f"<td{style}>{c}</td>")
+# Pad missing expert columns for layers with fewer experts
+for _ in range(num_experts - len(counts)):
+    html_parts.append("<td>-</td>")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/moe_utils.py` around lines 41 - 52, The code assumes a
constant expert count by using num_experts = rows[0][1].shape[0] to build the
table header, which breaks when MoE layers have different expert counts; change
the header generation to compute the maximum expert count across rows (e.g.,
max(r[1].shape[0] for r in rows)) and use that to create html_parts header
cells, and when emitting each row (the loop that writes per-layer <td> cells)
pad shorter rows with empty <td></td> cells (or colspan equivalently) up to that
max so every row has the same number of columns and the table remains aligned.
Ensure you update references to num_experts accordingly and keep
html_parts.extend and row emission logic consistent.
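For reference, a minimal self-contained sketch of the suggested fix (variable names follow the PR; the "-" placeholder for padded cells is an assumption, and the low-count highlighting is omitted for brevity):

def build_table_body(rows, html_parts):
    """Hypothetical helper: header sized to the widest layer, short rows padded."""
    num_experts = max(counts.shape[0] for _, counts in rows)
    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
    html_parts.append("</tr>")
    for name, counts in rows:
        html_parts.append(f"<tr><td>{name}</td>")
        for c in counts.tolist():
            html_parts.append(f"<td>{c}</td>")
        # Pad layers with fewer experts so every row matches the header width.
        html_parts.extend("<td>-</td>" for _ in range(num_experts - counts.shape[0]))
        html_parts.append("</tr>")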


    for name, counts in rows:
        avg = counts.float().mean().item()
        html_parts.append(f"<tr><td>{name}</td>")
        for c in counts.tolist():
Comment on lines +57 to +60

Copilot AI Feb 19, 2026

The module name is interpolated into HTML without escaping. For models that come from remote/custom code, a module name containing < or & could inject markup into the report. Escape name (and any other dynamic text) before writing HTML (e.g., via html.escape).
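A minimal sketch of the suggested escaping (html is the standard-library module; name and html_parts are the variables from the row loop above):

import html

# Escape dynamic text before interpolating it into the report so a module
# name such as "experts.<script>" cannot inject markup.
html_parts.append(f"<tr><td>{html.escape(name)}</td>")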

            if avg > 0 and c < avg * 0.05:
                style = ' style="background: #ff6666;"'
            elif avg > 0 and c < avg * 0.1:
                style = ' style="background: #ffcccc;"'
            else:
                style = ""
            html_parts.append(f"<td{style}>{c}</td>")
        html_parts.append("</tr>")

    html_parts.append("</table></body></html>")
    html_content = "\n".join(html_parts)

    if output_dir is None:
        output_dir = Path(".")
    output_path = Path(output_dir) / ".moe.html"
    output_path.write_text(html_content)
    print(f"Expert token count table saved to {output_path}")
3 changes: 3 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -76,6 +76,7 @@
    QUANTIZATION_W4A8_NVFP4_FP8,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
+from .moe_utils import save_expert_token_count_table
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
    fuse_prequant_layernorm,
@@ -1003,6 +1004,8 @@ def export_hf_checkpoint(
    try:
        post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)

+        save_expert_token_count_table(model, export_dir)
Contributor

⚠️ Potential issue | 🟠 Major

save_expert_token_count_table failure aborts model export.

The call sits inside the try block alongside _export_transformers_checkpoint. If writing .moe.html fails for any reason (disk full, permission error, unexpected exception for an edge-case module), the exception propagates to the outer except, which warns and re-raises — preventing model.save_pretrained from running. A diagnostic HTML report should not gate the primary export operation.

♻️ Proposed fix: isolate the diagnostic step
+        try:
             save_expert_token_count_table(model, export_dir)
+        except Exception as report_err:
+            warnings.warn(
+                f"Failed to save expert token count table: {report_err}. "
+                "Model export will continue."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 1004 - 1007, The
call to save_expert_token_count_table is inside the same try as
_export_transformers_checkpoint so any failure aborts the whole export and
prevents model.save_pretrained from running; move or guard this diagnostic write
so it cannot raise out of the main export flow — either call
save_expert_token_count_table after the try/except that handles
_export_transformers_checkpoint, or wrap
save_expert_token_count_table(model, export_dir) in its own try/except that
catches Exception, logs a non-fatal warning (including the exception), and
continues; ensure post_state_dict and hf_quant_config from
_export_transformers_checkpoint remain unaffected and model.save_pretrained is
always called even if the .moe.html write fails.


        if hf_quant_config is not None:
            # Save hf_quant_config.json for backward compatibility
            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
157 changes: 90 additions & 67 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -450,20 +450,50 @@ class _QuantSparseMoe(QuantModule):
    """

    def _setup(self):
-        pass
+        num_experts = 0
+        if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
+            num_experts = self.gate.num_experts
+        elif hasattr(self, "num_experts"):
+            num_experts = self.num_experts
+        elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
+            num_experts = self.experts.num_experts
+
+        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        self._count_expert_tokens = False
+
+        if hasattr(self, "gate"):
+            self.gate.register_forward_hook(self._gate_forward_hook)
Comment on lines 452 to +472
Contributor

⚠️ Potential issue | 🟡 Minor

Silent expert_token_count of shape (0,) when num_experts cannot be resolved.

If none of the three attribute lookups (gate.num_experts, self.num_experts, experts.num_experts) succeeds, num_experts stays 0 and expert_token_count is silently initialized to an empty tensor. The save_expert_token_count_table function will silently skip this layer (due to the numel() > 0 guard), giving no indication that the layer's routing was not tracked. A warning here would help debugging.

🛡️ Proposed fix
+import warnings
 ...
 self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+if num_experts == 0:
+    warnings.warn(
+        f"Could not determine num_experts for {type(self).__name__}; "
+        "expert_token_count will not be tracked.",
+        stacklevel=2,
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 452 - 465,
The _setup method can silently create an empty expert_token_count when
num_experts remains 0; update _setup (method name: _setup, symbols: self.gate,
gate.num_experts, self.num_experts, self.experts, experts.num_experts,
expert_token_count, save_expert_token_count_table) to detect when num_experts ==
0 after the three lookups and emit a clear warning (e.g., warnings.warn or
self.logger.warning) indicating the layer's routing won't be tracked and naming
the module (use self.__class__.__name__ or similar), then proceed (or skip
registering the hook) so callers know to fix the model instead of silently
losing tracking.


Contributor

nit:

Suggested change
+    @torch.no_grad()

+    def _gate_forward_hook(self, module, input, output):
+        if not self._count_expert_tokens:
+            return
+        with torch.no_grad():
+            if isinstance(output, tuple) and len(output) >= 3:
+                # v5.x TopKRouter: returns (logits, scores, indices)
+                indices = output[2]
+            else:
+                # v4.x nn.Linear gate: returns logits tensor
+                logits = output if not isinstance(output, tuple) else output[0]
+                top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
+                _, indices = torch.topk(logits.float(), top_k, dim=-1)
+            counts = torch.bincount(
+                indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
Contributor

Calling CPU here will add unnecessary CPU-GPU sync overhead. Can we accumulate to a GPU tensor?

+            )
+            self.expert_token_count += counts
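A sketch of the device-side accumulation suggested above (an assumption, not the merged code): keep the counter on the router's device and defer the single host transfer to export time.

# Inside _gate_forward_hook, after computing `indices`:
if self.expert_token_count.device != indices.device:
    # Move the counter next to the router output once, instead of syncing
    # indices to the CPU on every calibration step.
    self.expert_token_count = self.expert_token_count.to(indices.device)
self.expert_token_count += torch.bincount(
    indices.reshape(-1), minlength=self.expert_token_count.numel()
)
# Export-time readers would then call .cpu() once when building the table.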
Comment on lines +477 to +489
Contributor

can we have a try except here?

This looks risky and could fail

Collaborator Author

could you share why this is risky?

Contributor

This seems too tailored to the model - what if we run on a model for which this pattern is not compatible?

Contributor

This is telemetry code anyway, right? We can afford this to fail and still run PTQ and QAT without any issues.
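A sketch of the defensive variant discussed in this thread; _routing_indices is a hypothetical helper standing in for the index extraction shown above. Since the counter is telemetry, a failure degrades to a warning instead of breaking PTQ/QAT.

import warnings

import torch


def _gate_forward_hook(self, module, input, output):
    # Sketch of a guarded method body, not the merged code.
    if not self._count_expert_tokens:
        return
    try:
        with torch.no_grad():
            indices = self._routing_indices(output)  # hypothetical helper
            self.expert_token_count += torch.bincount(
                indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
            )
    except Exception as e:
        warnings.warn(f"expert_token_count tracking failed: {e}")
        self._count_expert_tokens = False  # do not retry a failing layer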


    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
+        is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
+        if is_calib:
            # If any of the experts are in calibration mode, we will forward all tokens to all experts
            # This is used only for calibration, we need to re-calculate the actual outputs again using
            # the original top_k
            if TRANSFORMERS_VERSION_GE_5_0:
                assert hasattr(self, "gate")
                # Path for transformers >= 5.0
-                original_top_k = self.gate.topk
-                self.gate.topk = self.gate.num_experts
+                original_top_k = self.gate.top_k
+                self.gate.top_k = self.gate.num_experts
                super().forward(hidden_states)
-                self.gate.topk = original_top_k
+                self.gate.top_k = original_top_k
            else:
                # Path for transformers < 5.0
                original_top_k = self.top_k
Expand All @@ -475,7 +505,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                raise ValueError(f"Could not find num_experts in module {self}")
            super().forward(hidden_states)
            self.top_k = original_top_k
-        return super().forward(hidden_states)
+        # Enable counting only for the real-routing forward during calibration
+        self._count_expert_tokens = is_calib
+        output = super().forward(hidden_states)
+        self._count_expert_tokens = False
+        return output


class _QuantLlama4TextExperts(QuantModule):
@@ -765,10 +799,7 @@ def unpack_weight(self):


try:
-    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
-
-    if Llama4TextMoe not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
+    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts

    if Llama4TextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -791,16 +822,6 @@ def unpack_weight(self):
except ImportError:
    pass

-try:
-    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-    if MixtralSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass

try:
    from transformers.models.falcon.modeling_falcon import FalconLinear

@@ -809,36 +830,6 @@ def unpack_weight(self):
except ImportError:
    pass

-try:
-    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-    if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
-
-    if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
-
-    if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass

try:
    from compressed_tensors.linear.compressed_linear import CompressedLinear

@@ -850,15 +841,7 @@ def unpack_weight(self):
    pass

try:
-    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-        Qwen3VLMoeTextExperts,
-        Qwen3VLMoeTextSparseMoeBlock,
-    )
-
-    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register(
-            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
-        )(_QuantSparseMoe)
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

    if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
@@ -989,15 +972,55 @@ def register_falcon_linears_on_the_fly(model):
        QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)


-def register_minimax_m2_moe_on_the_fly(model):
-    """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+def _is_sparse_moe_block(module):
+    """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.

-    MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
+    All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
+    share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
+    (``top_k`` and ``num_experts``), and an ``experts`` sub-module.
+
+    This function detects that pattern instead of relying on class names, making it forward-compatible
+    with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but
+    use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom
+    ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives.
    """
-    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
-        moe_type = type(model.model.layers[0].block_sparse_moe)
-        if QuantModuleRegistry.get(moe_type) is None:
-            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+    if not hasattr(module, "experts"):
+        return False

+    # Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern)
+    if hasattr(module, "gate"):
+        gate = module.gate
+        has_topk = hasattr(gate, "top_k")
+        has_num_experts = hasattr(gate, "num_experts")
+        if has_topk and has_num_experts:
+            return True
+
+    # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
+    return hasattr(module, "top_k") and hasattr(module, "num_experts")
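For intuition, a hypothetical toy module exercising the detection (not part of the PR):

import torch.nn as nn


class ToyRouter(nn.Module):
    def __init__(self):
        super().__init__()
        self.top_k, self.num_experts = 2, 8


class ToyMoeBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = ToyRouter()
        self.experts = nn.ModuleList(nn.Linear(4, 4) for _ in range(8))


assert _is_sparse_moe_block(ToyMoeBlock())        # gate exposes top_k + num_experts
assert not _is_sparse_moe_block(nn.Linear(4, 4))  # no experts sub-module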


+def register_sparse_moe_on_the_fly(model):
+    """Auto-detect and register MOE modules as _QuantSparseMoe.
+
+    Walks the model tree, identifies MoE blocks by their structural attributes
+    (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
+    """
+    registered_types = set()
+    for name, module in model.named_modules():
+        mod_type = type(module)
+
+        # Avoid duplicate registration: skip if we already processed this type
+        # in this walk, or if it was previously registered in the QuantModuleRegistry.
+        if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
+            continue
cjluo-nv marked this conversation as resolved.

+        if _is_sparse_moe_block(module):
+            print(
+                f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, "
+                f"registering with _QuantSparseMoe.\033[0m"
+            )
Comment on lines +1026 to +1029

Copilot AI Feb 19, 2026

Library code should avoid printing directly (especially with ANSI escape codes) because it pollutes stdout in normal use and will print once per rank under distributed execution. Prefer using the project logging helpers (e.g., print_rank_0/warn_rank_0) or a logging logger, and consider gating this behind a verbosity/debug flag.

+            QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
+            registered_types.add(mod_type)
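A sketch of the logging-based variant Copilot suggests; print_rank_0/warn_rank_0 are the project helpers it mentions, shown here with a plain logging logger as a stand-in (name and mod_type come from the loop above):

import logging

logger = logging.getLogger(__name__)

if _is_sparse_moe_block(module):
    # One structured log line per detected type; no ANSI codes, and easy to
    # silence or route per-rank under distributed execution.
    logger.info(
        "Detected MOE module '%s' of type %s, registering with _QuantSparseMoe.",
        name,
        mod_type.__name__,
    )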


def _is_supported_hf_model(model):
@@ -1065,7 +1088,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
    [
        register_falcon_linears_on_the_fly,
        register_dbrx_moe_on_the_fly,
-        register_minimax_m2_moe_on_the_fly,
+        register_sparse_moe_on_the_fly,
        register_hf_attentions_on_the_fly,
        convert_hf_parallel_linears_on_the_fly,
    ]