diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
index fe30e283c2d..bb39c8a81e3 100644
--- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
+++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -122,10 +122,16 @@ def get_weights_scaling_factor_from_quantizer(
             expected_shape = (*weight.shape[:-1], num_blocks_per_row)
             per_block_scale = per_block_scale.view(expected_shape)
 
-            # Quantize scales to FP8
+            # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
+            # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
+            # for an all-zero weight block) and global_amax is small, the pre-cast value
+            # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
+            # value >= 480 casts to NaN — clamp first to keep the stored byte finite.
             if not keep_high_precision:
-                per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
-                    torch.float8_e4m3fn
+                per_block_scale = (
+                    (per_block_scale * 448.0 / per_block_scale_max)
+                    .clamp_(max=448.0)
+                    .to(torch.float8_e4m3fn)
                 )
             return per_block_scale, weights_scaling_factor_2
         else:
diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
index b1b3691a797..430b7ee4113 100644
--- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
+++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
@@ -21,6 +21,7 @@
 from modelopt.torch.quantization.calib import NVFP4MSECalibrator
 from modelopt.torch.quantization.config import QuantizerAttributeConfig
 from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer
+from modelopt.torch.quantization.qtensor import NVFP4QTensor
 from modelopt.torch.quantization.tensor_quant import (
     scaled_e4m3_impl,
     static_blockwise_fp4_fake_quant,
@@ -64,6 +65,51 @@ def test_global_amax_property(self, device):
         quantizer.global_amax = None
         assert quantizer.global_amax is None
 
+    def test_export_fp8_scale_no_nan_for_zero_amax_block(self, device):
+        """Regression: export must not emit fp8 NaN bytes for an all-zero block.
+
+        When max-only calibration leaves ``_amax = 0`` for a fully-zero weight block,
+        the export's ``[per_block_scale == 0] = 1.0`` safety net drives the pre-cast
+        value to ``1.0 * 448 / (global_amax / 6)``. fp8_e4m3fn has no Inf, so any
+        pre-cast value >= 480 rounds to NaN — without a saturation clamp this writes
+        a 0x7F byte into ``weight_scale``. Reproduces the NaN seen in the saved
+        Kimi-K2.6-NVFP4-MSE checkpoint at expert 21 down_proj.
+        """
+        block_size = 16
+        cfg = QuantizerAttributeConfig(
+            num_bits=(2, 1),
+            block_sizes={-1: block_size, "type": "static", "scale_bits": (4, 3)},
+        )
+        quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device)
+
+        # Two-block weight: block 0 is non-trivial; block 1 is all zeros so its
+        # per-block amax is exactly 0.
+        weight = torch.zeros(1, 2 * block_size, device=device, dtype=torch.bfloat16)
+        weight[0, :block_size] = 0.1
+
+        per_block_amax = weight.abs().reshape(1, 2, block_size).amax(dim=-1).flatten()
+        quantizer.amax = per_block_amax
+        quantizer.global_amax = per_block_amax.max()
+
+        # Sanity: the bug only fires when the would-be cast value exceeds 480.
+        # With global_amax = 0.1, scale_in_fp8 for a zero block is
+        # 1.0 * 448 / (0.1 / 6) ≈ 26880 — well past the 480 NaN threshold.
+        assert (per_block_amax == 0).any()
+        assert quantizer.global_amax.float().item() < 1.0
+
+        weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
+            quantizer, weight, weights_scaling_factor_2=None
+        )
+        assert weight_scale.dtype == torch.float8_e4m3fn
+
+        # No fp8_e4m3fn NaN bytes (NaN encoding is (b & 0x7F) == 0x7F).
+        raw = weight_scale.view(torch.uint8)
+        n_nan = ((raw & 0x7F) == 0x7F).sum().item()
+        assert n_nan == 0, f"fp8 weight_scale contains {n_nan} NaN byte(s)"
+
+        # The all-zero block's stored fp8 scale should saturate to 448 (max finite, byte 0x7E).
+        assert raw.flatten()[1].item() == 0x7E
+
     def test_fake_quantize_with_both_amaxs(self, device):
         """Test _fake_quantize uses both _amax and _global_amax."""
         num_blocks = 4