diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
index bb39c8a81e..c67765463f 100644
--- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
+++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -171,9 +171,12 @@ def get_weights_scaling_factor(
         )
         # Set all zero values in scale to 1.0
         per_block_scale[per_block_scale == 0] = 1.0
-        # Convert to torch.float8_e4m3fn
+        # Convert to torch.float8_e4m3fn. fp8_e4m3fn has no Inf, so values
+        # too large to round down to the max finite value (448) cast to NaN.
+        # Clamp to 448 before casting so the exported weight_scale stays
+        # finite when an externally calibrated global scale makes a
+        # per-block scale too large.
         if not keep_high_precision:
-            per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
+            per_block_scale = per_block_scale.clamp_(max=448.0).to(torch.float8_e4m3fn)
         return per_block_scale, weights_scaling_factor_2
 
     @classmethod
diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py
index 08fac486f7..ea7a4b0098 100644
--- a/tests/gpu/torch/quantization/test_qtensor_cuda.py
+++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py
@@ -397,6 +397,28 @@ def _unpack_tensor(x):
         # Compare with input tensor
         assert torch.allclose(deq_x, x, rtol=2e-1, atol=2e-1)
 
+    @pytest.mark.parametrize("device", ["cuda"])
+    def test_nvfp4_dynamic_export_fp8_scale_no_nan_when_scale_exceeds_fp8(self, device):
+        """Regression: dynamic NVFP4 export must not emit fp8 NaN scale bytes."""
+        block_size = 16
+        weight = torch.ones(1, block_size, device=device, dtype=torch.bfloat16)
+
+        # Force per_block_scale = per_block_amax / (6 * scale_2) = 1000.
+        # torch.float8_e4m3fn has no Inf; casting 1000 directly would produce
+        # the NaN byte 0x7F. Export should instead saturate to 448 (0x7E).
+        weights_scaling_factor_2 = torch.tensor(
+            1.0 / (6.0 * 1000.0), device=device, dtype=torch.float32
+        )
+        weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor(
+            weight, block_size, weights_scaling_factor_2=weights_scaling_factor_2
+        )
+
+        assert weight_scale.dtype == torch.float8_e4m3fn
+        raw = weight_scale.view(torch.uint8)
+        n_nan = ((raw & 0x7F) == 0x7F).sum().item()
+        assert n_nan == 0, f"fp8 weight_scale contains {n_nan} NaN byte(s)"
+        assert raw.flatten()[0].item() == 0x7E
+
     @pytest.mark.parametrize("device", ["cuda"])
     @pytest.mark.parametrize(
         "test_input",