Implement per-token NVFP4 fprop recipe #2931
Conversation
Greptile Summary

This PR adds a per-token (row-scaled) NVFP4 forward quantization recipe for TransformerEngine. When enabled via the recipe's `per_token_activation` flag (settable through the `NVTE_NVFP4_PER_TOKEN_ACTIVATION` environment variable), activations are quantized with a per-row global scale instead of a single per-tensor amax.
Confidence Score: 4/5

Safe to merge for the fprop-only use case that was tested; the unsupported backward path is explicitly blocked. The row-scaled quantization logic (two-step kernel flow, per-row amax bookkeeping, neutralising the global amax before cuBLAS, and post-GEMM FP32 rescaling) is mathematically correct and verified by bitwise-exact tests on B200. The main unresolved structural issue is the positional `idx % 3 != 1` quantizer assignment discussed in the review below.
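The post-GEMM FP32 rescaling mentioned above can be pictured with a short sketch. Function and argument names here are illustrative, not the actual TransformerEngine API; the real logic lives inside `general_gemm`:

```python
import torch

# Hedged sketch of the post-GEMM FP32 rescaling step; names are assumptions.
def apply_rowwise_global_scales(out_fp32, rowwise_global_scales, bias=None,
                                out_dtype=torch.bfloat16):
    # cuBLAS ran with the global amax neutralised to 1.0, so the per-row
    # scales captured before the GEMM are applied afterwards in FP32.
    out_fp32 = out_fp32 * rowwise_global_scales.view(-1, 1)
    if bias is not None:
        out_fp32 = out_fp32 + bias.float()  # bias is added in FP32
    return out_fp32.to(out_dtype)           # cast to the requested dtype last
```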
Sequence Diagram

```mermaid
sequenceDiagram
    participant Mod as Module Forward
    participant Quant as NVFP4Quantizer (row_scaled=True)
    participant RowAmax as compute_rowwise_amax kernel
    participant QKernel as quantize_transpose kernel (ROW_SCALED_NVFP4)
    participant PyGEMM as general_gemm (Python)
    participant cuBLAS as cuBLAS NVFP4 GEMM
    Mod->>Quant: quantize(activation)
    Quant->>RowAmax: compute per-row abs-max into amax[0..M-1]
    RowAmax-->>Quant: amax array written to tensor
    Quant->>QKernel: quantize using amax[row_idx] per block
    QKernel-->>Quant: NVFP4 tensor (row_scaled_nvfp4=True)
    Quant-->>Mod: NVFP4Tensor with per-row amax metadata
    Mod->>PyGEMM: general_gemm(A=weight, B=activation_fp4)
    Note over PyGEMM: Detect _is_nvfp4_row_scaled_tensor(B)
    PyGEMM->>PyGEMM: Replace amax with 1.0 in A and B metadata, capture rowwise_global_scales
    PyGEMM->>cuBLAS: GEMM with amax=1.0, FP32 output
    cuBLAS-->>PyGEMM: out_fp32
    PyGEMM->>PyGEMM: out_fp32 *= rowwise_global_scales per row, add bias in FP32
    PyGEMM->>PyGEMM: cast to requested_out_dtype
    PyGEMM-->>Mod: scaled output
```
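The two quantization steps in the diagram (per-row amax, then block quantization) can be emulated in PyTorch roughly as follows. The scale constants assume standard NVFP4 (FP4 E2M1 blocks with FP8 E4M3 block scales); the helper name is made up:

```python
import torch

def rowwise_global_encode_scales(x: torch.Tensor) -> torch.Tensor:
    # Step 1 (compute_rowwise_amax): one abs-max per row/token, not per tensor.
    row_amax = x.abs().amax(dim=-1).float()  # shape (M,)
    # Step 2 feeds amax[row_idx] into the quantize kernel; the per-row global
    # encode scale follows the usual NVFP4 form (6.0 = E2M1 max, 448.0 = E4M3 max).
    s_enc = torch.where(row_amax > 0,
                        (6.0 * 448.0) / row_amax,
                        torch.ones_like(row_amax))
    return s_enc  # consumed as the per-row global scale by the quantize kernel
```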
Reviews (9)
| // Compute "correct" per-block encoding scaling factor | ||
| const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; | ||
| const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : | ||
| fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm); |
We have to change this to stay aligned with the PyTorch reference.
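Read side by side, the change clamps the per-block encode scale instead of letting the ratio overflow. A rough Python rendering, with a hard-coded float maximum standing in for `Numeric_Traits<float>::maxNorm` (an assumption):

```python
FLT_MAX = 3.4028234663852886e+38  # assumed stand-in for Numeric_Traits<float>::maxNorm

def s_enc_block_before(s_enc: float, s_dec_b: float) -> float:
    # Old form: plain ratio; a tiny block decode scale can overflow to inf.
    return 0.0 if s_dec_b == 0.0 else s_enc / s_dec_b

def s_enc_block_after(s_enc: float, s_dec_b: float) -> float:
    # New form: the same ratio, computed as 1 / (s_dec_b * (1 / s_enc)) and
    # clamped to the largest finite float, matching the PyTorch reference.
    return 0.0 if s_dec_b == 0.0 else min(1.0 / (s_dec_b * (1.0 / s_enc)), FLT_MAX)
```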
The following extended tests all passed.
Hi @zianglih, could you clarify why you needed the new quantization kernels? The existing NVFP4 quantization kernels should already work if you only use the rowwise mode there, no?
Hi @ptrendx, I have removed the standalone implementation and extended the existing kernel to support this per-token NVFP4 recipe.
```python
with_post_rht_amax=qparams.random_hadamard_transform,
with_2d_quantization=qparams.fp4_2d_quantization,
stochastic_rounding=qparams.stochastic_rounding,
per_token_activation=self.recipe.per_token_activation and idx % 3 != 1,
```
Hardcoded `idx % 3 != 1` pattern silently misassigns per-token scaling to the wrong quantizers

The expression `self.recipe.per_token_activation and idx % 3 != 1` assumes that every third quantizer (index 1, 4, 7, …) is always a weight quantizer and should be skipped for per-token scaling. This works for standard Linear / LayerNormLinear layers, but it is not documented and will silently produce wrong results if the quantizer ordering changes (e.g., MoE layers, attention layers with additional quantizers, or future refactors). The intent should either be codified as a named constant or enforced by tagging each quantizer with its semantic role (activation vs. weight) rather than by position, as sketched below.
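For illustration, the role-based tagging suggested above might look like the following (purely hypothetical names; #2620 may take a different shape):

```python
from enum import Enum

class QuantizerRole(Enum):
    ACTIVATION = "activation"
    WEIGHT = "weight"
    GRADIENT = "gradient"

# Per-token scaling keyed off an explicit role instead of the quantizer's
# position in the list, so reordering quantizers cannot misassign it.
def per_token_enabled(role: QuantizerRole, recipe) -> bool:
    return recipe.per_token_activation and role is QuantizerRole.ACTIVATION
```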
We should be able to get rid of this hack once #2620 is merged.
The functionality has been verified by an NVFP4 RL experiment.
timmoon10 left a comment:
Overall this is a nice feature, but we should make some changes to the core design. My biggest suggestions:
- Row-scaling should be part of the tensor data, not hidden in the quantizer. Both quantization and dequantization need to be aware of it, so the quantizer alone cannot own that state.
- We should enable row-scaling via a bool flag rather than inferring it from the amax tensor shape (a rough sketch follows this list), and we should make sure it is clearly documented.
- We should consider a better name like "1D scaling" or "row scaling", since I don't see any reason this is specific to tokens or activations.
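A minimal sketch of the flag-on-the-tensor idea, with hypothetical class and field names:

```python
import torch

class NVFP4TensorMeta:
    """Hypothetical metadata carrier: row-scaling is an explicit flag,
    not something consumers infer from the amax buffer's shape."""

    def __init__(self, amax: torch.Tensor, row_scaled: bool = False):
        self.amax = amax              # numel 1 if per-tensor, numel M if row-scaled
        self.row_scaled = row_scaled  # consumers branch on this flag
```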
```cpp
/*! Whether to enable per-token (per-row) NVFP4 quantization */
kNVTEQuantizationConfigNVFP4PerTokenActivation = 8,
```
We should configure this in NVTETensor rather than in NVTEQuantizationConfig. The quantization config is used for quantization, and is not available to any downstream consumers (dequant or GEMM). However, consumers need to be aware of tensor-scaling vs row-scaling. The buffer sizes are different, and getting it wrong means incorrect values or segfaults.
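The buffer-size hazard can be made concrete with an illustrative helper (not TE code):

```python
# For an (M, K) tensor: per-tensor NVFP4 keeps a single global amax, while
# row-scaled NVFP4 keeps one per row. A consumer that assumes the wrong mode
# reads the wrong number of floats, hence the incorrect values or segfaults.
def expected_amax_numel(rows: int, row_scaled: bool) -> int:
    return rows if row_scaled else 1
```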
Dequantization handles row-scaled NVFP4, but we need to make sure that other quantized-tensor consumers also handle it. Erroring out is fine for now. Currently the only other consumers we need to handle are GEMM and attention, although we should keep this in mind if we add more features in the future.
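For consumers that do not yet support it (e.g., attention), erroring out could look like this sketch; the attribute name mirrors the `row_scaled_nvfp4` flag from the sequence diagram but is otherwise an assumption:

```python
def reject_row_scaled_nvfp4(tensor) -> None:
    # Guard for consumers that only understand per-tensor scaling.
    if getattr(tensor, "row_scaled_nvfp4", False):
        raise NotImplementedError(
            "row-scaled NVFP4 is currently supported only in the GEMM fprop path"
        )
```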
```python
@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
def test_nvfp4_per_token_quantizer_roles():
    ...
```
This test will need to be updated once #2620 merges.
Description
Implement a per-token NVFP4 recipe, fprop only.
Currently, the per-token scaling is handled by separate PyTorch code.
The quantization kernels are bitwise exact with the existing TE reference implementation.
The following tests passed on B200.
Changes
Please list the changes introduced in this PR:
per_token_activationfield in nvfp4 recipe, can be turned on byNVTE_NVFP4_PER_TOKEN_ACTIVATIONtransformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with existing TE pytorch reference implementation and per-tesor nvfp4 emulated implmentation.transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuhto correctly handle this per-token nvfp4TransformerEngine/transformer_engine/pytorch/cpp_extensions/gemm.py, if per-token nvfp4 is detected, it conducts separate per-token scaling using pytorch code, after cublas gemmtests/cpp/operator/test_cast_nvfp4_transpose.cuto align with pytorch reference numericsChecklist: