Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,16 @@ def get_weights_scaling_factor_from_quantizer(
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
per_block_scale = per_block_scale.view(expected_shape)

# Quantize scales to FP8
# Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
# cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
# for an all-zero weight block) and global_amax is small, the pre-cast value
# explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
# value >= 480 casts to NaN — clamp first to keep the stored byte finite.
if not keep_high_precision:
per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
torch.float8_e4m3fn
per_block_scale = (
(per_block_scale * 448.0 / per_block_scale_max)
.clamp_(max=448.0)
.to(torch.float8_e4m3fn)
)
Comment on lines +125 to 135
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

This only closes the static branch; the dynamic all-zero path can still produce NaNs.

get_weights_scaling_factor() still computes per_block_amax / (6 * weights_scaling_factor_2), and for an all-zero tensor both terms are zero. That leaves NaN in per_block_scale because the later per_block_scale == 0 fixup does not catch it, and quantize() then divides by weights_scaling_factor * weights_scaling_factor_2 with a zero denominator. Please add the same zero-amax handling there, or an early all-zero fast path, so this fix covers both code paths.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/qtensor/nvfp4_tensor.py` around lines 125 - 135,
get_weights_scaling_factor() can produce NaN when both per_block_amax and
weights_scaling_factor_2 are zero (all-zero tensor), which later breaks
quantize(); update get_weights_scaling_factor (and/or add an early all-zero fast
path in quantize) to detect the zero-denominator case and short-circuit: if
per_block_amax == 0 or weights_scaling_factor_2 == 0, set per_block_scale to a
safe finite value (e.g., 0.0) or return an all-zero quantized result
immediately; ensure the same clamp/finite handling applied in the non-static
branch (the per_block_scale fixup used elsewhere) is applied here so
per_block_scale is never NaN before division in quantize().

return per_block_scale, weights_scaling_factor_2
else:
Expand Down
46 changes: 46 additions & 0 deletions tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from modelopt.torch.quantization.calib import NVFP4MSECalibrator
from modelopt.torch.quantization.config import QuantizerAttributeConfig
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer
from modelopt.torch.quantization.qtensor import NVFP4QTensor
from modelopt.torch.quantization.tensor_quant import (
scaled_e4m3_impl,
static_blockwise_fp4_fake_quant,
Expand Down Expand Up @@ -64,6 +65,51 @@ def test_global_amax_property(self, device):
quantizer.global_amax = None
assert quantizer.global_amax is None

def test_export_fp8_scale_no_nan_for_zero_amax_block(self, device):
"""Regression: export must not emit fp8 NaN bytes for an all-zero block.

When max-only calibration leaves ``_amax = 0`` for a fully-zero weight block,
the export's ``[per_block_scale == 0] = 1.0`` safety net drives the pre-cast
value to ``1.0 * 448 / (global_amax / 6)``. fp8_e4m3fn has no Inf, so any
pre-cast value >= 480 rounds to NaN — without a saturation clamp this writes
a 0x7F byte into ``weight_scale``. Reproduces the NaN seen in the saved
Kimi-K2.6-NVFP4-MSE checkpoint at expert 21 down_proj.
"""
block_size = 16
cfg = QuantizerAttributeConfig(
num_bits=(2, 1),
block_sizes={-1: block_size, "type": "static", "scale_bits": (4, 3)},
)
quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device)

# Two-block weight: block 0 is non-trivial; block 1 is all zeros so its
# per-block amax is exactly 0.
weight = torch.zeros(1, 2 * block_size, device=device, dtype=torch.bfloat16)
weight[0, :block_size] = 0.1

per_block_amax = weight.abs().reshape(1, 2, block_size).amax(dim=-1).flatten()
quantizer.amax = per_block_amax
quantizer.global_amax = per_block_amax.max()

# Sanity: the bug only fires when the would-be cast value exceeds 480.
# With global_amax = 0.1, scale_in_fp8 for a zero block is
# 1.0 * 448 / (0.1 / 6) ≈ 26880 — well past the 480 NaN threshold.
assert (per_block_amax == 0).any()
assert quantizer.global_amax.float().item() < 1.0

weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
quantizer, weight, weights_scaling_factor_2=None
)
assert weight_scale.dtype == torch.float8_e4m3fn

# No fp8_e4m3fn NaN bytes (NaN encoding is (b & 0x7F) == 0x7F).
raw = weight_scale.view(torch.uint8)
n_nan = ((raw & 0x7F) == 0x7F).sum().item()
assert n_nan == 0, f"fp8 weight_scale contains {n_nan} NaN byte(s)"

# The all-zero block's stored fp8 scale should saturate to 448 (max finite).
assert raw.flatten()[1].item() == 0x7E

def test_fake_quantize_with_both_amaxs(self, device):
"""Test _fake_quantize uses both _amax and _global_amax."""
num_blocks = 4
Expand Down
Loading