[Quantization] Saturate NVFP4 export FP8 scale cast to avoid NaN #1397
Conversation
When ``NVFP4QTensor.get_weights_scaling_factor_from_quantizer`` runs the static path on a weight that contains an all-zero block (per-block amax == 0), the existing ``per_block_scale[per_block_scale == 0] = 1.0`` safety net drives the pre-cast value to ``1.0 * 448 / (global_amax / 6)``. ``fp8_e4m3fn`` has no Inf, so any value >= 480 rounds to NaN — silently writing a 0x7F byte into the exported ``weight_scale``. Saturating to 448 before the cast keeps the stored byte finite. The all-zero block dequantizes to zero regardless of the stored fp8 scale (the FP4 nibbles are all 0), so dequantization is unaffected. For non-degenerate blocks the clamp is a no-op since ``per_block_amax <= global_amax`` already bounds the pre-cast value at 448. Includes a regression test that reproduces the NaN without the clamp. Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
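A minimal sketch of the overflow behavior described above (not the modelopt code; the `global_amax` value is illustrative):

```python
import torch

# A small global amax drives the pre-cast value far above the fp8_e4m3fn
# max finite value of 448 once the ``==0 -> 1.0`` safety net has fired.
global_amax = torch.tensor(0.05)
pre_cast = torch.tensor(1.0) * 448.0 / (global_amax / 6)  # ~53760

print(pre_cast.to(torch.float8_e4m3fn))                   # tensor(nan, ...)
print(pre_cast.clamp(max=448.0).to(torch.float8_e4m3fn))  # tensor(448., ...)
```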
📝 Walkthrough
This PR adds overflow safety to FP8 scale factor computation in NVFP4 quantization by clamping per-block scaling factors to 448.0 before casting to FP8, and introduces a regression test verifying no NaN values are emitted when exporting scales for all-zero weight blocks.

Changes: FP8 Scaling Safety and Regression Test
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks: ✅ 6 passed
meenchen left a comment
Bot review — DM the bot to share feedback.
Clean, targeted bug fix for a real NaN issue in NVFP4 FP8 scale export when zero-amax blocks are present. The .clamp_(max=448.0) before the to(torch.float8_e4m3fn) cast is correct — fp8_e4m3fn has no Inf representation and values ≥ 480 become NaN. The clamp is a no-op for non-degenerate blocks (where per_block_amax ≤ global_amax bounds the value at 448), and for zero blocks the FP4 nibbles are all zero so the scale value is irrelevant to dequantization correctness. The regression test is thorough — it constructs the exact triggering condition, verifies no NaN bytes, and checks the saturated value is 0x7E (448). +55/-3 lines, well-scoped.
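The 0x7E/0x7F distinction the test relies on can be verified directly; a small sketch, independent of the modelopt code, of how fp8_e4m3fn encodes the two outcomes:

```python
import torch

# In fp8_e4m3fn the max finite value 448 encodes as byte 0x7E; the NaN
# pattern produced by an overflowing cast is 0x7F.
saturated = torch.tensor([448.0]).to(torch.float8_e4m3fn)
overflowed = torch.tensor([1e6]).to(torch.float8_e4m3fn)

print(hex(saturated.view(torch.uint8).item()))   # 0x7e
print(hex(overflowed.view(torch.uint8).item()))  # 0x7f
```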
Testing: Test plan has 1 unchecked item(s) out of 3. Finish or remove them before approving.
meenchen left a comment
Bot review — DM the bot to share feedback.
Re-review: The previous bot comment flagged unchecked test plan items — those are now all checked (2/2). The fix is correct, minimal (+55/-3), well-commented, and the regression test thoroughly covers the exact triggering condition. No outstanding issues.
Actionable comments posted: 1
📒 Files selected for processing (2)
- modelopt/torch/quantization/qtensor/nvfp4_tensor.py
- tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
```diff
+ # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
+ # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
+ # for an all-zero weight block) and global_amax is small, the pre-cast value
+ # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
+ # value >= 480 casts to NaN — clamp first to keep the stored byte finite.
  if not keep_high_precision:
-     per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
-         torch.float8_e4m3fn
+     per_block_scale = (
+         (per_block_scale * 448.0 / per_block_scale_max)
+         .clamp_(max=448.0)
+         .to(torch.float8_e4m3fn)
      )
```
This only closes the static branch; the dynamic all-zero path can still produce NaNs.
get_weights_scaling_factor() still computes per_block_amax / (6 * weights_scaling_factor_2), and for an all-zero tensor both terms are zero. That leaves NaN in per_block_scale because the later per_block_scale == 0 fixup does not catch it, and quantize() then divides by weights_scaling_factor * weights_scaling_factor_2 with a zero denominator. Please add the same zero-amax handling there, or an early all-zero fast path, so this fix covers both code paths.
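A sketch of the hazard this comment describes, with one possible guard; the variable names mirror the comment, and the short-circuit shown is an assumption, not the actual fix:

```python
import torch

per_block_amax = torch.zeros(4)              # all-zero weight -> amax is 0
weights_scaling_factor_2 = torch.tensor(0.0)

# 0 / 0 yields NaN; the downstream ``== 0`` fixup does not catch NaN.
per_block_scale = per_block_amax / (6 * weights_scaling_factor_2)
print(per_block_scale)                       # tensor([nan, nan, nan, nan])

# One possible short-circuit: treat the degenerate case as an all-zero scale.
if weights_scaling_factor_2 == 0:
    per_block_scale = torch.zeros_like(per_block_amax)
```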
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/quantization/qtensor/nvfp4_tensor.py` around lines 125 - 135,
get_weights_scaling_factor() can produce NaN when both per_block_amax and
weights_scaling_factor_2 are zero (all-zero tensor), which later breaks
quantize(); update get_weights_scaling_factor (and/or add an early all-zero fast
path in quantize) to detect the zero-denominator case and short-circuit: if
per_block_amax == 0 or weights_scaling_factor_2 == 0, set per_block_scale to a
safe finite value (e.g., 0.0) or return an all-zero quantized result
immediately; ensure the same clamp/finite handling applied in the non-static
branch (the per_block_scale fixup used elsewhere) is applied here so
per_block_scale is never NaN before division in quantize().
Codecov Report
❌ Patch coverage is

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main    #1397       +/-   ##
===========================================
- Coverage   76.73%   66.07%   -10.67%
===========================================
  Files         476      476
  Lines       51306    51306
===========================================
- Hits        39369    33899     -5470
- Misses      11937    17407     +5470
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Summary
- Clamp `per_block_scale * 448 / per_block_scale_max` to ≤ 448 before the `to(torch.float8_e4m3fn)` cast in `NVFP4QTensor.get_weights_scaling_factor_from_quantizer`.

Why

When `_amax` contains a zero entry (e.g. an all-zero weight block left untouched by max calibration), the existing `per_block_scale[per_block_scale == 0] = 1.0` safety net drives the pre-cast value to `1.0 * 448 / (global_amax / 6)`. `fp8_e4m3fn` has no Inf — anything ≥ 480 rounds to NaN — so a 0x7F byte slips into the exported `weight_scale`.

This was observed in a saved Kimi-K2.6-NVFP4-MSE checkpoint at `language_model.model.layers.1.mlp.experts.21.down_proj.weight_scale[4001, 18]`. The MSE FP8 sweep itself never produces zero per-block amax (it always emits at least `c[0] * global_amax`), but any export path where `_amax` ends up zero — including pure max calibration — hits the bug. With the clamp the byte saturates to `0x7E` (= 448, fp8 max finite) and dequantization is unaffected: the FP4 nibbles for an all-zero block are all 0, so `0 × 448 × weight_scale_2 = 0` regardless of the stored fp8 scale. For non-degenerate blocks the clamp is a no-op since `per_block_amax ≤ global_amax` already bounds the pre-cast value at 448.
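The dequantization-invariance claim reduces to one line of arithmetic; a trivial sketch with illustrative values:

```python
# Dequantization multiplies the FP4 value by both scales; a zero block has
# all-zero FP4 nibbles, so any finite stored fp8 scale yields zero output.
fp4_values = [0.0] * 16      # one all-zero 16-element block
block_scale = 448.0          # the saturated fp8 scale (byte 0x7E)
weight_scale_2 = 0.0083      # arbitrary per-tensor scale

dequant = [v * block_scale * weight_scale_2 for v in fp4_values]
assert all(x == 0.0 for x in dequant)
```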
Test plan

- `test_export_fp8_scale_no_nan_for_zero_amax_block` fails on `main`'s export code (reproduces the 0x7F NaN byte) and passes with the clamp (a sketch of the test's shape follows below).
- Existing tests in `tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py` still pass (10/10).

🤖 Generated with Claude Code
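A hedged sketch of what such a regression test can look like; it re-derives the static-path formula from the description above rather than calling the modelopt API, so names and values are illustrative, not the checked-in test:

```python
import torch

def test_export_fp8_scale_no_nan_for_zero_amax_block():
    # Block 1 is all-zero, so its per-block amax is 0.
    per_block_amax = torch.tensor([0.5, 0.0, 0.25])
    global_amax = torch.tensor(0.5)

    per_block_scale = per_block_amax / 6.0
    per_block_scale[per_block_scale == 0] = 1.0      # the existing safety net
    per_block_scale_max = global_amax / 6.0

    fp8 = (
        (per_block_scale * 448.0 / per_block_scale_max)
        .clamp_(max=448.0)                           # the fix under test
        .to(torch.float8_e4m3fn)
    )
    raw = fp8.view(torch.uint8)
    assert int((raw == 0x7F).sum()) == 0             # no NaN bytes anywhere
    assert raw[1].item() == 0x7E                     # zero block saturates to 448

test_export_fp8_scale_no_nan_for_zero_amax_block()
```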
Summary by CodeRabbit