
[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search#1387

Merged
cjluo-nv merged 8 commits into main from chenjiel/nvfp4-fp8-sweep-triton on May 8, 2026
Conversation

@cjluo-nv
Collaborator

@cjluo-nv cjluo-nv commented May 4, 2026

Summary

  • Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block best_amax directly.
  • Triton-backed TritonNVFP4MSECalibrator is the default for mse_calibrate(..., fp8_scale_sweep=True). Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging or numerics comparison.
  • Microbenchmark: ~42x speedup on a B300 over the reference (NVFP4MSECalibrator) on a representative LLM weight (8192x4096, ~2M NVFP4 blocks): 176.68 ms -> 4.23 ms.
  • End-to-end on Qwen3-8B PTQ (hf_ptq.py --qformat nvfp4_mse, calib=128): ~6.7x faster mtq.quantize with identical global weight-MSE as the reference.
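The opt-out described above boils down to a one-line environment-variable dispatch. A minimal sketch (the `use_triton_sweep` helper name is hypothetical, not the actual ModelOpt code; only the variable name and default come from this PR):

```python
import os

def use_triton_sweep() -> bool:
    """Triton fast path is the default; MODELOPT_NVFP4_TRITON_SWEEP=0 opts out."""
    return os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"

# Default (variable unset): Triton path enabled.
os.environ.pop("MODELOPT_NVFP4_TRITON_SWEEP", None)
print(use_triton_sweep())  # True

# Explicit opt-out for debugging or numerics comparison.
os.environ["MODELOPT_NVFP4_TRITON_SWEEP"] = "0"
print(use_triton_sweep())  # False
```

Any value other than the literal string "0" (including unset) keeps the fast path on.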

End-to-end Qwen3-8B PTQ (B300)

Note: ran on Qwen3-8B instead of Qwen3.5-9B because the docker's transformers (4.57.3) doesn't yet recognize the qwen3_5 (multimodal) architecture and the model dir doesn't ship modeling_*.py for trust_remote_code. Qwen3-8B is the same family / similar size and gives a representative comparison.

Settings: --calib_size 128 --calib_seq 512, default nvfp4_mse config (NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG). The first nvfp4 run was discarded as a warm-up for HF weight loading.

| qformat | mtq.quantize time | global weight MSE vs FP16 orig | notes |
|---|---|---|---|
| nvfp4 (no MSE search) | 3.46 s | 6.363e-6 | max-calibration baseline |
| nvfp4_mse (reference, slow) | 43.42 s | 4.788e-6 | MODELOPT_NVFP4_TRITON_SWEEP=0 |
| nvfp4_mse (Triton, fast; this PR) | 6.51 s | 4.788e-6 | default in this PR |
  • The Triton path's quantized weights produce bit-identical global weight MSE to the reference (4.788e-6 vs 4.788e-6), validating the kernel on a real model.
  • Triton nvfp4_mse is 6.67x faster than the reference nvfp4_mse, and adds only ~3 s over the no-MSE baseline (vs ~40 s for the reference) — making the MSE-search option nearly free in practice.
  • The MSE search itself reduces weight quantization error by ~25% vs plain max-calibration (6.363e-6 → 4.788e-6).

Per-layer MSE CSVs are produced by tools/debugger/compare_mse_qwen.py (one row per Linear weight) for closer inspection if needed.

Microbenchmark (B300, 8192x4096 weight, ~2M NVFP4 blocks)

reference NVFP4MSECalibrator:   176.68 ms
triton  TritonNVFP4MSECalibrator: 4.23 ms
speedup: 41.8x

Why this works

Each candidate is constructed as valid_fp8_e4m3_value / 448. With block_amax = global_amax * candidate, the FP8 round-trip on the per-block scale block_amax / 6 (using global_amax / 6 as the FP8 amax) is the identity — so the kernel can compute scale = candidate * global_amax / 6.0 inline and skip the FP8 cast. This keeps the kernel runnable on any CUDA + Triton (no tl.float8e4nv requirement).
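The round-trip identity can be checked exhaustively in plain Python, since the 126 candidates are by construction the finite positive E4M3 grid points divided by 448. A standalone sketch (function names are illustrative, not the PR's code):

```python
def fp8_e4m3_positive_values():
    """All 126 finite positive FP8 E4M3 (OCP FN) values: 4 exponent bits
    (bias 7), 3 mantissa bits; the all-ones encoding (e=15, m=7) is NaN,
    so the largest finite value is 448."""
    vals = [2.0**-6 * m / 8 for m in range(1, 8)]  # 7 subnormals
    vals += [
        2.0 ** (e - 7) * (1 + m / 8)
        for e in range(1, 16) for m in range(8)
        if not (e == 15 and m == 7)                # skip the NaN encoding
    ]                                              # 119 normals
    return sorted(vals)

def round_to_e4m3(x):
    """Round-to-nearest onto the positive E4M3 grid (sketch, x > 0)."""
    return min(fp8_e4m3_positive_values(), key=lambda v: abs(v - x))

grid = fp8_e4m3_positive_values()
candidates = [v / 448.0 for v in grid]             # the 126 sweep candidates
assert len(candidates) == 126 and max(candidates) == 1.0

# With block_amax = global_amax * candidate and global_amax / 6 as the FP8
# amax, the value the E4M3 cast sees is candidate * 448, which is exactly a
# grid point, so round-to-nearest returns it unchanged:
for v in grid:
    assert round_to_e4m3(v) == v
```

Because the cast is exact for every candidate, computing scale = candidate * global_amax / 6.0 inline is numerically equivalent to the FP8 round trip, which is why the kernel can skip the cast entirely.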

Because every candidate's per-block scale is just a rescaling of the same input block, all 126 candidates can be evaluated against a single [BLOCKS_PER_PROGRAM, BLOCK_SIZE] tile held in registers — replacing 126 weight-bandwidth passes with 1.

Two follow-on optimizations close the gap to the compute ceiling:

  • @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps) — the original hand-picked default (BPP=4, num_warps=4) left ~4x on the table; the best B300 config is BPP=64, num_warps=8.
  • Drop the sign-handling tl.where: FP4 quant preserves sign, so (w - w_q)^2 == (|w| - |w_q|)^2 and the kernel works on |w| throughout (one fewer where + negation per element per candidate).
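The per-block search the kernel performs can be rendered in a few lines of pure Python. This is an illustrative sketch under the assumptions stated in the comments (the E2M1 value grid and all helper names are mine, not the Triton source):

```python
# Positive FP4 E2M1 values; FP4 preserves sign, so the sweep can work on |w|
# throughout: (w - w_q)^2 == (|w| - |w_q|)^2.
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4_quantize_abs(a, scale):
    """Fake-quantize magnitude a with per-block scale: snap a/scale to the
    nearest E2M1 code (clipping to 6 falls out of the nearest-point search)."""
    return min(FP4_GRID, key=lambda g: abs(g * scale - a)) * scale

def best_amax_for_block(block, global_amax, candidates):
    """Sweep all candidates for one block; return the per-block best_amax."""
    mags = [abs(w) for w in block]                 # sign dropped once, up front
    best_amax, best_mse = None, float("inf")
    for c in candidates:
        scale = c * global_amax / 6.0              # inline scale, no FP8 cast
        mse = sum((a - fp4_quantize_abs(a, scale)) ** 2 for a in mags)
        if mse < best_mse:
            best_mse, best_amax = mse, c * global_amax
    return best_amax

# Toy block whose magnitudes sit exactly on the E2M1 grid at candidate 1.0:
print(best_amax_for_block([6.0, 3.0, -1.5, 0.5], 6.0, [0.5, 1.0]))  # 6.0
```

The fused kernel does the same argmin, but for 126 candidates over a [BLOCKS_PER_PROGRAM, BLOCK_SIZE] tile held in registers, so the weight tensor is read once instead of once per candidate.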

Files

  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py — new kernel + nvfp4_fp8_scale_sweep wrapper, autotuned.
  • modelopt/torch/kernels/quantization/gemm/__init__.py — wire-in.
  • modelopt/torch/quantization/calib/mse.py — new TritonNVFP4MSECalibrator(NVFP4MSECalibrator).
  • modelopt/torch/quantization/model_calib.py — opt-out env var (MODELOPT_NVFP4_TRITON_SWEEP=0); Triton path is default.
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py — 16 GPU tests covering parity (across seeds, block counts, dtypes), input validation, output round-trip, reset, and a wall-clock speedup report.

Numerics

Bit-identical to the reference for typical block counts ({4, 64, 1024} blocks, 3 seeds, fp32/fp16/bf16 — 14/15 microbenchmark tests bit-exact).

On multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks in the speedup test, per-block MSE within ~1e-7 relative). The reference's CUDA fake_e4m3fy and the Triton inline math have slightly different op ordering, which lets nearly-tied candidates flip. The speedup test asserts the worst per-block MSE gap is < 1e-5 relative on differing blocks — both choices are valid argmins; the resulting quantized weights are equally good. The Qwen3-8B end-to-end run confirms this: aggregate weight MSE matches the reference exactly at the displayed precision.

Test plan

  • pytest tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py -v — 16/16 pass on B300
  • pytest tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py -v — existing NVFP4 tests still pass (9/9)
  • End-to-end PTQ on Qwen3-8B with --qformat nvfp4_mse: 6.67x speedup, identical weight MSE to reference

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added a Triton-based fused NVFP4 FP8 scale-sweep for faster per-block scale selection.
    • Exposed a device-aware FP8 scale candidate generator and a GPU sweep API used by the calibrator.
    • Calibrator now uses the Triton fast-path by default with env var opt-out (MODELOPT_NVFP4_TRITON_SWEEP).
  • Bug Fixes / Behavior

    • One-shot fast-path enforcement in the NVFP4 calibrator; reset restores reuse.
  • Tests

    • Added comprehensive GPU tests validating parity, input validation, dispatch behavior, one-shot semantics, and speed benchmarks.

@cjluo-nv cjluo-nv requested review from a team as code owners May 4, 2026 20:59
@coderabbitai
Contributor

coderabbitai Bot commented May 4, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds a Triton-fused NVFP4 FP8 per-block scale-sweep kernel and wrapper, integrates a Triton fast-path into the NVFP4 MSE calibrator with one-shot collect/reset semantics, conditionally re-exports the kernel, documents env-var opt-out, and adds extensive GPU tests for parity, validation, semantics, and performance.

Changes

Triton NVFP4 FP8 Scale Sweep + Calibrator Fast Path

| Layer | File(s) | Summary |
|---|---|---|
| Data Shape / Candidates | modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py | Adds fp8_scale_candidates(device) producing the 126 finite positive FP8 E4M3 scale candidates divided by 448.0. |
| Autotune & Kernel | modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py | Adds _FP8_SWEEP_AUTOTUNE_CONFIGS and the Triton JIT kernel _fp8_scale_sweep_kernel, which loads per-tile magnitudes, iterates candidates (compile-time unrolled), computes quantize+MSE per candidate, tracks the per-block argmin, and writes per-block best_amax. |
| Public Wrapper / Validation | modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py | Adds nvfp4_fp8_scale_sweep(x, global_amax, block_size=16) with CUDA/device checks, positive block_size and divisibility checks, on-device candidate materialization, flattening to NVFP4 blocks, casting global_amax to fp32 on device, allocating best_amax, constructing the grid, and launching the autotuned kernel. |
| Package Export | modelopt/torch/kernels/quantization/gemm/__init__.py | Conditionally re-exports public names from .nvfp4_fp8_sweep via from .nvfp4_fp8_sweep import * inside the CUDA+Triton import block. |
| Calibrator Implementation | modelopt/torch/quantization/calib/mse.py | Enhances NVFP4MSECalibrator with _best_amax_fast state and a _can_use_triton_fast_path(x) predicate; an eligible collect() runs nvfp4_fp8_scale_sweep once per cycle, storing the per-block result reshaped/cast to match the initial amax; subsequent collect() calls raise until reset(); compute_amax() returns the fast-path result when present; reset() clears fast-path state and reference accumulators. Adds an os import and docstring updates. |
| Integration Wiring / Docs | modelopt/torch/quantization/model_calib.py | Adds an inline comment documenting the default fused Triton kernel selection and opt-out via MODELOPT_NVFP4_TRITON_SWEEP=0. |
| Tests / Validation & Benchmarks | tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py | Adds GPU tests and helpers: parity across seeds and block counts, dtype coverage (float32/float16/bfloat16), fake-quant output equivalence, calibrator one-shot/reset semantics, nvfp4_fp8_scale_sweep input validation, dispatch/env-var behavior tests, end-to-end calibration parity, and a benchmark with speedup reporting and MSE-tolerant tie handling. |

Sequence Diagram

sequenceDiagram
    participant User
    participant Calibrator
    participant Wrapper
    participant Kernel

    User->>Calibrator: collect(x)
    Calibrator->>Calibrator: check fast-path eligibility
    alt triton fast-path eligible
        Calibrator->>Wrapper: nvfp4_fp8_scale_sweep(x, global_amax, block_size)
        Wrapper->>Wrapper: validate inputs and materialize candidates on device
        Wrapper->>Wrapper: flatten to NVFP4 blocks and cast global_amax
        Wrapper->>Kernel: launch _fp8_scale_sweep_kernel (autotuned)
        Kernel->>Kernel: load block magnitudes
        Kernel->>Kernel: for each candidate compute quantize and MSE, track best
        Kernel-->>Wrapper: write per-block best_amax
        Wrapper-->>Calibrator: best_amax
        Calibrator->>Calibrator: store _best_amax_fast
    else reference path
        Calibrator->>Calibrator: run reference sweep / accumulate losses
    end
    User->>Calibrator: compute_amax()
    Calibrator-->>User: return per-block amax

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 64.00%, below the required 80.00% threshold. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title accurately summarizes the main change: a fused Triton kernel for NVFP4 FP8 scale sweep search, the primary focus of the changeset across multiple files. |
| Linked Issues Check | ✅ Passed | Check skipped: no linked issues were found for this pull request. |
| Out of Scope Changes Check | ✅ Passed | Check skipped: no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | Security audit passed. No unsafe patterns found: no torch.load/numpy.load issues, eval/exec, hardcoded secrets, or # nosec bypasses. Triton uses the safe import_plugin pattern. |





Comment @coderabbitai help to get the list of available commands and usage tips.

@cjluo-nv cjluo-nv requested review from Fridah-nv and realAsma May 4, 2026 21:00
Contributor

@meenchen meenchen left a comment


Bot review — DM the bot to share feedback.

Clean, well-structured Triton kernel for speeding up the NVFP4 FP8 scale sweep. The implementation correctly reuses the existing nvfp4_scalar_quant JIT function from nvfp4_quant.py, the math insight about the FP8 round-trip identity for the candidate set is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, and speedup).

A few points for the owner:

  1. Unchecked test plan items: The PR body has two unchecked items — H100/A100 run and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge.

  2. Minor code duplication: fp8_scale_candidates() in nvfp4_fp8_sweep.py duplicates NVFP4MSECalibrator._generate_candidates(). Consider having one call the other (or extracting a shared utility) to keep the candidate generation logic in one place.

  3. local_hessian_calibrate not using the Triton path: This function still uses NVFP4MSECalibrator directly (not TritonNVFP4MSECalibrator), which is correct since it needs a custom error_func. Worth adding a comment there noting that the Triton path doesn't support custom error functions, so someone doesn't "helpfully" switch it later.

  4. collect assumes x.shape[-1] is block_size: This works for the current MSE weight calibration flow where the tensor is pre-reshaped, but could be fragile if the calibrator is used in a different context. A brief assert or docstring note would help.

block_size = x.shape[-1]
n_blocks = x.numel() // block_size
if self._initial_amax.numel() != n_blocks:
raise ValueError(
Contributor


Bot comment.

Nit: block_size = x.shape[-1] assumes the input tensor has already been reshaped to [n_blocks, block_size]. This is true for the current mse_calibrate weight flow, but could silently produce wrong results if someone uses this calibrator with a differently-shaped tensor. Consider adding a brief assertion or docstring note, e.g.:

assert x.ndim == 2, "Expected x to be [n_blocks, block_size] from the weight quantizer reshape"



def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
Contributor


Bot comment.

Minor duplication: this function reproduces the same logic as NVFP4MSECalibrator._generate_candidates() in calib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.

Contributor


+1, valid comment

Collaborator Author


Done in a9c8ccf. Extracted fp8_scale_candidates to a triton-free module modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py. Both the kernel wrapper and NVFP4MSECalibrator._generate_candidates now import from it, so there's a single source of truth.

# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
# Replace calibrator with the fused Triton sweep kernel by default
# (single-shot collect, ~7-20x faster for the weight-MSE phase).
Contributor


Bot comment.

The env var check os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" is evaluated on every weight quantizer in the loop. Since it won't change mid-loop, consider hoisting it above the loop for clarity and minor efficiency.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py (1)

72-74: 💤 Low value

Consider hoisting candidate loads outside the loop.

The candidate value is loaded inside tl.static_range, which means 126 separate scalar loads per program invocation. Since candidates_ptr points to shared read-only data, you could load all candidates into a register vector once before the loop for better memory efficiency.

That said, tl.static_range unrolls at compile time and Triton's compiler may already optimize repeated scalar loads. This is a minor optimization suggestion.

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 117-119: Replace the assert-based CUDA check with consistent
exception raising: in the nvfp4_fp8_scale_sweep (or the function containing the
lines checking x.is_cuda and block_size), change the assert x.is_cuda line to
raise a ValueError (or a module-specific custom exception) with a clear message
like "nvfp4_fp8_scale_sweep requires a CUDA tensor" so both the CUDA check and
the block_size divisibility check use the same error style and won't be removed
by python -O; keep the existing block_size check and message for x.numel()
as-is.

---

Nitpick comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 72-74: The loop uses tl.static_range over NUM_CANDIDATES and calls
tl.load(candidates_ptr + k) inside each iteration (producing many scalar loads);
hoist these loads by reading the entire candidates_ptr into a local vector/array
(e.g., candidates_arr) once before the tl.static_range loop and then use
candidates_arr[k] (or the equivalent register lookup) inside the loop; update
references to the temporary c to load from the prefilled candidates_arr instead
of calling tl.load each iteration (keep names NUM_CANDIDATES, candidates_ptr,
tl.static_range, and c to locate the code).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: e28188c9-6b33-47d0-ac6b-3029bbe39550

📥 Commits

Reviewing files that changed from the base of the PR and between 1d21ab9 and 4fbb181.

📒 Files selected for processing (5)
  • modelopt/torch/kernels/quantization/gemm/__init__.py
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

@github-actions
Contributor

github-actions Bot commented May 4, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-08 22:53 UTC

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 139-158: The code does not validate a custom candidates tensor
before launching _fp8_scale_sweep_kernel, allowing empty or malformed inputs
which make NUM_CANDIDATES zero and cause out-of-bounds reads (e.g.,
candidates_ptr[0]); fix by validating candidates returned or passed into the
function: if candidates is not None ensure it's a 1-D tensor, non-empty, finite,
positive, and has the expected length (ideally 126) or otherwise raise a clear
error; if candidates is None keep using fp8_scale_candidates(x.device) as
before; perform this validation before calling candidates.contiguous().to(...)
and before computing NUM_CANDIDATES/launching _fp8_scale_sweep_kernel so the
kernel never receives an invalid candidate tensor.
- Around line 112-137: In nvfp4_fp8_scale_sweep validate the public parameter
block_size before using it: add an explicit check that block_size is a positive
integer (e.g. if not isinstance(block_size, int) or block_size <= 0: raise
ValueError(...)) at the start of the function, and only afterwards perform
operations that use it (such as x.numel() % block_size) so you avoid
ZeroDivisionError for 0 and invalid negative values that would otherwise produce
bad kernel launches.

In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 243-247: The reset() implementation in TritonNVFP4MSECalibrator
currently calls MseCalibrator.reset(), which clears _initial_amax and makes the
next collect() dereference None; update TritonNVFP4MSECalibrator.reset() to
either preserve _initial_amax (do not delete or set _initial_amax to its prior
tensor) or reinitialize it to a valid tensor/zero-sized tensor so collect() can
safely call self._initial_amax.numel(), and ensure _best_amax is also reset
consistently; apply the same fix to the other reset override referenced around
the second occurrence (similar block at lines ~274-277) so both reset overrides
in TritonNVFP4MSECalibrator maintain the contract expected by collect().

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 61d118f9-f103-4290-945e-bbc50478d48c

📥 Commits

Reviewing files that changed from the base of the PR and between 4fbb181 and 6040607.

📒 Files selected for processing (2)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py

@codecov

codecov Bot commented May 4, 2026

Codecov Report

❌ Patch coverage is 72.22222% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.82%. Comparing base (e2d29c8) to head (c42de7f).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py | 50.00% | 23 Missing ⚠️ |
| modelopt/torch/quantization/calib/mse.py | 94.59% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1387      +/-   ##
==========================================
- Coverage   77.40%   76.82%   -0.58%     
==========================================
  Files         476      478       +2     
  Lines       51319    51404      +85     
==========================================
- Hits        39721    39492     -229     
- Misses      11598    11912     +314     
| Flag | Coverage | Δ |
|---|---|---|
| examples | 41.77% <23.33%> | -0.19% ⬇️ |
| gpu | 59.85% <72.22%> | -0.59% ⬇️ |
| regression | 15.21% <23.33%> | +0.09% ⬆️ |
| unit | 52.48% <7.77%> | -0.08% ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

cjluo-nv added a commit that referenced this pull request May 4, 2026
Addresses review comments on PR #1387:

- TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape /
  dtype / n_blocks of the initial amax are stashed in __init__, so collect()
  no longer depends on _initial_amax surviving reset(). Adds an x.ndim==2
  assertion in collect() since the weight quantizer always reshapes upstream.
- nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert
  (which is stripped by python -O): rejects non-CUDA tensors, non-positive
  block_size, and empty / non-1D candidates with ValueError. Skips the
  per-element finite/positive check on candidates since it would scan a 126-
  entry tensor on every kernel call.
- mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of
  the per-quantizer loop and resolves to the calibrator class once.
- Updates test_reset_allows_recollect to verify the new reuse contract; adds
  test_input_validation covering the new ValueErrors.

The duplicate fp8_scale_candidates implementation in the kernel file and
NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating
would force the reference path to import from the kernel module, which is
gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity
test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Contributor

@meenchen meenchen left a comment


Bot review — DM the bot to share feedback.

All critical issues from the previous review have been addressed:

  • ✅ assert x.is_cuda replaced with raise ValueError
  • ✅ block_size <= 0 validation added
  • ✅ Custom candidates tensor validation added
  • ✅ reset() reusability fixed (metadata stashed in __init__, tested in test_reset_allows_recollect)
  • ✅ Env var hoisted above the loop
  • ✅ x.ndim == 2 assertion added in collect()
  • ✅ Candidate duplication addressed with sync docstring in _generate_candidates

The kernel implementation is clean, the math insight about the FP8 round-trip identity is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, input validation, and speedup).

Remaining concern: The PR body still has two unchecked test plan items — H100/A100 validation and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge. Nudging for human sign-off on that.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 260-266: In collect(), replace the runtime assertion with explicit
input validation: check that the input tensor x has ndim == 2 and raise a
ValueError with a clear message if not; check block_size = x.shape[-1] is > 0
and raise ValueError if it is zero to avoid ZeroDivisionError; compute n_blocks
= x.numel() // block_size only after these checks and if n_blocks !=
self._n_blocks raise a ValueError describing the mismatch (referencing the
collect method and TensorQuantizer._process_for_blockquant behavior to explain
expected [n_blocks, block_size] shape).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b2ef23a9-7bcf-4284-b38a-2cf368077ba4

📥 Commits

Reviewing files that changed from the base of the PR and between 6040607 and bd4fc3a.

📒 Files selected for processing (4)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

Contributor

@meenchen meenchen left a comment


Bot review — DM the bot to share feedback.

All critical issues from previous reviews have been addressed:

  • assert x.is_cuda replaced with raise ValueError
  • block_size <= 0 validation added before use
  • Custom candidates tensor validation (non-empty, 1-D) added
  • reset() reusability fixed — shape/dtype/n_blocks metadata stashed in __init__, verified by test_reset_allows_recollect
  • Env var read hoisted above the weight quantizer loop
  • x.ndim == 2 shape check added in collect()
  • Candidate duplication addressed with sync docstring

The kernel implementation is correct — the FP8 round-trip identity insight is sound, input validation is thorough, and test coverage is solid (16 GPU tests covering parity across seeds/dtypes/block counts, round-trip, reset, input validation, and speedup reporting). License headers match the canonical LICENSE_HEADER.

Contributor

@meenchen meenchen left a comment


LGTM. @Fridah-nv could you also help review this PR?

Comment on lines +283 to +291
@torch.no_grad()
def compute_amax(self, verbose: bool = False):
"""Return the per-block amax computed during ``collect``."""
return self._best_amax

def reset(self):
"""Reset the stored best amax. Subsequent ``collect`` calls are allowed."""
self._best_amax = None
super().reset()
Contributor

@realAsma realAsma May 5, 2026


Comment on lines +221 to +247
def __init__(
self,
amax: torch.Tensor,
global_amax: torch.Tensor,
axis: int | tuple | list | None = None,
quant_func: Callable | None = None,
error_func: Callable | None = None,
):
"""Initialize the Triton-fused NVFP4 MSE calibrator.

See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
the kernel path but accepted for API parity. Tile shape and ``num_warps`` are
autotuned by the kernel per ``N_BLOCKS``.
"""
super().__init__(
amax=amax,
global_amax=global_amax,
axis=axis,
quant_func=quant_func,
error_func=error_func,
)
# Stash shape metadata so collect() can keep working after reset() releases
# the (potentially large) _initial_amax buffer.
self._initial_amax_shape = tuple(amax.shape)
self._initial_amax_dtype = amax.dtype
self._n_blocks = int(amax.numel())
self._best_amax: torch.Tensor | None = None
Contributor


why do we need this? why not inherit from parent class?

Collaborator Author


Agreed on the design call-out. The shape/dtype/n_blocks stash in __init__ is necessary today only because MseCalibrator.reset() deletes _initial_amax, and the Triton path still needs that metadata to reshape the kernel's flat output and validate later collect() calls — so I'd rather not strip it without the larger refactor.

The cleaner shape — a small NVFP4 sweep base/helper with the reference and Triton calibrators as siblings off it — I'd like to do as a follow-up so this PR stays scoped to "port the sweep to Triton." Will file a follow-up issue.

Per-thread items addressed in the latest commit:

  • collect() now raises ValueError instead of assert, and validates x.shape[-1] before using it.
  • Dropped the public candidates override in nvfp4_fp8_scale_sweep (no internal callers used it; wrong-length input would have silently inflated the unrolled tl.static_range codegen).
  • Added test_mse_calibrate_dispatch covering the default-on + MODELOPT_NVFP4_TRITON_SWEEP=0 fallback wiring.

Comment thread modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
@realAsma
Contributor

realAsma commented May 6, 2026

Bot comment.

Reviewed current head bd4fc3a. I did not find additional blocking correctness issues beyond the active threads.

Remaining items I would resolve before merge:

  • The unresolved collect() input-validation comment is still valid: assert x.ndim == 2 should be explicit ValueError validation so the check is not stripped under optimized Python, and x.shape[-1] should be checked before division.
  • The TritonNVFP4MSECalibrator inheritance works for the documented one-shot weight-MSE flow, but it is not a clean subtype of NVFP4MSECalibrator; a shared helper/base with sibling reference/Triton calibrators would be cleaner if we want this as the default exported path.
  • Public CI currently has regression-pr-required-check failing. The regression-tests step itself completed successfully, but the job was canceled during Codecov upload, so the required check still needs to be rerun or otherwise cleared before merge.

@realAsma
Contributor

realAsma commented May 6, 2026

Bot comment.

Notes on the active review threads:

  • On the TritonNVFP4MSECalibrator metadata/reset question: the extra shape/dtype/block-count metadata is needed with the current implementation because MseCalibrator.reset() intentionally deletes _initial_amax, while this Triton path still needs those values after reset to reshape the fused-kernel result and validate later collect() calls. That said, this is not a clean behavioral subtype of NVFP4MSECalibrator; if this remains exported and selected by default, I would prefer a small shared NVFP4 sweep helper/base with the reference and Triton calibrators as sibling implementations.
  • On the scale tie-breaking question: for this acceleration PR I would keep the behavior reference-compatible. loss < best_loss with ascending candidates selects the first/minimum candidate on exact ties, which matches the existing reference path's torch.argmin over the same ordered candidate list. Median-scale tie-breaking may be a valid quantization policy change, but I would make it deliberately in both the reference and Triton paths with tests/docs rather than folding it into the default-on kernel replacement.
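
The tie-break equivalence can be illustrated with a toy example (hypothetical loss values; relies on `torch.argmin`'s documented behavior of returning the first index on ties):

```python
import torch

# Both selection rules pick the FIRST (i.e. minimum) candidate on exact ties
# when candidates are scanned in ascending order.
losses = torch.tensor([3.0, 1.0, 1.0, 2.0])  # tie between indices 1 and 2

best_idx, best_loss = -1, float("inf")
for i, loss in enumerate(losses.tolist()):
    if loss < best_loss:  # strict '<' keeps the earlier candidate on a tie
        best_idx, best_loss = i, loss

assert best_idx == int(torch.argmin(losses)) == 1
```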

@realAsma
Contributor

realAsma commented May 6, 2026

Bot comment.

Additional review notes from the current head bd4fc3a:

  • nvfp4_fp8_scale_sweep() still exposes candidates as a public override, but the docstring says it must be the fixed 126-entry FP8 E4M3 candidate set. The wrapper currently only checks non-empty 1-D, and candidates.numel() becomes the constexpr for the unrolled tl.static_range, so wrong-length input can turn a caller mistake into excessive Triton codegen, while nonpositive or nonfinite entries violate the kernel's scale assumptions and can emit invalid amax values. I would either remove the override or validate the full contract before launch: exactly 126 entries, finite, positive, and ideally equal to fp8_scale_candidates(...) if this path is meant to stay reference-compatible. Please add coverage for wrong length and nonpositive/nonfinite candidates if the override remains.
  • Small test gap: the new default/opt-out wiring in mse_calibrate(..., fp8_scale_sweep=True) is not directly covered. A focused test that the default installs TritonNVFP4MSECalibrator for NVFP4 static weights, and that MODELOPT_NVFP4_TRITON_SWEEP=0 falls back to NVFP4MSECalibrator, would protect this public behavior without depending only on the lower-level kernel tests.

@cjluo-nv
Collaborator Author

cjluo-nv commented May 6, 2026

@realAsma — thanks for the thorough review. Punch-list against your three top-level comments, addressed in commit e57c7a8:

Fixed in this PR

  • TritonNVFP4MSECalibrator.collect() input validation: replaced assert x.ndim == 2 with ValueError so the contract holds under python -O, validate block_size > 0 before use, and switched n_blocks = x.numel() // block_size to x.shape[0] so a zero last-dim cannot fault before the shape check.
  • nvfp4_fp8_scale_sweep() candidates contract: dropped the public candidates override. The candidate set is fixed (FP8 E4M3 valid values / 448), no internal caller used the override, and a wrong length would have silently inflated the unrolled tl.static_range codegen while nonpositive/nonfinite entries violated the kernel's scale assumptions. Removed the now-unreachable validation cases from the kernel input-validation test.
  • Wiring test for mse_calibrate(..., fp8_scale_sweep=True): added test_mse_calibrate_dispatch in tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py that quantizes a SimpleLinear under the NVFP4-static + mse/fp8_scale_sweep=True config and asserts (a) default installs TritonNVFP4MSECalibrator and (b) MODELOPT_NVFP4_TRITON_SWEEP=0 falls back to NVFP4MSECalibrator (exact class match, since the Triton class is a subclass).
  • Tie-break thread resolved per your follow-up: loss < best_loss matches the reference path's torch.argmin over the same ascending candidate list, so this PR stays reference-compatible.

Deferring to a follow-up

  • The TritonNVFP4MSECalibrator-as-clean-subtype concern: I agree the right shape is a small NVFP4 sweep base/helper with reference and Triton calibrators as siblings off it. I'd like to do that as a separate PR so this one stays scoped to "port the sweep to Triton" — will file a follow-up issue. The metadata stash in __init__ is necessary today only because MseCalibrator.reset() deletes _initial_amax and the Triton path still needs that metadata to reshape the kernel's flat output and validate later collect() calls — keeping it as-is until the sibling refactor.
  • Median-tie-break: agreed worth doing, but as a deliberate change applied to both the reference and Triton paths together with tests/docs, in a separate PR.

CI

  • regression-pr-required-check: the regression-tests step itself succeeded; the job was canceled during Codecov upload. Will re-run once CI picks up the new commit.

@cjluo-nv cjluo-nv requested a review from jenchen13 May 6, 2026 18:17
@realAsma
Contributor

realAsma commented May 6, 2026

Bot comment.

Thanks for adding the Triton sweep path here. One alternate design that I think is worth considering is to keep NVFP4MSECalibrator as the only public calibrator and make Triton an internal fast path for the plain NVFP4 weight-MSE FP8 scale sweep, rather than exposing a separate TritonNVFP4MSECalibrator up to model_calib.py.

That would keep the higher-level calibration wiring focused on calibration intent instead of kernel selection, while still allowing the fast path when it is valid. Concretely, mse_calibrate() could always instantiate NVFP4MSECalibrator for NVFP4 static fp8_scale_sweep, and NVFP4MSECalibrator.collect() could choose the Triton path only when Triton is enabled, error_func is None, the input is CUDA, and the shape matches the expected blocked layout. Otherwise it would fall back to the existing reference MSE sweep automatically.

I like this boundary because custom error_func users, including local Hessian calibration, continue to get the reference behavior without higher-level special casing; the env opt-out remains an implementation detail; and future changes to the kernel wrapper stay localized inside the NVFP4 calibrator. Tests could then cover default Triton/reference parity, env opt-out fallback, custom error_func fallback, and a focused local-Hessian regression.

@cjluo-nv
Collaborator Author

cjluo-nv commented May 6, 2026

@realAsma — adopted this design in 95b8a95. Outcome:

Refactor

  • Removed TritonNVFP4MSECalibrator from the public surface entirely.

  • NVFP4MSECalibrator.collect() now selects the fused Triton kernel via an internal predicate _can_use_triton_fast_path(x) that requires:

    • error_func is None,
    • x.is_cuda,
    • x.ndim == 2 and x.shape[0] == _initial_amax.numel() (blocked layout matches the per-block amax),
    • the kernel package is importable,
    • MODELOPT_NVFP4_TRITON_SWEEP != "0".

    Otherwise it delegates to the parent's reference 126-step sweep. Multi-collect after the fast path raises RuntimeError with a message pointing the caller at reset(), the env opt-out, or error_func.

  • Custom error_func users (including the local-Hessian path) now naturally get the reference behavior — this fixes a latent correctness issue: the old TritonNVFP4MSECalibrator accepted error_func "for API parity" but the kernel always used squared error, so a Hessian-weighted loss would have been silently ignored.

  • mse_calibrate() always instantiates NVFP4MSECalibrator; the env var lookup at that layer is gone — kernel selection is now an implementation detail of the calibrator, not a higher-level switch.

  • Override NVFP4MSECalibrator.reset() to clear only per-cycle state and keep _initial_amax (shape [num_blocks], small) so the same instance can be re-used after reset().
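
The dispatch predicate described above can be sketched roughly as follows (a simplified, standalone version; the real `_can_use_triton_fast_path` is a method on `NVFP4MSECalibrator` and carries calibrator state):

```python
import os
import torch

def can_use_triton_fast_path(x, error_func, initial_amax) -> bool:
    # Custom losses (e.g. Hessian-weighted) must take the reference sweep.
    if error_func is not None:
        return False
    if not x.is_cuda:
        return False
    # Blocked layout must match the per-block amax.
    if x.ndim != 2 or x.shape[0] != initial_amax.numel():
        return False
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    return os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"

x = torch.zeros(4, 16)  # CPU tensor -> falls back to the reference path
amax = torch.ones(4)
assert not can_use_triton_fast_path(x, None, amax)
assert not can_use_triton_fast_path(x, lambda err: err, amax)
```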

Tests

  • _force_sweep_path(triton_enabled=…) context manager pins the dispatch via the env var around each test.
  • New dispatch tests assert the predicate behavior for the four cases you listed: default fast path, env opt-out fallback, custom-error_func fallback, and CPU input fallback (asserted via the predicate directly, since NVFP4 fake-quant itself is CUDA-only).
  • New end-to-end test_mse_calibrate_end_to_end runs mtq.quantize with the mse/fp8_scale_sweep=True config under both default and MODELOPT_NVFP4_TRITON_SWEEP=0, then asserts bitwise-identical post-calibration model outputs.
  • Reuse-after-reset is covered by test_reset_allows_recollect (now passes with the new reset() semantics).

Validation on B300

  • Full sweep test file: 21 passed in 8.61s. Triton fast path measured at 40.7× vs reference (173.03 ms → 4.25 ms on a (8192, 4096) weight, ~2M NVFP4 blocks).
  • Integration: test_quantize for NVFP4_WEIGHT_ACT_MSE_CFG, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, and NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG (which exercises the error_func fallback path) plus test_save_restore — 8 passed in 2:04.

Pre-existing items that are unaffected by this commit:

  • The earlier sibling-base-class follow-up is now obsolete — there are no siblings to share a base.
  • The regression-pr-required-check Codecov-cancellation will need a re-run once CI picks up 95b8a956a.

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/quantization/calib/mse.py (1)

233-234: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add explicit positive block_size validation before Triton launch.

Line 257 passes block_size=x.shape[-1] directly. With shape [n_blocks, 0], the fast-path predicate can still pass and launch the kernel with block_size=0, which can fail at runtime with a less actionable error.

Suggested patch
     if self._losses_sum is None and self._can_use_triton_fast_path(x):
         from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep
 
-        best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
+        block_size = x.shape[-1]
+        if block_size <= 0:
+            raise ValueError(f"Expected positive block_size in collect(), got {block_size}.")
+        best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=block_size)

As per coding guidelines, "**/*.py: All code must follow the security guidelines in SECURITY.md — violations are blocked as pre-merge errors."

Also applies to: 257-257

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/calib/mse.py` around lines 233 - 234, The
fast-path currently computes block_size as block_size = x.shape[-1] and can be
zero for shape [n_blocks, 0], allowing a Triton kernel launch with
block_size==0; add an explicit positive-size guard before the Triton
launch/fallback decision (the predicate that checks x.ndim and x.shape[0]
against int(self._initial_amax.numel())) by verifying block_size > 0 (using the
same local x and block_size variables) and return False / take the safe fallback
if block_size <= 0 to avoid launching with block_size==0.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 2826b6cc-25f7-4dc6-91ef-b2c9f8621b76

📥 Commits

Reviewing files that changed from the base of the PR and between e57c7a8 and 95b8a95.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

@jenchen13
Contributor

/claude review

Comment on lines +98 to +100
scale_safe = tl.where(scale == 0.0, 1.0, scale)
q_mag = fp4_round_magnitude(w_abs / scale_safe)
diff = w_abs - q_mag * scale

[SUGGESTION]

When global_amax == 0, scale is 0 and scale_safe becomes 1. The loss computation then uses:

diff = w_abs - q_mag * scale   # scale is 0 → diff = w_abs

This is correct — loss equals sum(w_abs²) for every candidate so any best_idx is fine, and the final best_amax = global_amax * best_c = 0 regardless of which candidate wins.

However, there's a subtlety: after the scale_safe substitution, q_mag = fp4_round_magnitude(w_abs / 1.0) which rounds |w| itself — this is dead work since q_mag * 0 == 0 anyway. You could short-circuit (diff = w_abs directly when scale == 0), but given it's a degenerate edge case that shouldn't appear in practice, this is purely cosmetic. Noting it for documentation purposes only.
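
A tiny numeric check of the degenerate case above (a sketch only; `.round()` stands in for the kernel's `fp4_round_magnitude`): with `scale == 0`, the `q_mag * scale` term vanishes, so `diff == w_abs` for every candidate and all candidates tie.

```python
import torch

w_abs = torch.tensor([0.5, 1.25, 3.0])
scale = torch.tensor(0.0)
# Mirror of the kernel's scale_safe substitution.
scale_safe = torch.where(scale == 0.0, torch.tensor(1.0), scale)
q_mag = (w_abs / scale_safe).round()  # dead work: multiplied by 0 below
diff = w_abs - q_mag * scale          # scale == 0 -> diff == w_abs
assert torch.equal(diff, w_abs)
```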

if self._losses_sum is None and self._can_use_triton_fast_path(x):
from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])

[SUGGESTION]

Minor: The reference path in MseCalibrator.collect (line 93) does x = x.detach().to(dtype=torch.float32) before computing losses, promoting the input to fp32 on the host side. Here the Triton path passes x.detach() without explicit fp32 promotion. This is fine because the kernel loads with .to(tl.float32) internally — but it means the kernel receives, e.g., a bf16 tensor and Triton must handle the load-conversion. Just noting for future readers that the fp32 semantics are maintained inside the kernel, not at the call site.

if x.numel() % block_size != 0:
raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32)

[IMPORTANT Performance]

fp8_scale_candidates(x.device) is called on every invocation of nvfp4_fp8_scale_sweep, allocating a new 126-element tensor each time. In the weight calibration loop (mse_calibrate step 3), this function is called once per weight quantizer — potentially hundreds of times for large models.

The allocation is small (504 bytes) and the cost is negligible compared to the kernel itself, so this is not a blocking issue. However, if you want to eliminate it, a module-level cache (e.g., functools.lru_cache keyed on device, or a _candidates_cache: dict[torch.device, torch.Tensor]) would ensure the tensor is computed once per device per process.

Not urgent given the 42x speedup context, but would be a clean follow-up.

@claude claude Bot left a comment


Claude Deep Review Summary

CRITICAL: 0 | IMPORTANT: 1 | SUGGESTION: 2

Overall Assessment: Low Risk — Approve

This PR is well-crafted. The core algorithmic insight (FP8 round-trip identity for candidates constructed as valid_fp8_e4m3_value / 448) is mathematically sound — I traced it through the two-level NVFP4 scaling chain (fp8_quantize_scalenvfp4_scalar_quant) and confirmed the identity holds. The kernel correctly:

  1. Loads weights once per tile, promotes to fp32 in-register
  2. Exploits sign-invariance of MSE to work on |w| throughout
  3. Uses fp4_round_magnitude (shared with the reference kernel) for exact FP4 rounding
  4. Handles the degenerate global_amax == 0 edge case correctly (any candidate is equally good, and the output best_amax = 0 regardless)

Dispatch/fallback design is robust:

  • Guards on error_func is None (protects Hessian-weighted custom callers)
  • Guards on CUDA, 2D layout matching _initial_amax, Triton importability
  • Env-var opt-out for debugging/bisection
  • One-shot semantics with clear RuntimeError on misuse + reset() for reuse

Mode/state composability: No changes to modelopt_state schema, mode registration, or export paths — the change is purely internal to the calibrator's implementation.

Findings

| # | Severity | File | Summary |
| --- | --- | --- | --- |
| 1 | IMPORTANT | nvfp4_fp8_sweep.py:139 | fp8_scale_candidates() allocates a new tensor per call; cacheable for marginal perf gain in large-model loops |
| 2 | SUGGESTION | nvfp4_fp8_sweep.py:98-100 | Documenting the scale==0 dead-work path for clarity |
| 3 | SUGGESTION | mse.py:257 | Noting that fp32 promotion happens inside the kernel rather than at the call site (correct, just non-obvious) |

None of these are blocking. The test coverage is thorough (parity across dtypes/seeds/block-counts, dispatch gate coverage, end-to-end through mtq.quantize, and a relaxed speedup benchmark). LGTM.


def _generate_candidates(self, device: torch.device) -> torch.Tensor:
"""Generate 126 valid FP8 E4M3 scale candidates."""
"""Generate 126 valid FP8 E4M3 scale candidates.
Contributor


Can the Triton kernel be an optional arg? We are seeing bugs in the MSE FP8 sweep path that @Fridah-nv is fixing, and completely replacing the current sweep is risky. We can keep Triton as an option (via a quant config arg) while we build confidence in it on more model types (you've only tested one Qwen 8B model).

Collaborator Author


@jenchen13 I will keep Triton as the default. Both paths have unit-test coverage ensuring they produce the same scales.

return False
if not x.is_cuda:
return False
if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0":
Contributor


@jenchen13 see here ->

# the same for every candidate, so any best_idx is fine.
scale_safe = tl.where(scale == 0.0, 1.0, scale)
q_mag = fp4_round_magnitude(w_abs / scale_safe)
diff = w_abs - q_mag * scale
Contributor


Can we use the same scale here please?

Suggested change
diff = w_abs - q_mag * scale
diff = w_abs - q_mag * scale_safe

Collaborator Author


Done in a9c8ccf. Equivalent on real inputs (the only case where scale_safe differs from scale is global_amax == 0, but in that case w_abs is also zero by construction — global_amax = max|w| — so the loss is zero for every candidate either way), but more consistent with the divisor used to compute q_mag.

@requires_triton
@pytest.mark.parametrize("seed", [0, 1, 2])
@pytest.mark.parametrize("num_blocks", [4, 64, 1024])
def test_parity_random_weights(seed, num_blocks):
Contributor


This test currently covers only fp32. Can we extend it to bf16 and fp16 as well? This is an important test.

Collaborator Author


Done in a9c8ccf. test_parity_random_weights is now parametrized on dtype ∈ {fp32, fp16, bf16} in addition to seed and num_blocks, so the canonical 3 seeds × 3 num_blocks grid is exercised on every supported dtype. Folded the smaller test_parity_dtypes into this since it was a strict subset.

Note on validation: the host I had available for this round (single-GPU B200) has unusually slow Triton compile (~9.5s cold per kernel signature vs ~50ms on the B300 used previously), so the full test file timed out at 600s on this host. The kernel itself runs correctly (verified via direct call: 9.55s cold → 0.7ms cached, correct shape/dtype out). Pre-commit + ruff + mypy clean. Will rely on CI for the full sweep validation.

Contributor

@realAsma realAsma left a comment


Looks great!

I have left a few comments. Could you please address them as well?

cjluo-nv added 6 commits May 8, 2026 05:12
Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single
fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid
FP8 E4M3 scale candidates in registers, and emits the per-block best amax
directly. For our specific candidate set (FP8 representable values / 448) the
FP8 round-trip on the per-block scale is the identity, so the kernel uses
`scale = candidate * global_amax / 6.0` and runs on any CUDA + Triton.

Triton-backed calibrator is on by default for `mse_calibrate(... fp8_scale_sweep=True)`;
set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging.

Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator
(8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to
the reference for typical block counts; on multi-million-block weights an
occasional adjacent-candidate tie-break can differ at the fp32-noise level
(observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
…ner loop

Two follow-on optimizations to the fused FP8 scale sweep kernel:

1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300
   showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the
   table — best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are
   included to cover small/medium/large N_BLOCKS without flooding compile time.

2. Drop the sign-handling tl.where: since FP4 quantization preserves sign,
   (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and
   skips one tl.where + negation per element per candidate.

Result on the same 8192x4096 weight (~2M blocks) on B300:
  reference NVFP4MSECalibrator:   176.68 ms
  triton  TritonNVFP4MSECalibrator: 4.23 ms
  speedup: 41.8x  (was 7.4x)

This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms),
so the kernel is now near saturation and further wins would need an algorithmic
change (candidate pruning, etc.).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Addresses review comments on PR #1387:

- TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape /
  dtype / n_blocks of the initial amax are stashed in __init__, so collect()
  no longer depends on _initial_amax surviving reset(). Adds an x.ndim==2
  assertion in collect() since the weight quantizer always reshapes upstream.
- nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert
  (which is stripped by python -O): rejects non-CUDA tensors, non-positive
  block_size, and empty / non-1D candidates with ValueError. Skips the
  per-element finite/positive check on candidates since it would scan a 126-
  entry tensor on every kernel call.
- mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of
  the per-quantizer loop and resolves to the calibrator class once.
- Updates test_reset_allows_recollect to verify the new reuse contract; adds
  test_input_validation covering the new ValueErrors.

The duplicate fp8_scale_candidates implementation in the kernel file and
NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating
would force the reference path to import from the kernel module, which is
gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity
test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Address realAsma's review feedback on the NVFP4 FP8 sweep kernel:

- TritonNVFP4MSECalibrator.collect: replace `assert x.ndim == 2` with
  ValueError so the contract still holds under `python -O`, validate
  block_size > 0 before use, and derive n_blocks from x.shape[0] so a
  zero last-dim cannot trigger division before the shape check.
- nvfp4_fp8_scale_sweep: drop the public `candidates` parameter. The
  candidate set is fixed (FP8 E4M3 valid values / 448) and a wrong
  length would silently inflate `tl.static_range` codegen, while
  nonpositive/nonfinite entries violate the kernel's scale assumptions.
  No internal caller used the override.
- Add test_mse_calibrate_dispatch covering the public default + opt-out
  wiring: confirms `mse_calibrate(fp8_scale_sweep=True)` installs
  TritonNVFP4MSECalibrator by default and falls back to NVFP4MSECalibrator
  when MODELOPT_NVFP4_TRITON_SWEEP=0.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Per realAsma's review, collapse TritonNVFP4MSECalibrator into NVFP4MSECalibrator
as an internal fast path rather than a separately-exported subclass:

- mse.py: NVFP4MSECalibrator.collect() picks the fused Triton kernel via a
  predicate _can_use_triton_fast_path(x) that requires error_func is None,
  CUDA input, blocked layout matching the per-block amax, the kernel package
  importable, and MODELOPT_NVFP4_TRITON_SWEEP != "0". Otherwise falls back
  to the parent's reference 126-step sweep. Override reset() to clear only
  per-cycle state and keep _initial_amax (shape [num_blocks], small) so the
  calibrator is reusable; the multi-collect-after-fast-path case raises a
  RuntimeError with a clear message. TritonNVFP4MSECalibrator class deleted.
- model_calib.py: always instantiate NVFP4MSECalibrator; drop the
  TritonNVFP4MSECalibrator import and the env-var dispatch (now internal).
- tests: drop the TritonNVFP4MSECalibrator references. Force the requested
  path via a _force_sweep_path() context manager around the env var. New
  dispatch tests assert the predicate's behavior for the env opt-out, custom
  error_func, and CPU input cases. test_mse_calibrate_end_to_end exercises
  the full mtq.quantize wiring with default and MODELOPT_NVFP4_TRITON_SWEEP=0
  and asserts bitwise-identical model outputs.

This fixes a latent correctness issue: the previous TritonNVFP4MSECalibrator
silently ignored a custom error_func, so a caller passing a Hessian-weighted
loss (e.g. local-Hessian calibration) would have gotten plain squared-error
results from the kernel. The new predicate routes any non-None error_func to
the reference path so the user's metric is honored.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Three changes from realAsma's latest review:

- nvfp4_fp8_sweep kernel: use ``scale_safe`` rather than ``scale`` in the
  per-candidate diff so the divisor and multiplier match. Numerically
  equivalent on real inputs (the only case where ``scale_safe`` differs
  from ``scale`` is ``global_amax == 0``, in which case ``w_abs`` is also
  zero so the loss is zero either way), but more consistent.
- Extract ``fp8_scale_candidates`` to a triton-free module
  ``_fp8_scale_candidates.py`` so the calibrator's reference sweep and the
  Triton kernel wrapper share one definition. Removes the duplicate copy
  in ``NVFP4MSECalibrator._generate_candidates``.
- Parity test: extend ``test_parity_random_weights`` to cover bf16 and
  fp16 in addition to fp32 by parametrizing on dtype, so the canonical
  parity grid (3 seeds × 3 num_blocks) is now exercised on every supported
  dtype. Folded the smaller ``test_parity_dtypes`` into this since it was
  a strict subset.
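
For context on the shared candidate definition: 126 is exactly the number of positive, finite, non-zero FP8 E4M3 codes (0x01 through 0x7E), so a generator in the spirit of ``_fp8_scale_candidates.py`` can be sketched by decoding those codes directly. This is a pure-Python illustration under that assumption, not the actual module:

```python
def fp8_e4m3_scale_candidates():
    """All positive, finite, non-zero OCP FP8 E4M3 values (codes 0x01..0x7E).

    Exactly 126 values, matching the sweep size quoted above. E4M3 layout:
    1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits; code 0x00 is zero
    and code 0x7F is NaN, so both are excluded.
    """
    vals = []
    for code in range(0x01, 0x7F):
        exp = (code >> 3) & 0xF
        mant = code & 0x7
        if exp == 0:
            vals.append(2.0 ** -6 * (mant / 8.0))  # subnormal
        else:
            vals.append(2.0 ** (exp - 7) * (1.0 + mant / 8.0))  # normal
    return vals
```

The candidates come out in ascending order, from 2^-9 (smallest subnormal) up to 448 (the E4M3 maximum).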

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv force-pushed the chenjiel/nvfp4-fp8-sweep-triton branch from a9c8ccf to 8f04a9a on May 8, 2026 05:13
@cjluo-nv cjluo-nv enabled auto-merge (squash) May 8, 2026 05:13
cjluo-nv added 2 commits May 8, 2026 05:15
Trim the parity grid to keep all three axes but with smaller per-axis
ranges: 2 seeds × 2 num_blocks × 3 dtypes = 12 parametrized cases (down
from 3×3×3 = 27). Still exercises every supported dtype and the small/
large num_blocks extremes that drive different autotune choices, while
roughly halving the cold-compile cost on hosts where Triton compilation
is expensive.
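
The trimmed grid amounts to the following (the concrete seed and num_blocks values here are invented for illustration; only the 2 x 2 x 3 = 12 shape comes from the commit):

```python
import itertools

SEEDS = (0, 1)                 # assumed seed values
NUM_BLOCKS = (4, 100_000)      # assumed small/large extremes driving
                               # different autotune choices
DTYPES = ("float32", "float16", "bfloat16")

# 2 seeds x 2 num_blocks x 3 dtypes = 12 parametrized cases (down from 27)
CASES = list(itertools.product(SEEDS, NUM_BLOCKS, DTYPES))
```

With pytest this would typically be expressed as stacked ``@pytest.mark.parametrize`` decorators over the same three axes.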

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
After folding the Triton fast path into NVFP4MSECalibrator.collect(),
two tests in tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
broke because they inspect path-specific state:

- test_collect_and_compute_amax asserts ``cal._losses_sum is not None``
  with ``len == 126``. Only the reference 126-step sweep populates that
  list; the Triton fast path produces ``_best_amax_fast`` directly and
  leaves ``_losses_sum = None``.
- test_multiple_collections asserts that two ``collect()`` calls
  accumulate. The Triton fast path is one-shot by design and refuses
  a second collect until ``reset()``, so multi-collect is fundamentally
  reference-path semantics.

Fix: take the ``monkeypatch`` fixture in both tests and force
``MODELOPT_NVFP4_TRITON_SWEEP=0`` so they exercise the reference
accumulator. Triton-path coverage stays in
test_nvfp4_fp8_sweep_kernel.py (parity, dispatch predicate, end-to-end
mtq.quantize). The other tests in the same class (initialization,
candidate generation, per-block independence) are path-agnostic and
unchanged.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv merged commit 1d796f9 into main May 8, 2026
103 of 114 checks passed
@cjluo-nv cjluo-nv deleted the chenjiel/nvfp4-fp8-sweep-triton branch May 8, 2026 22:53