diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index 39b07b4faa9..70f729cffb0 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -32,6 +32,7 @@ # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * from .fp8_kernel import * + from .nvfp4_fp8_sweep import * # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): diff --git a/modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py b/modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py new file mode 100644 index 00000000000..9c48c180896 --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single source of truth for the NVFP4 FP8 scale-candidate set. + +Pure PyTorch, no Triton dependency, so it can be imported from both the kernel +wrapper (which is triton-gated) and the reference Python sweep in the +:class:`NVFP4MSECalibrator` (which must work without triton too). +""" + +import torch + + +def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: + """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + return fp8_values[valid_mask] / 448.0 diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py new file mode 100644 index 00000000000..e15eab328ab --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. + +Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single +kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates +and emits the per-block ``best_amax`` directly. 
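+
+Illustrative usage (a sketch; assumes a CUDA device, and the shapes/values here
+are examples only)::
+
+    x = torch.randn(1024, 16, device="cuda")           # flat [n_blocks, block_size]
+    global_amax = x.abs().amax()
+    best_amax = nvfp4_fp8_scale_sweep(x, global_amax)  # [1024] fp32, on x.device
+    # each entry equals global_amax * c for one of the 126 candidates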
+ +The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see +:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on +the per-block scale is the identity, so the kernel can use +``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it +runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). + +Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. +""" + +import torch +import triton +import triton.language as tl + +from ._fp8_scale_candidates import fp8_scale_candidates +from .nvfp4_quant import fp4_round_magnitude + +__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] + + +# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300: +# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms +# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64 +# would underfill the SMs. +_FP8_SWEEP_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2), + triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4), + triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8), +] + + +@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"]) +@triton.jit +def _fp8_scale_sweep_kernel( + x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32) + candidates_ptr, # [NUM_CANDIDATES] fp32 + global_amax_ptr, # scalar fp32 + best_amax_ptr, # [N_BLOCKS] fp32 output + N_BLOCKS, + BLOCK_SIZE: tl.constexpr, + NUM_CANDIDATES: tl.constexpr, + BLOCKS_PER_PROGRAM: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCKS_PER_PROGRAM + block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM) + block_mask = block_idx < N_BLOCKS + + # Load weights for this tile and pre-compute their absolute values once. + # The squared error is sign-invariant since FP4 quant preserves sign: + # (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2 + # so we never need ``w`` itself again, dropping a tl.where + negation per element. + elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + elem_mask = block_mask[:, None] + w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32)) + + global_amax = tl.load(global_amax_ptr).to(tl.float32) + + best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32) + best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32) + + # Loop over the 126 FP8 candidates (compile-time unrolled). + # Scales are guaranteed positive and finite (constructed from a positive candidate + # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is + # unnecessary apart from the global_amax == 0 case handled below. + for k in tl.static_range(NUM_CANDIDATES): + c = tl.load(candidates_ptr + k).to(tl.float32) + scale = c * global_amax / 6.0 + # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero + # (global_amax = max|w|), so the loss is zero for every candidate either way. + scale_safe = tl.where(scale == 0.0, 1.0, scale) + q_mag = fp4_round_magnitude(w_abs / scale_safe) + diff = w_abs - q_mag * scale_safe + loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] + is_better = loss < best_loss + best_loss = tl.where(is_better, loss, best_loss) + best_idx = tl.where(is_better, k, best_idx) + + # Map each block's winning candidate index back to its amax = global_amax * c[best]. 
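+    # (The gather below is masked with block_mask, so tail programs whose block
+    # indices run past N_BLOCKS read a dummy candidate via ``other=0.0``; their
+    # stores are masked the same way, so the dummy value never escapes.)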
+ best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32) + best_amax = global_amax * best_c + tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask) + + +def nvfp4_fp8_scale_sweep( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, +) -> torch.Tensor: + """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. + + Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into + a single Triton kernel: every block's weight elements are loaded once, all 126 + candidates are evaluated in registers, and the running argmin is kept inline. + + Args: + x: Weight tensor on CUDA. Total element count must be divisible by + ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. + global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). + block_size: NVFP4 block size (typically 16). + + Returns: + ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. + """ + if not x.is_cuda: + raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.") + if not isinstance(block_size, int) or block_size <= 0: + raise ValueError(f"block_size must be a positive int, got {block_size!r}.") + if x.numel() % block_size != 0: + raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") + + candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32) + + n_blocks = x.numel() // block_size + x_flat = x.contiguous().view(-1) + global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) + best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) + + grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),) + with torch.cuda.device(x.device): + _fp8_scale_sweep_kernel[grid]( + x_flat, + candidates, + global_amax_f32, + best_amax, + n_blocks, + BLOCK_SIZE=block_size, + NUM_CANDIDATES=int(candidates.numel()), + ) + return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 1f439a7e778..c3cacd9f993 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -16,6 +16,7 @@ """Calibrator that returns the MSE amax of all collected tensors.""" import math +import os from collections.abc import Callable import torch @@ -172,7 +173,15 @@ def compute_amax(self, verbose: bool = False): class NVFP4MSECalibrator(MseCalibrator): - """Per-block FP8 scale sweep calibrator for NVFP4 static quantization.""" + """Per-block FP8 scale sweep calibrator for NVFP4 static quantization. + + Uses a fused Triton kernel as an internal fast path on the first ``collect`` call + when (a) ``error_func is None``, (b) the input tensor is on CUDA in the standard + blocked ``[n_blocks, block_size]`` layout, and (c) Triton + the kernel package are + importable. Falls back to the reference 126-step Python sweep otherwise (custom + ``error_func`` users, multi-``collect`` activation flows, CPU inputs, or when the + fast path is disabled via ``MODELOPT_NVFP4_TRITON_SWEEP=0``). + """ def __init__( self, @@ -185,6 +194,8 @@ def __init__( """Initialize NVFP4 MSE calibrator with per-block and global amax.""" super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func) self._global_amax = global_amax + # Set by the Triton fast path on its (one-shot) collect; consumed by compute_amax. 
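+        # ``None`` means the fast path has not run (or was reset); once set,
+        # ``collect`` refuses further input until ``reset``.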
+ self._best_amax_fast: torch.Tensor | None = None def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: if candidates.ndim != 0: # Called during final compute amax @@ -192,9 +203,77 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: return torch.ones_like(self._initial_amax) * self._global_amax * candidates def _generate_candidates(self, device: torch.device) -> torch.Tensor: - """Generate 126 valid FP8 E4M3 scale candidates.""" - uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) - fp8_values = uint8_values.view(torch.float8_e4m3fn).float() - valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) - fp8_values = fp8_values[valid_mask] - return fp8_values / 448.0 + """Generate the 126 valid FP8 E4M3 scale candidates.""" + from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import ( + fp8_scale_candidates, + ) + + return fp8_scale_candidates(device) + + def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool: + """Whether the Triton fast path is usable for this ``collect`` input. + + The kernel produces the final per-block amax in one shot, so it's only usable + when the caller wants the standard squared-error sweep on a single CUDA tensor + whose layout already matches the per-block amax. + """ + if self._error_func is not None: + return False + if not x.is_cuda: + return False + if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0": + return False + if self._initial_amax is None: + return False + if x.ndim != 2 or x.shape[0] != int(self._initial_amax.numel()): + return False + try: + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep # noqa: F401 + except ImportError: + return False + return True + + @torch.no_grad() + def collect(self, x: torch.Tensor): + """Collect input statistics. Uses the Triton fast path when eligible.""" + if self._best_amax_fast is not None: + raise RuntimeError( + "NVFP4MSECalibrator: the Triton fast path produced a final amax on a " + "previous collect() call; multi-collect after the fast path is not " + "supported. Call reset() to start a fresh cycle, set " + "MODELOPT_NVFP4_TRITON_SWEEP=0, or pass a non-None error_func to force " + "the reference path for activation-style accumulation." + ) + # Fast path is eligible only on the first call, before the reference accumulator + # has produced any state. + if self._losses_sum is None and self._can_use_triton_fast_path(x): + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1]) + # Match the original shape/dtype of the initial amax so downstream + # load_calib_amax behaves identically to the reference path. + self._best_amax_fast = best_flat.reshape(self._initial_amax.shape).to( + self._initial_amax.dtype + ) + return + super().collect(x) + + @torch.no_grad() + def compute_amax(self, verbose: bool = False): + """Return the per-block amax — from the fast path if it ran, else from the reference sweep.""" + if self._best_amax_fast is not None: + return self._best_amax_fast + return super().compute_amax(verbose=verbose) + + def reset(self): + """Reset per-cycle state. Keep ``_initial_amax`` so the calibrator stays reusable. 
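+
+        A sketch of the intended reuse cycle (``cal``/``w`` are illustrative)::
+
+            cal.collect(w)             # one-shot: stashes the final amax
+            amax = cal.compute_amax()  # returns the stashed tensor
+            cal.reset()                # required before any further collect()
+            cal.collect(w)             # fast path runs again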
+ + ``MseCalibrator.reset()`` intentionally drops ``_initial_amax`` to free memory in + the multi-step search, but the NVFP4 per-block amax is shape ``[num_blocks]`` — + small enough to keep so a follow-up ``collect()`` can run again on the same + calibrator instance. + """ + self._best_amax_fast = None + self._losses_sum = None + self._candidates = None + self._amax = None diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index aeae3dd4321..fe4c3f77ce6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -391,7 +391,9 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with NVFP4MSECalibrator + # NVFP4MSECalibrator internally selects a fused Triton kernel for + # the standard squared-error sweep; set MODELOPT_NVFP4_TRITON_SWEEP=0 + # to force the reference Python sweep for debugging. module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py new file mode 100644 index 00000000000..17d1f1fea55 --- /dev/null +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -0,0 +1,389 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity + speedup tests for the NVFP4 FP8 scale sweep Triton fast path. + +Compares the Triton fast path inside :class:`NVFP4MSECalibrator` against its +reference 126-step Python sweep on the same inputs and asserts the resulting +per-block amax tensors are bit-identical. Also reports a wall-clock speedup +number for the weight-MSE search step on a representative LLM-sized weight, +plus dispatch coverage for the conditions that gate the fast path. 
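+
+All tests pin the dispatch explicitly through ``MODELOPT_NVFP4_TRITON_SWEEP``
+(via the ``_force_sweep_path`` helper below or pytest's ``monkeypatch``), so
+each test exercises a known path regardless of the ambient environment.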
+""" + +import os +import time +from contextlib import contextmanager + +import pytest +import torch +from conftest import requires_triton + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator +from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant + +BLOCK_SIZE = 16 + + +@contextmanager +def _force_sweep_path(triton_enabled: bool): + """Pin the NVFP4 sweep dispatch to the requested path for the duration of the + block, restoring the prior environment afterwards.""" + key = "MODELOPT_NVFP4_TRITON_SWEEP" + prev = os.environ.get(key) + os.environ[key] = "1" if triton_enabled else "0" + try: + yield + finally: + if prev is None: + os.environ.pop(key, None) + else: + os.environ[key] = prev + + +def _reference_quant_func(global_amax): + """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + return quant_func + + +def _make_calibrator(per_block_amax, global_amax): + return NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + + +def _run_reference(x, per_block_amax, global_amax): + with _force_sweep_path(triton_enabled=False): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + return cal.compute_amax() + + +def _run_triton(x, per_block_amax, global_amax): + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + return cal.compute_amax() + + +@requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("num_blocks", [4, 1024]) +@pytest.mark.parametrize("seed", [0, 1]) +def test_parity_random_weights(seed, num_blocks, dtype): + """Triton sweep must produce the exact same per-block amax as the reference, + across every dtype supported by the NVFP4 quantizer (fp32, fp16, bf16).""" + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). + per_block_amax = x.float().abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + + assert ref.shape == tri.shape + # Both pick from the same 126-element discrete candidate set, so any disagreement + # would show up as a non-zero diff (not a small float epsilon). Demand exact match. 
+ assert torch.equal(ref, tri), ( + f"Triton sweep diverged from reference (dtype={dtype}): max |diff| = " + f"{(ref - tri).abs().max().item():.3e}, " + f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" + ) + + +@requires_triton +def test_quantized_output_matches(): + """Round-tripping x through the chosen amax should give the same fake-quant result.""" + torch.manual_seed(7) + device = "cuda" + num_blocks = 128 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + assert torch.equal(ref_xq, tri_xq) + + +@requires_triton +def test_reset_allows_recollect(): + """After the fast path runs, a second collect() requires reset() in between.""" + torch.manual_seed(0) + device = "cuda" + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + first = cal.compute_amax().clone() + assert cal._best_amax_fast is not None # fast path was taken + + # Second collect after the fast path is not allowed without a reset. + with pytest.raises(RuntimeError, match="multi-collect after the fast path"): + cal.collect(x) + + cal.reset() + # After reset, the same calibrator instance can be re-used; fast path runs again. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) + + +@requires_triton +def test_input_validation(): + """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + device = "cuda" + x = torch.randn(64, BLOCK_SIZE, device=device) + g = x.abs().amax() + + # CPU tensor → ValueError (not bare AssertionError). + with pytest.raises(ValueError, match="CUDA"): + nvfp4_fp8_scale_sweep(x.cpu(), g.cpu()) + + # block_size <= 0. + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=0) + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=-1) + + # Non-divisible numel. + with pytest.raises(ValueError, match="not divisible"): + nvfp4_fp8_scale_sweep(x, g, block_size=15) + + +@requires_triton +def test_dispatch_fast_path_default(): + """Default config on CUDA with no error_func takes the Triton fast path.""" + torch.manual_seed(0) + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device="cuda", dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + # Fast path stashes the final amax directly; reference accumulator stays empty. 
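+        # (_best_amax_fast / _losses_sum are the two dispatch witnesses used
+        # throughout this file: a collect() populates exactly one of them.)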
+ assert cal._best_amax_fast is not None + assert cal._losses_sum is None + + +@requires_triton +def test_dispatch_env_optout_falls_back(): + """``MODELOPT_NVFP4_TRITON_SWEEP=0`` forces the reference 126-step sweep.""" + torch.manual_seed(0) + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device="cuda", dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + with _force_sweep_path(triton_enabled=False): + cal = _make_calibrator(per_block_amax, global_amax) + cal.collect(x) + assert cal._best_amax_fast is None + assert cal._losses_sum is not None + + +@requires_triton +def test_dispatch_custom_error_func_falls_back(): + """A non-None ``error_func`` keeps the reference path so the user's metric is honored. + + This protects custom error-function callers (e.g. local-Hessian calibration's + Hessian-weighted error) from silently being routed through a kernel that only + knows squared-error. + """ + torch.manual_seed(0) + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device="cuda", dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + def hessian_like_error(a, b): + return (a - b).pow(2) # placeholder; the point is "non-None" + + with _force_sweep_path(triton_enabled=True): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + error_func=hessian_like_error, + ) + cal.collect(x) + assert cal._best_amax_fast is None + assert cal._losses_sum is not None + + +@requires_triton +def test_dispatch_cpu_path_excluded(): + """The fast-path predicate must reject CPU inputs (kernel is CUDA-only). + + Tests the dispatch decision directly via the predicate rather than running + ``collect()``, since the reference NVFP4 fake-quant kernel is itself CUDA-only — + NVFP4 calibration as a whole isn't a CPU code path. + """ + torch.manual_seed(0) + num_blocks = 32 + x_cpu = torch.randn(num_blocks, BLOCK_SIZE, dtype=torch.float32) + # Build the calibrator on CUDA so other predicate guards aren't the rejection cause. + per_block_amax = x_cpu.abs().amax(dim=-1).cuda() + global_amax = per_block_amax.max() + + with _force_sweep_path(triton_enabled=True): + cal = _make_calibrator(per_block_amax, global_amax) + assert cal._can_use_triton_fast_path(x_cpu) is False + + +@requires_triton +def test_mse_calibrate_end_to_end(monkeypatch): + """End-to-end: the ``mse``/``fp8_scale_sweep=True`` path produces the same quantized + weights with the fast path on (default) and off (``MODELOPT_NVFP4_TRITON_SWEEP=0``).""" + from _test_utils.torch.quantization.models import SimpleLinear + + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + + if get_cuda_ext_mx() is None: + pytest.skip("cuda_ext_mx is not available") + + cfg = { + "quant_cfg": [ + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + }, + {"quantizer_name": "*input_quantizer", "enable": False}, + ], + "algorithm": {"method": "mse", "fp8_scale_sweep": True}, + } + + def _run_calibrated(env_value): + torch.manual_seed(0) + model = SimpleLinear().cuda() + # Snapshot the pre-calibration weights so both runs start from identical state. 
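+        # (torch.manual_seed(0) above makes both constructions identical; the
+        # snapshot lets the caller assert that rather than assume it.)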
+        weight_snapshots = {n: p.detach().clone() for n, p in model.named_parameters()}
+        if env_value is None:
+            monkeypatch.delenv("MODELOPT_NVFP4_TRITON_SWEEP", raising=False)
+        else:
+            monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", env_value)
+        calib_data = [model.get_input().cuda() for _ in range(2)]
+
+        def forward_loop(m):
+            for batch in calib_data:
+                m(batch)
+
+        mtq.quantize(model, cfg, forward_loop=forward_loop)
+        # Run a deterministic input through and snapshot the output.
+        torch.manual_seed(1)
+        x = torch.randn(4, 16, device="cuda")
+        with torch.no_grad():
+            y = model(x).detach().clone()
+        return y, weight_snapshots
+
+    y_default, w0 = _run_calibrated(env_value=None)
+    y_optout, w1 = _run_calibrated(env_value="0")
+    # Both runs must start from the same weights (sanity: SimpleLinear is deterministic
+    # under the same seed) before we compare post-calibration outputs.
+    for name in w0:
+        assert torch.equal(w0[name], w1[name]), name
+    assert torch.equal(y_default, y_optout)
+
+
+def _bench(fn, warmup=2, iters=5):
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+    for _ in range(iters):
+        fn()
+    torch.cuda.synchronize()
+    return (time.perf_counter() - t0) / iters
+
+
+@requires_triton
+def test_speedup_report(capsys):
+    """Sanity-check that the Triton path is meaningfully faster on a realistic weight.
+
+    Uses an 8192 x 4096 weight (~33M elements, ~2M NVFP4 blocks) — roughly the size
+    of an LLM attention/MLP projection. Reports the speedup; does not gate on a
+    minimum factor (kernel timing is noisy on shared CI), but does require the
+    chosen amax to agree with the reference up to the MSE tie-break tolerance
+    checked below.
+    """
+    torch.manual_seed(123)
+    device = "cuda"
+    cout, cin = 8192, 4096
+    x = torch.randn(cout, cin // BLOCK_SIZE, BLOCK_SIZE, device=device, dtype=torch.float32)
+    x = x.reshape(-1, BLOCK_SIZE)
+    per_block_amax = x.abs().amax(dim=-1)
+    global_amax = per_block_amax.max()
+
+    ref_amax = _run_reference(x, per_block_amax, global_amax)
+    tri_amax = _run_triton(x, per_block_amax, global_amax)
+    # Bit-equality across millions of blocks isn't guaranteed: when two adjacent FP8
+    # candidates yield near-identical per-block MSE (within fp32 noise), the reference's
+    # CUDA fake_e4m3fy path and our Triton inline math can break ties differently. Demand
+    # instead that the Triton choice produces a per-block MSE within fp32 epsilon of the
+    # reference's choice.
+    n_blocks = ref_amax.numel()
+    n_diff = int((ref_amax != tri_amax).sum())
+    if n_diff:
+        ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax)
+        tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax)
+        per_block_mse_ref = (x - ref_xq).pow(2).sum(dim=-1)
+        per_block_mse_tri = (x - tri_xq).pow(2).sum(dim=-1)
+        # Reference is the formal argmin, so triton's loss should be ≥ reference's.
+        # Allow at most 1e-5 relative gap on differing blocks (observed ~1e-7 in practice).
+        rel_gap = (per_block_mse_tri - per_block_mse_ref).abs() / per_block_mse_ref.clamp_min(1e-12)
+        worst = rel_gap.max().item()
+        assert worst < 1e-5, (
+            f"{n_diff}/{n_blocks} blocks disagree with worst relative MSE gap {worst:.3e} "
+            "— exceeds tie-break tolerance"
+        )
+
+    ref_t = _bench(lambda: _run_reference(x, per_block_amax, global_amax))
+    tri_t = _bench(lambda: _run_triton(x, per_block_amax, global_amax))
+    speedup = ref_t / tri_t
+
+    # Force-print regardless of pytest capture mode.
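+    # (capsys.disabled() temporarily suspends pytest's output capture, so the
+    # numbers are printed even without -s.)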
+ with capsys.disabled(): + n_blocks = x.numel() // BLOCK_SIZE + print( + f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " + f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" + f" reference path: {ref_t * 1e3:8.2f} ms\n" + f" triton fast path: {tri_t * 1e3:8.2f} ms\n" + f" speedup: {speedup:.1f}x" + ) diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py index 430b7ee4113..2b5caea16f0 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -164,8 +164,15 @@ def test_fp8_candidates_generation(self, device): assert torch.all(torch.isfinite(candidates)) assert torch.all(candidates > 0) - def test_collect_and_compute_amax(self, device): - """Test collect and compute_amax workflow.""" + def test_collect_and_compute_amax(self, device, monkeypatch): + """Test reference-path collect and compute_amax workflow. + + Pinned to the reference 126-step sweep (``MODELOPT_NVFP4_TRITON_SWEEP=0``) + because this test inspects ``_losses_sum``, which only the reference path + populates; the Triton fast path produces ``_best_amax_fast`` directly and + is covered separately in ``test_nvfp4_fp8_sweep_kernel.py``. + """ + monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") num_blocks = 8 block_size = 16 per_block_amax = torch.ones(num_blocks, device=device) @@ -193,8 +200,14 @@ def quant_func(x, amax): assert torch.all(torch.isfinite(amax)) assert torch.all(amax > 0) - def test_multiple_collections(self, device): - """Test that multiple collections accumulate correctly.""" + def test_multiple_collections(self, device, monkeypatch): + """Test that multiple collections accumulate correctly. + + Multi-collect is reference-path-only — the Triton fast path is one-shot + and refuses a second ``collect()`` until ``reset()``. Forcing the env var + keeps this exercising the accumulator. + """ + monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") num_blocks = 4 block_size = 16 per_block_amax = torch.ones(num_blocks, device=device)