diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 6ca99c7f963..693506929d9 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -49,18 +49,7 @@ dense | sparsegpt) ;; ;; esac -#Iterate over list of qformats provided and check if they are valid -IFS="," -for qformat in $QFORMAT; do - case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;; - *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2 - exit 1 - ;; - esac -done -IFS=" " +# Quant format / recipe validation is delegated to hf_ptq.py. script_dir="$(dirname "$(readlink -f "$0")")" @@ -72,7 +61,14 @@ fi QFORMAT_MODIFIED="${QFORMAT//,/_}" -MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}} +# When using --recipe, build the model name from the recipe basename (without +# directory or .yaml suffix) so each recipe gets its own SAVE_PATH. +if [ -n "$RECIPE" ]; then + RECIPE_TAG=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g') + MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_recipe_${RECIPE_TAG} +else + MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}} +fi SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME} @@ -177,11 +173,16 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH if [[ "$MODEL_CONFIG_EXIST" == false ]]; then echo "Quantizing original model..." + if [ -n "$RECIPE" ]; then + QUANT_SPEC_ARGS="--recipe=$RECIPE" + else + QUANT_SPEC_ARGS="--qformat=${QFORMAT// /,}" + fi python hf_ptq.py \ --pyt_ckpt_path=$MODEL_PATH \ --export_path=$SAVE_PATH \ --sparsity_fmt=$SPARSITY_FMT \ - --qformat="${QFORMAT// /,}" \ + $QUANT_SPEC_ARGS \ --calib_size=$CALIB_SIZE \ --batch_size=$CALIB_BATCH_SIZE \ --inference_tensor_parallel=$TP \ @@ -203,7 +204,7 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH exit 0 fi - if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then + if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]] || [[ "$RECIPE" == *"nvfp4"* ]]; then cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1) if [ "$cuda_major" -lt 10 ]; then @@ -212,6 +213,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH fi fi + if [ -n "$RECIPE" ]; then + echo "Recipe $RECIPE used. Please deploy with TensorRT-LLM directly. Checkpoint export_path: $SAVE_PATH" + exit 0 + fi + if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then echo "Quant $QFORMAT specified. Please read TensorRT-LLM quantization support matrix https://nvidia.github.io/TensorRT-LLM/features/quantization.html#quantization-in-tensorrt-llm and use TensorRT-LLM for deployment. Checkpoint export_path: $SAVE_PATH" exit 0 diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 3817c1dee7c..2a9a28b3566 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -20,6 +20,7 @@ parse_options() { # Default values MODEL_PATH="" QFORMAT="" + RECIPE="" KV_CACHE_QUANT="" TP=1 PP=1 @@ -37,13 +38,14 @@ parse_options() { CAST_MXFP4_TO_NVFP4=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") eval set -- "$ARGS" while true; do case "$1" in --model ) MODEL_PATH="$2"; shift 2;; --quant ) QFORMAT="$2"; shift 2;; + --recipe ) RECIPE="$2"; shift 2;; --kv_cache_quant ) KV_CACHE_QUANT="$2"; shift 2;; --tp ) TP="$2"; shift 2;; --pp ) PP="$2"; shift 2;; @@ -99,12 +101,19 @@ parse_options() { fi # Verify required options are provided - if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then - echo "Usage: $0 --model= --quant= --tasks=" + if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then + echo "Usage: $0 --model= (--quant= | --recipe=) --tasks=" echo "Optional args: --sparsity= --awq_block_size= --calib=" exit 1 fi + # --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while + # --quant selects a built-in qformat preset. Pick exactly one. + if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then + echo "Cannot specify both --quant and --recipe; pick one." >&2 + exit 1 + fi + VALID_TASKS=("quant" "mmlu" "lm_eval" "livecodebench" "simple_eval") for task in $(echo "$TASKS" | tr ',' ' '); do @@ -135,6 +144,7 @@ parse_options() { echo "=================" echo "model: $MODEL_PATH" echo "quant: $QFORMAT" + echo "recipe: $RECIPE" echo "tp (TensorRT-LLM Checkpoint only): $TP" echo "pp (TensorRT-LLM Checkpoint only): $PP" echo "sparsity: $SPARSITY_FMT" diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 952ed1e39c1..21e6537e92e 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -87,11 +87,25 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: and w_quantizer._amax.dim() >= 1 ): amax = w_quantizer._amax + # Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep) + # produces a per-block _amax with shape (num_blocks_total, ...) + # where num_blocks_total = fused_total * blocks_per_row. That + # shape collapses the row axis we want to slice on. Restore the + # row dimension so the dim-0 slicing below splits gate / up + # correctly. No-op when _amax is already aligned with fused_total. + if amax.numel() != fused_total and amax.numel() % fused_total == 0: + amax = amax.contiguous().view(fused_total, amax.numel() // fused_total) amax_dim0 = amax.shape[0] if fused_total % amax_dim0 == 0: slice_start = fused_start * amax_dim0 // fused_total slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total - w_quantizer.amax = amax[slice_start:slice_end].contiguous() + sliced = amax[slice_start:slice_end].contiguous() + # The amax setter refuses shape changes once `_amax` exists, + # so drop the existing buffer before re-registering with the + # sliced shape. + if hasattr(w_quantizer, "_amax"): + delattr(w_quantizer, "_amax") + w_quantizer.amax = sliced else: warnings.warn( f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not " diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c0f00f7e9a1..080c98ea950 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1130,6 +1130,32 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None: mod.revert_weight_conversion = original +def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None: + """Coerce ``model.generation_config`` so it passes transformers' strict validation. + + Some upstream HF checkpoints ship a ``generation_config.json`` that mixes + ``do_sample=False`` with sampling-only attrs (``top_p``, ``top_k``, ...). + Newer transformers raise ``ValueError("GenerationConfig is invalid: ...")`` + inside ``save_pretrained``, blocking export. We try a strict validate and + on failure flip ``do_sample`` to ``True`` so the upstream sampling intent + is preserved (rather than silently dropping ``top_p`` etc.). Quietly does + nothing if the model has no generation_config or it's already valid. + """ + gc = getattr(model, "generation_config", None) + if gc is None or not hasattr(gc, "validate"): + return + try: + gc.validate(strict=True) + return + except Exception: + pass + if not getattr(gc, "do_sample", False): + try: + gc.do_sample = True + except Exception: + pass + + def export_speculative_decoding( model: torch.nn.Module, dtype: torch.dtype | None = None, @@ -1211,6 +1237,12 @@ def export_hf_checkpoint( # modeling_utils does `from core_model_loading import revert_weight_conversion`. _patches = _patch_revert_weight_conversion() + # Some upstream HF checkpoints ship a generation_config.json that fails + # transformers' strict validation on save (e.g. ``top_p`` set without + # ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to + # the sampling-attrs intent so save_pretrained can write the file. + _sanitize_generation_config_for_save(model) + try: model.save_pretrained( export_dir, diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index 39b07b4faa9..70f729cffb0 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -32,6 +32,7 @@ # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * from .fp8_kernel import * + from .nvfp4_fp8_sweep import * # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py new file mode 100644 index 00000000000..4b9f19837f2 --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. + +Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single +kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates +and emits the per-block ``best_amax`` directly. + +The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see +:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on +the per-block scale is the identity, so the kernel can use +``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it +runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). + +Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. +""" + +import torch +import triton +import triton.language as tl + +from .nvfp4_quant import fp4_round_magnitude + +__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] + + +def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: + """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + return fp8_values[valid_mask] / 448.0 + + +# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300: +# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms +# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64 +# would underfill the SMs. +_FP8_SWEEP_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2), + triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4), + triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8), +] + + +@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"]) +@triton.jit +def _fp8_scale_sweep_kernel( + x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32) + candidates_ptr, # [NUM_CANDIDATES] fp32 + global_amax_ptr, # scalar fp32 + best_amax_ptr, # [N_BLOCKS] fp32 output + N_BLOCKS, + BLOCK_SIZE: tl.constexpr, + NUM_CANDIDATES: tl.constexpr, + BLOCKS_PER_PROGRAM: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCKS_PER_PROGRAM + block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM) + block_mask = block_idx < N_BLOCKS + + # Load weights for this tile and pre-compute their absolute values once. + # The squared error is sign-invariant since FP4 quant preserves sign: + # (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2 + # so we never need ``w`` itself again, dropping a tl.where + negation per element. + elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + elem_mask = block_mask[:, None] + w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32)) + + global_amax = tl.load(global_amax_ptr).to(tl.float32) + + best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32) + best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32) + + # Loop over the 126 FP8 candidates (compile-time unrolled). + # Scales are guaranteed positive and finite (constructed from a positive candidate + # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is + # unnecessary apart from the global_amax == 0 case handled below. + for k in tl.static_range(NUM_CANDIDATES): + c = tl.load(candidates_ptr + k).to(tl.float32) + scale = c * global_amax / 6.0 + # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is + # the same for every candidate, so any best_idx is fine. + scale_safe = tl.where(scale == 0.0, 1.0, scale) + q_mag = fp4_round_magnitude(w_abs / scale_safe) + diff = w_abs - q_mag * scale + loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] + is_better = loss < best_loss + best_loss = tl.where(is_better, loss, best_loss) + best_idx = tl.where(is_better, k, best_idx) + + # Map each block's winning candidate index back to its amax = global_amax * c[best]. + best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32) + best_amax = global_amax * best_c + tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask) + + +def nvfp4_fp8_scale_sweep( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, + candidates: torch.Tensor | None = None, +) -> torch.Tensor: + """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. + + Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into + a single Triton kernel: every block's weight elements are loaded once, all 126 + candidates are evaluated in registers, and the running argmin is kept inline. + + Args: + x: Weight tensor on CUDA. Total element count must be divisible by + ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. + global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). + block_size: NVFP4 block size (typically 16). + candidates: Optional precomputed candidate tensor of shape ``[126]`` (must + be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. + + Returns: + ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. + """ + if not x.is_cuda: + raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.") + if not isinstance(block_size, int) or block_size <= 0: + raise ValueError(f"block_size must be a positive int, got {block_size!r}.") + if x.numel() % block_size != 0: + raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") + + if candidates is None: + candidates = fp8_scale_candidates(x.device) + candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) + if candidates.ndim != 1 or candidates.numel() == 0: + raise ValueError( + f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}." + ) + + n_blocks = x.numel() // block_size + x_flat = x.contiguous().view(-1) + global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) + best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) + + grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),) + with torch.cuda.device(x.device): + _fp8_scale_sweep_kernel[grid]( + x_flat, + candidates, + global_amax_f32, + best_amax, + n_blocks, + BLOCK_SIZE=block_size, + NUM_CANDIDATES=int(candidates.numel()), + ) + return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 1f439a7e778..fff0b8af1b2 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -24,7 +24,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -192,9 +192,100 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: return torch.ones_like(self._initial_amax) * self._global_amax * candidates def _generate_candidates(self, device: torch.device) -> torch.Tensor: - """Generate 126 valid FP8 E4M3 scale candidates.""" + """Generate 126 valid FP8 E4M3 scale candidates. + + Kept in sync with ``fp8_scale_candidates`` in + ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3 + spec is fixed, and the parity test exercises both paths against each other. + """ uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) fp8_values = uint8_values.view(torch.float8_e4m3fn).float() valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) fp8_values = fp8_values[valid_mask] return fp8_values / 448.0 + + +class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): + """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE. + + Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126 + candidates in a single fused Triton kernel — one weight read instead of 126. + + Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle. + This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where + the calibrator is collected once per weight and immediately consumed. For + activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. + Call :meth:`reset` to free internal state and re-enable :meth:`collect`. + """ + + def __init__( + self, + amax: torch.Tensor, + global_amax: torch.Tensor, + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + ): + """Initialize the Triton-fused NVFP4 MSE calibrator. + + See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by + the kernel path but accepted for API parity. Tile shape and ``num_warps`` are + autotuned by the kernel per ``N_BLOCKS``. + """ + super().__init__( + amax=amax, + global_amax=global_amax, + axis=axis, + quant_func=quant_func, + error_func=error_func, + ) + # Stash shape metadata so collect() can keep working after reset() releases + # the (potentially large) _initial_amax buffer. + self._initial_amax_shape = tuple(amax.shape) + self._initial_amax_dtype = amax.dtype + self._n_blocks = int(amax.numel()) + self._best_amax: torch.Tensor | None = None + + @torch.no_grad() + def collect(self, x: torch.Tensor): + """Run the fused FP8 sweep kernel and store the resulting per-block amax.""" + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + if self._best_amax is not None: + raise RuntimeError( + "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to " + "discard the previous result before collecting again." + ) + + x = x.detach() + # The weight quantizer reshapes its input to [n_blocks, block_size] before + # calling collect (see TensorQuantizer._process_for_blockquant). + assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." + block_size = x.shape[-1] + n_blocks = x.numel() // block_size + if n_blocks != self._n_blocks: + raise ValueError( + f"initial amax.numel() ({self._n_blocks}) does not match the number " + f"of NVFP4 blocks in x ({n_blocks})." + ) + + best_amax_flat = nvfp4_fp8_scale_sweep( + x, + self._global_amax, + block_size=block_size, + ) + # Match the original shape/dtype of the initial amax so downstream + # load_calib_amax behaves identically to the reference path. + self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to( + self._initial_amax_dtype + ) + + @torch.no_grad() + def compute_amax(self, verbose: bool = False): + """Return the per-block amax computed during ``collect``.""" + return self._best_amax + + def reset(self): + """Reset the stored best amax. Subsequent ``collect`` calls are allowed.""" + self._best_amax = None + super().reset() diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 4ce0f62a75d..2ac1d9aee1a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,7 @@ """Calibration utilities.""" import math +import os import time import warnings from collections.abc import Callable @@ -37,7 +38,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -52,7 +53,6 @@ promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, - weight_attr_names, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -64,8 +64,143 @@ "max_calibrate", "smoothquant", "svdquant", + "sync_grouped_weight_global_amax", ] + +# Sibling weight-quantizer name groups whose ``global_amax`` should share an +# FP8 scale-of-scales. All members of a group sit under the same parent module +# (e.g. one self-attention or one MLP block) and either consume the same input +# tensor or get fused at deployment, so a divergent global_amax across siblings +# would split their FP8 grids and skew the round. +_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = ( + # Standard self-attention (skipped for fused qkv_proj — single weight). + ("q_proj", "k_proj", "v_proj"), + # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.). + ("gate_proj", "up_proj"), + # Gated MLP, older Mixtral-style naming. + ("w1", "w3"), +) + + +def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool: + """Check whether ``q`` is an enabled, calibrated NVFP4-static weight quantizer. + + True when ``max_calibrate`` already populated a per-block ``_amax``. + """ + return ( + isinstance(q, TensorQuantizer) + and not q._disabled + and q.is_nvfp4_static + and hasattr(q, "_amax") + and q._amax is not None + ) + + +def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: + """Find Linear-like submodule groups whose NVFP4-static weight quantizers should share global_amax. + + Groups are Q/K/V under one attention parent and gate/up under one MLP parent. + """ + groups: list[list[nn.Module]] = [] + wq_attr = quantizer_attr_names("weight").weight_quantizer + for parent in model.modules(): + for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS: + members: list[nn.Module] = [] + for n in sibling_names: + child = getattr(parent, n, None) + if child is None: + continue + wq = getattr(child, wq_attr, None) + if _is_calibrated_nvfp4_static_weight_quantizer(wq): + members.append(child) + if len(members) >= 2: + groups.append(members) + return groups + + +@torch.no_grad() +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: + """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing. + + Forward-pass max calibration only populates per-expert weight quantizers in + fused-experts containers when tokens are routed to that expert. "Dead" + experts that received no tokens end up with no ``_amax``, which causes + ``mse_calibrate``'s subsequent walk to skip them and forces the export-time + fallback to derive separate per-half amax for gate/up. This helper walks + every QuantModule's :meth:`iter_weights_for_calibration` pairs and, for any + quantizer that lacks ``_amax``, runs the existing calibrator (typically + :class:`MaxCalibrator`) on the corresponding weight slice — populating + ``_amax`` from the weight rather than from runtime activations. + + Returns the number of quantizers bootstrapped (mostly for diagnostics). + """ + n = 0 + for module in model.modules(): + if not isinstance(module, QuantModule): + continue + try: + pairs = list(module.iter_weights_for_calibration()) + except Exception: + continue + for weight, q in pairs: + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: + continue + if q._calibrator is None: + continue + if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0): + continue + q.disable_quant() + q.enable_calib() + q(weight) + if q._calibrator.compute_amax() is not None: + q.load_calib_amax() + q.enable_quant() + q.disable_calib() + if hasattr(q._calibrator, "reset"): + q._calibrator.reset() + n += 1 + return n + + +@torch.no_grad() +def sync_grouped_weight_global_amax(model: nn.Module) -> int: + """Sync ``global_amax`` across sibling NVFP4-static weight quantizers. + + For each group of siblings (Q/K/V projections under one attention parent; + gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the + NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a + consistent FP8 grid across the group during MSE / local-Hessian search. + + Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion` + (whose ``NVFP4StaticQuantizer`` branch performs the same + ``max(stack(global_amax))`` unification at export time). To call it before + MSE, this helper first promotes each grouped weight quantizer to + :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (= + ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in + place. + + Must be called after ``max_calibrate`` has populated each weight + quantizer's ``_amax``. Idempotent. Returns the number of groups synced. + """ + from modelopt.torch.export.quant_utils import preprocess_linear_fusion + + n_groups = 0 + for group in _collect_grouped_linears(model): + # Promote each member's weight quantizer so `preprocess_linear_fusion` + # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads + # `global_amax`, which only exists post-promotion). + wq_attr = quantizer_attr_names("weight").weight_quantizer + for child in group: + wq = getattr(child, wq_attr) + if not isinstance(wq, NVFP4StaticQuantizer): + local_global = reduce_amax(wq._amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global) + preprocess_linear_fusion(group) + n_groups += 1 + return n_groups + + CalibratorFactory: TypeAlias = Callable[ [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator ] @@ -349,10 +484,29 @@ def mse_calibrate( # Step 1: First get initial amax using max calibration max_calibrate(model, forward_loop, distributed_sync) - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers - # and identify weight quantizers - weight_quantizers = [] - seen_modules = set() + # Step 1a: Bootstrap any weight quantizer that didn't receive an _amax from + # the forward-pass max calibration (typical of dead MoE experts in fused- + # experts containers). Without this, the dead-expert per-expert quantizers + # would be silently skipped by step 2's `hasattr(_amax)` gate, leaving the + # export-time fallback to derive separate gate/up amaxes from each half of + # the fused weight (breaking the gate==up weight_scale_2 invariant). + _bootstrap_uncalibrated_weight_quantizers(model) + + # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers + # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one + # MLP block) so their FP8 scale-of-scales matches and the per-block FP8 + # round uses a consistent grid. No-op when there are no sibling groups + # (e.g. fused QKV / fused gate_up_proj). + sync_grouped_weight_global_amax(model) + + # Step 2: Replace calibrators with MseCalibrator for enabled quantizers. + # (Weight-quantizer discovery + calibration happens in step 3 below using + # iter_weights_for_calibration.) + + # Triton-fused FP8 sweep is on by default for NVFP4 static quant; set + # MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging. + use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" + nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: @@ -360,19 +514,16 @@ def mse_calibrate( # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) + is_nvfp4_static = module.is_nvfp4_static if is_nvfp4_static: - # Compute and set global_amax - global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + # If sync_grouped_weight_global_amax already promoted this + # quantizer (it's a sibling in a Q/K/V or gate/up group), + # its global_amax has been unified across the group; just + # leave it. Otherwise convert + set local global_amax. + if not isinstance(module, NVFP4StaticQuantizer): + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if fp8_scale_sweep: # Check if backend has a registered custom calibrator factory. @@ -391,8 +542,7 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with NVFP4MSECalibrator - module._calibrator = NVFP4MSECalibrator( + module._calibrator = nvfp4_calibrator_cls( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, @@ -410,52 +560,80 @@ def mse_calibrate( quant_func=partial(_mse_quant_func, quantizer=module), ) - # Identify weight quantizers by checking if they have corresponding weight parameters + # Step 3+4: discover and calibrate weight quantizers via + # iter_weights_for_calibration, which yields (weight_or_slice, quantizer) + # pairs. For non-fused QuantModules, this is one pair per weight (same as + # the previous singular-only walk). For fused-experts containers + # (transformers 5.x: gate_up_proj / down_proj as 3-D Parameters with per- + # expert quantizer ModuleLists) it yields one pair per expert per + # projection — so every per-expert weight quantizer gets MSE-calibrated, + # not just the ones that received forward-pass tokens. name_to_module = dict(model.named_modules()) + weight_calib_seen: set[int] = set() + + # Pre-count for an accurate tqdm total (the same iter is cheap to repeat; + # actually run-time work happens in the second pass). + total_to_calib = 0 for parent_module in name_to_module.values(): - if parent_module in seen_modules: + if id(parent_module) in weight_calib_seen or not isinstance(parent_module, QuantModule): + continue + try: + pairs = list(parent_module.iter_weights_for_calibration()) + except Exception: + continue + for _, q in pairs: + if ( + isinstance(q, TensorQuantizer) + and q.is_enabled + and getattr(q, "_calibrator", None) is not None + ): + total_to_calib += 1 + + pbar = tqdm(total=total_to_calib, desc="MSE weight calibration") + n_calibrated = 0 + for parent_module in name_to_module.values(): + if id(parent_module) in weight_calib_seen: + continue + weight_calib_seen.add(id(parent_module)) + if not isinstance(parent_module, QuantModule): continue - for weight_name in weight_attr_names(parent_module): - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: - if getattr(weight_quantizer, "_calibrator", None) is not None: - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) - seen_modules.add(parent_module) - - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation - # This prevents massive memory accumulation seen in large models - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( - tqdm(weight_quantizers, desc="MSE weight calibration") - ): - # Enable calibration mode for the weight quantizer - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() with enable_weight_access_and_writeback(parent_module, model, name_to_module): - weight = getattr(parent_module, weight_name) - weight_quantizer(weight) + try: + pairs = list(parent_module.iter_weights_for_calibration()) + except Exception: + pairs = [] + for weight, weight_quantizer in pairs: + if not ( + isinstance(weight_quantizer, TensorQuantizer) + and weight_quantizer.is_enabled + and getattr(weight_quantizer, "_calibrator", None) is not None + ): + continue + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + weight_quantizer(weight) - # IMMEDIATELY compute amax and reset calibrator to free memory - cal = getattr(weight_quantizer, "_calibrator", None) - if cal is not None and cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() + cal = weight_quantizer._calibrator + if cal.compute_amax() is not None: + weight_quantizer.load_calib_amax() - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() + weight_quantizer.enable_quant() + weight_quantizer.disable_calib() - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete - # This is critical for multi-GPU setups where tensors may be on different devices - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - if cal is not None and hasattr(cal, "reset"): - cal.reset() + if hasattr(cal, "reset"): + cal.reset() - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - torch.cuda.empty_cache() + pbar.update(1) + n_calibrated += 1 + if n_calibrated % 10 == 0 and torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + pbar.close() if torch.cuda.is_available(): for dev_id in range(torch.cuda.device_count()): @@ -610,6 +788,11 @@ def forward(self, input, *args, **kwargs): print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) + # Sync global_amax across sibling NVFP4-static weight quantizers + # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales + # is consistent across the group. Idempotent; no-op when fused. + sync_grouped_weight_global_amax(model) + # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) weight_quantizers_info = [] @@ -664,14 +847,9 @@ def quant_func(x, amax, quantizer=weight_quantizer): return xq - is_nvfp4_static = ( - weight_quantizer.is_static_block_quant - and weight_quantizer._num_bits == (2, 1) - and weight_quantizer._block_sizes is not None - and weight_quantizer._block_sizes.get("scale_bits") == (4, 3) - ) + is_nvfp4_static = weight_quantizer.is_nvfp4_static - if is_nvfp4_static: + if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3e..12649691453 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -514,6 +514,22 @@ def is_mx_format(self): and self.block_sizes.get("scale_bits", None) == (8, 0) ) + @property + def is_nvfp4_static(self): + """Check if this quantizer is configured for NVFP4 static block quantization. + + Format-only check (does not consider whether ``_amax`` has been + populated by calibration). True when the quantizer holds E2M1 weights + with E4M3 per-block scales in a static layout — i.e. the two-level + scaling NVFP4 path consumed by :class:`NVFP4StaticQuantizer`. + """ + return ( + self.is_static_block_quant + and self._num_bits == (2, 1) + and self._block_sizes is not None + and self._block_sizes.get("scale_bits") == (4, 3) + ) + def is_mxfp(self, bits): """Check if is MXFP4/MXFP6/MXFP8.""" if bits == 4: diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 77f26b20602..10f9721ad4d 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -900,6 +900,27 @@ def forward(self, *args, **kwargs): self._down_proj_linear = False return super().forward(*args, **kwargs) + def iter_weights_for_calibration(self): + """Yield ``(weight_slice, quantizer)`` pairs for every per-expert weight quantizer. + + Overrides the default :meth:`QuantModule.iter_weights_for_calibration`, + which uses ``weight_attr_names`` + singular ``*_weight_quantizer`` and + therefore silently skips fused-experts modules. Without this override, + weight-only calibration paths (``mse_calibrate``, ``weight_only_quantize``) + never reach per-expert weight quantizers — leaving any expert that the + forward-pass max-calibration didn't route to with no ``_amax``. + """ + for weight_name, quantizers_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers"), + ("down_proj", "down_proj_weight_quantizers"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + if weight is None or quantizers is None: + continue + for idx, q in enumerate(quantizers): + yield weight[idx], q + def fold_weight(self, keep_attrs: bool = False): """Fold per-expert weight quantizers into the fused 3-D weights. diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index fe30e283c2d..bb39c8a81e3 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -122,10 +122,16 @@ def get_weights_scaling_factor_from_quantizer( expected_shape = (*weight.shape[:-1], num_blocks_per_row) per_block_scale = per_block_scale.view(expected_shape) - # Quantize scales to FP8 + # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the + # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero + # for an all-zero weight block) and global_amax is small, the pre-cast value + # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any + # value >= 480 casts to NaN — clamp first to keep the stored byte finite. if not keep_high_precision: - per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to( - torch.float8_e4m3fn + per_block_scale = ( + (per_block_scale * 448.0 / per_block_scale_max) + .clamp_(max=448.0) + .to(torch.float8_e4m3fn) ) return per_block_scale, weights_scaling_factor_2 else: diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 1a177e04dc8..cea3d4260e4 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: for _name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: + if module.is_nvfp4_static: initial_amax = module._amax.clone().detach() global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml new file mode 100644 index 00000000000..8b0df2ebb68 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +metadata: + recipe_type: ptq + description: >- + NVFP4 static weight and dynamic activation for expert layers only (W4A4), + FP8 KV cache with constant amax (skips KV calibration; amax hardcoded to + FP8 E4M3 max 448.0), max layerwise calibration. +quantize: + algorithm: + method: max + # Max calibration is fast and does not typically need checkpointing. + # layerwise=false required for VLMs where the decoder layers are nested under + # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). + layerwise: false + quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp.experts*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp.experts*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - $import: kv_fp8_cast + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml new file mode 100644 index 00000000000..749a875b368 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + recipe_type: ptq + description: >- + NVFP4 W4A4 for expert layers only with MSE-search FP8 scale calibration on + expert weights, FP8 KV cache with constant amax (skips KV calibration; amax + hardcoded to FP8 E4M3 max 448.0). Expert weight quantizers are static + (per-block amax fixed by the MSE FP8-scale sweep); input quantizers remain + dynamic. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + # Max calibration is fast and does not typically need checkpointing. + # layerwise=false required for VLMs where the decoder layers are nested under + # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). + layerwise: false + quant_cfg: + - quantizer_name: '*' + enable: false + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*block_sparse_moe*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + use_constant_amax: true + - quantizer_name: '*block_sparse_moe.gate*' + enable: false + - quantizer_name: '*linear_attn.conv1d*' + enable: false + - quantizer_name: '*lm_head*' + enable: false + - quantizer_name: '*mixer.conv1d*' + enable: false + - quantizer_name: '*mlp.gate.*' + enable: false + - quantizer_name: '*mlp.shared_expert_gate.*' + enable: false + - quantizer_name: '*output_layer*' + enable: false + - quantizer_name: '*proj_out.*' + enable: false + - quantizer_name: '*router*' + enable: false + - quantizer_name: 'output.*' + enable: false + - parent_class: 'nn.BatchNorm1d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_name: '*' + enable: false diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml new file mode 100644 index 00000000000..bca46608d59 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + recipe_type: ptq + description: >- + NVFP4 W4A4 for all MLP layers (dense + MoE) with MSE-search FP8 scale + calibration on MLP weights, FP8 KV cache with constant amax (skips KV + calibration; amax hardcoded to FP8 E4M3 max 448.0). MLP weight quantizers + are static (per-block amax fixed by the MSE FP8-scale sweep); input + quantizers remain dynamic. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + # Max calibration is fast and does not typically need checkpointing. + # layerwise=false required for VLMs where the decoder layers are nested under + # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). + layerwise: false + quant_cfg: + - quantizer_name: '*' + enable: false + - quantizer_name: '*mlp*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*block_sparse_moe*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + use_constant_amax: true + - quantizer_name: '*block_sparse_moe.gate*' + enable: false + - quantizer_name: '*linear_attn.conv1d*' + enable: false + - quantizer_name: '*lm_head*' + enable: false + - quantizer_name: '*mixer.conv1d*' + enable: false + - quantizer_name: '*mlp.gate.*' + enable: false + - quantizer_name: '*mlp.shared_expert_gate.*' + enable: false + - quantizer_name: '*output_layer*' + enable: false + - quantizer_name: '*proj_out.*' + enable: false + - quantizer_name: '*router*' + enable: false + - quantizer_name: 'output.*' + enable: false + - parent_class: 'nn.BatchNorm1d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_name: '*' + enable: false diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py new file mode 100644 index 00000000000..c25867d8327 --- /dev/null +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity + speedup tests for the fused NVFP4 FP8 scale sweep Triton kernel. + +Compares :class:`TritonNVFP4MSECalibrator` against the reference +:class:`NVFP4MSECalibrator` on the same inputs and asserts the resulting per-block +amax tensors are bit-identical. Also reports a wall-clock speedup number for the +weight-MSE search step on a representative LLM-sized weight. +""" + +import time + +import pytest +import torch +from conftest import requires_triton + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator, TritonNVFP4MSECalibrator +from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant + +BLOCK_SIZE = 16 + + +def _reference_quant_func(global_amax): + """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + return quant_func + + +def _run_reference(x, per_block_amax, global_amax): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +def _run_triton(x, per_block_amax, global_amax): + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +@requires_triton +@pytest.mark.parametrize("seed", [0, 1, 2]) +@pytest.mark.parametrize("num_blocks", [4, 64, 1024]) +def test_parity_random_weights(seed, num_blocks): + """Triton sweep must produce the exact same per-block amax as the reference.""" + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + + assert ref.shape == tri.shape + # Both pick from the same 126-element discrete candidate set, so any disagreement + # would show up as a non-zero diff (not a small float epsilon). Demand exact match. + assert torch.equal(ref, tri), ( + f"Triton sweep diverged from reference: max |diff| = " + f"{(ref - tri).abs().max().item():.3e}, " + f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" + ) + + +@requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_parity_dtypes(dtype): + """Sweep must agree across the dtypes supported by the NVFP4 quantizer.""" + torch.manual_seed(42) + device = "cuda" + num_blocks = 256 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). + per_block_amax = x.float().abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + assert torch.equal(ref, tri) + + +@requires_triton +def test_quantized_output_matches(): + """Round-tripping x through the chosen amax should give the same fake-quant result.""" + torch.manual_seed(7) + device = "cuda" + num_blocks = 128 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + assert torch.equal(ref_xq, tri_xq) + + +@requires_triton +def test_reset_allows_recollect(): + torch.manual_seed(0) + device = "cuda" + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal.collect(x) + first = cal.compute_amax().clone() + + # collect() is one-shot per cycle until reset() is called. + with pytest.raises(RuntimeError, match="one-shot"): + cal.collect(x) + + cal.reset() + # After reset, the same calibrator instance can be re-used. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) + + +@requires_triton +def test_input_validation(): + """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" + from modelopt.torch.kernels.quantization.gemm import fp8_scale_candidates, nvfp4_fp8_scale_sweep + + device = "cuda" + x = torch.randn(64, BLOCK_SIZE, device=device) + g = x.abs().amax() + + # CPU tensor → ValueError (not bare AssertionError). + with pytest.raises(ValueError, match="CUDA"): + nvfp4_fp8_scale_sweep(x.cpu(), g.cpu()) + + # block_size <= 0. + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=0) + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=-1) + + # Non-divisible numel. + with pytest.raises(ValueError, match="not divisible"): + nvfp4_fp8_scale_sweep(x, g, block_size=15) + + # Empty / wrong-rank candidates. + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=torch.empty(0, device=device)) + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=fp8_scale_candidates(device).reshape(2, -1)) + + +def _bench(fn, warmup=2, iters=5): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters + + +@requires_triton +def test_speedup_report(capsys): + """Sanity-check that the Triton path is meaningfully faster on a realistic weight. + + Uses an 8192 x 4096 weight (~33M elements, ~2M NVFP4 blocks) — roughly the size + of an LLM attention/MLP projection. Reports the speedup; does not gate on a + minimum factor (kernel timing is noisy on shared CI), but does require parity + on the chosen amax. + """ + torch.manual_seed(123) + device = "cuda" + cout, cin = 8192, 4096 + x = torch.randn(cout, cin // BLOCK_SIZE, BLOCK_SIZE, device=device, dtype=torch.float32) + x = x.reshape(-1, BLOCK_SIZE) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + # Bit-equality across millions of blocks isn't guaranteed: when two adjacent FP8 + # candidates yield near-identical per-block MSE (within fp32 noise), the reference's + # CUDA fake_e4m3fy path and our Triton inline math can break ties differently. Demand + # instead that the Triton choice produces a per-block MSE within fp32 epsilon of the + # reference's choice. + n_blocks = ref_amax.numel() + n_diff = int((ref_amax != tri_amax).sum()) + if n_diff: + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + per_block_mse_ref = (x - ref_xq).pow(2).sum(dim=-1) + per_block_mse_tri = (x - tri_xq).pow(2).sum(dim=-1) + # Reference is the formal argmin, so triton's loss should be ≥ reference's. + # Allow at most 1e-5 relative gap on differing blocks (observed ~1e-7 in practice). + rel_gap = (per_block_mse_tri - per_block_mse_ref).abs() / per_block_mse_ref.clamp_min(1e-12) + worst = rel_gap.max().item() + assert worst < 1e-5, ( + f"{n_diff}/{n_blocks} blocks disagree with worst relative MSE gap {worst:.3e} " + "— exceeds tie-break tolerance" + ) + + ref_t = _bench(lambda: _run_reference(x, per_block_amax, global_amax)) + tri_t = _bench(lambda: _run_triton(x, per_block_amax, global_amax)) + speedup = ref_t / tri_t + + # Force-print regardless of pytest capture mode. + with capsys.disabled(): + n_blocks = x.numel() // BLOCK_SIZE + print( + f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " + f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" + f" reference NVFP4MSECalibrator: {ref_t * 1e3:8.2f} ms\n" + f" triton TritonNVFP4MSECalibrator: {tri_t * 1e3:8.2f} ms\n" + f" speedup: {speedup:.1f}x" + ) diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py index b1b3691a797..430b7ee4113 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -21,6 +21,7 @@ from modelopt.torch.quantization.calib import NVFP4MSECalibrator from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.qtensor import NVFP4QTensor from modelopt.torch.quantization.tensor_quant import ( scaled_e4m3_impl, static_blockwise_fp4_fake_quant, @@ -64,6 +65,51 @@ def test_global_amax_property(self, device): quantizer.global_amax = None assert quantizer.global_amax is None + def test_export_fp8_scale_no_nan_for_zero_amax_block(self, device): + """Regression: export must not emit fp8 NaN bytes for an all-zero block. + + When max-only calibration leaves ``_amax = 0`` for a fully-zero weight block, + the export's ``[per_block_scale == 0] = 1.0`` safety net drives the pre-cast + value to ``1.0 * 448 / (global_amax / 6)``. fp8_e4m3fn has no Inf, so any + pre-cast value >= 480 rounds to NaN — without a saturation clamp this writes + a 0x7F byte into ``weight_scale``. Reproduces the NaN seen in the saved + Kimi-K2.6-NVFP4-MSE checkpoint at expert 21 down_proj. + """ + block_size = 16 + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: block_size, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + + # Two-block weight: block 0 is non-trivial; block 1 is all zeros so its + # per-block amax is exactly 0. + weight = torch.zeros(1, 2 * block_size, device=device, dtype=torch.bfloat16) + weight[0, :block_size] = 0.1 + + per_block_amax = weight.abs().reshape(1, 2, block_size).amax(dim=-1).flatten() + quantizer.amax = per_block_amax + quantizer.global_amax = per_block_amax.max() + + # Sanity: the bug only fires when the would-be cast value exceeds 480. + # With global_amax = 0.1, scale_in_fp8 for a zero block is + # 1.0 * 448 / (0.1 / 6) ≈ 26880 — well past the 480 NaN threshold. + assert (per_block_amax == 0).any() + assert quantizer.global_amax.float().item() < 1.0 + + weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + quantizer, weight, weights_scaling_factor_2=None + ) + assert weight_scale.dtype == torch.float8_e4m3fn + + # No fp8_e4m3fn NaN bytes (NaN encoding is (b & 0x7F) == 0x7F). + raw = weight_scale.view(torch.uint8) + n_nan = ((raw & 0x7F) == 0x7F).sum().item() + assert n_nan == 0, f"fp8 weight_scale contains {n_nan} NaN byte(s)" + + # The all-zero block's stored fp8 scale should saturate to 448 (max finite). + assert raw.flatten()[1].item() == 0x7E + def test_fake_quantize_with_both_amaxs(self, device): """Test _fake_quantize uses both _amax and _global_amax.""" num_blocks = 4