[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE#1421
Conversation
📝 Walkthrough

Adds NVFP4-static grouped weight-quantizer support and synchronization, bootstraps missing per-weight amax, ensures per-expert fused-expert calibration coverage, adjusts fused-expert export amax slicing for static block quant, and sanitizes the HF generation config before save.

Changes: NVFP4-static grouped quantizer support with MoE improvements
🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ Passed checks (6 passed)
Force-pushed from 360b53e to 8e21516
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1138-1160: The function _sanitize_generation_config_for_save
currently mutates model.generation_config.do_sample in-place and never restores
it; change the flow so the original value is preserved and restored after export
(e.g., capture original_do_sample = getattr(gc, "do_sample", None) before
setting gc.do_sample = True and restore gc.do_sample = original_do_sample after
the save operation), and apply the same pattern to the other affected block
around the code referenced at 1262-1270; locate usages of
_sanitize_generation_config_for_save (and the other block) surrounding the
save_pretrained/export call and ensure restoration occurs even on exceptions
(use try/finally or a context manager).
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 121-162: The bootstrap loop in
_bootstrap_uncalibrated_weight_quantizers must run weight reads inside
enable_weight_access_and_writeback() so FSDP/HF-TP/offload sharded modules
perform proper local access/writeback instead of triggering an access failure
swallowed by the blanket except; wrap the per-module calibration work (the call
to module.iter_weights_for_calibration() and the q(weight) calibration call
inside the loop) with with enable_weight_access_and_writeback(module):, remove
or narrow the broad try/except that currently hides access errors so genuine
access failures surface, and keep the rest of the logic (q.disable_quant(),
q.enable_calib(), q(weight), q.load_calib_amax(), q.enable_quant(),
q.disable_calib(), q._calibrator.reset()) unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5bd29857-aece-4f2e-9df6-23870190b9ec
📒 Files selected for processing (6)
- modelopt/torch/export/moe_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/torch/quantization/utils/core_utils.py
```python
@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
    """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing.

    Forward-pass max calibration only populates per-expert weight quantizers in
    fused-experts containers when tokens are routed to that expert. "Dead"
    experts that received no tokens end up with no ``_amax``, which causes
    ``mse_calibrate``'s subsequent walk to skip them and forces the export-time
    fallback to derive separate per-half amax for gate/up. This helper walks
    every QuantModule's :meth:`iter_weights_for_calibration` pairs and, for any
    quantizer that lacks ``_amax``, runs the existing calibrator (typically
    :class:`MaxCalibrator`) on the corresponding weight slice — populating
    ``_amax`` from the weight rather than from runtime activations.

    Returns the number of quantizers bootstrapped (mostly for diagnostics).
    """
    n = 0
    for module in model.modules():
        if not isinstance(module, QuantModule):
            continue
        try:
            pairs = list(module.iter_weights_for_calibration())
        except Exception:
            continue
        for weight, q in pairs:
            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
                continue
            if q._calibrator is None:
                continue
            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
                continue
            q.disable_quant()
            q.enable_calib()
            q(weight)
            if q._calibrator.compute_amax() is not None:
                q.load_calib_amax()
            q.enable_quant()
            q.disable_calib()
            if hasattr(q._calibrator, "reset"):
                q._calibrator.reset()
            n += 1
    return n
```
Run dead-expert bootstrap under enable_weight_access_and_writeback().
This helper reads weight slices and calibrates them without entering any weight-access context. On FSDP/HF-TP/offloaded modules it can either calibrate only the local shard or hit an access failure that gets swallowed by the blanket `except`, leaving the "dead" expert unbootstrapped. That recreates the missing-`_amax` path this PR is trying to eliminate in distributed/export flows.
Suggested adjustment

```diff
 @torch.no_grad()
 def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
     """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing."""
     n = 0
+    name_to_module = dict(model.named_modules())
     for module in model.modules():
         if not isinstance(module, QuantModule):
             continue
-        try:
-            pairs = list(module.iter_weights_for_calibration())
-        except Exception:
-            continue
-        for weight, q in pairs:
-            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
-                continue
-            if q._calibrator is None:
-                continue
-            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
-                continue
-            q.disable_quant()
-            q.enable_calib()
-            q(weight)
-            if q._calibrator.compute_amax() is not None:
-                q.load_calib_amax()
-            q.enable_quant()
-            q.disable_calib()
-            if hasattr(q._calibrator, "reset"):
-                q._calibrator.reset()
-            n += 1
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            for weight, q in module.iter_weights_for_calibration():
+                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                    continue
+                if q._calibrator is None:
+                    continue
+                if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
+                    continue
+                q.disable_quant()
+                q.enable_calib()
+                q(weight)
+                if q._calibrator.compute_amax() is not None:
+                    q.load_calib_amax()
+                q.enable_quant()
+                q.disable_calib()
+                if hasattr(q._calibrator, "reset"):
+                    q._calibrator.reset()
+                n += 1
     return n
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/quantization/model_calib.py` around lines 121 - 162, The
bootstrap loop in _bootstrap_uncalibrated_weight_quantizers must run weight
reads inside enable_weight_access_and_writeback() so FSDP/HF-TP/offload sharded
modules perform proper local access/writeback instead of triggering an access
failure swallowed by the blanket except; wrap the per-module calibration work
(the call to module.iter_weights_for_calibration() and the q(weight) calibration
call inside the loop) with with enable_weight_access_and_writeback(module):,
remove or narrow the broad try/except that currently hides access errors so
genuine access failures surface, and keep the rest of the logic
(q.disable_quant(), q.enable_calib(), q(weight), q.load_calib_amax(),
q.enable_quant(), q.disable_calib(), q._calibrator.reset()) unchanged.
Force-pushed from 8e21516 to adee8b5
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE
Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE /
Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers
live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`,
`down_proj_weight_quantizers`):
1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields
per-expert (weight_slice, quantizer) pairs for both projections. The base
impl uses singular `*_weight_quantizer` and silently skips fused-experts
modules, so weight-only calibration paths never reach per-expert
quantizers.
2. Refactor `mse_calibrate`:
- Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate`
to populate `_amax` on quantizers the forward pass didn't reach (dead
MoE experts that received no calibration tokens). Runs the existing
calibrator on the weight slice surfaced by
`iter_weights_for_calibration`.
- Replace the singular-only `weight_attr_names` discovery + `getattr`-by-
name walk with an `iter_weights_for_calibration` walk done inside each
parent module's `enable_weight_access_and_writeback` context, so MSE
processes every per-expert quantizer (active and dead) and remains
FSDP-safe.
Without this, the export-time fallback in `_export_fused_experts` derived
separate gate/up amaxes from each half of the fused weight, breaking the
gate==up `weight_scale_2` invariant on dead experts. End-to-end check on
Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`:
- Before: 1/12288 (layer 38 expert 69) gate \!= up; 0 weights MSE-calibrated
- After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
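For orientation, a minimal sketch of the per-expert iteration this commit describes, reconstructed from the message above rather than copied from the repo. The fused-parameter names (`gate_up_proj`, `down_proj`) and the stacked `[num_experts, ...]` layout are assumptions; only the `*_weight_quantizers` ModuleList names come from the text.

```python
# Hedged sketch: assumes fused weights are 3-D parameters stacked per expert
# along dim 0, with one quantizer per expert in each ModuleList.
def iter_weights_for_calibration(self):
    """Yield per-expert (weight_slice, quantizer) pairs for both projections."""
    for fused_weight, quantizers in (
        (self.gate_up_proj, self.gate_up_proj_weight_quantizers),
        (self.down_proj, self.down_proj_weight_quantizers),
    ):
        for expert_idx, quantizer in enumerate(quantizers):
            # One expert's slice of the fused [num_experts, ...] parameter.
            yield fused_weight[expert_idx], quantizer
```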
Force-pushed from adee8b5 to 12e3c24
♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_hf.py (1)
1137-1148: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Scope generation-config mutation and patch cleanup to avoid state leakage.
`_sanitize_generation_config_for_save` mutates `model.generation_config.do_sample` permanently, and it currently runs outside the `try/finally` that unpatches transformers internals. If sanitize raises, `_unpatch_revert_weight_conversion` is skipped; if it succeeds, `do_sample` still leaks into later calls on the same model object.

Proposed fix
```diff
-def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
+def _sanitize_generation_config_for_save(model: torch.nn.Module) -> Callable[[], None]:
@@
-    gc = getattr(model, "generation_config", None)
-    if gc is None:
-        return
+    gc = getattr(model, "generation_config", None)
+    if gc is None:
+        return lambda: None
     if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
-        gc.do_sample = True
+        original_do_sample = getattr(gc, "do_sample", None)
+        gc.do_sample = True
+        return lambda: setattr(gc, "do_sample", original_do_sample)
+    return lambda: None
@@
-    _sanitize_generation_config_for_save(model)
-
-    try:
+    restore_generation_config = lambda: None
+    try:
+        restore_generation_config = _sanitize_generation_config_for_save(model)
         model.save_pretrained(
             export_dir,
             state_dict={**post_state_dict, **(extra_state_dict or {})},
             save_modelopt_state=save_modelopt_state,
             max_shard_size=max_shard_size,
         )
     finally:
+        restore_generation_config()
         _unpatch_revert_weight_conversion(_patches)
```

Also applies to: 1242-1254
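The same restore-on-exit behavior can also be sketched as a context manager, which makes the `try/finally` implicit. `_sanitize_generation_config_for_save` and the attribute names come from the PR; the wrapper itself is illustrative, not the committed fix.

```python
from contextlib import contextmanager

@contextmanager
def _sanitized_generation_config(model):
    """Temporarily force do_sample=True when top_k/top_p are set; restore on exit."""
    gc = getattr(model, "generation_config", None)
    needs_patch = gc is not None and (
        getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None
    )
    original = getattr(gc, "do_sample", None) if needs_patch else None
    if needs_patch:
        gc.do_sample = True
    try:
        yield
    finally:
        # Runs even if save_pretrained raises, so no state leaks.
        if needs_patch:
            gc.do_sample = original

# Hypothetical call site:
# with _sanitized_generation_config(model):
#     model.save_pretrained(export_dir, ...)
```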
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/export/unified_export_hf.py` around lines 1137 - 1148, The helper _sanitize_generation_config_for_save currently mutates model.generation_config.do_sample permanently and runs outside the unpatch/cleanup flow, causing state leakage and skipped cleanup on exceptions; change it to only temporarily set do_sample: inside the same try/finally that calls _unpatch_revert_weight_conversion, capture the original value (orig = getattr(gc, "do_sample", None)), set gc.do_sample = True only when sampling attrs exist, and always restore gc.do_sample = orig in the finally block so the model's generation_config is unchanged after save. Apply the identical temporary-mutation+restore pattern to the other occurrence referenced around lines 1242-1254.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1137-1148: The helper _sanitize_generation_config_for_save
currently mutates model.generation_config.do_sample permanently and runs outside
the unpatch/cleanup flow, causing state leakage and skipped cleanup on
exceptions; change it to only temporarily set do_sample: inside the same
try/finally that calls _unpatch_revert_weight_conversion, capture the original
value (orig = getattr(gc, "do_sample", None)), set gc.do_sample = True only when
sampling attrs exist, and always restore gc.do_sample = orig in the finally
block so the model's generation_config is unchanged after save. Apply the
identical temporary-mutation+restore pattern to the other occurrence referenced
around lines 1242-1254.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: a398d472-2294-43dc-80ef-57d6c78c285a
📒 Files selected for processing (6)
- modelopt/torch/export/moe_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/torch/quantization/utils/core_utils.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/quantization/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/utils/core_utils.py
- modelopt/torch/export/moe_utils.py
- modelopt/torch/quantization/model_calib.py
meenchen
left a comment
Bot review — DM the bot to share feedback.
Follow-up bug fix to #1407 for dead-expert MSE calibration in fused-experts MoE. The design is not a new subsystem — the new sync_grouped_weight_global_amax composes on the existing preprocess_linear_fusion from quant_utils.py, and iter_weights_for_calibration extends an already-existing base-class hook. The fix targets a real correctness bug (gate/up weight_scale_2 divergence on experts that received no calibration tokens), backed by end-to-end Qwen3.5-122B numbers (0/12288 vs 1/12288 mismatches).
Reasons I did not approve:
- No unit tests for non-trivial new behavior. The PR body explicitly opts out ("Did you write any new necessary tests?: ❌"). `tests/unit/torch/quantization/plugins/test_fused_experts.py` already has good infrastructure (`_SyntheticFusedExperts`, `TestFusedExpertsCalibration`) that could exercise (a sketch of the first two follows this review):
  - `_QuantFusedExperts.iter_weights_for_calibration` yielding `num_experts * 2` pairs.
  - `_bootstrap_uncalibrated_weight_quantizers` populating `_amax` on never-routed experts (simulate by only routing to a subset in `forward_loop`).
  - The new `_export_fused_experts` per-block amax reshape path for NVFP4 (the existing `test_uncalibrated_expert_gate_up_share_amax` covers the per-tensor fallback but not the new `amax.numel() % fused_total == 0` reshape branch).
  - `sync_grouped_weight_global_amax` unifying `global_amax` across a Q/K/V sibling group.
- `_sanitize_generation_config_for_save` silently mutates user state. It flips `do_sample=False → True` on the model's live `generation_config` whenever `top_k`/`top_p` is set, and this mutation is persisted to the exported `generation_config.json`. In practice this is a semantic no-op for greedy decoding, but the change is invisible to the caller. Consider (a) emitting a warning naming the fields being rewritten, or (b) doing the mutation on a copy that's scoped to the `save_pretrained` call only.
- `_GROUPED_WEIGHT_QUANTIZER_PATTERNS` is an architecture-name heuristic. The hardcoded tuple covers Llama/Qwen/Mistral/Mixtral but will silently miss any model using different attribute names (`wqkv`, `fusedqkv_proj`, DeepSeek naming variants, etc.) — grouped unification would just not run and export would fall back to per-module amax. Worth either documenting this as a known limitation in the docstring or logging when a model has NVFP4-static quantizers but produces zero groups.
Minor: `sync_grouped_weight_global_amax` is added to `__all__` as public API, but given the hardcoded sibling-name heuristic it's really an internal helper; consider dropping it from `__all__` or renaming to `_sync_...`.
No licensing changes. Size is fine (+195/-71 across 6 files). Happy for a human reviewer with MoE export context to make the final call on the testing gap and the generation_config mutation policy.
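A rough shape for the first two suggested tests, as referenced in the list above. The fixtures (`make_quant_fused_experts`, `route_to_first_expert_only`) are hypothetical stand-ins for the `_SyntheticFusedExperts` infrastructure; only the method and helper names under test come from this PR.

```python
import torch

def test_iter_weights_yields_every_expert(make_quant_fused_experts):
    # Hypothetical fixture returning a quantized fused-experts module.
    experts = make_quant_fused_experts(num_experts=8)
    pairs = list(experts.iter_weights_for_calibration())
    # One (weight_slice, quantizer) pair per expert for each of the two projections.
    assert len(pairs) == 8 * 2

def test_bootstrap_populates_dead_experts(quant_moe_model, route_to_first_expert_only):
    from modelopt.torch.quantization.model_calib import (
        _bootstrap_uncalibrated_weight_quantizers,
    )
    # Calibrate with tokens routed to a single expert, leaving the rest "dead".
    route_to_first_expert_only(quant_moe_model)
    bootstrapped = _bootstrap_uncalibrated_weight_quantizers(quant_moe_model)
    assert bootstrapped > 0
    # After bootstrap, every per-expert weight quantizer has a finite amax.
    for module in quant_moe_model.modules():
        if hasattr(module, "iter_weights_for_calibration"):
            for _, q in module.iter_weights_for_calibration():
                assert q.amax is not None and torch.isfinite(q.amax).all()
```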
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1421      +/-   ##
==========================================
- Coverage   77.30%   76.80%   -0.51%
==========================================
  Files         478      478
  Lines       51404    51480      +76
==========================================
- Hits        39737    39537     -200
- Misses      11667    11943     +276
```
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
What does this PR do?
Type of change: Bug fix
Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`):

1. Per-expert weight iteration for calibration. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert `(weight_slice, quantizer)` pairs for both projections. The base impl uses the singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers.
2. `mse_calibrate` refactor. Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens); it runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`. Replace the singular-only `weight_attr_names` discovery + `getattr`-by-name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe.

Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up `weight_scale_2` invariant on dead experts.

Also includes:

- `_sanitize_generation_config_for_save` in `unified_export_hf` — coerces `do_sample=True` when an upstream `generation_config.json` has `top_k`/`top_p` set, so newer transformers' strict validate doesn't block `save_pretrained`.
- Supporting changes in `moe_utils.py`, `tensor_quantizer.py`, and `core_utils.py` for the per-expert iteration and bootstrap path.

Usage
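A representative invocation, hedged: the config object for the `nvfp4_experts_only_mse-fp8_cast_kv` recipe and the calibration loader are placeholders, not commands taken from this PR.

```python
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

# model: a fused-experts MoE checkpoint (e.g. Qwen3.5) loaded via transformers.
# quant_cfg: the MSE recipe config; how it is loaded is assumed, not shown here.

def forward_loop(m):
    for batch in calib_dataloader:  # a handful of calibration samples
        m(**batch)

# max_calibrate runs first; the new bootstrap then fills _amax on dead
# experts before the MSE search touches every per-expert quantizer.
model = mtq.quantize(model, quant_cfg, forward_loop)
export_hf_checkpoint(model, export_dir="./exported-nvfp4")
```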
Testing
- Original validation — Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`: before, 1/12288 (layer 38 expert 69) gate != up and 0 weights MSE-calibrated; after, 0/12288 mismatches and 24576 weights MSE-calibrated (~4.2 min).
- End-to-end pipeline validation — Qwen3.5-35B-A3B (40 layers × 256 experts × 2 projections = 20,480 per-expert weight quantizers), TRT-LLM 1.3.0rc13 + transformers 5.6 docker, single B200:
  - `_amax` parity between the two calibration paths after `mtq.quantize`: n=20480 exact=20480 diff=0 max_rel=0. With 8/256 experts routed per token and 4 calib samples, almost all experts are "dead" in Path A. Bootstrap fills them from max(|weight|), MSE searches deterministically from there → identical to Path B which bootstraps everything.
  - Exported `generation_config.json` has `do_sample: true` (upstream had `top_k=20` + `top_p=0.95`, which would have failed strict validate).
  - Sample generation: "Born in north-east France, Soyer trained as a" → " tailor. Demonstrating his craft at a young age, at 20 he moved to Paris at the requests of the noble people of Picardy." (coherent grammar; factually wrong as expected with 4-sample calib, but no NaN/Inf in logits, no scale-mismatch crash). 92 GB GPU memory used.

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Follow-up to PR #1407 (MSE+FP8-cast-KV recipes). The recipe YAML files landed there; this PR fixes the calibration codepath so the MSE recipes actually exercise per-expert weight quantizers in fused-experts MoE containers.