
[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE #1421

Open
cjluo-nv wants to merge 2 commits into main from chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-3

Conversation

@cjluo-nv
Collaborator

@cjluo-nv cjluo-nv commented May 8, 2026

What does this PR do?

Type of change: Bug fix

Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in nn.ModuleLists (gate_up_proj_weight_quantizers, down_proj_weight_quantizers):

  1. Per-expert weight iteration for calibration. Add _QuantFusedExperts.iter_weights_for_calibration that yields per-expert (weight_slice, quantizer) pairs for both projections. The base impl uses the singular *_weight_quantizer and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers (a rough sketch follows below).

  2. mse_calibrate refactor.

    • Add _bootstrap_uncalibrated_weight_quantizers after max_calibrate to populate _amax on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by iter_weights_for_calibration.
    • Replace the singular-only weight_attr_names discovery + getattr-by-name walk with an iter_weights_for_calibration walk done inside each parent module's enable_weight_access_and_writeback context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe.

Without this, the export-time fallback in _export_fused_experts derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up weight_scale_2 invariant on dead experts.
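
The per-expert iteration has roughly the following shape. This is a sketch only, not the exact ModelOpt code: the stacked gate_up_proj / down_proj parameters and the num_experts attribute are assumptions about the fused-experts container layout.

def iter_weights_for_calibration(self):
    # Yield one (weight_slice, quantizer) pair per expert and per projection so
    # weight-only calibration visits every per-expert quantizer, not just a
    # singular *_weight_quantizer attribute.
    for expert_idx in range(self.num_experts):
        yield self.gate_up_proj[expert_idx], self.gate_up_proj_weight_quantizers[expert_idx]
        yield self.down_proj[expert_idx], self.down_proj_weight_quantizers[expert_idx]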

Also includes:

  • _sanitize_generation_config_for_save in unified_export_hf — coerces do_sample=True when an upstream generation_config.json has top_k/top_p set, so newer transformers' strict validation doesn't block save_pretrained (a simplified sketch follows this list).
  • Small companion plumbing in moe_utils.py, tensor_quantizer.py, and core_utils.py to support the per-expert iteration and bootstrap path.
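
For reference, the gist of the generation-config sanitizer, simplified from the proposed code in the review thread further down (the exact helper and a restore-on-exit tweak are discussed there):

def _sanitize_generation_config_for_save(model):
    gc = getattr(model, "generation_config", None)
    if gc is None:
        return
    # Newer transformers rejects top_k/top_p combined with do_sample=False during
    # strict validation, which would block save_pretrained.
    if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
        gc.do_sample = True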

Usage

import modelopt.torch.quantization as mtq
from modelopt.recipe import load_config

# Recipe `nvfp4_experts_only_mse-fp8_cast_kv` (already on main) now correctly
# MSE-calibrates every per-expert weight quantizer in fused-experts MoE models.
cfg = load_config("general/ptq/nvfp4_experts_only_mse-fp8_cast_kv")
mtq.quantize(model, cfg, forward_loop=calibration_forward_loop)

Testing

Original validation — Qwen3.5-122B-A10B with nvfp4_experts_only_mse-fp8_cast_kv:

  • Before: 1/12288 (layer 38 expert 69) gate != up; 0 weights MSE-calibrated.
  • After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min.

End-to-end pipeline validation — Qwen3.5-35B-A3B (40 layers × 256 experts × 2 projections = 20,480 per-expert weight quantizers), TRT-LLM 1.3.0rc13 + transformers 5.6 docker, single B200:

                                           Path A (4-sample calib, deliberately undercalibrated)   Path B (zero forward-pass tokens)
Per-expert weight quantizers calibrated    20,480 / 20,480                                         20,480 / 20,480
Missing _amax                              0                                                       0
All-zero _amax                             0                                                       0
mtq.quantize time                          25–34 s                                                 23 s
  • Cross-path diff: every per-expert weight amax matches bit-for-bit between the two paths (n=20480 exact=20480 diff=0 max_rel=0; a verification sketch follows this list). With 8/256 experts routed per token and 4 calib samples, almost all experts are "dead" in Path A. Bootstrap fills them from max(|weight|), MSE searches deterministically from there → identical to Path B, which bootstraps everything.
  • Export to HF NVFP4 checkpoint succeeded (~95 s, 22 GB checkpoint). Resulting generation_config.json has do_sample: true (upstream had top_k=20 + top_p=0.95 which would have failed strict validate).
  • TRT-LLM inference loaded the checkpoint and generated text: "Born in north-east France, Soyer trained as a tailor. Demonstrating his craft at a young age, at 20 he moved to Paris at the requests of the noble people of Picardy." (coherent grammar; factually wrong as expected with 4-sample calib, but no NaN/Inf in logits, no scale-mismatch crash). 92 GB GPU memory used.
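
A hypothetical version of the cross-path amax comparison (not the actual script behind the numbers above; the buffer-name filter assumes ModelOpt registers per-expert weight amax as `*_weight_quantizers.<i>._amax` buffers):

import torch

def collect_weight_amax(model):
    # Gather per-expert weight-quantizer amax buffers by name.
    return {
        name: buf.detach().cpu()
        for name, buf in model.named_buffers()
        if name.endswith("_amax") and "weight_quantizer" in name
    }

def diff_amax(model_a, model_b):
    a, b = collect_weight_amax(model_a), collect_weight_amax(model_b)
    assert a.keys() == b.keys(), "quantizer sets differ between the two paths"
    exact = sum(int(torch.equal(a[k], b[k])) for k in a)
    max_rel = max(
        ((a[k].float() - b[k].float()).abs() / a[k].float().abs().clamp_min(1e-12)).max().item()
        for k in a
    )
    print(f"n={len(a)} exact={exact} diff={len(a) - exact} max_rel={max_rel}")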

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ❌ (relies on existing recipe-level integration coverage; verified end-to-end on Qwen3.5-122B-A10B and Qwen3.5-35B-A3B + TRT-LLM 1.3.0rc13)
  • Did you update Changelog?: N/A
  • Did you get Claude approval on this PR?: ❌ (will run `/claude review`)

Additional Information

Follow-up to PR #1407 (MSE+FP8-cast-KV recipes). The recipe YAML files landed there; this PR fixes the calibration codepath so the MSE recipes actually exercise per-expert weight quantizers in fused-experts MoE containers.

@cjluo-nv cjluo-nv requested review from a team as code owners May 8, 2026 23:35
@cjluo-nv cjluo-nv requested a review from meenchen May 8, 2026 23:35
@coderabbitai
Contributor

coderabbitai Bot commented May 8, 2026

📝 Walkthrough

Adds NVFP4-static grouped weight-quantizer support and synchronization, bootstraps missing per-weight amax, ensures per-expert fused-expert calibration coverage, adjusts fused-expert export amax slicing for static block quant, and sanitizes HF generation config before save.

Changes

NVFP4-static grouped quantizer support with MoE improvements

  • Type foundation (modelopt/torch/quantization/nn/modules/tensor_quantizer.py): Added is_nvfp4_static property to detect the NVFP4 static block quantization format without requiring calibrated amax values.
  • Utility simplification (modelopt/torch/quantization/utils/core_utils.py): Simplified promote_nvfp4_static_quantizers() to use the is_nvfp4_static property instead of recomputing eligibility from quantizer internals.
  • Grouped quantizer patterns and sync (modelopt/torch/quantization/model_calib.py): Introduced pattern matching for grouped linears (Q/K/V and gate/up pairs), sibling detection, missing-amax bootstrap, and the exported sync_grouped_weight_global_amax() to unify grouped NVFP4-static global_amax grids.
  • MSE pre-calibration bootstrap & sync (modelopt/torch/quantization/model_calib.py): Runs _bootstrap_uncalibrated_weight_quantizers() and sync_grouped_weight_global_amax() immediately after max_calibrate to populate missing amax and unify grouped global_amax before the MSE search.
  • Weight calibration refactoring (modelopt/torch/quantization/model_calib.py): Refactored weight discovery from the weight_attr_names() loop to a two-pass QuantModule.iter_weights_for_calibration()-driven approach for improved fused-expert coverage and accurate progress counting.
  • Hessian calibration integration (modelopt/torch/quantization/model_calib.py): Updated local_hessian_calibrate to call sync_grouped_weight_global_amax() after max_calibrate and to skip re-conversion when already an NVFP4StaticQuantizer, using is_nvfp4_static.
  • Fused expert weight iteration (modelopt/torch/quantization/plugins/huggingface.py): Added an iter_weights_for_calibration() override to _QuantFusedExperts yielding per-expert weight quantizer pairs for calibration coverage.
  • Export amax slicing for fused experts (modelopt/torch/export/moe_utils.py): Extended _export_fused_experts to handle static block-quant _amax shapes by reshaping to restore the row dimension and deleting the existing buffer before re-registration.
  • Export integration for HuggingFace checkpoints (modelopt/torch/export/unified_export_hf.py): Added a _sanitize_generation_config_for_save() helper, called before model.save_pretrained(), that sets do_sample=True when top_k/top_p are present.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 6 passed

  • Title check: ✅ Passed. The title directly reflects the main change: enabling MSE calibration for per-expert weights in fused-expert MoE models, which is the primary bug fix addressed across multiple files.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 88.89%, which is sufficient. The required threshold is 80.00%.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns: ✅ Passed. Audit passed. No torch.load(weights_only=False), numpy.load(allow_pickle=True), trust_remote_code=True hardcoding, eval/exec, nosec comments, or problematic dependencies added.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.


@github-actions
Contributor

github-actions Bot commented May 8, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1421/

Built to branch gh-pages at 2026-05-08 23:56 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@cjluo-nv cjluo-nv force-pushed the chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-3 branch from 360b53e to 8e21516 on May 8, 2026 23:41
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1138-1160: The function _sanitize_generation_config_for_save
currently mutates model.generation_config.do_sample in-place and never restores
it; change the flow so the original value is preserved and restored after export
(e.g., capture original_do_sample = getattr(gc, "do_sample", None) before
setting gc.do_sample = True and restore gc.do_sample = original_do_sample after
the save operation), and apply the same pattern to the other affected block
around the code referenced at 1262-1270; locate usages of
_sanitize_generation_config_for_save (and the other block) surrounding the
save_pretrained/export call and ensure restoration occurs even on exceptions
(use try/finally or a context manager).

In `@modelopt/torch/quantization/model_calib.py`:
- Around line 121-162: The bootstrap loop in
_bootstrap_uncalibrated_weight_quantizers must run weight reads inside
enable_weight_access_and_writeback() so FSDP/HF-TP/offload sharded modules
perform proper local access/writeback instead of triggering an access failure
swallowed by the blanket except; wrap the per-module calibration work (the call
to module.iter_weights_for_calibration() and the q(weight) calibration call
inside the loop) with with enable_weight_access_and_writeback(module):, remove
or narrow the broad try/except that currently hides access errors so genuine
access failures surface, and keep the rest of the logic (q.disable_quant(),
q.enable_calib(), q(weight), q.load_calib_amax(), q.enable_quant(),
q.disable_calib(), q._calibrator.reset()) unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 5bd29857-aece-4f2e-9df6-23870190b9ec

📥 Commits

Reviewing files that changed from the base of the PR and between 1d796f9 and 360b53e.

📒 Files selected for processing (6)
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/core_utils.py

Comment thread modelopt/torch/quantization/model_calib.py
Comment on lines +121 to +162
@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
    """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing.

    Forward-pass max calibration only populates per-expert weight quantizers in
    fused-experts containers when tokens are routed to that expert. "Dead"
    experts that received no tokens end up with no ``_amax``, which causes
    ``mse_calibrate``'s subsequent walk to skip them and forces the export-time
    fallback to derive separate per-half amax for gate/up. This helper walks
    every QuantModule's :meth:`iter_weights_for_calibration` pairs and, for any
    quantizer that lacks ``_amax``, runs the existing calibrator (typically
    :class:`MaxCalibrator`) on the corresponding weight slice — populating
    ``_amax`` from the weight rather than from runtime activations.

    Returns the number of quantizers bootstrapped (mostly for diagnostics).
    """
    n = 0
    for module in model.modules():
        if not isinstance(module, QuantModule):
            continue
        try:
            pairs = list(module.iter_weights_for_calibration())
        except Exception:
            continue
        for weight, q in pairs:
            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
                continue
            if q._calibrator is None:
                continue
            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
                continue
            q.disable_quant()
            q.enable_calib()
            q(weight)
            if q._calibrator.compute_amax() is not None:
                q.load_calib_amax()
            q.enable_quant()
            q.disable_calib()
            if hasattr(q._calibrator, "reset"):
                q._calibrator.reset()
            n += 1
    return n
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Run dead-expert bootstrap under enable_weight_access_and_writeback().

This helper reads weight slices and calibrates them before entering any weight-access context. On FSDP/HF-TP/offloaded modules, that can either calibrate only the local shard or hit an access failure that gets swallowed by the blanket except, leaving the “dead” expert unbootstrapped. That recreates the missing-_amax path this PR is trying to eliminate in distributed/export flows.

Suggested adjustment
 @torch.no_grad()
 def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
     """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing."""
     n = 0
+    name_to_module = dict(model.named_modules())
     for module in model.modules():
         if not isinstance(module, QuantModule):
             continue
-        try:
-            pairs = list(module.iter_weights_for_calibration())
-        except Exception:
-            continue
-        for weight, q in pairs:
-            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
-                continue
-            if q._calibrator is None:
-                continue
-            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
-                continue
-            q.disable_quant()
-            q.enable_calib()
-            q(weight)
-            if q._calibrator.compute_amax() is not None:
-                q.load_calib_amax()
-            q.enable_quant()
-            q.disable_calib()
-            if hasattr(q._calibrator, "reset"):
-                q._calibrator.reset()
-            n += 1
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            for weight, q in module.iter_weights_for_calibration():
+                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                    continue
+                if q._calibrator is None:
+                    continue
+                if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
+                    continue
+                q.disable_quant()
+                q.enable_calib()
+                q(weight)
+                if q._calibrator.compute_amax() is not None:
+                    q.load_calib_amax()
+                q.enable_quant()
+                q.disable_calib()
+                if hasattr(q._calibrator, "reset"):
+                    q._calibrator.reset()
+                n += 1
     return n

@cjluo-nv cjluo-nv force-pushed the chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-3 branch from 8e21516 to adee8b5 on May 8, 2026 23:49
cjluo-nv added 2 commits May 8, 2026 23:52
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
… MoE

Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE /
Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers
live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`,
`down_proj_weight_quantizers`):

1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields
   per-expert (weight_slice, quantizer) pairs for both projections. The base
   impl uses singular `*_weight_quantizer` and silently skips fused-experts
   modules, so weight-only calibration paths never reach per-expert
   quantizers.

2. Refactor `mse_calibrate`:
   - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate`
     to populate `_amax` on quantizers the forward pass didn't reach (dead
     MoE experts that received no calibration tokens). Runs the existing
     calibrator on the weight slice surfaced by
     `iter_weights_for_calibration`.
   - Replace the singular-only `weight_attr_names` discovery + `getattr`-by-
     name walk with an `iter_weights_for_calibration` walk done inside each
     parent module's `enable_weight_access_and_writeback` context, so MSE
     processes every per-expert quantizer (active and dead) and remains
     FSDP-safe.

Without this, the export-time fallback in `_export_fused_experts` derived
separate gate/up amaxes from each half of the fused weight, breaking the
gate==up `weight_scale_2` invariant on dead experts. End-to-end check on
Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`:
  - Before: 1/12288 (layer 38 expert 69) gate != up; 0 weights MSE-calibrated
  - After:  0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv force-pushed the chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-3 branch from adee8b5 to 12e3c24 on May 8, 2026 23:52
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_hf.py (1)

1137-1148: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Scope generation-config mutation and patch cleanup to avoid state leakage.

_sanitize_generation_config_for_save mutates model.generation_config.do_sample permanently, and it currently runs outside the try/finally that unpatches transformers internals. If sanitize raises, _unpatch_revert_weight_conversion is skipped; if it succeeds, do_sample still leaks into later calls on the same model object.

Proposed fix
-def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
+def _sanitize_generation_config_for_save(model: torch.nn.Module) -> Callable[[], None]:
@@
-    gc = getattr(model, "generation_config", None)
-    if gc is None:
-        return
+    gc = getattr(model, "generation_config", None)
+    if gc is None:
+        return lambda: None
     if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
-        gc.do_sample = True
+        original_do_sample = getattr(gc, "do_sample", None)
+        gc.do_sample = True
+        return lambda: setattr(gc, "do_sample", original_do_sample)
+    return lambda: None
@@
-        _sanitize_generation_config_for_save(model)
-
-        try:
+        restore_generation_config = lambda: None
+        try:
+            restore_generation_config = _sanitize_generation_config_for_save(model)
             model.save_pretrained(
                 export_dir,
                 state_dict={**post_state_dict, **(extra_state_dict or {})},
                 save_modelopt_state=save_modelopt_state,
                 max_shard_size=max_shard_size,
             )
         finally:
+            restore_generation_config()
             _unpatch_revert_weight_conversion(_patches)

Also applies to: 1242-1254

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/export/unified_export_hf.py` around lines 1137 - 1148, The
helper _sanitize_generation_config_for_save currently mutates
model.generation_config.do_sample permanently and runs outside the
unpatch/cleanup flow, causing state leakage and skipped cleanup on exceptions;
change it to only temporarily set do_sample: inside the same try/finally that
calls _unpatch_revert_weight_conversion, capture the original value (orig =
getattr(gc, "do_sample", None)), set gc.do_sample = True only when sampling
attrs exist, and always restore gc.do_sample = orig in the finally block so the
model's generation_config is unchanged after save. Apply the identical
temporary-mutation+restore pattern to the other occurrence referenced around
lines 1242-1254.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a398d472-2294-43dc-80ef-57d6c78c285a

📥 Commits

Reviewing files that changed from the base of the PR and between adee8b5 and 12e3c24.

📒 Files selected for processing (6)
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/core_utils.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/quantization/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/quantization/model_calib.py

Contributor

@meenchen meenchen left a comment


Bot review — DM the bot to share feedback.

Follow-up bug fix to #1407 for dead-expert MSE calibration in fused-experts MoE. The design is not a new subsystem — the new sync_grouped_weight_global_amax builds on the existing preprocess_linear_fusion from quant_utils.py, and iter_weights_for_calibration extends an already-existing base-class hook. The fix targets a real correctness bug (gate/up weight_scale_2 divergence on experts that received no calibration tokens), backed by end-to-end Qwen3.5-122B numbers (0/12288 vs 1/12288 mismatches).

Reasons I did not approve:

  1. No unit tests for non-trivial new behavior. The PR body explicitly opts out ("Did you write any new necessary tests?: ❌"). tests/unit/torch/quantization/plugins/test_fused_experts.py already has good infrastructure (_SyntheticFusedExperts, TestFusedExpertsCalibration) that could exercise the following (a rough sketch of one such test follows this list):

    • _QuantFusedExperts.iter_weights_for_calibration yielding num_experts * 2 pairs.
    • _bootstrap_uncalibrated_weight_quantizers populating _amax on never-routed experts (simulate by only routing to a subset in forward_loop).
    • The new _export_fused_experts per-block amax reshape path for NVFP4 (the existing test_uncalibrated_expert_gate_up_share_amax covers the per-tensor fallback but not the new amax.numel() % fused_total == 0 reshape branch).
    • sync_grouped_weight_global_amax unifying global_amax across a Q/K/V sibling group.
  2. _sanitize_generation_config_for_save silently mutates user state. It flips do_sample=False → True on the model's live generation_config whenever top_k/top_p is set, and this mutation is persisted to the exported generation_config.json. In practice this is a semantic no-op for greedy decoding, but the change is invisible to the caller. Consider (a) emitting a warning naming the fields being rewritten, or (b) doing the mutation on a copy that's scoped to the save_pretrained call only (a scoped-mutation sketch appears after this review).

  3. _GROUPED_WEIGHT_QUANTIZER_PATTERNS is architecture-name heuristic. The hardcoded tuple covers Llama/Qwen/Mistral/Mixtral but will silently miss any model using different attribute names (wqkv, fused qkv_proj, DeepSeek naming variants, etc.) — grouped unification would just not run and export would fall back to per-module amax. Worth either documenting this as a known limitation in the docstring or logging when a model has NVFP4-static quantizers but produces zero groups.
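
A minimal, self-contained sketch of the dead-expert bootstrap test suggested in point 1. It uses toy stand-ins rather than the real _SyntheticFusedExperts fixture or ModelOpt quantizer classes, so the wiring below is illustrative, not the repo's test code:

import torch
import torch.nn as nn

class ToyQuantizer:
    def __init__(self):
        self._amax = None  # stand-in for TensorQuantizer's calibrated amax

class ToyFusedExperts(nn.Module):
    def __init__(self, num_experts=4, hidden=8, inter=16):
        super().__init__()
        self.num_experts = num_experts
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden, 2 * inter))
        self.down_proj = nn.Parameter(torch.randn(num_experts, inter, hidden))
        self.gate_up_proj_weight_quantizers = [ToyQuantizer() for _ in range(num_experts)]
        self.down_proj_weight_quantizers = [ToyQuantizer() for _ in range(num_experts)]

    def iter_weights_for_calibration(self):
        # num_experts * 2 pairs: one per expert and per projection.
        for e in range(self.num_experts):
            yield self.gate_up_proj[e], self.gate_up_proj_weight_quantizers[e]
            yield self.down_proj[e], self.down_proj_weight_quantizers[e]

def test_bootstrap_fills_dead_expert_amax():
    m = ToyFusedExperts()
    assert len(list(m.iter_weights_for_calibration())) == m.num_experts * 2
    # Pretend only expert 0 was routed during forward-pass calibration.
    m.gate_up_proj_weight_quantizers[0]._amax = m.gate_up_proj[0].abs().amax()
    # Bootstrap: max(|weight|) for every quantizer that still has no amax.
    for w, q in m.iter_weights_for_calibration():
        if q._amax is None:
            q._amax = w.detach().abs().amax()
    assert all(q._amax is not None for q in m.gate_up_proj_weight_quantizers)
    assert all(q._amax is not None for q in m.down_proj_weight_quantizers)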

Minor: sync_grouped_weight_global_amax is added to __all__ as public API, but given the hardcoded sibling-name heuristic it's really an internal helper; consider dropping it from __all__ or renaming to _sync_....

No licensing changes. Size is fine (+195/-71 across 6 files). Happy for a human reviewer with MoE export context to make the final call on the testing gap and the generation_config mutation policy.
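
One possible shape for the scoped-mutation idea in point 2 (closer to the restore pattern CodeRabbit proposed above than to a literal config copy): a small context manager, hypothetical and not part of this PR, that flips do_sample only for the duration of save_pretrained and restores the original value afterwards.

from contextlib import contextmanager

@contextmanager
def temporarily_enable_do_sample(model):
    gc = getattr(model, "generation_config", None)
    needs_flip = gc is not None and (
        getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None
    )
    original = getattr(gc, "do_sample", None) if needs_flip else None
    if needs_flip:
        gc.do_sample = True
    try:
        yield
    finally:
        # Restore the caller's generation_config even if the save raises.
        if needs_flip:
            gc.do_sample = original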

@codecov

codecov Bot commented May 9, 2026

Codecov Report

❌ Patch coverage is 74.10714% with 29 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.80%. Comparing base (1d796f9) to head (12e3c24).
⚠️ Report is 1 commits behind head on main.

Files with missing lines                             Patch %   Lines
modelopt/torch/quantization/model_calib.py           82.75%    15 Missing ⚠️
modelopt/torch/quantization/plugins/huggingface.py   12.50%    7 Missing ⚠️
modelopt/torch/export/moe_utils.py                    0.00%    6 Missing ⚠️
modelopt/torch/export/unified_export_hf.py           85.71%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1421      +/-   ##
==========================================
- Coverage   77.30%   76.80%   -0.51%     
==========================================
  Files         478      478              
  Lines       51404    51480      +76     
==========================================
- Hits        39737    39537     -200     
- Misses      11667    11943     +276     
Flag         Coverage Δ
examples     41.74% <15.17%> (-0.23%) ⬇️
gpu          59.86% <73.21%> (-0.58%) ⬇️
regression   15.21% <8.92%> (+0.07%) ⬆️
unit         52.48% <57.14%> (-0.01%) ⬇️

