[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search#1387
Conversation
📝 Walkthrough

Adds a Triton-fused NVFP4 FP8 per-block scale-sweep kernel and wrapper, integrates a Triton fast path into the NVFP4 MSE calibrator with one-shot collect/reset semantics, conditionally re-exports the kernel, documents the env-var opt-out, and adds extensive GPU tests for parity, validation, semantics, and performance.

Changes: Triton NVFP4 FP8 Scale Sweep + Calibrator Fast Path
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Calibrator
    participant Wrapper
    participant Kernel
    User->>Calibrator: collect(x)
    Calibrator->>Calibrator: check fast-path eligibility
    alt triton fast-path eligible
        Calibrator->>Wrapper: nvfp4_fp8_scale_sweep(x, global_amax, block_size)
        Wrapper->>Wrapper: validate inputs and materialize candidates on device
        Wrapper->>Wrapper: flatten to NVFP4 blocks and cast global_amax
        Wrapper->>Kernel: launch _fp8_scale_sweep_kernel (autotuned)
        Kernel->>Kernel: load block magnitudes
        Kernel->>Kernel: for each candidate compute quantize and MSE, track best
        Kernel-->>Wrapper: write per-block best_amax
        Wrapper-->>Calibrator: best_amax
        Calibrator->>Calibrator: store _best_amax_fast
    else reference path
        Calibrator->>Calibrator: run reference sweep / accumulate losses
    end
    User->>Calibrator: compute_amax()
    Calibrator-->>User: return per-block amax
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 5 passed

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
meenchen
left a comment
Bot review — DM the bot to share feedback.
Clean, well-structured Triton kernel for speeding up the NVFP4 FP8 scale sweep. The implementation correctly reuses the existing `nvfp4_scalar_quant` JIT function from `nvfp4_quant.py`, the math insight about the FP8 round-trip identity for the candidate set is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, and speedup).
A few points for the owner:
- Unchecked test plan items: The PR body has two unchecked items — H100/A100 run and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge.
- Minor code duplication: `fp8_scale_candidates()` in `nvfp4_fp8_sweep.py` duplicates `NVFP4MSECalibrator._generate_candidates()`. Consider having one call the other (or extracting a shared utility) to keep the candidate generation logic in one place.
- `local_hessian_calibrate` not using the Triton path: This function still uses `NVFP4MSECalibrator` directly (not `TritonNVFP4MSECalibrator`), which is correct since it needs a custom `error_func`. Worth adding a comment there noting that the Triton path doesn't support custom error functions, so someone doesn't "helpfully" switch it later.
- `collect` assumes `x.shape[-1]` is block_size: This works for the current MSE weight calibration flow where the tensor is pre-reshaped, but could be fragile if the calibrator is used in a different context. A brief assert or docstring note would help.
```python
block_size = x.shape[-1]
n_blocks = x.numel() // block_size
if self._initial_amax.numel() != n_blocks:
    raise ValueError(
```
Bot comment.
Nit: `block_size = x.shape[-1]` assumes the input tensor has already been reshaped to `[n_blocks, block_size]`. This is true for the current `mse_calibrate` weight flow, but could silently produce wrong results if someone uses this calibrator with a differently-shaped tensor. Consider adding a brief assertion or docstring note, e.g.:

```python
assert x.ndim == 2, "Expected x to be [n_blocks, block_size] from the weight quantizer reshape"
```
```python
def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
    """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
```
Bot comment.
Minor duplication: this function reproduces the same logic as `NVFP4MSECalibrator._generate_candidates()` in `calib/mse.py`. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.
Done in a9c8ccf. Extracted `fp8_scale_candidates` to a Triton-free module `modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py`. Both the kernel wrapper and `NVFP4MSECalibrator._generate_candidates` now import from it, so there's a single source of truth.
```diff
-# Replace calibrator with NVFP4MSECalibrator
+# Replace calibrator with the fused Triton sweep kernel by default
+# (single-shot collect, ~7-20x faster for the weight-MSE phase).
 module._calibrator = NVFP4MSECalibrator(
```
Bot comment.
The env var check `os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"` is evaluated on every weight quantizer in the loop. Since it won't change mid-loop, consider hoisting it above the loop for clarity and minor efficiency.
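A minimal sketch of the suggested hoisting, with hypothetical names (`resolve_calibrator_cls` is not the PR's API; the real code picks the calibrator class inside `mse_calibrate`):

```python
def resolve_calibrator_cls(env: dict) -> str:
    """Resolve the calibrator choice once, before the per-quantizer loop."""
    use_triton = env.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
    return "TritonNVFP4MSECalibrator" if use_triton else "NVFP4MSECalibrator"

# In mse_calibrate (sketch): cls = resolve_calibrator_cls(dict(os.environ)),
# then the loop over weight quantizers reuses cls instead of re-reading os.environ.
```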
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py (1)
72-74: 💤 Low value

Consider hoisting candidate loads outside the loop.

The candidate value is loaded inside `tl.static_range`, which means 126 separate scalar loads per program invocation. Since `candidates_ptr` points to shared read-only data, you could load all candidates into a register vector once before the loop for better memory efficiency.

That said, `tl.static_range` unrolls at compile time and Triton's compiler may already optimize repeated scalar loads. This is a minor optimization suggestion.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 117-119: Replace the assert-based CUDA check with consistent
exception raising: in the nvfp4_fp8_scale_sweep (or the function containing the
lines checking x.is_cuda and block_size), change the assert x.is_cuda line to
raise a ValueError (or a module-specific custom exception) with a clear message
like "nvfp4_fp8_scale_sweep requires a CUDA tensor" so both the CUDA check and
the block_size divisibility check use the same error style and won't be removed
by python -O; keep the existing block_size check and message for x.numel()
as-is.
---
Nitpick comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 72-74: The loop uses tl.static_range over NUM_CANDIDATES and calls
tl.load(candidates_ptr + k) inside each iteration (producing many scalar loads);
hoist these loads by reading the entire candidates_ptr into a local vector/array
(e.g., candidates_arr) once before the tl.static_range loop and then use
candidates_arr[k] (or the equivalent register lookup) inside the loop; update
references to the temporary c to load from the prefilled candidates_arr instead
of calling tl.load each iteration (keep names NUM_CANDIDATES, candidates_ptr,
tl.static_range, and c to locate the code).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: e28188c9-6b33-47d0-ac6b-3029bbe39550
📒 Files selected for processing (5)
- modelopt/torch/kernels/quantization/gemm/__init__.py
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 139-158: The code does not validate a custom candidates tensor
before launching _fp8_scale_sweep_kernel, allowing empty or malformed inputs
which make NUM_CANDIDATES zero and cause out-of-bounds reads (e.g.,
candidates_ptr[0]); fix by validating candidates returned or passed into the
function: if candidates is not None ensure it's a 1-D tensor, non-empty, finite,
positive, and has the expected length (ideally 126) or otherwise raise a clear
error; if candidates is None keep using fp8_scale_candidates(x.device) as
before; perform this validation before calling candidates.contiguous().to(...)
and before computing NUM_CANDIDATES/launching _fp8_scale_sweep_kernel so the
kernel never receives an invalid candidate tensor.
- Around line 112-137: In nvfp4_fp8_scale_sweep validate the public parameter
block_size before using it: add an explicit check that block_size is a positive
integer (e.g. if not isinstance(block_size, int) or block_size <= 0: raise
ValueError(...)) at the start of the function, and only afterwards perform
operations that use it (such as x.numel() % block_size) so you avoid
ZeroDivisionError for 0 and invalid negative values that would otherwise produce
bad kernel launches.
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 243-247: The reset() implementation in TritonNVFP4MSECalibrator
currently calls MseCalibrator.reset(), which clears _initial_amax and makes the
next collect() dereference None; update TritonNVFP4MSECalibrator.reset() to
either preserve _initial_amax (do not delete or set _initial_amax to its prior
tensor) or reinitialize it to a valid tensor/zero-sized tensor so collect() can
safely call self._initial_amax.numel(), and ensure _best_amax is also reset
consistently; apply the same fix to the other reset override referenced around
the second occurrence (similar block at lines ~274-277) so both reset overrides
in TritonNVFP4MSECalibrator maintain the contract expected by collect().
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 61d118f9-f103-4290-945e-bbc50478d48c
📒 Files selected for processing (2)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1387      +/-   ##
==========================================
- Coverage   77.40%   76.82%    -0.58%
==========================================
  Files         476      478        +2
  Lines       51319    51404       +85
==========================================
- Hits        39721    39492      -229
- Misses      11598    11912      +314
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Addresses review comments on PR #1387:

- `TritonNVFP4MSECalibrator.reset()` now leaves the calibrator reusable: shape / dtype / n_blocks of the initial amax are stashed in `__init__`, so `collect()` no longer depends on `_initial_amax` surviving `reset()`. Adds an `x.ndim == 2` assertion in `collect()` since the weight quantizer always reshapes upstream.
- `nvfp4_fp8_scale_sweep` validates inputs cleanly instead of using `assert` (which is stripped by `python -O`): rejects non-CUDA tensors, non-positive `block_size`, and empty / non-1D candidates with `ValueError`. Skips the per-element finite/positive check on candidates since it would scan a 126-entry tensor on every kernel call.
- `mse_calibrate` hoists the `MODELOPT_NVFP4_TRITON_SWEEP` env-var lookup out of the per-quantizer loop and resolves to the calibrator class once.
- Updates `test_reset_allows_recollect` to verify the new reuse contract; adds `test_input_validation` covering the new ValueErrors.

The duplicate `fp8_scale_candidates` implementation in the kernel file and `NVFP4MSECalibrator._generate_candidates()` is left in place: deduplicating would force the reference path to import from the kernel module, which is gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
meenchen
left a comment
Bot review — DM the bot to share feedback.
All critical issues from the previous review have been addressed:
- ✅ `assert x.is_cuda` → `raise ValueError`
- ✅ `block_size <= 0` validation added
- ✅ Custom `candidates` tensor validation added
- ✅ `reset()` reusability fixed (metadata stashed in `__init__`, tested in `test_reset_allows_recollect`)
- ✅ Env var hoisted above the loop
- ✅ `x.ndim == 2` assertion added in `collect()`
- ✅ Candidate duplication addressed with sync docstring in `_generate_candidates`
The kernel implementation is clean, the math insight about the FP8 round-trip identity is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, input validation, and speedup).
Remaining concern: The PR body still has two unchecked test plan items — H100/A100 validation and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge. Nudging for human sign-off on that.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 260-266: In collect(), replace the runtime assertion with explicit
input validation: check that the input tensor x has ndim == 2 and raise a
ValueError with a clear message if not; check block_size = x.shape[-1] is > 0
and raise ValueError if it is zero to avoid ZeroDivisionError; compute n_blocks
= x.numel() // block_size only after these checks and if n_blocks !=
self._n_blocks raise a ValueError describing the mismatch (referencing the
collect method and TensorQuantizer._process_for_blockquant behavior to explain
expected [n_blocks, block_size] shape).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b2ef23a9-7bcf-4284-b38a-2cf368077ba4
📒 Files selected for processing (4)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
meenchen
left a comment
Bot review — DM the bot to share feedback.
All critical issues from previous reviews have been addressed:
- `assert x.is_cuda` replaced with `raise ValueError`
- `block_size <= 0` validation added before use
- Custom `candidates` tensor validation (non-empty, 1-D) added
- `reset()` reusability fixed — shape/dtype/n_blocks metadata stashed in `__init__`, verified by `test_reset_allows_recollect`
- Env var read hoisted above the weight quantizer loop
- `x.ndim == 2` shape check added in `collect()`
- Candidate duplication addressed with sync docstring
The kernel implementation is correct — the FP8 round-trip identity insight is sound, input validation is thorough, and test coverage is solid (16 GPU tests covering parity across seeds/dtypes/block counts, round-trip, reset, input validation, and speedup reporting). License headers match the canonical LICENSE_HEADER.
meenchen
left a comment
LGTM. @Fridah-nv could you also help review this PR?
```python
@torch.no_grad()
def compute_amax(self, verbose: bool = False):
    """Return the per-block amax computed during ``collect``."""
    return self._best_amax

def reset(self):
    """Reset the stored best amax. Subsequent ``collect`` calls are allowed."""
    self._best_amax = None
    super().reset()
```
```python
def __init__(
    self,
    amax: torch.Tensor,
    global_amax: torch.Tensor,
    axis: int | tuple | list | None = None,
    quant_func: Callable | None = None,
    error_func: Callable | None = None,
):
    """Initialize the Triton-fused NVFP4 MSE calibrator.

    See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
    the kernel path but accepted for API parity. Tile shape and ``num_warps`` are
    autotuned by the kernel per ``N_BLOCKS``.
    """
    super().__init__(
        amax=amax,
        global_amax=global_amax,
        axis=axis,
        quant_func=quant_func,
        error_func=error_func,
    )
    # Stash shape metadata so collect() can keep working after reset() releases
    # the (potentially large) _initial_amax buffer.
    self._initial_amax_shape = tuple(amax.shape)
    self._initial_amax_dtype = amax.dtype
    self._n_blocks = int(amax.numel())
    self._best_amax: torch.Tensor | None = None
```
Why do we need this? Why not inherit from the parent class?
Agreed on the design call-out. The shape/dtype/n_blocks stash in __init__ is necessary today only because MseCalibrator.reset() deletes _initial_amax, and the Triton path still needs that metadata to reshape the kernel's flat output and validate later collect() calls — so I'd rather not strip it without the larger refactor.
The cleaner shape — a small NVFP4 sweep base/helper with the reference and Triton calibrators as siblings off it — I'd like to do as a follow-up so this PR stays scoped to "port the sweep to Triton." Will file a follow-up issue.
Per-thread items addressed in the latest commit:
- `collect()` now raises `ValueError` instead of `assert`, and validates `x.shape[-1]` before using it.
- Dropped the public `candidates` override in `nvfp4_fp8_scale_sweep` (no internal callers used it; wrong-length input would have silently inflated the unrolled `tl.static_range` codegen).
- Added `test_mse_calibrate_dispatch` covering the default-on + `MODELOPT_NVFP4_TRITON_SWEEP=0` fallback wiring.
Reviewed current head. Remaining items I would resolve before merge:
Notes on the active review threads:
Additional review notes from the current head
@realAsma — thanks for the thorough review. Punch-list against your three top-level comments, addressed in commit e57c7a8:

Fixed in this PR
Deferring to a follow-up
CI
Thanks for adding the Triton sweep path here. One alternate design that I think is worth considering is to keep That would keep the higher-level calibration wiring focused on calibration intent instead of kernel selection, while still allowing the fast path when it is valid. Concretely, I like this boundary because custom
@realAsma — adopted this design in 95b8a95. Outcome: Refactor
Tests
Validation on B300
Pre-existing items that are unaffected by this commit:
♻️ Duplicate comments (1)
modelopt/torch/quantization/calib/mse.py (1)
233-234: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add explicit positive `block_size` validation before Triton launch.

Line 257 passes `block_size=x.shape[-1]` directly. With shape `[n_blocks, 0]`, the fast-path predicate can still pass and launch the kernel with `block_size=0`, which can fail at runtime with a less actionable error.

Suggested patch:

```diff
 if self._losses_sum is None and self._can_use_triton_fast_path(x):
     from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

-    best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
+    block_size = x.shape[-1]
+    if block_size <= 0:
+        raise ValueError(f"Expected positive block_size in collect(), got {block_size}.")
+    best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=block_size)
```

As per coding guidelines: "`**/*.py`: All code must follow the security guidelines in `SECURITY.md` — violations are blocked as pre-merge errors."

Also applies to: line 257.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 233-234: The fast-path currently computes block_size as block_size
= x.shape[-1] and can be zero for shape [n_blocks, 0], allowing a Triton kernel
launch with block_size==0; add an explicit positive-size guard before the Triton
launch/fallback decision (the predicate that checks x.ndim and x.shape[0]
against int(self._initial_amax.numel())) by verifying block_size > 0 (using the
same local x and block_size variables) and return False / take the safe fallback
if block_size <= 0 to avoid launching with block_size==0.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 2826b6cc-25f7-4dc6-91ef-b2c9f8621b76
📒 Files selected for processing (3)
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
/claude review
```python
scale_safe = tl.where(scale == 0.0, 1.0, scale)
q_mag = fp4_round_magnitude(w_abs / scale_safe)
diff = w_abs - q_mag * scale
```
[SUGGESTION]
When `global_amax == 0`, `scale` is 0 and `scale_safe` becomes 1. The loss computation then uses:

```python
diff = w_abs - q_mag * scale  # scale is 0 → diff = w_abs
```

This is correct — loss equals sum(w_abs²) for every candidate so any best_idx is fine, and the final `best_amax = global_amax * best_c = 0` regardless of which candidate wins.

However, there's a subtlety: after the `scale_safe` substitution, `q_mag = fp4_round_magnitude(w_abs / 1.0)`, which rounds `|w|` itself — this is dead work since `q_mag * 0 == 0` anyway. You could short-circuit (`diff = w_abs` directly when `scale == 0`), but given it's a degenerate edge case that shouldn't appear in practice, this is purely cosmetic. Noting it for documentation purposes only.
```python
if self._losses_sum is None and self._can_use_triton_fast_path(x):
    from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

    best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
```
[SUGGESTION]
Minor: The reference path in MseCalibrator.collect (line 93) does x = x.detach().to(dtype=torch.float32) before computing losses, promoting the input to fp32 on the host side. Here the Triton path passes x.detach() without explicit fp32 promotion. This is fine because the kernel loads with .to(tl.float32) internally — but it means the kernel receives, e.g., a bf16 tensor and Triton must handle the load-conversion. Just noting for future readers that the fp32 semantics are maintained inside the kernel, not at the call site.
```python
if x.numel() % block_size != 0:
    raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32)
```
[IMPORTANT Performance]
fp8_scale_candidates(x.device) is called on every invocation of nvfp4_fp8_scale_sweep, allocating a new 126-element tensor each time. In the weight calibration loop (mse_calibrate step 3), this function is called once per weight quantizer — potentially hundreds of times for large models.
The allocation is small (504 bytes) and the cost is negligible compared to the kernel itself, so this is not a blocking issue. However, if you want to eliminate it, a module-level cache (e.g., functools.lru_cache keyed on device, or a _candidates_cache: dict[torch.device, torch.Tensor]) would ensure the tensor is computed once per device per process.
Not urgent given the 42x speedup context, but would be a clean follow-up.
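A device-keyed cache along those lines could look like the following pure-Python sketch; `fp8_scale_candidates_cached` is a hypothetical name, the real helper returns a `torch.Tensor`, and the E4M3FN value enumeration here is only for illustration:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def fp8_scale_candidates_cached(device: str = "cpu") -> tuple[float, ...]:
    """Compute the 126 candidates once per device key and reuse the result."""
    # Positive finite FP8 E4M3FN values, each divided by 448 (the E4M3FN max).
    vals = [2.0 ** -6 * (m / 8) for m in range(1, 8)]  # 7 subnormals
    vals += [
        2.0 ** (e - 7) * (1 + m / 8)
        for e in range(1, 16)
        for m in range(8)
        if not (e == 15 and m == 7)  # e=15, m=7 encodes NaN in E4M3FN
    ]  # 119 normals
    return tuple(sorted(v / 448.0 for v in vals))
```

With `lru_cache`, repeated calls for the same device key return the identical object, so the per-call allocation disappears.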
Claude Deep Review Summary
CRITICAL: 0 | IMPORTANT: 1 | SUGGESTION: 2
Overall Assessment: Low Risk — Approve
This PR is well-crafted. The core algorithmic insight (FP8 round-trip identity for candidates constructed as valid_fp8_e4m3_value / 448) is mathematically sound — I traced it through the two-level NVFP4 scaling chain (fp8_quantize_scale → nvfp4_scalar_quant) and confirmed the identity holds. The kernel correctly:
- Loads weights once per tile, promotes to fp32 in-register
- Exploits sign-invariance of MSE to work on `|w|` throughout
- Uses `fp4_round_magnitude` (shared with the reference kernel) for exact FP4 rounding
- Handles the degenerate `global_amax == 0` edge case correctly (any candidate is equally good, and the output `best_amax = 0` regardless)
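The round-trip identity the review traces can be sanity-checked with a toy pure-Python E4M3 rounder (illustrative only; the real kernel never materializes FP8 values):

```python
def e4m3_positive_finite() -> list[float]:
    # Positive finite FP8 E4M3FN values: 7 subnormals plus 119 normals
    # (exponent e = 15 with mantissa m = 7 encodes NaN, so it is excluded).
    vals = [2.0 ** -6 * (m / 8) for m in range(1, 8)]
    vals += [
        2.0 ** (e - 7) * (1 + m / 8)
        for e in range(1, 16)
        for m in range(8)
        if not (e == 15 and m == 7)
    ]
    return sorted(vals)

def round_to_e4m3(x: float) -> float:
    # Toy round-to-nearest over the positive finite set (no saturation or tie rules).
    return min(e4m3_positive_finite(), key=lambda v: abs(v - x))

# Each candidate is c = v / 448 for a representable v, so quantizing the implied
# FP8 value (c * 448) is the identity: it rounds straight back to v.
vals = e4m3_positive_finite()
assert len(vals) == 126
assert all(round_to_e4m3(v / 448.0 * 448.0) == v for v in vals)
```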
Dispatch/fallback design is robust:
- Guards on `error_func is None` (protects Hessian-weighted custom callers)
- Guards on CUDA, 2D layout matching `_initial_amax`, Triton importability
- Env-var opt-out for debugging/bisection
- One-shot semantics with clear `RuntimeError` on misuse + `reset()` for reuse
Mode/state composability: No changes to modelopt_state schema, mode registration, or export paths — the change is purely internal to the calibrator's implementation.
Findings
| # | Severity | File | Summary |
|---|---|---|---|
| 1 | IMPORTANT | nvfp4_fp8_sweep.py:139 | `fp8_scale_candidates()` allocates a new tensor per call; cacheable for marginal perf gain in large-model loops |
| 2 | SUGGESTION | nvfp4_fp8_sweep.py:98-100 | Documenting the `scale==0` dead-work path for clarity |
| 3 | SUGGESTION | mse.py:257 | Noting that fp32 promotion happens inside the kernel rather than at the call site (correct, just non-obvious) |
None of these are blocking. The test coverage is thorough (parity across dtypes/seeds/block-counts, dispatch gate coverage, end-to-end through mtq.quantize, and a relaxed speedup benchmark). LGTM.
```diff
 def _generate_candidates(self, device: torch.device) -> torch.Tensor:
-    """Generate 126 valid FP8 E4M3 scale candidates."""
+    """Generate 126 valid FP8 E4M3 scale candidates.
```
Can the Triton kernel be an optional arg? We are seeing bugs in the MSE FP8 sweep path that @Fridah-nv is fixing, and completely replacing the current sweep is risky. We can keep Triton as an option (via a quant config arg) while we test its confidence on more model types (you've only tested one Qwen 8B model).
@jenchen13 I will keep the Triton path as the default. Both paths have unit-test coverage verifying they produce the same scales.
```python
        return False
    if not x.is_cuda:
        return False
    if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0":
```
```python
# the same for every candidate, so any best_idx is fine.
scale_safe = tl.where(scale == 0.0, 1.0, scale)
q_mag = fp4_round_magnitude(w_abs / scale_safe)
diff = w_abs - q_mag * scale
```
Can we use the same scale here please?
```diff
-diff = w_abs - q_mag * scale
+diff = w_abs - q_mag * scale_safe
```
Done in a9c8ccf. Equivalent on real inputs (the only case where scale_safe differs from scale is global_amax == 0, but in that case w_abs is also zero by construction — global_amax = max|w| — so the loss is zero for every candidate either way), but more consistent with the divisor used to compute q_mag.
```python
@requires_triton
@pytest.mark.parametrize("seed", [0, 1, 2])
@pytest.mark.parametrize("num_blocks", [4, 64, 1024])
def test_parity_random_weights(seed, num_blocks):
```
This test currently covers only fp32. Can we extend it to bf16 and fp16 as well? This is an important test.
Done in a9c8ccf. test_parity_random_weights is now parametrized on dtype ∈ {fp32, fp16, bf16} in addition to seed and num_blocks, so the canonical 3 seeds × 3 num_blocks grid is exercised on every supported dtype. Folded the smaller test_parity_dtypes into this since it was a strict subset.
Note on validation: the host I had available for this round (single-GPU B200) has unusually slow Triton compile (~9.5s cold per kernel signature vs ~50ms on the B300 used previously), so the full test file timed out at 600s on this host. The kernel itself runs correctly (verified via direct call: 9.55s cold → 0.7ms cached, correct shape/dtype out). Pre-commit + ruff + mypy clean. Will rely on CI for the full sweep validation.
realAsma
left a comment
Looks great!
I have left a few comments. Could you please address them as well?
Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block best amax directly. For our specific candidate set (FP8 representable values / 448) the FP8 round-trip on the per-block scale is the identity, so the kernel uses `scale = candidate * global_amax / 6.0` and runs on any CUDA + Triton setup.

The Triton-backed calibrator is on by default for `mse_calibrate(..., fp8_scale_sweep=True)`; set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging.

Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator (8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to the reference for typical block counts; on multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
…ner loop

Two follow-on optimizations to the fused FP8 scale sweep kernel:

1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300 showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the table — the best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are included to cover small/medium/large N_BLOCKS without flooding compile time.
2. Drop the sign-handling tl.where: since FP4 quantization preserves sign, (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and skips one tl.where + negation per element per candidate.

Result on the same 8192x4096 weight (~2M blocks) on B300: reference NVFP4MSECalibrator 176.68 ms, triton TritonNVFP4MSECalibrator 4.23 ms, speedup 41.8x (was 7.4x).

This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms), so the kernel is now near saturation and further wins would need an algorithmic change (candidate pruning, etc.).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
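The sign identity behind optimization 2 can be sanity-checked with a toy quantizer. This is a sketch: the grid is the E2M1 magnitude set, and the nearest-value rounding is a stand-in for the real kernel math.

```python
# Toy symmetric round-to-nearest quantizer onto the FP4 E2M1 magnitude grid.
# Illustrative only, not the real NVFP4 quantizer.
GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 magnitudes

def quantize(w):
    mag = min(GRID, key=lambda g: abs(g - abs(w)))
    return mag if w >= 0 else -mag  # sign is preserved

for w in [-5.3, -0.2, 0.0, 0.7, 2.6, 5.9]:
    wq = quantize(w)
    # Because sign(wq) matches sign(w), the squared error on signed values
    # equals the squared error on magnitudes (IEEE negation is exact):
    assert (w - wq) ** 2 == (abs(w) - abs(wq)) ** 2
```

Working on `|w|` throughout therefore changes nothing about the per-candidate loss, while saving a select and a negation per element.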
Addresses review comments on PR #1387:

- TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape / dtype / n_blocks of the initial amax are stashed in __init__, so collect() no longer depends on _initial_amax surviving reset(). Adds an x.ndim == 2 assertion in collect() since the weight quantizer always reshapes upstream.
- nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert (which is stripped by python -O): rejects non-CUDA tensors, non-positive block_size, and empty / non-1D candidates with ValueError. Skips the per-element finite/positive check on candidates since it would scan a 126-entry tensor on every kernel call.
- mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of the per-quantizer loop and resolves to the calibrator class once.
- Updates test_reset_allows_recollect to verify the new reuse contract; adds test_input_validation covering the new ValueErrors.

The duplicate fp8_scale_candidates implementation in the kernel file and NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating would force the reference path to import from the kernel module, which is gated behind Triton availability. The FP8 E4M3 spec is fixed, and the parity test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
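The assert-vs-ValueError point can be illustrated with a small sketch. The function name, checks, and messages here are hypothetical, not the actual wrapper code.

```python
def validate_sweep_inputs(block_size, candidates):
    """Hypothetical sketch of assert-free input validation.

    `python -O` sets __debug__ to False and strips `assert` statements,
    so checks that must always run are raised as ValueError instead.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if candidates is None or len(candidates) == 0:
        raise ValueError("candidates must be a non-empty 1-D sequence")

validate_sweep_inputs(16, [0.5, 1.0])  # passes silently
try:
    validate_sweep_inputs(0, [0.5])
except ValueError as err:
    print(err)  # prints: block_size must be positive, got 0
```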
Address realAsma's review feedback on the NVFP4 FP8 sweep kernel:

- TritonNVFP4MSECalibrator.collect: replace `assert x.ndim == 2` with ValueError so the contract still holds under `python -O`, validate block_size > 0 before use, and derive n_blocks from x.shape[0] so a zero last-dim cannot trigger division before the shape check.
- nvfp4_fp8_scale_sweep: drop the public `candidates` parameter. The candidate set is fixed (FP8 E4M3 valid values / 448), a wrong length would silently inflate `tl.static_range` codegen, and nonpositive/nonfinite entries violate the kernel's scale assumptions. No internal caller used the override.
- Add test_mse_calibrate_dispatch covering the public default + opt-out wiring: confirms `mse_calibrate(fp8_scale_sweep=True)` installs TritonNVFP4MSECalibrator by default and falls back to NVFP4MSECalibrator when MODELOPT_NVFP4_TRITON_SWEEP=0.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Per realAsma's review, collapse TritonNVFP4MSECalibrator into NVFP4MSECalibrator as an internal fast path rather than a separately-exported subclass:

- mse.py: NVFP4MSECalibrator.collect() picks the fused Triton kernel via a predicate _can_use_triton_fast_path(x) that requires error_func is None, CUDA input, blocked layout matching the per-block amax, the kernel package importable, and MODELOPT_NVFP4_TRITON_SWEEP != "0". Otherwise it falls back to the parent's reference 126-step sweep. Override reset() to clear only per-cycle state and keep _initial_amax (shape [num_blocks], small) so the calibrator is reusable; the multi-collect-after-fast-path case raises a RuntimeError with a clear message. TritonNVFP4MSECalibrator class deleted.
- model_calib.py: always instantiate NVFP4MSECalibrator; drop the TritonNVFP4MSECalibrator import and the env-var dispatch (now internal).
- tests: drop the TritonNVFP4MSECalibrator references. Force the requested path via a _force_sweep_path() context manager around the env var. New dispatch tests assert the predicate's behavior for the env opt-out, custom error_func, and CPU input cases. test_mse_calibrate_end_to_end exercises the full mtq.quantize wiring with default and MODELOPT_NVFP4_TRITON_SWEEP=0 and asserts bitwise-identical model outputs.

This fixes a latent correctness issue: the previous TritonNVFP4MSECalibrator silently ignored a custom error_func, so a caller passing a Hessian-weighted loss (e.g. local-Hessian calibration) would have gotten plain squared-error results from the kernel. The new predicate routes any non-None error_func to the reference path so the user's metric is honored.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
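A standalone sketch of such a dispatch predicate, covering the conditions listed above. The free-function signature is a simplification and the blocked-layout check is omitted; this is not the actual mse.py code.

```python
import importlib.util
import os

def can_use_triton_fast_path(error_func, x):
    """Sketch of the dispatch predicate described above.

    Hypothetical standalone signature; the real method also verifies that
    the blocked layout of `x` matches the per-block amax shape.
    """
    return (
        os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"  # env opt-out
        and error_func is None  # custom error metrics go to the reference sweep
        and bool(getattr(x, "is_cuda", False))  # kernel needs a CUDA tensor
        and importlib.util.find_spec("triton") is not None  # kernel importable
    )
```

Any failed condition silently routes the call to the reference 126-step sweep, so behavior degrades gracefully rather than erroring.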
Three changes from realAsma's latest review:

- nvfp4_fp8_sweep kernel: use ``scale_safe`` rather than ``scale`` in the per-candidate diff so the divisor and multiplier match. Numerically equivalent on real inputs (the only case where ``scale_safe`` differs from ``scale`` is ``global_amax == 0``, in which case ``w_abs`` is also zero so the loss is zero either way), but more consistent.
- Extract ``fp8_scale_candidates`` to a triton-free module ``_fp8_scale_candidates.py`` so the calibrator's reference sweep and the Triton kernel wrapper share one definition. Removes the duplicate copy in ``NVFP4MSECalibrator._generate_candidates``.
- Parity test: extend ``test_parity_random_weights`` to cover bf16 and fp16 in addition to fp32 by parametrizing on dtype, so the canonical parity grid (3 seeds × 3 num_blocks) is now exercised on every supported dtype. Folded the smaller ``test_parity_dtypes`` into this since it was a strict subset.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
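The ``scale_safe`` convention can be sketched in plain Python. Names are illustrative and the built-in `round` stands in for the FP4 quantize; this is not the kernel code.

```python
def quantize_block(w_abs, scale):
    """Sketch of the scale_safe pattern described above."""
    # Guard the zero-scale case once, and use the SAME guarded value as
    # both the divisor and the multiplier so they always match.
    scale_safe = scale if scale > 0 else 1.0  # differs only when global_amax == 0
    return round(w_abs / scale_safe) * scale_safe  # toy round, not FP4 rounding

# When global_amax == 0 every |w| is also 0, so the loss is zero either way:
assert quantize_block(0.0, 0.0) == 0.0
```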
Force-pushed a9c8ccf to 8f04a9a.
Trim the parity grid to keep all three axes but with smaller per-axis ranges: 2 seeds × 2 num_blocks × 3 dtypes = 12 parametrized cases (down from 3×3×3 = 27). This still exercises every supported dtype and the small/large num_blocks extremes that drive different autotune choices, while roughly halving the cold-compile cost on hosts where Triton compilation is expensive.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
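The trimmed grid arithmetic, with placeholder seed and block-count values (the actual parametrize values live in the test file):

```python
import itertools

# Placeholder axis values for illustration only.
SEEDS = (0, 1)                     # 2 seeds
NUM_BLOCKS = (4, 1024)             # small / large extremes
DTYPES = ("fp32", "fp16", "bf16")  # every supported dtype

cases = list(itertools.product(SEEDS, NUM_BLOCKS, DTYPES))
assert len(cases) == 12  # down from 3 * 3 * 3 = 27
```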
After folding the Triton fast path into NVFP4MSECalibrator.collect(), two tests in tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py broke because they inspect path-specific state:

- test_collect_and_compute_amax asserts ``cal._losses_sum is not None`` with ``len == 126``. Only the reference 126-step sweep populates that list; the Triton fast path produces ``_best_amax_fast`` directly and leaves ``_losses_sum = None``.
- test_multiple_collections asserts that two ``collect()`` calls accumulate. The Triton fast path is one-shot by design and refuses a second collect until ``reset()``, so multi-collect is fundamentally reference-path semantics.

Fix: take the ``monkeypatch`` fixture in both tests and force ``MODELOPT_NVFP4_TRITON_SWEEP=0`` so they exercise the reference accumulator. Triton-path coverage stays in test_nvfp4_fp8_sweep_kernel.py (parity, dispatch predicate, end-to-end mtq.quantize). The other tests in the same class (initialization, candidate generation, per-block independence) are path-agnostic and unchanged.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
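Outside pytest, the same env-var pinning can be sketched as a small context manager, a stand-in for ``monkeypatch.setenv`` and for the ``_force_sweep_path()`` helper mentioned earlier:

```python
import os
from contextlib import contextmanager

@contextmanager
def force_sweep_path(value):
    """Pin MODELOPT_NVFP4_TRITON_SWEEP for the duration of a block.

    Minimal stand-in for pytest's monkeypatch.setenv: restores the
    previous value (or removes the variable) on exit.
    """
    key = "MODELOPT_NVFP4_TRITON_SWEEP"
    prev = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if prev is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = prev

with force_sweep_path("0"):
    assert os.environ["MODELOPT_NVFP4_TRITON_SWEEP"] == "0"  # reference path forced
```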
Summary
- Replaces the 126-iteration Python sweep in `NVFP4MSECalibrator` with a fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block `best_amax` directly.
- `TritonNVFP4MSECalibrator` is the default for `mse_calibrate(..., fp8_scale_sweep=True)`. Set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging or numerics comparison.
- ~41.8x speedup over the reference (`NVFP4MSECalibrator`) on a representative LLM weight (8192x4096, ~2M NVFP4 blocks): 176.68 ms -> 4.23 ms.
- End-to-end Qwen3-8B PTQ (`hf_ptq.py --qformat nvfp4_mse`, calib=128): ~6.7x faster `mtq.quantize` with identical global weight-MSE as the reference.

End-to-end Qwen3-8B PTQ (B300)
Settings: `--calib_size 128 --calib_seq 512`, default `nvfp4_mse` config (`NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG`). The first `nvfp4` run was discarded as a warm-up for HF weight loading.

The run compared `mtq.quantize` wall time across three configs: `nvfp4` (no MSE search), `nvfp4_mse` with the reference sweep (`MODELOPT_NVFP4_TRITON_SWEEP=0`), and `nvfp4_mse` with the Triton fast path (this PR). Triton `nvfp4_mse` is 6.67x faster than the reference `nvfp4_mse`, and adds only ~3 s over the no-MSE baseline (vs ~40 s for the reference), making the MSE-search option nearly free in practice.
Per-layer MSE CSVs are produced by `tools/debugger/compare_mse_qwen.py` (one row per Linear weight) for closer inspection if needed.

Microbenchmark (B300, 8192x4096 weight, ~2M NVFP4 blocks)

Reference `NVFP4MSECalibrator`: 176.68 ms; Triton fast path: 4.23 ms; speedup: 41.8x.
Why this works
Each candidate is constructed as `valid_fp8_e4m3_value / 448`. With `block_amax = global_amax * candidate`, the FP8 round-trip on the per-block scale `block_amax / 6` (using `global_amax / 6` as the FP8 amax) is the identity — so the kernel can compute `scale = candidate * global_amax / 6.0` inline and skip the FP8 cast. This keeps the kernel runnable on any CUDA + Triton (no `tl.float8e4nv` requirement).

Because every candidate's per-block scale is just a rescaling of the same input block, all 126 candidates can be evaluated against a single `[BLOCKS_PER_PROGRAM, BLOCK_SIZE]` tile held in registers — replacing 126 weight-bandwidth passes with 1.

Two follow-on optimizations close the gap to the compute ceiling:
- `@triton.autotune` over `(BLOCKS_PER_PROGRAM, num_warps)` — the original hand-picked default (BPP=4, num_warps=4) left ~4x on the table; the best B300 config is `BPP=64, num_warps=8`.
- Drop the sign-handling `tl.where`: FP4 quant preserves sign, so `(w - w_q)^2 == (|w| - |w_q|)^2` and the kernel works on `|w|` throughout (one fewer where + negation per element per candidate).

Files
- `modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py` — new kernel + `nvfp4_fp8_scale_sweep` wrapper, autotuned.
- `modelopt/torch/kernels/quantization/gemm/__init__.py` — wire-in.
- `modelopt/torch/quantization/calib/mse.py` — new `TritonNVFP4MSECalibrator(NVFP4MSECalibrator)`.
- `modelopt/torch/quantization/model_calib.py` — opt-out env var (`MODELOPT_NVFP4_TRITON_SWEEP=0`); Triton path is default.
- `tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py` — 16 GPU tests covering parity (across seeds, block counts, dtypes), input validation, output round-trip, reset, and a wall-clock speedup report.

Numerics
Bit-identical to the reference for typical block counts (`{4, 64, 1024}` blocks, 3 seeds, fp32/fp16/bf16 — 14/15 microbenchmark tests bit-exact).

On multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks in the speedup test, per-block MSE within ~1e-7 relative). The reference's CUDA `fake_e4m3fy` and the Triton inline math have slightly different op ordering, which lets nearly-tied candidates flip. The speedup test asserts the worst per-block MSE gap is `< 1e-5` relative on differing blocks — both choices are valid argmins, and the resulting quantized weights are equally good. The Qwen3-8B end-to-end run confirms this: aggregate weight MSE matches the reference exactly at the displayed precision.

Test plan
- `pytest tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py -v` — 16/16 pass on B300
- `pytest tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py -v` — existing NVFP4 tests still pass (9/9)
- Qwen3-8B `hf_ptq.py --qformat nvfp4_mse`: 6.67x speedup, identical weight MSE to reference

🤖 Generated with Claude Code