fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration #1382
base: main
Changes from all commits: d797509, 9aac0fb, 60e1851, ab8a162, b161f3b, 5dcda40, 4de8abf
@@ -828,6 +828,17 @@ def pre_quantize(
     preview_input_ids = next(iter(calib_dataloader))[
         "input_features" if model_type == "whisper" else "input_ids"
     ][0:1]
+    # Strip leading padding tokens so the preview input shows real content
+    if model_type != "whisper" and tokenizer is not None and tokenizer.pad_token_id is not None:
+        first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
+        if first_non_pad.numel() > 0:
+            preview_input_ids = preview_input_ids[:, first_non_pad[0] :]
+        else:
+            warnings.warn(
+                "Preview calibration sample is entirely padding; generated preview will be "
+                "degenerate. Check tokenizer padding side / dataset preprocessing.",
+                stacklevel=2,
+            )

     # Generate preview before quantization
     if args.skip_generate:
@@ -928,7 +939,7 @@ def input_decode(input_ids):
     if processor is not None and isinstance(processor, WhisperProcessor):
         return first_text_speech_dataset
     elif tokenizer is not None:
-        return tokenizer.batch_decode(input_ids)
+        return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
Collaborator: Why is this needed? Do you see any issues?

Contributor (author): Yes, when I tested with Qwen3.6, in some settings the pad tokens are very long and fill up the output token length.

Collaborator: @meenchen, could you comment on this one?
     else:
         raise ValueError("The processor or tokenizer must be set")
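For context, here is a minimal standalone sketch of what the padding-stripping logic above does. The `pad_token_id` value and the token ids are made up purely for illustration: left-padded calibration batches can begin with a long run of pad tokens, and slicing from the first non-pad index (together with `skip_special_tokens=True` at decode time) keeps the generated preview readable.

```python
import torch

# Hypothetical left-padded preview sample; pad_token_id = 0 is assumed only for this sketch.
pad_token_id = 0
preview_input_ids = torch.tensor([[0, 0, 0, 0, 11, 22, 33, 2]])

# Same slicing logic as the diff above: drop the leading run of pad tokens.
first_non_pad = (preview_input_ids[0] != pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
    preview_input_ids = preview_input_ids[:, first_non_pad[0]:]

print(preview_input_ids.tolist())  # [[11, 22, 33, 2]] (leading pads removed)
```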
@@ -28,6 +28,11 @@
 __all__ = ["NVFP4QTensor"]


+def _cast_per_block_scale_to_fp8(per_block_scale: torch.Tensor) -> torch.Tensor:
+    """Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN."""
+    return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)
+
+
 class NVFP4QTensor(BaseQuantizedTensor):
     """Implements the INT4 quantization on tensors for more efficient storage or computation.
@@ -122,16 +127,10 @@ def get_weights_scaling_factor_from_quantizer(
         expected_shape = (*weight.shape[:-1], num_blocks_per_row)
         per_block_scale = per_block_scale.view(expected_shape)

-        # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
-        # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
-        # for an all-zero weight block) and global_amax is small, the pre-cast value
-        # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
-        # value >= 480 casts to NaN — clamp first to keep the stored byte finite.
         if not keep_high_precision:
-            per_block_scale = (
-                (per_block_scale * 448.0 / per_block_scale_max)
-                .clamp_(max=448.0)
-                .to(torch.float8_e4m3fn)
-            )
+            # The [==0]=1.0 safety net + small global_amax can drive the pre-cast value above 448 (PR #1397).
+            per_block_scale = _cast_per_block_scale_to_fp8(
+                per_block_scale * 448.0 / per_block_scale_max
+            )
Comment on lines 130 to 134: [SUGGESTION] The original code carried a useful comment explaining why
             return per_block_scale, weights_scaling_factor_2
         else:

@@ -171,9 +170,8 @@ def get_weights_scaling_factor(
         )
         # Set all zero values in scale to 1.0
         per_block_scale[per_block_scale == 0] = 1.0
-        # Convert to torch.float8_e4m3fn
         if not keep_high_precision:
-            per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
+            per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale)
         return per_block_scale, weights_scaling_factor_2

     @classmethod
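As a sanity check on the helper's bounds (not part of this PR, just a sketch using the constants above, on a PyTorch build that supports `torch.float8_e4m3fn`): 2**-9 is the smallest positive E4M3FN subnormal and 448 is the largest finite value, so clamping into [2**-9, 448] before the cast keeps every stored scale finite and non-zero. Per the removed comment, an unclamped cast of a value at or above 480 would instead store NaN, since E4M3FN has no Inf.

```python
import torch

def cast_scale_to_fp8(scale: torch.Tensor) -> torch.Tensor:
    # Same clamp-then-cast pattern as _cast_per_block_scale_to_fp8 in the diff above.
    return scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)

scales = torch.tensor([1e-12, 0.5, 448.0, 7000.0])  # below, inside, at, and above the FP8 range
fp8 = cast_scale_to_fp8(scales).float()
assert torch.isfinite(fp8).all()  # overflow saturates at 448 instead of becoming NaN
assert (fp8 > 0).all()            # underflow clamps to 2**-9 instead of collapsing to 0
```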
@@ -15,6 +15,8 @@
 """Tests for _QuantFusedExperts: generic fused MoE quantization and export."""

+from unittest.mock import patch
+
 import pytest
 import torch
 import torch.nn as nn
@@ -256,27 +258,51 @@ def test_expert_index_recovery(self):
# Tests for export
# ---------------------------------------------------------------------------
class TestExportFusedExperts:
    @staticmethod
    def _cleanup_registry(mod_type):
        if QuantModuleRegistry.get(mod_type) is not None:
            QuantModuleRegistry.unregister(mod_type)

    def test_export_creates_per_expert_submodules(self):
        """_export_fused_experts should create per-expert submodules with standard naming."""
        import modelopt.torch.quantization as mtq
        from modelopt.torch.export.moe_utils import _export_fused_experts

        experts = _SyntheticFusedExperts()
        expert_type = type(experts)
        model = _TinyMoEModel()
        expert_type = type(model.moe.experts)
        self._cleanup_registry(expert_type)

        # Manually register and convert
        if QuantModuleRegistry.get(expert_type) is None:
            QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
                _QuantFusedExperts
            )
        converted = QuantModuleRegistry.convert(experts)
        quant_cfg = {
            "quant_cfg": [
                {"quantizer_name": "*", "enable": False},
                {
                    "quantizer_name": "*gate_up_proj_input_quantizer",
                    "cfg": {"num_bits": 8, "axis": None},
                },
                {
                    "quantizer_name": "*down_proj_input_quantizer",
                    "cfg": {"num_bits": 8, "axis": None},
                },
                {
                    "quantizer_name": "*gate_up_proj_weight_quantizer",
                    "cfg": {"num_bits": 8, "axis": 0},
                },
                {
                    "quantizer_name": "*down_proj_weight_quantizer",
                    "cfg": {"num_bits": 8, "axis": 0},
                },
            ],
            "algorithm": "max",
        }

        # Run a forward pass to calibrate (set amaxes)
        seq_len = 16
        hidden_states = torch.randn(seq_len, HIDDEN_DIM)
        top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K))
        top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1)
        with torch.no_grad():
            converted(hidden_states, top_k_index, top_k_weights)
        def forward_loop(m):
            torch.manual_seed(0)
            for _ in range(2):
                x = torch.randn(1, 4, HIDDEN_DIM)
                m(x)

        mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
        converted = model.moe.experts

        _export_fused_experts(converted, torch.float16)
@@ -297,8 +323,7 @@ def test_export_creates_per_expert_submodules(self):
         assert not hasattr(converted, "down_proj")
         assert not hasattr(converted, "gate_up_proj_weight_quantizers")

-        if QuantModuleRegistry.get(expert_type) is not None:
-            QuantModuleRegistry.unregister(expert_type)
+        self._cleanup_registry(expert_type)

     def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
         """gate_proj and up_proj must share weight_scale_2 even when an expert
@@ -899,3 +924,73 @@ def test_unrelated_dotted_number_unchanged(self):
            _normalize_fused_experts_quantizer_name("moe.layers.3.gate.weight")
            == "moe.layers.3.gate.weight"
        )


# Verifies that MSE calibration discovers and calibrates every per-expert weight quantizer
# inside a fused-expert ModuleList (both gate_up_proj and down_proj, for all experts).
class TestFusedExpertsMSECalibration:
    @staticmethod
    def _cleanup_registry(mod_type):
        if QuantModuleRegistry.get(mod_type) is not None:
            QuantModuleRegistry.unregister(mod_type)

    def test_mse_calibration_populates_all_expert_quantizers(self):
        # Strong assertion: every per-expert weight quantizer must be touched by the MSE
        # search loop (mse_calibrate Step 3), not just have _amax set by max-calibrate or
        # the dead-expert bootstrap. Spy on MseCalibrator.collect — that method is only
        # invoked from Step 3, after Step 2 installs MseCalibrator on each quantizer.
        import modelopt.torch.quantization as mtq
        from modelopt.torch.quantization.calib.mse import MseCalibrator

        model = _TinyMoEModel()
        expert_type = type(model.moe.experts)
        self._cleanup_registry(expert_type)

        collected_calib_ids: set[int] = set()
        original_collect = MseCalibrator.collect

        def _spy_collect(self, x):
            collected_calib_ids.add(id(self))
            return original_collect(self, x)

        with patch.object(MseCalibrator, "collect", _spy_collect):
            mtq.quantize(
                model,
                {
                    "quant_cfg": [
                        {"quantizer_name": "*", "enable": False},
                        {
                            "quantizer_name": "*gate_up_proj_weight_quantizer",
                            "cfg": {"num_bits": 8, "axis": None},
                        },
                        {
                            "quantizer_name": "*down_proj_weight_quantizer",
                            "cfg": {"num_bits": 8, "axis": None},
                        },
                    ],
                    "algorithm": "mse",
                },
                forward_loop=lambda m: [m(torch.randn(1, 4, HIDDEN_DIM)) for _ in range(2)],
            )

        experts = model.moe.experts
        missed = []
        for idx in range(NUM_EXPERTS):
            assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, (
                f"gate_up_proj_weight_quantizers[{idx}] not calibrated — Bug 1 regression"
            )
            assert experts.down_proj_weight_quantizers[idx].amax is not None, (
                f"down_proj_weight_quantizers[{idx}] not calibrated"
            )
Comment on lines +978 to +984

[IMPORTANT Compatibility] This assertion does not actually verify the bug fix described in the PR ("MSE weight calibration: 0it" / per-expert MSE search not running). It only checks that
Both run unconditionally before MSE search inside
To actually verify MSE ran on per-expert quantizers, do one of:
            if (
                id(experts.gate_up_proj_weight_quantizers[idx]._calibrator)
                not in collected_calib_ids
            ):
                missed.append(f"gate_up_proj_weight_quantizers[{idx}]")
            if id(experts.down_proj_weight_quantizers[idx]._calibrator) not in collected_calib_ids:
                missed.append(f"down_proj_weight_quantizers[{idx}]")
        assert not missed, (
            f"MSE search loop skipped these per-expert quantizers: {missed}. "
            "mse_calibrate Step 3 did not iterate them via iter_weights_for_calibration."
        )
        self._cleanup_registry(expert_type)
@@ -0,0 +1,114 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow)."""

from types import SimpleNamespace

import torch

from modelopt.torch.quantization.qtensor.nvfp4_tensor import (
    NVFP4QTensor,
    _cast_per_block_scale_to_fp8,
)

_FP8_E4M3FN_MIN = 2**-9  # 0.001953125 — smallest positive FP8 E4M3FN subnormal
_FP8_E4M3FN_MAX = 448.0


class TestNVFP4ScaleClamping:
    """Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN."""

    def test_no_zero_scales_for_tiny_weights(self):
        """Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
        block_size = 16
        tiny_weight = torch.full((4, block_size), 1e-10)
        # wsf2=1.0 → per_block_scale = amax/(6*wsf2) ≈ 1.7e-11 << 2^-9, exercises FP8-min clamp
        wsf2 = torch.tensor(1.0)

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(tiny_weight, block_size, wsf2)
        per_block_scale_f32 = per_block_scale.float()

        assert (per_block_scale_f32 > 0).all(), (
            f"Zero per-block scales found after FP8 cast: {per_block_scale_f32.tolist()}. "
            "FP8 scale underflow clamping likely regressed."
        )
        assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), (
            "Per-block scales below FP8 minimum subnormal found after cast."
        )

    def test_normal_weights_unaffected_by_clamp(self):
        """Weights with typical magnitudes must not be affected by the underflow clamp."""
        block_size = 16
        torch.manual_seed(42)
        normal_weight = torch.randn(8, block_size)

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(normal_weight, block_size)
        assert (per_block_scale.float() > 0).all(), "Normal weights produced zero scales."

    def test_mixed_weight_no_zeros(self):
        """Mixed-magnitude tensor (normal + tiny blocks) must have no zero scales."""
        block_size = 16
        weight = torch.cat(
            [
                torch.randn(4, block_size),
                torch.full((4, block_size), 1e-12),
            ],
            dim=0,
        )

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, block_size)
        assert (per_block_scale.float() > 0).all(), (
            "Zero scales in mixed-magnitude tensor after FP8 cast."
        )

    def test_helper_clamps_overflow_to_max(self):
        """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
        oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
        out = _cast_per_block_scale_to_fp8(oversized).float()
        assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
        assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"

    def test_helper_clamps_underflow_to_min(self):
        """Values below the FP8 subnormal must clamp up, not collapse to 0."""
        tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
        out = _cast_per_block_scale_to_fp8(tiny).float()
        assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"

    def test_static_path_no_nan_when_block_amax_zero(self):
        """Static path: when a block's amax is 0 (all-zero weights), the `[==0]=1.0` safety net
        and a small global_amax push the pre-cast value above 448. Without the max clamp,
        fp8_e4m3fn would cast it to NaN — regression for the export-time NaN reported on this PR.
        """
        block_size = 16
        # global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448.
        global_amax = torch.tensor(0.01)
        # One block with amax=0 (triggers safety net), three normal blocks.
        per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]])
        weight = torch.randn(1, 4 * block_size)
        q = SimpleNamespace(
            global_amax=global_amax,
            _amax=per_block_amax,
            block_sizes={-1: block_size},
        )

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
        per_block_scale_f32 = per_block_scale.float()
        assert torch.isfinite(per_block_scale_f32).all(), (
            f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
        )
        assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
            f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}"
        )
[SUGGESTION] Two minor robustness gaps:

1. `tokenizer` may be undefined in some branches reached by `pre_quantize` — the existing reference at line 936 (`elif tokenizer is not None:`) already handles a `None` tokenizer for whisper-style models. The `model_type != "whisper"` guard you added is necessary for whisper, but if anyone refactors `pre_quantize` to be reachable for additional non-tokenizer models, the access here would `NameError`. Since you already established `tokenizer is not None` is a real possibility downstream, the gate `tokenizer is not None and tokenizer.pad_token_id is not None` is correctly defensive on that side — just confirming the order matters (short-circuit prevents the attribute access on `None`).
2. When the sample is entirely padding (`first_non_pad.numel() == 0`), the `if` branch is skipped and `preview_input_ids` stays as the all-pad tensor — the post-quantization `generate` call below will then run on pure pad tokens. That matches prior behavior, but it's worth a one-line warning so a silently-degenerate preview doesn't mislead users into thinking PTQ broke their model.
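To make the short-circuit point in item 1 concrete, a trivial sketch (not from the PR): with `and`, the right-hand attribute access is never evaluated when the left-hand check fails, so a `None` tokenizer cannot raise.

```python
# The guard order matters: `and` short-circuits, so pad_token_id is never read on None.
tokenizer = None
if tokenizer is not None and tokenizer.pad_token_id is not None:
    print("would strip padding")
else:
    print("skipped safely")  # no AttributeError raised
```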