From 38148d0d1f7dfae71ffdbe472883af96c3ba89a5 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 4 May 2026 20:57:52 +0000 Subject: [PATCH 1/8] [Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block best amax directly. For our specific candidate set (FP8 representable values / 448) the FP8 round-trip on the per-block scale is the identity, so the kernel uses `scale = candidate * global_amax / 6.0` and runs on any CUDA + Triton. Triton-backed calibrator is on by default for `mse_calibrate(... fp8_scale_sweep=True)`; set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging. Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator (8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to the reference for typical block counts; on multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative). Signed-off-by: Chenjie Luo --- .../kernels/quantization/gemm/__init__.py | 1 + .../quantization/gemm/nvfp4_fp8_sweep.py | 142 +++++++++++ modelopt/torch/quantization/calib/mse.py | 81 ++++++- modelopt/torch/quantization/model_calib.py | 12 +- .../test_nvfp4_fp8_sweep_kernel.py | 221 ++++++++++++++++++ 5 files changed, 453 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py create mode 100644 tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index 39b07b4faa9..70f729cffb0 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -32,6 +32,7 @@ # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * from .fp8_kernel import * + from .nvfp4_fp8_sweep import * # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py new file mode 100644 index 00000000000..4fdeaf7c104 --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. + +Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single +kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates +and emits the per-block ``best_amax`` directly. + +The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see +:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on +the per-block scale is the identity, so the kernel can use +``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it +runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). +""" + +import torch +import triton +import triton.language as tl + +from .nvfp4_quant import nvfp4_scalar_quant + +__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] + + +def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: + """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + return fp8_values[valid_mask] / 448.0 + + +@triton.jit +def _fp8_scale_sweep_kernel( + x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32) + candidates_ptr, # [NUM_CANDIDATES] fp32 + global_amax_ptr, # scalar fp32 + best_amax_ptr, # [N_BLOCKS] fp32 output + N_BLOCKS, + BLOCK_SIZE: tl.constexpr, + NUM_CANDIDATES: tl.constexpr, + BLOCKS_PER_PROGRAM: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCKS_PER_PROGRAM + block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM) + block_mask = block_idx < N_BLOCKS + + # Load weights for this tile: [BLOCKS_PER_PROGRAM, BLOCK_SIZE] + elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + elem_mask = block_mask[:, None] + w = tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32) + + global_amax = tl.load(global_amax_ptr).to(tl.float32) + + best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32) + best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32) + + # Loop over the 126 FP8 candidates (compile-time unrolled). + for k in tl.static_range(NUM_CANDIDATES): + c = tl.load(candidates_ptr + k).to(tl.float32) + # block_amax = global_amax * c by construction; the FP8 round on the resulting + # scale is the identity for our candidate set, so we can skip the FP8 cast. + scale = c * global_amax / 6.0 + w_q = nvfp4_scalar_quant(w, scale, BLOCK_SIZE) + diff = w - w_q + loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] + is_better = loss < best_loss + best_loss = tl.where(is_better, loss, best_loss) + best_idx = tl.where(is_better, k, best_idx) + + # Map each block's winning candidate index back to its amax = global_amax * c[best]. + best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32) + best_amax = global_amax * best_c + tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask) + + +def nvfp4_fp8_scale_sweep( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, + candidates: torch.Tensor | None = None, + blocks_per_program: int = 4, +) -> torch.Tensor: + """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. + + Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into + a single Triton kernel: every block's weight elements are loaded once, all 126 + candidates are evaluated in registers, and the running argmin is kept inline. + + Args: + x: Weight tensor on CUDA. Total element count must be divisible by + ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. + global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). + block_size: NVFP4 block size (typically 16). + candidates: Optional precomputed candidate tensor of shape ``[126]`` (must + be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. + blocks_per_program: Number of blocks each Triton program handles. Trades + launch overhead for register pressure; 4 is a reasonable default. + + Returns: + ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. + """ + assert x.is_cuda, "nvfp4_fp8_scale_sweep requires a CUDA tensor" + if x.numel() % block_size != 0: + raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") + + if candidates is None: + candidates = fp8_scale_candidates(x.device) + candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) + + n_blocks = x.numel() // block_size + x_flat = x.contiguous().view(-1) + global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) + best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) + + grid = (triton.cdiv(n_blocks, blocks_per_program),) + with torch.cuda.device(x.device): + _fp8_scale_sweep_kernel[grid]( + x_flat, + candidates, + global_amax_f32, + best_amax, + n_blocks, + BLOCK_SIZE=block_size, + NUM_CANDIDATES=int(candidates.numel()), + BLOCKS_PER_PROGRAM=blocks_per_program, + ) + return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 1f439a7e778..a879790cc42 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -24,7 +24,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -198,3 +198,82 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor: valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) fp8_values = fp8_values[valid_mask] return fp8_values / 448.0 + + +class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): + """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE. + + Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126 + candidates in a single fused Triton kernel — one weight read instead of 126. + + Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle. + This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where + the calibrator is collected once per weight and immediately consumed. For + activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. + """ + + def __init__( + self, + amax: torch.Tensor, + global_amax: torch.Tensor, + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + blocks_per_program: int = 4, + ): + """Initialize the Triton-fused NVFP4 MSE calibrator. + + See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by + the kernel path but accepted for API parity. + """ + super().__init__( + amax=amax, + global_amax=global_amax, + axis=axis, + quant_func=quant_func, + error_func=error_func, + ) + self._blocks_per_program = blocks_per_program + self._best_amax: torch.Tensor | None = None + + @torch.no_grad() + def collect(self, x: torch.Tensor): + """Run the fused FP8 sweep kernel and store the resulting per-block amax.""" + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + if self._best_amax is not None: + raise RuntimeError( + "TritonNVFP4MSECalibrator only supports a single collect() per cycle; " + "call reset() before collecting again." + ) + + x = x.detach() + block_size = x.shape[-1] + n_blocks = x.numel() // block_size + if self._initial_amax.numel() != n_blocks: + raise ValueError( + f"initial_amax.numel() ({self._initial_amax.numel()}) does not match " + f"the number of NVFP4 blocks ({n_blocks})." + ) + + best_amax_flat = nvfp4_fp8_scale_sweep( + x, + self._global_amax, + block_size=block_size, + blocks_per_program=self._blocks_per_program, + ) + # Match the original shape/dtype of _initial_amax so downstream load_calib_amax + # behaves identically to the reference path. + self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to( + self._initial_amax.dtype + ) + + @torch.no_grad() + def compute_amax(self, verbose: bool = False): + """Return the per-block amax computed during ``collect``.""" + return self._best_amax + + def reset(self): + """Reset the stored best amax.""" + self._best_amax = None + super().reset() diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index aeae3dd4321..9920802567c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,7 @@ """Calibration utilities.""" import math +import os import time import warnings from collections.abc import Callable @@ -37,7 +38,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -391,8 +392,13 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with NVFP4MSECalibrator - module._calibrator = NVFP4MSECalibrator( + # Replace calibrator with the fused Triton sweep kernel by default + # (single-shot collect, ~7-20x faster for the weight-MSE phase). + # Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference + # NVFP4MSECalibrator for debugging or numerics comparison. + use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" + cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator + module._calibrator = cls( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py new file mode 100644 index 00000000000..f1ac5b7f24d --- /dev/null +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity + speedup tests for the fused NVFP4 FP8 scale sweep Triton kernel. + +Compares :class:`TritonNVFP4MSECalibrator` against the reference +:class:`NVFP4MSECalibrator` on the same inputs and asserts the resulting per-block +amax tensors are bit-identical. Also reports a wall-clock speedup number for the +weight-MSE search step on a representative LLM-sized weight. +""" + +import time + +import pytest +import torch +from conftest import requires_triton + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator, TritonNVFP4MSECalibrator +from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant + +BLOCK_SIZE = 16 + + +def _reference_quant_func(global_amax): + """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + return quant_func + + +def _run_reference(x, per_block_amax, global_amax): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +def _run_triton(x, per_block_amax, global_amax): + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +@requires_triton +@pytest.mark.parametrize("seed", [0, 1, 2]) +@pytest.mark.parametrize("num_blocks", [4, 64, 1024]) +def test_parity_random_weights(seed, num_blocks): + """Triton sweep must produce the exact same per-block amax as the reference.""" + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + + assert ref.shape == tri.shape + # Both pick from the same 126-element discrete candidate set, so any disagreement + # would show up as a non-zero diff (not a small float epsilon). Demand exact match. + assert torch.equal(ref, tri), ( + f"Triton sweep diverged from reference: max |diff| = " + f"{(ref - tri).abs().max().item():.3e}, " + f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" + ) + + +@requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_parity_dtypes(dtype): + """Sweep must agree across the dtypes supported by the NVFP4 quantizer.""" + torch.manual_seed(42) + device = "cuda" + num_blocks = 256 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). + per_block_amax = x.float().abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + assert torch.equal(ref, tri) + + +@requires_triton +def test_quantized_output_matches(): + """Round-tripping x through the chosen amax should give the same fake-quant result.""" + torch.manual_seed(7) + device = "cuda" + num_blocks = 128 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + assert torch.equal(ref_xq, tri_xq) + + +@requires_triton +def test_reset_allows_recollect(): + torch.manual_seed(0) + device = "cuda" + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal.collect(x) + first = cal.compute_amax().clone() + + with pytest.raises(RuntimeError, match="single collect"): + cal.collect(x) + + cal.reset() + # After reset the calibrator's _initial_amax has been freed; reconstruct. + cal2 = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal2.collect(x) + assert torch.equal(first, cal2.compute_amax()) + + +def _bench(fn, warmup=2, iters=5): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters + + +@requires_triton +def test_speedup_report(capsys): + """Sanity-check that the Triton path is meaningfully faster on a realistic weight. + + Uses an 8192 x 4096 weight (~33M elements, ~2M NVFP4 blocks) — roughly the size + of an LLM attention/MLP projection. Reports the speedup; does not gate on a + minimum factor (kernel timing is noisy on shared CI), but does require parity + on the chosen amax. + """ + torch.manual_seed(123) + device = "cuda" + cout, cin = 8192, 4096 + x = torch.randn(cout, cin // BLOCK_SIZE, BLOCK_SIZE, device=device, dtype=torch.float32) + x = x.reshape(-1, BLOCK_SIZE) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + # Bit-equality across millions of blocks isn't guaranteed: when two adjacent FP8 + # candidates yield near-identical per-block MSE (within fp32 noise), the reference's + # CUDA fake_e4m3fy path and our Triton inline math can break ties differently. Demand + # instead that the Triton choice produces a per-block MSE within fp32 epsilon of the + # reference's choice. + n_blocks = ref_amax.numel() + n_diff = int((ref_amax != tri_amax).sum()) + if n_diff: + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + per_block_mse_ref = (x - ref_xq).pow(2).sum(dim=-1) + per_block_mse_tri = (x - tri_xq).pow(2).sum(dim=-1) + # Reference is the formal argmin, so triton's loss should be ≥ reference's. + # Allow at most 1e-5 relative gap on differing blocks (observed ~1e-7 in practice). + rel_gap = (per_block_mse_tri - per_block_mse_ref).abs() / per_block_mse_ref.clamp_min(1e-12) + worst = rel_gap.max().item() + assert worst < 1e-5, ( + f"{n_diff}/{n_blocks} blocks disagree with worst relative MSE gap {worst:.3e} " + "— exceeds tie-break tolerance" + ) + + ref_t = _bench(lambda: _run_reference(x, per_block_amax, global_amax)) + tri_t = _bench(lambda: _run_triton(x, per_block_amax, global_amax)) + speedup = ref_t / tri_t + + # Force-print regardless of pytest capture mode. + with capsys.disabled(): + n_blocks = x.numel() // BLOCK_SIZE + print( + f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " + f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" + f" reference NVFP4MSECalibrator: {ref_t * 1e3:8.2f} ms\n" + f" triton TritonNVFP4MSECalibrator: {tri_t * 1e3:8.2f} ms\n" + f" speedup: {speedup:.1f}x" + ) From 7a132d78f9be7de5606c9106ae5fd8cc7f57e518 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 4 May 2026 21:11:09 +0000 Subject: [PATCH 2/8] [Quantization] Autotune NVFP4 FP8 sweep kernel; drop sign-where in inner loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two follow-on optimizations to the fused FP8 scale sweep kernel: 1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300 showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the table — best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are included to cover small/medium/large N_BLOCKS without flooding compile time. 2. Drop the sign-handling tl.where: since FP4 quantization preserves sign, (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and skips one tl.where + negation per element per candidate. Result on the same 8192x4096 weight (~2M blocks) on B300: reference NVFP4MSECalibrator: 176.68 ms triton TritonNVFP4MSECalibrator: 4.23 ms speedup: 41.8x (was 7.4x) This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms), so the kernel is now near saturation and further wins would need an algorithmic change (candidate pruning, etc.). Signed-off-by: Chenjie Luo --- .../quantization/gemm/nvfp4_fp8_sweep.py | 41 +++++++++++++------ modelopt/torch/quantization/calib/mse.py | 6 +-- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py index 4fdeaf7c104..8492a9c93a5 100644 --- a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -24,13 +24,15 @@ the per-block scale is the identity, so the kernel can use ``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). + +Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. """ import torch import triton import triton.language as tl -from .nvfp4_quant import nvfp4_scalar_quant +from .nvfp4_quant import fp4_round_magnitude __all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] @@ -43,6 +45,18 @@ def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: return fp8_values[valid_mask] / 448.0 +# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300: +# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms +# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64 +# would underfill the SMs. +_FP8_SWEEP_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2), + triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4), + triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8), +] + + +@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"]) @triton.jit def _fp8_scale_sweep_kernel( x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32) @@ -59,10 +73,13 @@ def _fp8_scale_sweep_kernel( block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM) block_mask = block_idx < N_BLOCKS - # Load weights for this tile: [BLOCKS_PER_PROGRAM, BLOCK_SIZE] + # Load weights for this tile and pre-compute their absolute values once. + # The squared error is sign-invariant since FP4 quant preserves sign: + # (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2 + # so we never need ``w`` itself again, dropping a tl.where + negation per element. elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] elem_mask = block_mask[:, None] - w = tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32) + w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32)) global_amax = tl.load(global_amax_ptr).to(tl.float32) @@ -70,13 +87,17 @@ def _fp8_scale_sweep_kernel( best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32) # Loop over the 126 FP8 candidates (compile-time unrolled). + # Scales are guaranteed positive and finite (constructed from a positive candidate + # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is + # unnecessary apart from the global_amax == 0 case handled below. for k in tl.static_range(NUM_CANDIDATES): c = tl.load(candidates_ptr + k).to(tl.float32) - # block_amax = global_amax * c by construction; the FP8 round on the resulting - # scale is the identity for our candidate set, so we can skip the FP8 cast. scale = c * global_amax / 6.0 - w_q = nvfp4_scalar_quant(w, scale, BLOCK_SIZE) - diff = w - w_q + # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is + # the same for every candidate, so any best_idx is fine. + scale_safe = tl.where(scale == 0.0, 1.0, scale) + q_mag = fp4_round_magnitude(w_abs / scale_safe) + diff = w_abs - q_mag * scale loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] is_better = loss < best_loss best_loss = tl.where(is_better, loss, best_loss) @@ -93,7 +114,6 @@ def nvfp4_fp8_scale_sweep( global_amax: torch.Tensor, block_size: int = 16, candidates: torch.Tensor | None = None, - blocks_per_program: int = 4, ) -> torch.Tensor: """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. @@ -108,8 +128,6 @@ def nvfp4_fp8_scale_sweep( block_size: NVFP4 block size (typically 16). candidates: Optional precomputed candidate tensor of shape ``[126]`` (must be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. - blocks_per_program: Number of blocks each Triton program handles. Trades - launch overhead for register pressure; 4 is a reasonable default. Returns: ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. @@ -127,7 +145,7 @@ def nvfp4_fp8_scale_sweep( global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) - grid = (triton.cdiv(n_blocks, blocks_per_program),) + grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),) with torch.cuda.device(x.device): _fp8_scale_sweep_kernel[grid]( x_flat, @@ -137,6 +155,5 @@ def nvfp4_fp8_scale_sweep( n_blocks, BLOCK_SIZE=block_size, NUM_CANDIDATES=int(candidates.numel()), - BLOCKS_PER_PROGRAM=blocks_per_program, ) return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index a879790cc42..7471ec23bb2 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -219,12 +219,12 @@ def __init__( axis: int | tuple | list | None = None, quant_func: Callable | None = None, error_func: Callable | None = None, - blocks_per_program: int = 4, ): """Initialize the Triton-fused NVFP4 MSE calibrator. See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by - the kernel path but accepted for API parity. + the kernel path but accepted for API parity. Tile shape and ``num_warps`` are + autotuned by the kernel per ``N_BLOCKS``. """ super().__init__( amax=amax, @@ -233,7 +233,6 @@ def __init__( quant_func=quant_func, error_func=error_func, ) - self._blocks_per_program = blocks_per_program self._best_amax: torch.Tensor | None = None @torch.no_grad() @@ -260,7 +259,6 @@ def collect(self, x: torch.Tensor): x, self._global_amax, block_size=block_size, - blocks_per_program=self._blocks_per_program, ) # Match the original shape/dtype of _initial_amax so downstream load_calib_amax # behaves identically to the reference path. From 4644ed11cadaa76d0a2e26cb8ee88e683d7feed8 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 4 May 2026 22:05:02 +0000 Subject: [PATCH 3/8] [Quantization] Address PR review feedback on FP8 sweep kernel Addresses review comments on PR #1387: - TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape / dtype / n_blocks of the initial amax are stashed in __init__, so collect() no longer depends on _initial_amax surviving reset(). Adds an x.ndim==2 assertion in collect() since the weight quantizer always reshapes upstream. - nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert (which is stripped by python -O): rejects non-CUDA tensors, non-positive block_size, and empty / non-1D candidates with ValueError. Skips the per-element finite/positive check on candidates since it would scan a 126- entry tensor on every kernel call. - mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of the per-quantizer loop and resolves to the calibrator class once. - Updates test_reset_allows_recollect to verify the new reuse contract; adds test_input_validation covering the new ValueErrors. The duplicate fp8_scale_candidates implementation in the kernel file and NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating would force the reference path to import from the kernel module, which is gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity test exercises both paths against each other. Signed-off-by: Chenjie Luo --- .../quantization/gemm/nvfp4_fp8_sweep.py | 9 +++- modelopt/torch/quantization/calib/mse.py | 36 ++++++++++----- modelopt/torch/quantization/model_calib.py | 13 +++--- .../test_nvfp4_fp8_sweep_kernel.py | 44 +++++++++++++++---- 4 files changed, 74 insertions(+), 28 deletions(-) diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py index 8492a9c93a5..4b9f19837f2 100644 --- a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -132,13 +132,20 @@ def nvfp4_fp8_scale_sweep( Returns: ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. """ - assert x.is_cuda, "nvfp4_fp8_scale_sweep requires a CUDA tensor" + if not x.is_cuda: + raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.") + if not isinstance(block_size, int) or block_size <= 0: + raise ValueError(f"block_size must be a positive int, got {block_size!r}.") if x.numel() % block_size != 0: raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") if candidates is None: candidates = fp8_scale_candidates(x.device) candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) + if candidates.ndim != 1 or candidates.numel() == 0: + raise ValueError( + f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}." + ) n_blocks = x.numel() // block_size x_flat = x.contiguous().view(-1) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 7471ec23bb2..fff0b8af1b2 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -192,7 +192,12 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: return torch.ones_like(self._initial_amax) * self._global_amax * candidates def _generate_candidates(self, device: torch.device) -> torch.Tensor: - """Generate 126 valid FP8 E4M3 scale candidates.""" + """Generate 126 valid FP8 E4M3 scale candidates. + + Kept in sync with ``fp8_scale_candidates`` in + ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3 + spec is fixed, and the parity test exercises both paths against each other. + """ uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) fp8_values = uint8_values.view(torch.float8_e4m3fn).float() valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) @@ -210,6 +215,7 @@ class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where the calibrator is collected once per weight and immediately consumed. For activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. + Call :meth:`reset` to free internal state and re-enable :meth:`collect`. """ def __init__( @@ -233,6 +239,11 @@ def __init__( quant_func=quant_func, error_func=error_func, ) + # Stash shape metadata so collect() can keep working after reset() releases + # the (potentially large) _initial_amax buffer. + self._initial_amax_shape = tuple(amax.shape) + self._initial_amax_dtype = amax.dtype + self._n_blocks = int(amax.numel()) self._best_amax: torch.Tensor | None = None @torch.no_grad() @@ -242,17 +253,20 @@ def collect(self, x: torch.Tensor): if self._best_amax is not None: raise RuntimeError( - "TritonNVFP4MSECalibrator only supports a single collect() per cycle; " - "call reset() before collecting again." + "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to " + "discard the previous result before collecting again." ) x = x.detach() + # The weight quantizer reshapes its input to [n_blocks, block_size] before + # calling collect (see TensorQuantizer._process_for_blockquant). + assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." block_size = x.shape[-1] n_blocks = x.numel() // block_size - if self._initial_amax.numel() != n_blocks: + if n_blocks != self._n_blocks: raise ValueError( - f"initial_amax.numel() ({self._initial_amax.numel()}) does not match " - f"the number of NVFP4 blocks ({n_blocks})." + f"initial amax.numel() ({self._n_blocks}) does not match the number " + f"of NVFP4 blocks in x ({n_blocks})." ) best_amax_flat = nvfp4_fp8_scale_sweep( @@ -260,10 +274,10 @@ def collect(self, x: torch.Tensor): self._global_amax, block_size=block_size, ) - # Match the original shape/dtype of _initial_amax so downstream load_calib_amax - # behaves identically to the reference path. - self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to( - self._initial_amax.dtype + # Match the original shape/dtype of the initial amax so downstream + # load_calib_amax behaves identically to the reference path. + self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to( + self._initial_amax_dtype ) @torch.no_grad() @@ -272,6 +286,6 @@ def compute_amax(self, verbose: bool = False): return self._best_amax def reset(self): - """Reset the stored best amax.""" + """Reset the stored best amax. Subsequent ``collect`` calls are allowed.""" self._best_amax = None super().reset() diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9920802567c..d4b5d78f016 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -355,6 +355,11 @@ def mse_calibrate( weight_quantizers = [] seen_modules = set() + # Triton-fused FP8 sweep is on by default for NVFP4 static quant; set + # MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging. + use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" + nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator + for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): @@ -392,13 +397,7 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with the fused Triton sweep kernel by default - # (single-shot collect, ~7-20x faster for the weight-MSE phase). - # Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference - # NVFP4MSECalibrator for debugging or numerics comparison. - use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" - cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator - module._calibrator = cls( + module._calibrator = nvfp4_calibrator_cls( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index f1ac5b7f24d..c25867d8327 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -140,18 +140,44 @@ def test_reset_allows_recollect(): cal.collect(x) first = cal.compute_amax().clone() - with pytest.raises(RuntimeError, match="single collect"): + # collect() is one-shot per cycle until reset() is called. + with pytest.raises(RuntimeError, match="one-shot"): cal.collect(x) cal.reset() - # After reset the calibrator's _initial_amax has been freed; reconstruct. - cal2 = TritonNVFP4MSECalibrator( - amax=per_block_amax, - axis=0, - global_amax=global_amax, - ) - cal2.collect(x) - assert torch.equal(first, cal2.compute_amax()) + # After reset, the same calibrator instance can be re-used. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) + + +@requires_triton +def test_input_validation(): + """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" + from modelopt.torch.kernels.quantization.gemm import fp8_scale_candidates, nvfp4_fp8_scale_sweep + + device = "cuda" + x = torch.randn(64, BLOCK_SIZE, device=device) + g = x.abs().amax() + + # CPU tensor → ValueError (not bare AssertionError). + with pytest.raises(ValueError, match="CUDA"): + nvfp4_fp8_scale_sweep(x.cpu(), g.cpu()) + + # block_size <= 0. + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=0) + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=-1) + + # Non-divisible numel. + with pytest.raises(ValueError, match="not divisible"): + nvfp4_fp8_scale_sweep(x, g, block_size=15) + + # Empty / wrong-rank candidates. + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=torch.empty(0, device=device)) + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=fp8_scale_candidates(device).reshape(2, -1)) def _bench(fn, warmup=2, iters=5): From c4be1bb11fa0142f7f8edf1255cec333de196c66 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 6 May 2026 17:43:03 +0000 Subject: [PATCH 4/8] [Quantization] Tighten FP8 sweep input contracts and add dispatch test Address realAsma's review feedback on the NVFP4 FP8 sweep kernel: - TritonNVFP4MSECalibrator.collect: replace `assert x.ndim == 2` with ValueError so the contract still holds under `python -O`, validate block_size > 0 before use, and derive n_blocks from x.shape[0] so a zero last-dim cannot trigger division before the shape check. - nvfp4_fp8_scale_sweep: drop the public `candidates` parameter. The candidate set is fixed (FP8 E4M3 valid values / 448) and a wrong length would silently inflate `tl.static_range` codegen, while nonpositive/nonfinite entries violate the kernel's scale assumptions. No internal caller used the override. - Add test_mse_calibrate_dispatch covering the public default + opt-out wiring: confirms `mse_calibrate(fp8_scale_sweep=True)` installs TritonNVFP4MSECalibrator by default and falls back to NVFP4MSECalibrator when MODELOPT_NVFP4_TRITON_SWEEP=0. Signed-off-by: Chenjie Luo --- .../quantization/gemm/nvfp4_fp8_sweep.py | 11 +-- modelopt/torch/quantization/calib/mse.py | 12 +++- .../test_nvfp4_fp8_sweep_kernel.py | 67 +++++++++++++++++-- 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py index 4b9f19837f2..49e4839a3c1 100644 --- a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -113,7 +113,6 @@ def nvfp4_fp8_scale_sweep( x: torch.Tensor, global_amax: torch.Tensor, block_size: int = 16, - candidates: torch.Tensor | None = None, ) -> torch.Tensor: """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. @@ -126,8 +125,6 @@ def nvfp4_fp8_scale_sweep( ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). block_size: NVFP4 block size (typically 16). - candidates: Optional precomputed candidate tensor of shape ``[126]`` (must - be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. Returns: ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. @@ -139,13 +136,7 @@ def nvfp4_fp8_scale_sweep( if x.numel() % block_size != 0: raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") - if candidates is None: - candidates = fp8_scale_candidates(x.device) - candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) - if candidates.ndim != 1 or candidates.numel() == 0: - raise ValueError( - f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}." - ) + candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32) n_blocks = x.numel() // block_size x_flat = x.contiguous().view(-1) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index fff0b8af1b2..79961a0b677 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -259,10 +259,16 @@ def collect(self, x: torch.Tensor): x = x.detach() # The weight quantizer reshapes its input to [n_blocks, block_size] before - # calling collect (see TensorQuantizer._process_for_blockquant). - assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." + # calling collect (see TensorQuantizer._process_for_blockquant). Validate + # via ValueError so the contract still holds under ``python -O``. + if x.ndim != 2: + raise ValueError( + f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." + ) block_size = x.shape[-1] - n_blocks = x.numel() // block_size + if block_size <= 0: + raise ValueError(f"x.shape[-1] must be positive; got {block_size}.") + n_blocks = x.shape[0] if n_blocks != self._n_blocks: raise ValueError( f"initial amax.numel() ({self._n_blocks}) does not match the number " diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index c25867d8327..14d70d007c3 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -153,7 +153,7 @@ def test_reset_allows_recollect(): @requires_triton def test_input_validation(): """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" - from modelopt.torch.kernels.quantization.gemm import fp8_scale_candidates, nvfp4_fp8_scale_sweep + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep device = "cuda" x = torch.randn(64, BLOCK_SIZE, device=device) @@ -173,11 +173,66 @@ def test_input_validation(): with pytest.raises(ValueError, match="not divisible"): nvfp4_fp8_scale_sweep(x, g, block_size=15) - # Empty / wrong-rank candidates. - with pytest.raises(ValueError, match="non-empty 1-D"): - nvfp4_fp8_scale_sweep(x, g, candidates=torch.empty(0, device=device)) - with pytest.raises(ValueError, match="non-empty 1-D"): - nvfp4_fp8_scale_sweep(x, g, candidates=fp8_scale_candidates(device).reshape(2, -1)) + +@requires_triton +def test_mse_calibrate_dispatch(monkeypatch): + """``mse_calibrate(fp8_scale_sweep=True)`` must install the right calibrator class. + + Default path: ``TritonNVFP4MSECalibrator``. + With ``MODELOPT_NVFP4_TRITON_SWEEP=0``: ``NVFP4MSECalibrator`` (and not its subclass). + """ + from _test_utils.torch.quantization.models import SimpleLinear + + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + from modelopt.torch.quantization.nn import TensorQuantizer + + if get_cuda_ext_mx() is None: + pytest.skip("cuda_ext_mx is not available") + + cfg = { + "quant_cfg": [ + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + }, + {"quantizer_name": "*input_quantizer", "enable": False}, + ], + "algorithm": {"method": "mse", "fp8_scale_sweep": True}, + } + + def _quantize_and_get_weight_calibrators(model): + calib_data = [model.get_input().cuda() for _ in range(2)] + + def forward_loop(m): + for batch in calib_data: + m(batch) + + mtq.quantize(model, cfg, forward_loop=forward_loop) + return [ + type(m._calibrator) + for name, m in model.named_modules() + if isinstance(m, TensorQuantizer) + and name.endswith("weight_quantizer") + and getattr(m, "_calibrator", None) is not None + ] + + # Default: triton path. + monkeypatch.delenv("MODELOPT_NVFP4_TRITON_SWEEP", raising=False) + types_default = _quantize_and_get_weight_calibrators(SimpleLinear().cuda()) + assert types_default, "expected at least one weight quantizer with a calibrator" + assert all(t is TritonNVFP4MSECalibrator for t in types_default), types_default + + # Opt-out: reference path, exact class match (TritonNVFP4MSECalibrator is a subclass). + monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") + types_optout = _quantize_and_get_weight_calibrators(SimpleLinear().cuda()) + assert types_optout, "expected at least one weight quantizer with a calibrator" + assert all(t is NVFP4MSECalibrator for t in types_optout), types_optout def _bench(fn, warmup=2, iters=5): From c10f7846aebd8b7ea453ff2cf1c5b4583db1f0e3 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 6 May 2026 23:54:19 +0000 Subject: [PATCH 5/8] [Quantization] Fold Triton FP8 sweep into NVFP4MSECalibrator Per realAsma's review, collapse TritonNVFP4MSECalibrator into NVFP4MSECalibrator as an internal fast path rather than a separately-exported subclass: - mse.py: NVFP4MSECalibrator.collect() picks the fused Triton kernel via a predicate _can_use_triton_fast_path(x) that requires error_func is None, CUDA input, blocked layout matching the per-block amax, the kernel package importable, and MODELOPT_NVFP4_TRITON_SWEEP \!= "0". Otherwise falls back to the parent's reference 126-step sweep. Override reset() to clear only per-cycle state and keep _initial_amax (shape [num_blocks], small) so the calibrator is reusable; the multi-collect-after-fast-path case raises a RuntimeError with a clear message. TritonNVFP4MSECalibrator class deleted. - model_calib.py: always instantiate NVFP4MSECalibrator; drop the TritonNVFP4MSECalibrator import and the env-var dispatch (now internal). - tests: drop the TritonNVFP4MSECalibrator references. Force the requested path via a _force_sweep_path() context manager around the env var. New dispatch tests assert the predicate's behavior for the env opt-out, custom error_func, and CPU input cases. test_mse_calibrate_end_to_end exercises the full mtq.quantize wiring with default and MODELOPT_NVFP4_TRITON_SWEEP=0 and asserts bitwise-identical model outputs. This fixes a latent correctness issue: the previous TritonNVFP4MSECalibrator silently ignored a custom error_func, so a caller passing a Hessian-weighted loss (e.g. local-Hessian calibration) would have gotten plain squared-error results from the kernel. The new predicate routes any non-None error_func to the reference path so the user's metric is honored. Signed-off-by: Chenjie Luo --- modelopt/torch/quantization/calib/mse.py | 147 ++++++------ modelopt/torch/quantization/model_calib.py | 13 +- .../test_nvfp4_fp8_sweep_kernel.py | 219 +++++++++++++----- 3 files changed, 232 insertions(+), 147 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 79961a0b677..417d0173853 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -16,6 +16,7 @@ """Calibrator that returns the MSE amax of all collected tensors.""" import math +import os from collections.abc import Callable import torch @@ -24,7 +25,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -172,7 +173,15 @@ def compute_amax(self, verbose: bool = False): class NVFP4MSECalibrator(MseCalibrator): - """Per-block FP8 scale sweep calibrator for NVFP4 static quantization.""" + """Per-block FP8 scale sweep calibrator for NVFP4 static quantization. + + Uses a fused Triton kernel as an internal fast path on the first ``collect`` call + when (a) ``error_func is None``, (b) the input tensor is on CUDA in the standard + blocked ``[n_blocks, block_size]`` layout, and (c) Triton + the kernel package are + importable. Falls back to the reference 126-step Python sweep otherwise (custom + ``error_func`` users, multi-``collect`` activation flows, CPU inputs, or when the + fast path is disabled via ``MODELOPT_NVFP4_TRITON_SWEEP=0``). + """ def __init__( self, @@ -185,6 +194,8 @@ def __init__( """Initialize NVFP4 MSE calibrator with per-block and global amax.""" super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func) self._global_amax = global_amax + # Set by the Triton fast path on its (one-shot) collect; consumed by compute_amax. + self._best_amax_fast: torch.Tensor | None = None def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: if candidates.ndim != 0: # Called during final compute amax @@ -204,94 +215,70 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor: fp8_values = fp8_values[valid_mask] return fp8_values / 448.0 + def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool: + """Whether the Triton fast path is usable for this ``collect`` input. -class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): - """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE. - - Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126 - candidates in a single fused Triton kernel — one weight read instead of 126. - - Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle. - This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where - the calibrator is collected once per weight and immediately consumed. For - activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. - Call :meth:`reset` to free internal state and re-enable :meth:`collect`. - """ - - def __init__( - self, - amax: torch.Tensor, - global_amax: torch.Tensor, - axis: int | tuple | list | None = None, - quant_func: Callable | None = None, - error_func: Callable | None = None, - ): - """Initialize the Triton-fused NVFP4 MSE calibrator. - - See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by - the kernel path but accepted for API parity. Tile shape and ``num_warps`` are - autotuned by the kernel per ``N_BLOCKS``. + The kernel produces the final per-block amax in one shot, so it's only usable + when the caller wants the standard squared-error sweep on a single CUDA tensor + whose layout already matches the per-block amax. """ - super().__init__( - amax=amax, - global_amax=global_amax, - axis=axis, - quant_func=quant_func, - error_func=error_func, - ) - # Stash shape metadata so collect() can keep working after reset() releases - # the (potentially large) _initial_amax buffer. - self._initial_amax_shape = tuple(amax.shape) - self._initial_amax_dtype = amax.dtype - self._n_blocks = int(amax.numel()) - self._best_amax: torch.Tensor | None = None + if self._error_func is not None: + return False + if not x.is_cuda: + return False + if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0": + return False + if self._initial_amax is None: + return False + if x.ndim != 2 or x.shape[0] != int(self._initial_amax.numel()): + return False + try: + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep # noqa: F401 + except ImportError: + return False + return True @torch.no_grad() def collect(self, x: torch.Tensor): - """Run the fused FP8 sweep kernel and store the resulting per-block amax.""" - from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep - - if self._best_amax is not None: + """Collect input statistics. Uses the Triton fast path when eligible.""" + if self._best_amax_fast is not None: raise RuntimeError( - "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to " - "discard the previous result before collecting again." - ) - - x = x.detach() - # The weight quantizer reshapes its input to [n_blocks, block_size] before - # calling collect (see TensorQuantizer._process_for_blockquant). Validate - # via ValueError so the contract still holds under ``python -O``. - if x.ndim != 2: - raise ValueError( - f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." + "NVFP4MSECalibrator: the Triton fast path produced a final amax on a " + "previous collect() call; multi-collect after the fast path is not " + "supported. Call reset() to start a fresh cycle, set " + "MODELOPT_NVFP4_TRITON_SWEEP=0, or pass a non-None error_func to force " + "the reference path for activation-style accumulation." ) - block_size = x.shape[-1] - if block_size <= 0: - raise ValueError(f"x.shape[-1] must be positive; got {block_size}.") - n_blocks = x.shape[0] - if n_blocks != self._n_blocks: - raise ValueError( - f"initial amax.numel() ({self._n_blocks}) does not match the number " - f"of NVFP4 blocks in x ({n_blocks})." + # Fast path is eligible only on the first call, before the reference accumulator + # has produced any state. + if self._losses_sum is None and self._can_use_triton_fast_path(x): + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1]) + # Match the original shape/dtype of the initial amax so downstream + # load_calib_amax behaves identically to the reference path. + self._best_amax_fast = best_flat.reshape(self._initial_amax.shape).to( + self._initial_amax.dtype ) - - best_amax_flat = nvfp4_fp8_scale_sweep( - x, - self._global_amax, - block_size=block_size, - ) - # Match the original shape/dtype of the initial amax so downstream - # load_calib_amax behaves identically to the reference path. - self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to( - self._initial_amax_dtype - ) + return + super().collect(x) @torch.no_grad() def compute_amax(self, verbose: bool = False): - """Return the per-block amax computed during ``collect``.""" - return self._best_amax + """Return the per-block amax — from the fast path if it ran, else from the reference sweep.""" + if self._best_amax_fast is not None: + return self._best_amax_fast + return super().compute_amax(verbose=verbose) def reset(self): - """Reset the stored best amax. Subsequent ``collect`` calls are allowed.""" - self._best_amax = None - super().reset() + """Reset per-cycle state. Keep ``_initial_amax`` so the calibrator stays reusable. + + ``MseCalibrator.reset()`` intentionally drops ``_initial_amax`` to free memory in + the multi-step search, but the NVFP4 per-block amax is shape ``[num_blocks]`` — + small enough to keep so a follow-up ``collect()`` can run again on the same + calibrator instance. + """ + self._best_amax_fast = None + self._losses_sum = None + self._candidates = None + self._amax = None diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d4b5d78f016..fe4c3f77ce6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,7 +16,6 @@ """Calibration utilities.""" import math -import os import time import warnings from collections.abc import Callable @@ -38,7 +37,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -355,11 +354,6 @@ def mse_calibrate( weight_quantizers = [] seen_modules = set() - # Triton-fused FP8 sweep is on by default for NVFP4 static quant; set - # MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging. - use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" - nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator - for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): @@ -397,7 +391,10 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - module._calibrator = nvfp4_calibrator_cls( + # NVFP4MSECalibrator internally selects a fused Triton kernel for + # the standard squared-error sweep; set MODELOPT_NVFP4_TRITON_SWEEP=0 + # to force the reference Python sweep for debugging. + module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index 14d70d007c3..dc8d245ee4d 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -13,26 +13,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Parity + speedup tests for the fused NVFP4 FP8 scale sweep Triton kernel. +"""Parity + speedup tests for the NVFP4 FP8 scale sweep Triton fast path. -Compares :class:`TritonNVFP4MSECalibrator` against the reference -:class:`NVFP4MSECalibrator` on the same inputs and asserts the resulting per-block -amax tensors are bit-identical. Also reports a wall-clock speedup number for the -weight-MSE search step on a representative LLM-sized weight. +Compares the Triton fast path inside :class:`NVFP4MSECalibrator` against its +reference 126-step Python sweep on the same inputs and asserts the resulting +per-block amax tensors are bit-identical. Also reports a wall-clock speedup +number for the weight-MSE search step on a representative LLM-sized weight, +plus dispatch coverage for the conditions that gate the fast path. """ +import os import time +from contextlib import contextmanager import pytest import torch from conftest import requires_triton -from modelopt.torch.quantization.calib import NVFP4MSECalibrator, TritonNVFP4MSECalibrator +from modelopt.torch.quantization.calib import NVFP4MSECalibrator from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant BLOCK_SIZE = 16 +@contextmanager +def _force_sweep_path(triton_enabled: bool): + """Pin the NVFP4 sweep dispatch to the requested path for the duration of the + block, restoring the prior environment afterwards.""" + key = "MODELOPT_NVFP4_TRITON_SWEEP" + prev = os.environ.get(key) + os.environ[key] = "1" if triton_enabled else "0" + try: + yield + finally: + if prev is None: + os.environ.pop(key, None) + else: + os.environ[key] = prev + + def _reference_quant_func(global_amax): """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" @@ -42,26 +61,27 @@ def quant_func(x, amax): return quant_func -def _run_reference(x, per_block_amax, global_amax): - cal = NVFP4MSECalibrator( +def _make_calibrator(per_block_amax, global_amax): + return NVFP4MSECalibrator( amax=per_block_amax, axis=0, global_amax=global_amax, quant_func=_reference_quant_func(global_amax), ) - cal.collect(x) - return cal.compute_amax() + + +def _run_reference(x, per_block_amax, global_amax): + with _force_sweep_path(triton_enabled=False): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + return cal.compute_amax() def _run_triton(x, per_block_amax, global_amax): - cal = TritonNVFP4MSECalibrator( - amax=per_block_amax, - axis=0, - global_amax=global_amax, - quant_func=_reference_quant_func(global_amax), - ) - cal.collect(x) - return cal.compute_amax() + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + return cal.compute_amax() @requires_triton @@ -125,6 +145,7 @@ def test_quantized_output_matches(): @requires_triton def test_reset_allows_recollect(): + """After the fast path runs, a second collect() requires reset() in between.""" torch.manual_seed(0) device = "cuda" num_blocks = 32 @@ -132,22 +153,20 @@ def test_reset_allows_recollect(): per_block_amax = x.abs().amax(dim=-1) global_amax = per_block_amax.max() - cal = TritonNVFP4MSECalibrator( - amax=per_block_amax, - axis=0, - global_amax=global_amax, - ) - cal.collect(x) - first = cal.compute_amax().clone() - - # collect() is one-shot per cycle until reset() is called. - with pytest.raises(RuntimeError, match="one-shot"): + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) cal.collect(x) + first = cal.compute_amax().clone() + assert cal._best_amax_fast is not None # fast path was taken - cal.reset() - # After reset, the same calibrator instance can be re-used. - cal.collect(x) - assert torch.equal(first, cal.compute_amax()) + # Second collect after the fast path is not allowed without a reset. + with pytest.raises(RuntimeError, match="multi-collect after the fast path"): + cal.collect(x) + + cal.reset() + # After reset, the same calibrator instance can be re-used; fast path runs again. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) @requires_triton @@ -175,17 +194,96 @@ def test_input_validation(): @requires_triton -def test_mse_calibrate_dispatch(monkeypatch): - """``mse_calibrate(fp8_scale_sweep=True)`` must install the right calibrator class. +def test_dispatch_fast_path_default(): + """Default config on CUDA with no error_func takes the Triton fast path.""" + torch.manual_seed(0) + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device="cuda", dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + # Fast path stashes the final amax directly; reference accumulator stays empty. + assert cal._best_amax_fast is not None + assert cal._losses_sum is None + + +@requires_triton +def test_dispatch_env_optout_falls_back(): + """``MODELOPT_NVFP4_TRITON_SWEEP=0`` forces the reference 126-step sweep.""" + torch.manual_seed(0) + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device="cuda", dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() - Default path: ``TritonNVFP4MSECalibrator``. - With ``MODELOPT_NVFP4_TRITON_SWEEP=0``: ``NVFP4MSECalibrator`` (and not its subclass). + with _force_sweep_path(triton_enabled=False): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + assert cal._best_amax_fast is None + assert cal._losses_sum is not None + + +@requires_triton +def test_dispatch_custom_error_func_falls_back(): + """A non-None ``error_func`` keeps the reference path so the user's metric is honored. + + This protects custom error-function callers (e.g. local-Hessian calibration's + Hessian-weighted error) from silently being routed through a kernel that only + knows squared-error. """ + torch.manual_seed(0) + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device="cuda", dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + def hessian_like_error(a, b): + return (a - b).pow(2) # placeholder; the point is "non-None" + + with _force_sweep_path(triton_enabled=True): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + error_func=hessian_like_error, + ) + cal.collect(x) + assert cal._best_amax_fast is None + assert cal._losses_sum is not None + + +@requires_triton +def test_dispatch_cpu_path_excluded(): + """The fast-path predicate must reject CPU inputs (kernel is CUDA-only). + + Tests the dispatch decision directly via the predicate rather than running + ``collect()``, since the reference NVFP4 fake-quant kernel is itself CUDA-only — + NVFP4 calibration as a whole isn't a CPU code path. + """ + torch.manual_seed(0) + num_blocks = 32 + x_cpu = torch.randn(num_blocks, BLOCK_SIZE, dtype=torch.float32) + # Build the calibrator on CUDA so other predicate guards aren't the rejection cause. + per_block_amax = x_cpu.abs().amax(dim=-1).cuda() + global_amax = per_block_amax.max() + + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + assert cal._can_use_triton_fast_path(x_cpu) is False + + +@requires_triton +def test_mse_calibrate_end_to_end(monkeypatch): + """End-to-end: the ``mse``/``fp8_scale_sweep=True`` path produces the same quantized + weights with the fast path on (default) and off (``MODELOPT_NVFP4_TRITON_SWEEP=0``).""" from _test_utils.torch.quantization.models import SimpleLinear import modelopt.torch.quantization as mtq from modelopt.torch.quantization.extensions import get_cuda_ext_mx - from modelopt.torch.quantization.nn import TensorQuantizer if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") @@ -206,7 +304,15 @@ def test_mse_calibrate_dispatch(monkeypatch): "algorithm": {"method": "mse", "fp8_scale_sweep": True}, } - def _quantize_and_get_weight_calibrators(model): + def _run_calibrated(env_value): + torch.manual_seed(0) + model = SimpleLinear().cuda() + # Snapshot the pre-calibration weights so both runs start from identical state. + weight_snapshots = {n: p.detach().clone() for n, p in model.named_parameters()} + if env_value is None: + monkeypatch.delenv("MODELOPT_NVFP4_TRITON_SWEEP", raising=False) + else: + monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", env_value) calib_data = [model.get_input().cuda() for _ in range(2)] def forward_loop(m): @@ -214,25 +320,20 @@ def forward_loop(m): m(batch) mtq.quantize(model, cfg, forward_loop=forward_loop) - return [ - type(m._calibrator) - for name, m in model.named_modules() - if isinstance(m, TensorQuantizer) - and name.endswith("weight_quantizer") - and getattr(m, "_calibrator", None) is not None - ] - - # Default: triton path. - monkeypatch.delenv("MODELOPT_NVFP4_TRITON_SWEEP", raising=False) - types_default = _quantize_and_get_weight_calibrators(SimpleLinear().cuda()) - assert types_default, "expected at least one weight quantizer with a calibrator" - assert all(t is TritonNVFP4MSECalibrator for t in types_default), types_default - - # Opt-out: reference path, exact class match (TritonNVFP4MSECalibrator is a subclass). - monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") - types_optout = _quantize_and_get_weight_calibrators(SimpleLinear().cuda()) - assert types_optout, "expected at least one weight quantizer with a calibrator" - assert all(t is NVFP4MSECalibrator for t in types_optout), types_optout + # Run a deterministic input through and snapshot the output. + torch.manual_seed(1) + x = torch.randn(4, 16, device="cuda") + with torch.no_grad(): + y = model(x).detach().clone() + return y, weight_snapshots + + y_default, w0 = _run_calibrated(env_value=None) + y_optout, w1 = _run_calibrated(env_value="0") + # Both runs must start from the same weights (sanity: SimpleLinear is deterministic + # under the same seed) before we compare post-calibration outputs. + for name in w0: + assert torch.equal(w0[name], w1[name]), name + assert torch.equal(y_default, y_optout) def _bench(fn, warmup=2, iters=5): @@ -296,7 +397,7 @@ def test_speedup_report(capsys): print( f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" - f" reference NVFP4MSECalibrator: {ref_t * 1e3:8.2f} ms\n" - f" triton TritonNVFP4MSECalibrator: {tri_t * 1e3:8.2f} ms\n" + f" reference path: {ref_t * 1e3:8.2f} ms\n" + f" triton fast path: {tri_t * 1e3:8.2f} ms\n" f" speedup: {speedup:.1f}x" ) From 8f04a9a8aea9fb8d989e629b2b0d387413a07b21 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 8 May 2026 05:04:33 +0000 Subject: [PATCH 6/8] [Quantization] Address review feedback round 3 on FP8 sweep MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three changes from realAsma's latest review: - nvfp4_fp8_sweep kernel: use ``scale_safe`` rather than ``scale`` in the per-candidate diff so the divisor and multiplier match. Numerically equivalent on real inputs (the only case where ``scale_safe`` differs from ``scale`` is ``global_amax == 0``, in which case ``w_abs`` is also zero so the loss is zero either way), but more consistent. - Extract ``fp8_scale_candidates`` to a triton-free module ``_fp8_scale_candidates.py`` so the calibrator's reference sweep and the Triton kernel wrapper share one definition. Removes the duplicate copy in ``NVFP4MSECalibrator._generate_candidates``. - Parity test: extend ``test_parity_random_weights`` to cover bf16 and fp16 in addition to fp32 by parametrizing on dtype, so the canonical parity grid (3 seeds × 3 num_blocks) is now exercised on every supported dtype. Folded the smaller ``test_parity_dtypes`` into this since it was a strict subset. Signed-off-by: Chenjie Luo --- .../gemm/_fp8_scale_candidates.py | 31 +++++++++++++++++++ .../quantization/gemm/nvfp4_fp8_sweep.py | 15 +++------ modelopt/torch/quantization/calib/mse.py | 15 +++------ .../test_nvfp4_fp8_sweep_kernel.py | 30 +++++------------- 4 files changed, 48 insertions(+), 43 deletions(-) create mode 100644 modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py diff --git a/modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py b/modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py new file mode 100644 index 00000000000..9c48c180896 --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single source of truth for the NVFP4 FP8 scale-candidate set. + +Pure PyTorch, no Triton dependency, so it can be imported from both the kernel +wrapper (which is triton-gated) and the reference Python sweep in the +:class:`NVFP4MSECalibrator` (which must work without triton too). +""" + +import torch + + +def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: + """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + return fp8_values[valid_mask] / 448.0 diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py index 49e4839a3c1..e15eab328ab 100644 --- a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -32,19 +32,12 @@ import triton import triton.language as tl +from ._fp8_scale_candidates import fp8_scale_candidates from .nvfp4_quant import fp4_round_magnitude __all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] -def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: - """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" - uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) - fp8_values = uint8_values.view(torch.float8_e4m3fn).float() - valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) - return fp8_values[valid_mask] / 448.0 - - # Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300: # BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms # The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64 @@ -93,11 +86,11 @@ def _fp8_scale_sweep_kernel( for k in tl.static_range(NUM_CANDIDATES): c = tl.load(candidates_ptr + k).to(tl.float32) scale = c * global_amax / 6.0 - # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is - # the same for every candidate, so any best_idx is fine. + # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero + # (global_amax = max|w|), so the loss is zero for every candidate either way. scale_safe = tl.where(scale == 0.0, 1.0, scale) q_mag = fp4_round_magnitude(w_abs / scale_safe) - diff = w_abs - q_mag * scale + diff = w_abs - q_mag * scale_safe loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] is_better = loss < best_loss best_loss = tl.where(is_better, loss, best_loss) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 417d0173853..c3cacd9f993 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -203,17 +203,12 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: return torch.ones_like(self._initial_amax) * self._global_amax * candidates def _generate_candidates(self, device: torch.device) -> torch.Tensor: - """Generate 126 valid FP8 E4M3 scale candidates. + """Generate the 126 valid FP8 E4M3 scale candidates.""" + from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import ( + fp8_scale_candidates, + ) - Kept in sync with ``fp8_scale_candidates`` in - ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3 - spec is fixed, and the parity test exercises both paths against each other. - """ - uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) - fp8_values = uint8_values.view(torch.float8_e4m3fn).float() - valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) - fp8_values = fp8_values[valid_mask] - return fp8_values / 448.0 + return fp8_scale_candidates(device) def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool: """Whether the Triton fast path is usable for this ``collect`` input. diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index dc8d245ee4d..73c61a45100 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -85,14 +85,17 @@ def _run_triton(x, per_block_amax, global_amax): @requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("seed", [0, 1, 2]) @pytest.mark.parametrize("num_blocks", [4, 64, 1024]) -def test_parity_random_weights(seed, num_blocks): - """Triton sweep must produce the exact same per-block amax as the reference.""" +def test_parity_random_weights(seed, num_blocks, dtype): + """Triton sweep must produce the exact same per-block amax as the reference, + across every dtype supported by the NVFP4 quantizer (fp32, fp16, bf16).""" torch.manual_seed(seed) device = "cuda" - x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) - per_block_amax = x.abs().amax(dim=-1) + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). + per_block_amax = x.float().abs().amax(dim=-1) global_amax = per_block_amax.max() ref = _run_reference(x, per_block_amax, global_amax) @@ -102,29 +105,12 @@ def test_parity_random_weights(seed, num_blocks): # Both pick from the same 126-element discrete candidate set, so any disagreement # would show up as a non-zero diff (not a small float epsilon). Demand exact match. assert torch.equal(ref, tri), ( - f"Triton sweep diverged from reference: max |diff| = " + f"Triton sweep diverged from reference (dtype={dtype}): max |diff| = " f"{(ref - tri).abs().max().item():.3e}, " f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" ) -@requires_triton -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -def test_parity_dtypes(dtype): - """Sweep must agree across the dtypes supported by the NVFP4 quantizer.""" - torch.manual_seed(42) - device = "cuda" - num_blocks = 256 - x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) - # Promote to fp32 for the per-block amax (matches what max_calibrate produces). - per_block_amax = x.float().abs().amax(dim=-1) - global_amax = per_block_amax.max() - - ref = _run_reference(x, per_block_amax, global_amax) - tri = _run_triton(x, per_block_amax, global_amax) - assert torch.equal(ref, tri) - - @requires_triton def test_quantized_output_matches(): """Round-tripping x through the chosen amax should give the same fake-quant result.""" From 2bc8a54ed85eea66f00d6dd5a4a1c74f98d72462 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 8 May 2026 05:15:11 +0000 Subject: [PATCH 7/8] [Quantization] Shrink FP8 sweep parity matrix from 27 to 12 cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trim the parity grid to keep all three axes but with smaller per-axis ranges: 2 seeds × 2 num_blocks × 3 dtypes = 12 parametrized cases (down from 3×3×3 = 27). Still exercises every supported dtype and the small/ large num_blocks extremes that drive different autotune choices, while roughly halving the cold-compile cost on hosts where Triton compilation is expensive. Signed-off-by: Chenjie Luo --- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index 73c61a45100..17d1f1fea55 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -86,8 +86,8 @@ def _run_triton(x, per_block_amax, global_amax): @requires_triton @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("seed", [0, 1, 2]) -@pytest.mark.parametrize("num_blocks", [4, 64, 1024]) +@pytest.mark.parametrize("num_blocks", [4, 1024]) +@pytest.mark.parametrize("seed", [0, 1]) def test_parity_random_weights(seed, num_blocks, dtype): """Triton sweep must produce the exact same per-block amax as the reference, across every dtype supported by the NVFP4 quantizer (fp32, fp16, bf16).""" From c42de7f7de629b228938f18371f83774f80719fd Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 8 May 2026 16:55:49 +0000 Subject: [PATCH 8/8] [Quantization] Pin NVFP4 calibrator collect tests to reference path After folding the Triton fast path into NVFP4MSECalibrator.collect(), two tests in tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py broke because they inspect path-specific state: - test_collect_and_compute_amax asserts ``cal._losses_sum is not None`` with ``len == 126``. Only the reference 126-step sweep populates that list; the Triton fast path produces ``_best_amax_fast`` directly and leaves ``_losses_sum = None``. - test_multiple_collections asserts that two ``collect()`` calls accumulate. The Triton fast path is one-shot by design and refuses a second collect until ``reset()``, so multi-collect is fundamentally reference-path semantics. Fix: take the ``monkeypatch`` fixture in both tests and force ``MODELOPT_NVFP4_TRITON_SWEEP=0`` so they exercise the reference accumulator. Triton-path coverage stays in test_nvfp4_fp8_sweep_kernel.py (parity, dispatch predicate, end-to-end mtq.quantize). The other tests in the same class (initialization, candidate generation, per-block independence) are path-agnostic and unchanged. Signed-off-by: Chenjie Luo --- .../test_nvfp4_static_quantizer_cuda.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py index 430b7ee4113..2b5caea16f0 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -164,8 +164,15 @@ def test_fp8_candidates_generation(self, device): assert torch.all(torch.isfinite(candidates)) assert torch.all(candidates > 0) - def test_collect_and_compute_amax(self, device): - """Test collect and compute_amax workflow.""" + def test_collect_and_compute_amax(self, device, monkeypatch): + """Test reference-path collect and compute_amax workflow. + + Pinned to the reference 126-step sweep (``MODELOPT_NVFP4_TRITON_SWEEP=0``) + because this test inspects ``_losses_sum``, which only the reference path + populates; the Triton fast path produces ``_best_amax_fast`` directly and + is covered separately in ``test_nvfp4_fp8_sweep_kernel.py``. + """ + monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") num_blocks = 8 block_size = 16 per_block_amax = torch.ones(num_blocks, device=device) @@ -193,8 +200,14 @@ def quant_func(x, amax): assert torch.all(torch.isfinite(amax)) assert torch.all(amax > 0) - def test_multiple_collections(self, device): - """Test that multiple collections accumulate correctly.""" + def test_multiple_collections(self, device, monkeypatch): + """Test that multiple collections accumulate correctly. + + Multi-collect is reference-path-only — the Triton fast path is one-shot + and refuses a second ``collect()`` until ``reset()``. Forcing the env var + keeps this exercising the accumulator. + """ + monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") num_blocks = 4 block_size = 16 per_block_amax = torch.ones(num_blocks, device=device)