13 changes: 12 additions & 1 deletion examples/llm_ptq/hf_ptq.py
@@ -828,6 +828,17 @@ def pre_quantize(
preview_input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
# Strip leading padding tokens so the preview input shows real content
if model_type != "whisper" and tokenizer is not None and tokenizer.pad_token_id is not None:
first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
preview_input_ids = preview_input_ids[:, first_non_pad[0] :]
Comment on lines +831 to +835
[SUGGESTION] Two minor robustness gaps:

  1. tokenizer may be undefined in some branches reached by pre_quantize — the existing reference at line 936 (elif tokenizer is not None:) already handles a None tokenizer for whisper-style models. The model_type != "whisper" guard you added is necessary for whisper, but if anyone refactors pre_quantize to be reachable for additional non-tokenizer models, the access here would raise a NameError. Since you already established that tokenizer is not None is a real possibility downstream, the gate tokenizer is not None and tokenizer.pad_token_id is not None is correctly defensive on that side — just confirming that the order matters (short-circuit evaluation prevents the attribute access on None).
  2. After stripping leading pads, if the entire sample was padding (first_non_pad.numel() == 0), the if branch is skipped and preview_input_ids stays as the all-pad tensor — the post-quantization generate call below will then run on pure pad tokens. That matches prior behavior, but it's worth adding a one-line warning so a silently degenerate preview doesn't mislead users into thinking PTQ broke their model.

else:
warnings.warn(
"Preview calibration sample is entirely padding; generated preview will be "
"degenerate. Check tokenizer padding side / dataset preprocessing.",
stacklevel=2,
)

# Generate preview before quantization
if args.skip_generate:
@@ -928,7 +939,7 @@ def input_decode(input_ids):
if processor is not None and isinstance(processor, WhisperProcessor):
return first_text_speech_dataset
elif tokenizer is not None:
return tokenizer.batch_decode(input_ids)
return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
Collaborator
Why is this needed? Do you see any issues?

Contributor Author
Yes, when I tested with Qwen3.6, in some settings the pad tokens are very long and fill up the output token length.
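
A minimal sketch of the effect (hypothetical token IDs; assumes a GPT-2 tokenizer with its EOS token reused as pad):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token

# Five leading pad tokens followed by real content.
ids = [[tok.pad_token_id] * 5 + tok("Hello world")["input_ids"]]
print(tok.batch_decode(ids))                            # pad tokens dominate the decoded text
print(tok.batch_decode(ids, skip_special_tokens=True))  # ['Hello world']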

Collaborator
@meenchen could you comment on this one?

else:
raise ValueError("The processor or tokenizer must be set")

2 changes: 1 addition & 1 deletion modelopt/torch/quantization/model_calib.py
@@ -864,7 +864,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg

cal = getattr(module, "_calibrator", None)
if cal and not getattr(module, "_dynamic", False):
if method in {"entropy"}:
if method == "entropy":
realAsma marked this conversation as resolved.
if cal.compute_amax(method) is not None:
module.load_calib_amax("entropy", **kwargs)
elif cal.compute_amax(**kwargs) is not None:
1 change: 1 addition & 0 deletions modelopt/torch/quantization/model_quant.py
@@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
lines.append(f"{len(lines)} TensorQuantizers found in model")

if output_dir:
os.makedirs(output_dir, exist_ok=True)
path = os.path.join(output_dir, ".quant_summary.txt")
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -1122,7 +1122,7 @@ def forward(self, inputs):

return outputs

def _short_amax(self, fmt=".4f"):
def _short_amax(self, fmt=".2e"):
"""Short description of amax.

Returns:
@@ -1140,7 +1140,7 @@ def _short_amax(self, fmt=".4f"):
return "meta"
return self._short_tensor(self._amax, fmt)

def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
"""Short description of tensor."""
if tensor.numel() == 1:
return f"{tensor.item():{fmt}}"
20 changes: 9 additions & 11 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -28,6 +28,11 @@
__all__ = ["NVFP4QTensor"]


def _cast_per_block_scale_to_fp8(per_block_scale: torch.Tensor) -> torch.Tensor:
"""Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN."""
return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)


class NVFP4QTensor(BaseQuantizedTensor):
"""Implements the INT4 quantization on tensors for more efficient storage or computation.

@@ -122,16 +127,10 @@ def get_weights_scaling_factor_from_quantizer(
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
per_block_scale = per_block_scale.view(expected_shape)

# Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
# cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
# for an all-zero weight block) and global_amax is small, the pre-cast value
# explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
# value >= 480 casts to NaN — clamp first to keep the stored byte finite.
if not keep_high_precision:
per_block_scale = (
(per_block_scale * 448.0 / per_block_scale_max)
.clamp_(max=448.0)
.to(torch.float8_e4m3fn)
# The [==0]=1.0 safety net + small global_amax can drive the pre-cast value above 448 (PR #1397).
per_block_scale = _cast_per_block_scale_to_fp8(
per_block_scale * 448.0 / per_block_scale_max
)
Comment on lines 130 to 134
[SUGGESTION] The original code carried a useful comment explaining why clamp(max=448) is needed in this branch (the [==0]=1.0 safety net combined with a small global_amax produces pre-cast values that overflow FP8 to NaN). That rationale is now gone — the helper docstring describes what the clamp does generally, but a future reader of this branch will not know why this specific multiplication can blow up. Consider keeping the original 4-line "why" comment here, since the math behind the overflow is non-obvious from the helper alone.
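
For reference, a minimal sketch of that overflow arithmetic (the values mirror the static-path test added below; the out-of-range cast behavior is as the original comment described):

import torch

# The [==0]=1.0 safety net times 448 / (global_amax / 6) far exceeds the
# fp8_e4m3fn maximum of 448 whenever global_amax is small.
global_amax = torch.tensor(0.01)
per_block_scale_max = global_amax / 6.0                      # ~0.00167
pre_cast = torch.tensor(1.0) * 448.0 / per_block_scale_max   # ~268800 >> 448

# Clamping before the cast keeps the stored byte finite (saturates at 448).
clamped = pre_cast.clamp(max=448.0).to(torch.float8_e4m3fn)
print(pre_cast.item(), clamped.float().item())               # ~268800.0 448.0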

return per_block_scale, weights_scaling_factor_2
else:
@@ -171,9 +170,8 @@ def get_weights_scaling_factor(
)
# Set all zero values in scale to 1.0
per_block_scale[per_block_scale == 0] = 1.0
# Convert to torch.float8_e4m3fn
if not keep_high_precision:
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale)
return per_block_scale, weights_scaling_factor_2

@classmethod
129 changes: 112 additions & 17 deletions tests/unit/torch/quantization/plugins/test_fused_experts.py
@@ -15,6 +15,8 @@

"""Tests for _QuantFusedExperts: generic fused MoE quantization and export."""

from unittest.mock import patch

import pytest
import torch
import torch.nn as nn
@@ -256,27 +258,51 @@ def test_expert_index_recovery(self):
# Tests for export
# ---------------------------------------------------------------------------
class TestExportFusedExperts:
@staticmethod
def _cleanup_registry(mod_type):
if QuantModuleRegistry.get(mod_type) is not None:
QuantModuleRegistry.unregister(mod_type)

def test_export_creates_per_expert_submodules(self):
"""_export_fused_experts should create per-expert submodules with standard naming."""
import modelopt.torch.quantization as mtq
from modelopt.torch.export.moe_utils import _export_fused_experts

experts = _SyntheticFusedExperts()
expert_type = type(experts)
model = _TinyMoEModel()
expert_type = type(model.moe.experts)
self._cleanup_registry(expert_type)

# Manually register and convert
if QuantModuleRegistry.get(expert_type) is None:
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
_QuantFusedExperts
)
converted = QuantModuleRegistry.convert(experts)
quant_cfg = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*gate_up_proj_input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*down_proj_input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*gate_up_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
{
"quantizer_name": "*down_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
],
"algorithm": "max",
}

# Run a forward pass to calibrate (set amaxes)
seq_len = 16
hidden_states = torch.randn(seq_len, HIDDEN_DIM)
top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K))
top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1)
with torch.no_grad():
converted(hidden_states, top_k_index, top_k_weights)
def forward_loop(m):
torch.manual_seed(0)
for _ in range(2):
x = torch.randn(1, 4, HIDDEN_DIM)
m(x)

mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
converted = model.moe.experts

_export_fused_experts(converted, torch.float16)

@@ -297,8 +323,7 @@ def test_export_creates_per_expert_submodules(self):
assert not hasattr(converted, "down_proj")
assert not hasattr(converted, "gate_up_proj_weight_quantizers")

if QuantModuleRegistry.get(expert_type) is not None:
QuantModuleRegistry.unregister(expert_type)
self._cleanup_registry(expert_type)

def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
"""gate_proj and up_proj must share weight_scale_2 even when an expert
@@ -899,3 +924,73 @@ def test_unrelated_dotted_number_unchanged(self):
_normalize_fused_experts_quantizer_name("moe.layers.3.gate.weight")
== "moe.layers.3.gate.weight"
)


# Verifies that MSE calibration discovers and calibrates every per-expert weight quantizer
# inside a fused-expert ModuleList (both gate_up_proj and down_proj, for all experts).
class TestFusedExpertsMSECalibration:
@staticmethod
def _cleanup_registry(mod_type):
if QuantModuleRegistry.get(mod_type) is not None:
QuantModuleRegistry.unregister(mod_type)

def test_mse_calibration_populates_all_expert_quantizers(self):
# Strong assertion: every per-expert weight quantizer must be touched by the MSE
# search loop (mse_calibrate Step 3), not just have _amax set by max-calibrate or
# the dead-expert bootstrap. Spy on MseCalibrator.collect — that method is only
# invoked from Step 3, after Step 2 installs MseCalibrator on each quantizer.
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.calib.mse import MseCalibrator

model = _TinyMoEModel()
expert_type = type(model.moe.experts)
self._cleanup_registry(expert_type)

collected_calib_ids: set[int] = set()
original_collect = MseCalibrator.collect

def _spy_collect(self, x):
collected_calib_ids.add(id(self))
return original_collect(self, x)

with patch.object(MseCalibrator, "collect", _spy_collect):
mtq.quantize(
model,
{
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*gate_up_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*down_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
],
"algorithm": "mse",
},
forward_loop=lambda m: [m(torch.randn(1, 4, HIDDEN_DIM)) for _ in range(2)],
)

experts = model.moe.experts
missed = []
for idx in range(NUM_EXPERTS):
assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, (
f"gate_up_proj_weight_quantizers[{idx}] not calibrated — Bug 1 regression"
)
assert experts.down_proj_weight_quantizers[idx].amax is not None, (
f"down_proj_weight_quantizers[{idx}] not calibrated"
)
Comment on lines +978 to +984
[IMPORTANT Compatibility] This assertion does not actually verify the bug fix described in the PR ("MSE weight calibration: 0it" / per-expert MSE search not running). It only checks that amax is not None, which is set by either:

  1. max_calibrate(...) — which iterates model.named_modules() and reaches per-expert TensorQuantizers inside the nn.ModuleList directly, OR
  2. _bootstrap_uncalibrated_weight_quantizers(...) — which calls iter_weights_for_calibration() (already overridden on _QuantFusedExperts).

Both run unconditionally before MSE search inside mse_calibrate. So this test passes even if the per-expert MSE search loop is fully skipped — it does not catch the regression its docstring/error message claims to guard against.

To actually verify MSE ran on per-expert quantizers, do one of:

  • Snapshot _amax after an algorithm="max" run, then re-run with algorithm="mse" and assert that at least some per-expert amaxes changed,
  • Assert that experts.gate_up_proj_weight_quantizers[idx]._calibrator is an MseCalibrator instance after calibration (a minimal sketch follows this list),
  • Patch MseCalibrator.compute_amax and assert it was called once per expert × per fused-projection.
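
A minimal sketch of the second option (hypothetical assertions; assumes the same test-file context, i.e. NUM_EXPERTS and experts, and that quantize leaves the calibrator attached):

from modelopt.torch.quantization.calib.mse import MseCalibrator

# Hypothetical post-calibration check: each per-expert weight quantizer should
# carry an MseCalibrator if the MSE search actually visited it.
for idx in range(NUM_EXPERTS):
    assert isinstance(
        experts.gate_up_proj_weight_quantizers[idx]._calibrator, MseCalibrator
    ), f"gate_up_proj_weight_quantizers[{idx}] never received an MSE calibrator"
    assert isinstance(
        experts.down_proj_weight_quantizers[idx]._calibrator, MseCalibrator
    ), f"down_proj_weight_quantizers[{idx}] never received an MSE calibrator"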

if (
id(experts.gate_up_proj_weight_quantizers[idx]._calibrator)
not in collected_calib_ids
):
missed.append(f"gate_up_proj_weight_quantizers[{idx}]")
if id(experts.down_proj_weight_quantizers[idx]._calibrator) not in collected_calib_ids:
missed.append(f"down_proj_weight_quantizers[{idx}]")
assert not missed, (
f"MSE search loop skipped these per-expert quantizers: {missed}. "
"mse_calibrate Step 3 did not iterate them via iter_weights_for_calibration."
)
self._cleanup_registry(expert_type)
114 changes: 114 additions & 0 deletions tests/unit/torch/quantization/test_nvfp4_tensor.py
@@ -0,0 +1,114 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow)."""

from types import SimpleNamespace

import torch

from modelopt.torch.quantization.qtensor.nvfp4_tensor import (
NVFP4QTensor,
_cast_per_block_scale_to_fp8,
)

_FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive FP8 E4M3FN subnormal
_FP8_E4M3FN_MAX = 448.0


class TestNVFP4ScaleClamping:
"""Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN."""

def test_no_zero_scales_for_tiny_weights(self):
"""Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
block_size = 16
tiny_weight = torch.full((4, block_size), 1e-10)
# wsf2=1.0 → per_block_scale = amax/(6*wsf2) ≈ 1.7e-11 << 2^-9, exercises FP8-min clamp
wsf2 = torch.tensor(1.0)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(tiny_weight, block_size, wsf2)
per_block_scale_f32 = per_block_scale.float()

assert (per_block_scale_f32 > 0).all(), (
f"Zero per-block scales found after FP8 cast: {per_block_scale_f32.tolist()}. "
"FP8 scale underflow clamping likely regressed."
)
assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), (
"Per-block scales below FP8 minimum subnormal found after cast."
)

def test_normal_weights_unaffected_by_clamp(self):
"""Weights with typical magnitudes must not be affected by the underflow clamp."""
block_size = 16
torch.manual_seed(42)
normal_weight = torch.randn(8, block_size)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(normal_weight, block_size)
assert (per_block_scale.float() > 0).all(), "Normal weights produced zero scales."

def test_mixed_weight_no_zeros(self):
"""Mixed-magnitude tensor (normal + tiny blocks) must have no zero scales."""
block_size = 16
weight = torch.cat(
[
torch.randn(4, block_size),
torch.full((4, block_size), 1e-12),
],
dim=0,
)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, block_size)
assert (per_block_scale.float() > 0).all(), (
"Zero scales in mixed-magnitude tensor after FP8 cast."
)

def test_helper_clamps_overflow_to_max(self):
"""Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
out = _cast_per_block_scale_to_fp8(oversized).float()
assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"

def test_helper_clamps_underflow_to_min(self):
"""Values below the FP8 subnormal must clamp up, not collapse to 0."""
tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
out = _cast_per_block_scale_to_fp8(tiny).float()
assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"

def test_static_path_no_nan_when_block_amax_zero(self):
"""Static path: when a block's amax is 0 (all-zero weights), the `[==0]=1.0` safety net
and a small global_amax push the pre-cast value above 448. Without the max clamp,
fp8_e4m3fn would cast it to NaN — regression for the export-time NaN reported on this PR.
"""
block_size = 16
# global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448.
global_amax = torch.tensor(0.01)
# One block with amax=0 (triggers safety net), three normal blocks.
per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]])
weight = torch.randn(1, 4 * block_size)
q = SimpleNamespace(
global_amax=global_amax,
_amax=per_block_amax,
block_sizes={-1: block_size},
)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
per_block_scale_f32 = per_block_scale.float()
assert torch.isfinite(per_block_scale_f32).all(), (
f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
)
assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}"
)