From 14f77bee4f7231a0d2b9b343c16582b214a9c498 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 00:37:14 -0700 Subject: [PATCH 01/45] Adapt initial implementation and make quantization bitwise exact Signed-off-by: Ziang Li Co-authored-by: Yigong Qin --- docs/envvars.rst | 6 + .../nvfp4/test_nvfp4_quantize_exact.py | 55 +++++ tests/pytorch/test_backward_override.py | 27 ++- tests/pytorch/utils.py | 10 +- .../common/cast/dispatch/quantize.cuh | 19 ++ .../common/cast/nvfp4/dequantize_nvfp4.cuh | 14 +- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 220 ++++++++++++++++++ transformer_engine/common/common.h | 4 +- .../transformer_engine/transformer_engine.h | 9 + transformer_engine/common/recipe/__init__.py | 6 + .../common/transformer_engine.cpp | 6 + .../pytorch/cpp_extensions/gemm.py | 74 +++++- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 131 ++++++++++- .../pytorch/csrc/extensions/pybind.cpp | 2 + transformer_engine/pytorch/csrc/quantizer.cpp | 49 +++- .../custom_recipes/quantization_nvfp4.py | 60 ++++- transformer_engine/pytorch/quantization.py | 2 + .../pytorch/tensor/nvfp4_tensor.py | 11 +- 20 files changed, 663 insertions(+), 45 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh diff --git a/docs/envvars.rst b/docs/envvars.rst index 1e040b4c3e..58988b5473 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -281,6 +281,12 @@ Kernel Configuration :Default: ``0`` :Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations. +.. envvar:: NVTE_NVFP4_PER_TOKEN_ACTIVATION + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe in GroupedLinear split-quantize paths. 
When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), NVFP4 rowwise ``amax`` metadata stores one FP32 value per token (row) instead of a single scalar. + Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index bf3f545b8b..7e94911ddd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -23,6 +23,20 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: return repeated +def maybe_skip_pertoken_nvfp4( + x_dtype: torch.dtype = torch.bfloat16, + *, + return_transpose: bool = False, + with_2d_quantization: bool = False, +) -> None: + if x_dtype == torch.float32: + pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") + if return_transpose: + pytest.skip("Per-token NVFP4 currently supports rowwise-only quantization") + if with_2d_quantization: + pytest.skip("Per-token NVFP4 does not support 2D quantization") + + def check_quantization_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, @@ -31,6 +45,7 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + per_token_activation: bool = False, ) -> None: te_dtype = tex.DType.kFloat4E2M1 @@ -52,6 +67,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, + per_token_activation=per_token_activation, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -83,6 +99,7 @@ def check_quantization_nvfp4_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -155,6 +172,9 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) 
+@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -163,7 +183,14 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + per_token_activation: bool, ) -> None: + if per_token_activation: + maybe_skip_pertoken_nvfp4( + x_dtype=x_dtype, + return_transpose=return_transpose, + with_2d_quantization=with_2d_quantization, + ) check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, @@ -172,6 +199,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, + per_token_activation=per_token_activation, ) @@ -188,6 +216,9 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -195,6 +226,7 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -208,6 +240,9 @@ def test_nvfp4_quantization_extrema_versus_reference( else: x = torch.zeros((M, N), dtype=x_dtype, device=device) + if per_token_activation: + maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -216,6 +251,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -245,6 +281,7 @@ def 
test_nvfp4_quantization_extrema_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -286,12 +323,16 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -319,6 +360,9 @@ def test_nvfp4_quantization_boundary_values( row[1::2] = upper x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) + if per_token_activation: + maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -327,6 +371,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -356,6 +401,7 @@ def test_nvfp4_quantization_boundary_values( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -397,12 +443,16 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -416,6 +466,9 @@ def 
test_nvfp4_quantization_noncontiguous_inputs( x_nc = x_base.t() # shape (N, M), non-contiguous assert not x_nc.is_contiguous() + if per_token_activation: + maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -424,6 +477,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -453,6 +507,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed4f73adbc..5da55b14b6 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -78,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_pertoken", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4PerTokenBlockScaling", + ), ] @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -178,6 +183,16 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + if recipe_name == "nvfp4_pertoken" and module_type in ( + "linear", + "layernorm_linear", + "ops_linear", + "grouped_linear", + ): + pytest.skip( + 
"Per-token NVFP4 currently supports rowwise-only quantization paths " + "(columnwise usage is unsupported for these modules)." + ) def _maybe_skip_unsupported_recipe_shape( @@ -195,7 +210,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -220,7 +237,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -239,7 +258,7 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 8f8852edc2..04ac2becbc 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -115,7 +115,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name == "nvfp4": + if name in ("nvfp4", "nvfp4_pertoken"): return 
dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -149,6 +149,14 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_pertoken": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + per_token_activation=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 5d0d3c28e8..0d86022cc1 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,6 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" +#include "../nvfp4/quantize_pertoken_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -100,6 +101,15 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t rows = input_tensor->flat_first_dim(); int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); + const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; + if (per_token_activation) { + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Per-token NVFP4 quantization does not support 2D quantization."); + nvfp4::quantize_pertoken(*input_tensor, noop_tensor, output_tensor, stream); + break; + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -239,6 +249,15 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t 
rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; + if (per_token_activation) { + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Per-token NVFP4 quantization does not support 2D quantization."); + nvfp4::quantize_pertoken(*grad_tensor, noop_tensor, output_tensor, stream); + break; + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 4143208153..9436b94939 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,7 +34,7 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, + const float *const tensor_amax, const size_t amax_numel, const size_t N, const size_t M, const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; @@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; + float amax = (amax_numel == 1) ? 
tensor_amax[0] : tensor_amax[y]; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -110,11 +110,11 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, input.scale_inv.shape.back(), - num_scale_tiles_X);); // NOLINT(*) - ); // NOLINT(*) + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), input.amax.numel(), N, Mread, input.scale_inv.shape.back(), + num_scale_tiles_X);); // NOLINT(*) +); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh new file mode 100644 index 0000000000..5e1e23f5d5 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -0,0 +1,220 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_pertoken_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ + +#include +#include + +#include +#include + +#include "../../common.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_pertoken_kernel { + +using namespace core; +using namespace ptx; + +constexpr int PERTOKEN_BLOCK_SIZE = 256; +constexpr int PERTOKEN_SF_VEC_SIZE = 16; + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + quantize_pertoken_nvfp4_kernel( + const int num_rows, const int num_cols, const IType *__restrict__ input, + const int *__restrict__ row_offsets, uint8_t *__restrict__ output_data, + fp8e4m3 *__restrict__ output_scales, float *__restrict__ output_per_token_amax, + const int scale_stride, const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using namespace detail; + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int actual_row = (row_offsets != nullptr) ? 
row_offsets[row_idx] : row_idx; + if (actual_row < 0) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + actual_row * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + } + const float thread_max = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + __shared__ float shared_s_enc; + if (threadIdx.x == 0) { + const float s_enc = compute_global_encode_scaling_factor_FP4(row_amax); + output_per_token_amax[row_idx] = row_amax; + shared_s_enc = s_enc; + } + __syncthreads(); + const float S_enc = shared_s_enc; + const float S_dec_rowwise = 1.0 / S_enc; + constexpr float fp4_max_inv = 1.0f / detail::TypeExtrema::max; + const float global_encode_scale_multiplier = S_enc * fp4_max_inv; + + const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE; + for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) { + const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE; + + IType2 block_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + alignas(8) IType2 vals[PERTOKEN_SF_VEC_SIZE / 2]; + const IType2 *input_block = + reinterpret_cast(input + actual_row * num_cols + col_start); + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; ++j) { + vals[j] = input_block[j]; + ptx::abs_max_2x(block_amax_2x, block_amax_2x, vals[j]); + } + const float block_max = + static_cast(__hmax(__habs(block_amax_2x.x), __habs(block_amax_2x.y))); + + const float S_dec_b_f32 = + fminf(block_max * global_encode_scale_multiplier, detail::TypeExtrema::max); + const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); + 
output_scales[row_idx * scale_stride + sf_idx] = S_dec_b_fp8; + + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2; + if constexpr (std::is_same_v) { + auto *out_fp4_8x = reinterpret_cast(out_ptr); + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 4) { + const uint64_t elts03 = *reinterpret_cast(&vals[j]); + const uint64_t elts47 = *reinterpret_cast(&vals[j + 2]); + out_fp4_8x[j / 4] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest( + elts03, elts47, block_scale_inverse); + } + } else { + auto *out_fp4 = reinterpret_cast(out_ptr); + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 2) { + const float2 in01 = + make_float2(static_cast(vals[j].x), static_cast(vals[j].y)); + const float2 in23 = + make_float2(static_cast(vals[j + 1].x), static_cast(vals[j + 1].y)); + out_fp4[j / 2] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, /*rbits=*/0u); + } + } + } +#endif +} + +template +void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input, + const int *row_offsets, uint8_t *output_data, + fp8e4m3 *output_scales, float *output_per_token_amax, + const int scale_stride, cudaStream_t stream, + const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ", + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 quantization, got ", num_cols); + dim3 grid(num_rows); + dim3 block(PERTOKEN_BLOCK_SIZE); + + quantize_pertoken_nvfp4_kernel + <<>>(num_rows, num_cols, input, row_offsets, output_data, + output_scales, output_per_token_amax, scale_stride, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is 
needed for FP4 calculation!"); +#endif +} + +} // namespace quantize_pertoken_kernel + +inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Per-token amax tensor must be allocated."); + NVTE_CHECK(!output->has_columnwise_data(), + "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + NVTE_CHECK(cols % quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE == 0, + "Per-token NVFP4 quantization requires last dim divisible by ", + quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE, "."); + + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + auto *amax_ptr = reinterpret_cast(output->amax.dptr); + const int *row_offsets = nullptr; + const int scale_stride = static_cast(output->scale_inv.shape.back()); + + if (input.dtype() == DType::kBFloat16) { + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( + static_cast(rows), static_cast(cols), + reinterpret_cast(input.data.dptr), row_offsets, data_ptr, scale_ptr, + amax_ptr, scale_stride, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat16) { + 
quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + static_cast(rows), static_cast(cols), + reinterpret_cast(input.data.dptr), row_offsets, data_ptr, scale_ptr, amax_ptr, + scale_stride, stream, noop_ptr); + } else { + NVTE_ERROR( + "Unsupported input dtype for per-token NVFP4 quantization. " + "Expected BFloat16 or Float16."); + } +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index c1b3f8f427..c5b4254e8b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -470,6 +470,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_per_token_activation = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -479,7 +480,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // nvfp4_per_token_activation }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..0463d51d1c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,8 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! 
Whether to enable per-token (per-row) NVFP4 quantization */ + kNVTEQuantizationConfigNVFP4PerTokenActivation = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1296,6 +1298,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to enable per-token NVFP4 quantization */ + void set_nvfp4_per_token_activation(bool nvfp4_per_token_activation) { + const auto val = static_cast(nvfp4_per_token_activation); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP4PerTokenActivation, + &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 67b6f87067..e59d01d82a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -478,6 +478,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + per_token_activation : bool, default = False + If set to `True`, GroupedLinear activation split quantization uses per-token + (per-row) NVFP4 global amax values. In this mode, rowwise ``amax`` metadata + is stored as a vector with one FP32 value per token. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. 
None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -491,6 +495,7 @@ class NVFP4BlockScaling(Recipe): os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + per_token_activation: bool = os.getenv("NVTE_NVFP4_PER_TOKEN_ACTIVATION", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -534,6 +539,7 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " + f"per_token_activation={self.per_token_activation}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1261879a8b..a0a0ffa45f 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1043,6 +1043,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP4PerTokenActivation: + bool_to_uint8(config_.nvfp4_per_token_activation, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1098,6 +1101,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP4PerTokenActivation: + uint8_to_bool(buf, config_.nvfp4_per_token_activation); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py 
b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6f3553bf94..9cf58f9dce 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -69,6 +70,50 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 +def _maybe_apply_nvfp4_pertoken_output_rescale( + out: torch.Tensor, + B: torch.Tensor, + *, + layout: str, + bias: Optional[torch.Tensor], + grad: bool, + gelu: bool, + accumulate: bool, +) -> None: + """Apply per-token NVFP4 global-scale correction for TN forward GEMM outputs. + + Current NVFP4 GEMM alpha path consumes one scalar amax. Per-token NVFP4 stores + rowwise amax vector in B._amax_rowwise, so we correct by row using ratio + (amax[row] / amax[0]). If bias was fused in epilogue, remove/reapply it around + the row rescale to avoid bias distortion. 
+ """ + + if grad or gelu or accumulate or layout != "TN": + return + if not isinstance(B, NVFP4TensorStorage): + return + if not isinstance(out, torch.Tensor) or is_custom(out): + return + if out.numel() == 0: + return + amax = B._amax_rowwise + if amax is None or amax.numel() <= 1: + return + + out_2d = out.reshape(-1, out.shape[-1]) + if amax.numel() != out_2d.shape[0]: + return + + ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) + if bias is not None: + bias_cast = bias.to(dtype=out.dtype) + out_2d.sub_(bias_cast) + out_2d.mul_(ratios) + out_2d.add_(bias_cast) + else: + out_2d.mul_(ratios) + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -147,6 +192,22 @@ def general_gemm( # FP8 block-scaling requires split accumulator use_split_accumulator = True + requested_out_dtype = out_dtype + needs_fp32_rescale_path = ( + layout == "TN" + and not grad + and not gelu + and not accumulate + and isinstance(B, NVFP4TensorStorage) + and B._amax_rowwise is not None + and B._amax_rowwise.numel() > 1 + and quantization_params is None + and out is None + and requested_out_dtype is not None + and requested_out_dtype != torch.float32 + ) + effective_out_dtype = torch.float32 if needs_fp32_rescale_path else requested_out_dtype + args = ( A, transa, # transa @@ -154,7 +215,7 @@ def general_gemm( transb, # transb out, quantization_params, - TE_DType[out_dtype] if out_dtype is not None else None, + TE_DType[effective_out_dtype] if effective_out_dtype is not None else None, bias, bias_dtype, gelu, @@ -175,6 +236,17 @@ def general_gemm( } out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + _maybe_apply_nvfp4_pertoken_output_rescale( + out, + B, + layout=layout, + bias=bias, + grad=grad, + gelu=gelu, + accumulate=accumulate, + ) + if needs_fp32_rescale_path: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) diff --git a/transformer_engine/pytorch/csrc/common.h 
b/transformer_engine/pytorch/csrc/common.h index 8e3bcdd5b3..b9f852c07d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,6 +320,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + bool per_token_activation; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..06478b54e0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -326,6 +326,8 @@ py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); +std::tuple quantize_nvfp4_pertoken(at::Tensor input); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 50fe4c109e..f1654f5525 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -801,6 +801,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + const bool per_token_activation = quantizer_cpp_list[0]->per_token_activation; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; constexpr size_t scale_elem_size = 1; @@ -828,6 +829,16 @@ std::tuple, std::vector, bool> bulk_alloc } return fp4_shape; }; + auto flat_first_dim = [](const std::vector &shape) -> size_t { + if (shape.empty()) { + return 1; + } + size_t rows = 1; + for (size_t i = 0; i + 1 < 
shape.size(); ++i) { + rows *= shape[i]; + } + return rows; + }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; @@ -866,7 +877,9 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + const size_t amax_size = + per_token_activation ? 4 * flat_first_dim(rowwise_data_shapes[i]) : 4; + buffer_size = offset + amax_size; } // Allocate full buffer @@ -879,8 +892,11 @@ std::tuple, std::vector, bool> bulk_alloc data_offsets[i], torch::kUInt8)); rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + const std::vector amax_shape = + per_token_activation ? std::vector{flat_first_dim(rowwise_data_shapes[i])} + : std::vector{1}; amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -983,7 +999,7 @@ std::tuple, std::vector, bool> bulk_alloc // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_rowwise_list[i])); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, @@ -1263,6 +1279,35 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } + if (quantizer.per_token_activation) { + NVTE_CHECK(!quantizer.with_rht, "Per-token NVFP4 split quantize does not support RHT."); + NVTE_CHECK(!quantizer.columnwise_usage, + "Per-token NVFP4 split quantize currently supports rowwise-only quantization."); + NVTE_CHECK(!quantizer.with_2d_quantization, + "Per-token NVFP4 split quantize does not support 2D quantization."); 
+ NVTE_CHECK(!quantizer.stochastic_rounding, + "Per-token NVFP4 split quantize does not support stochastic rounding."); + + std::vector quant_config_list; + quant_config_list.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list.emplace_back(QuantizationConfigWrapper()); + quant_config_list.back().set_nvfp4_per_token_activation(true); + } + + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) { + continue; + } + const size_t input_ndim = input_list[i].ndim(); + const size_t cols = input_ndim > 0 ? input_list[i].size(input_ndim - 1) : 1; + NVTE_CHECK(cols % 16 == 0, + "Per-token NVFP4 split quantize requires split inner dim divisible by 16."); + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + } + return; + } + // In this case without RHT, the rowwise and colwise quantization are fused // we don't need separate rng states for rowwise and colwise bool need_separate_rng_states = false; @@ -1360,8 +1405,13 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? 
input.size(input.ndim() - 1) : 1; - NVTE_CHECK(input_last_dim % 128 == 0, - "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + if (quantizer.per_token_activation) { + NVTE_CHECK(input_last_dim % 16 == 0, + "Per-token NVFP4 split-quantize requires inner dim to be multiple of 16."); + } else { + NVTE_CHECK(input_last_dim % 128 == 0, + "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + } // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); @@ -1433,12 +1483,25 @@ std::vector split_quantize(const at::Tensor &tensor, for (size_t i = 0; i < num_splits; i++) { quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); } + const bool all_nvfp4_quantizers = std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsNVFP4Quantizers(quantizer.ptr()); + }); + const bool all_nvfp4_per_token_activation = + all_nvfp4_quantizers && + std::all_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) -> bool { + return static_cast(quantizer.get())->per_token_activation; + }); // Choose implementation for allocating and populating tensors enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 }; enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; + if (all_nvfp4_per_token_activation) { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } if (!disable_bulk_allocation) { if (std::all_of(quantizer_list.begin(), quantizer_list.end(), [](const py::handle &quantizer) -> bool { @@ -1450,10 +1513,7 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsMXFP8Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_MXFP8; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool 
{ - return detail::IsNVFP4Quantizers(quantizer.ptr()); - })) { + } else if (all_nvfp4_quantizers) { allocation_method = AllocationMethod::BULK_NVFP4; quantization_method = QuantizationMethod::FUSED_NVFP4; } @@ -1492,7 +1552,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!input_shape.empty() && input_shape.back() % 128 != 0) { + if (!all_nvfp4_per_token_activation && !input_shape.empty() && + input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { NVTE_WARN( @@ -1502,7 +1563,7 @@ std::vector split_quantize(const at::Tensor &tensor, }); quantization_method = QuantizationMethod::UNFUSED; } - if (!contiguous_data_and_scale) { + if (!all_nvfp4_per_token_activation && !contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; } @@ -1540,5 +1601,53 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } +std::tuple quantize_nvfp4_pertoken(at::Tensor input) { + init_extension(); + + NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); + NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device"); + NVTE_CHECK(input.scalar_type() == at::ScalarType::BFloat16 || + input.scalar_type() == at::ScalarType::Half, + "Input must be BFloat16 or Half"); + + const int num_rows = input.size(0); + const int num_cols = input.size(1); + NVTE_CHECK(num_cols % 16 == 0, + "num_cols must be a multiple of 16 for per-token NVFP4 quantization"); + + if (num_rows == 0) { + auto options = input.options(); + return {at::empty({0, num_cols / 2}, options.dtype(at::kByte)), + at::empty({0, num_cols / 16}, options.dtype(at::kByte)), + at::empty({0}, options.dtype(at::kFloat))}; + } + + 
auto input_contig = input.contiguous(); + auto options = input_contig.options(); + + auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte)); + auto output_scales = at::empty({num_rows, num_cols / 16}, options.dtype(at::kByte)); + auto output_per_token_amax = at::empty({num_rows}, options.dtype(at::kFloat)); + + auto te_input = makeTransformerEngineTensor(input_contig); + TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); + te_output.set_rowwise_data( + output_data.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(num_rows), static_cast(num_cols)}); + te_output.set_rowwise_scale_inv( + output_scales.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(num_rows), static_cast(num_cols / 16)}); + te_output.set_amax(output_per_token_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(num_rows)}); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, stream); }); + + return {output_data, output_scales, output_per_token_amax}; +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index eb7576d905..4021792f86 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,6 +145,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + m.def("quantize_nvfp4_pertoken", transformer_engine::pytorch::quantize_nvfp4_pertoken, + "Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", 
transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index da91e5c170..d6fedc707b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,6 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->per_token_activation = quantizer.attr("per_token_activation").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1760,9 +1761,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + const int64_t amax_rows = this->per_token_activation ? 
static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1850,7 +1852,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1862,7 +1864,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_columnwise)); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -1975,15 +1977,22 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); // Register amax pointer from quantized tensor - void* amax_ptr = quantized_tensor.amax(); + auto rowwise_amax = quantized_tensor.get_amax(); + auto columnwise_amax = quantized_tensor.get_columnwise_amax(); + + void* amax_ptr = rowwise_amax.data_ptr; + std::vector amax_shape = convertShape(rowwise_amax.shape); if (amax_ptr == nullptr) { - amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + amax_ptr = columnwise_amax.data_ptr; + amax_shape = convertShape(columnwise_amax.shape); } 
NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); - out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_ptr, DType::kFloat32, amax_shape); // Zero out amax - NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + const size_t amax_numel = product(amax_shape); + NVTE_CHECK_CUDA( + cudaMemsetAsync(amax_ptr, 0, amax_numel * sizeof(float), at::cuda::getCurrentCUDAStream())); return {std::move(out_cpp), std::move(out_py)}; } @@ -2050,9 +2059,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_rowwise) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + const int64_t amax_rows = + this->per_token_activation ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, opts); + amax_rowwise = at::empty({amax_rows}, opts); tensor.attr("_amax_rowwise") = *amax_rowwise; } } else { // rowwise_usage == false @@ -2118,7 +2129,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, getTensorShape(*amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -2241,6 +2252,22 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); + if (this->per_token_activation) { + NVTE_CHECK(!this->with_rht, "Per-token NVFP4 activation does not support RHT."); + NVTE_CHECK(!this->with_2d_quantization, + "Per-token NVFP4 activation does not 
support 2D quantization."); + NVTE_CHECK(!this->stochastic_rounding, + "Per-token NVFP4 activation does not support stochastic rounding."); + NVTE_CHECK(!this->columnwise_usage, + "Per-token NVFP4 activation currently supports rowwise-only quantization."); + NVTE_CHECK(!this->with_amax_reduction, + "Per-token NVFP4 activation does not support amax reduction."); + NVTE_CHECK(input.dtype() == DType::kBFloat16 || input.dtype() == DType::kFloat16, + "Per-token NVFP4 activation supports BF16/FP16 inputs only."); + NVTE_CHECK(cols % 16 == 0, "Per-token NVFP4 activation requires last dim divisible by 16."); + quant_config.set_nvfp4_per_token_activation(true); + } + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; @@ -2307,7 +2334,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax) { + if (compute_amax && !this->per_token_activation) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2408,6 +2435,8 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + NVTE_CHECK(!this->per_token_activation, + "quantize_with_amax is not supported for per-token NVFP4 activation."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); auto output_rowwise_amax_ptr = out.get_amax().data_ptr; diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index dd01ae05d3..6a5e400592 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ 
b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,6 +350,7 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), + per_token_activation: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): @@ -360,6 +361,7 @@ def __init__( self.pow_2_scales = pow_2_scales self.eps = eps self.quant_tile_shape = quant_tile_shape + self.per_token_activation = per_token_activation self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -447,6 +449,7 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, + per_token_activation: bool, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -488,6 +491,11 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: + if per_token_activation: + global_amax = global_amax.to(torch.float32).view(m, 1, 1) + else: + global_amax = global_amax.to(torch.float32) + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( global_encode_scale, @@ -497,8 +505,15 @@ def _quantize_blockwise_reference( dtype=torch.float32, ), ) - if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): - global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + if global_encode_scale.numel() == 1: + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + else: + global_encode_scale = torch.where( + global_encode_scale == 0.0, + torch.ones_like(global_encode_scale), + global_encode_scale, + ) global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) @@ -609,6 +624,10 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got 
{self.quant_tile_shape}" ) + if self.per_token_activation: + raise ValueError( + "Per-token activation is only supported for NVFP4 (non-pow2) mode." + ) # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) @@ -625,13 +644,24 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.with_rht else tensor.t().contiguous() ) - # Compute amax for rowwise and columnwise paths separately - global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) - global_amax_col = ( - torch.max(torch.abs(col_input)).to(torch.float32).view(1) - if self.columnwise_usage - else global_amax_row - ) + if self.per_token_activation: + if self.quant_tile_shape != (1, 16): + raise ValueError( + "Per-token activation only supports NVFP4 1x16 tile shape, " + f"got {self.quant_tile_shape}" + ) + if self.columnwise_usage: + raise ValueError("Per-token activation reference supports rowwise-only usage.") + global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) + global_amax_col = global_amax_row + else: + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) transpose_scales = False @@ -648,6 +678,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + per_token_activation=self.per_token_activation, eps=self.eps, ) if transpose_scales: @@ -671,6 +702,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + per_token_activation=False, eps=self.eps, ) @@ -863,6 +895,16 @@ def qgemm( sw = sw.to(torch.float32) factor = 6.0 * 6.0 * 448.0 * 448.0 + if ( + 
qresult_x.global_amax_row.numel() != 1 + or qresult_w.global_amax_row.numel() != 1 + or qresult_w.global_amax_col.numel() != 1 + or qresult_x.global_amax_col.numel() != 1 + ): + raise ValueError( + "NVFP4QuantizerRef.qgemm expects scalar global amax values; " + "per-token amax vectors are not supported in reference GEMM." + ) if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9956fb77ec..6ffca84a7d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1375,6 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, + per_token_activation=self.recipe.per_token_activation, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] @@ -1389,6 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + per_token_activation=self.recipe.per_token_activation, ) for _ in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 65678aa347..cd63fb5221 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,6 +128,9 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for gradients.""" stochastic_rounding: bool + """Per-token activation quantization path (grouped split quantize).""" + per_token_activation: bool + """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int 
rht_matrix: torch.Tensor @@ -143,6 +146,7 @@ def __init__( with_post_rht_amax: bool = False, with_2d_quantization: bool = False, stochastic_rounding: bool = False, + per_token_activation: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -153,6 +157,7 @@ def __init__( self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding + self.per_token_activation = per_token_activation self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -198,6 +203,7 @@ def copy(self) -> NVFP4Quantizer: with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, + per_token_activation=self.per_token_activation, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -330,7 +336,10 @@ def make_empty( scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) # Allocate per tensor scale inverse. FP32 format. 
- amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + amax_rows = flat_first_dim if self.per_token_activation else 1 + amax_rowwise = torch.zeros( + amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory + ) # Allocate FP8 data transpose if needed columnwise_data = None From 700cbce0f382a585a09e122bb4ec8dda5913bf5a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 01:45:09 -0700 Subject: [PATCH 02/45] Add col Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 110 +++++--- .../nvfp4/test_nvfp4_quantize_exact.py | 4 +- tests/pytorch/test_backward_override.py | 13 +- .../common/cast/dispatch/quantize.cuh | 4 - .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 245 ++++++++++++++++-- .../pytorch/cpp_extensions/gemm.py | 18 +- .../pytorch/csrc/extensions/activation.cpp | 10 +- .../pytorch/csrc/extensions/bias.cpp | 5 +- .../pytorch/csrc/extensions/cast.cpp | 12 +- .../pytorch/csrc/extensions/normalization.cpp | 10 +- transformer_engine/pytorch/csrc/quantizer.cpp | 11 +- .../custom_recipes/quantization_nvfp4.py | 109 ++++++-- .../pytorch/tensor/nvfp4_tensor.py | 3 +- 13 files changed, 438 insertions(+), 116 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..d22442cd64 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -15,6 +16,20 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +def maybe_skip_pertoken_nvfp4_gemm( + x_dtype: 
torch.dtype, + *, + accumulate: bool, + x_columnwise: bool, +) -> None: + if x_dtype == torch.float32: + pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") + if accumulate: + pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") + if x_columnwise: + pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") + + def check_nvfp4_gemm_versus_reference( x_dtype: torch.dtype, w_dtype: torch.dtype, @@ -26,6 +41,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, + per_token_activation: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -56,6 +72,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -112,7 +129,16 @@ def check_nvfp4_gemm_versus_reference( sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM - ref_quantizer = NVFP4QuantizerRef( + x_ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=not per_token_activation, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, + ) + w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, columnwise=True, @@ -124,16 +150,16 @@ def check_nvfp4_gemm_versus_reference( # Create reference quantized tensors needed by reference GEMM # Reference GEMM is only rowwise. 
if x_columnwise: - x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous()) + x_nvfp4_ref = x_ref_quantizer.quantize(x.t().contiguous()) else: - x_nvfp4_ref = ref_quantizer.quantize(x) + x_nvfp4_ref = x_ref_quantizer.quantize(x) if w_columnwise: - w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous()) + w_nvfp4_ref = w_ref_quantizer.quantize(w.t().contiguous()) else: - w_nvfp4_ref = ref_quantizer.quantize(w) + w_nvfp4_ref = w_ref_quantizer.quantize(w) # Reference GEMM using quantizer's qgemm method - y_ref = ref_quantizer.qgemm( + y_ref = x_ref_quantizer.qgemm( qx=qx_data, qw=qw_data, m_params=None, # MMParams not used in reference @@ -148,7 +174,7 @@ def check_nvfp4_gemm_versus_reference( qresult_w=w_nvfp4_ref, ) - # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) + # Native TE GEMM path # Allocate cuBLAS workspace workspace = torch.empty(4, dtype=torch.uint8, device=device) @@ -166,27 +192,38 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - # Native cuBLAS GEMM - # return type is out, bias_grad, gelu_input, extra_output - # We are just capturing out. - y_native = tex.generic_gemm( - w_nvfp4_native, - transa, - x_nvfp4_native, - transb, - out.clone() if accumulate else None, - out_quantizer, - TE_DType[out_dtype], - bias, - bias_dtype, - use_gelu, - gelu_input, - use_grad, - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - )[0] + if per_token_activation: + layout = ("T" if transa else "N") + ("T" if transb else "N") + y_native = general_gemm( + w_nvfp4_native, + x_nvfp4_native, + out_dtype=out_dtype, + accumulate=accumulate, + layout=layout, + out=out.clone() if accumulate else None, + )[0] + else: + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. 
+ y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] # just in case of accumulation, make sure y_ref and y_native are not the same tensor assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" @@ -224,10 +261,14 @@ def check_nvfp4_gemm_versus_reference( "is_x_columnwise, is_w_columnwise", [ (False, False), # TN - (True, False), # NN + (False, True), # NN + (True, False), # TT (True, True), # NT ], - ids=["rowxrow", "colxrow", "colxcol"], + ids=["rowxrow", "rowxcol", "colxrow", "colxcol"], +) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] ) def test_nvfp4_gemm_versus_reference( M: int, @@ -239,7 +280,15 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, + per_token_activation: bool, ): + if per_token_activation: + maybe_skip_pertoken_nvfp4_gemm( + x_dtype=x_dtype, + accumulate=accumulate, + x_columnwise=is_x_columnwise, + ) + check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, w_dtype=w_dtype, @@ -250,4 +299,5 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, + per_token_activation=per_token_activation, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 7e94911ddd..7e2a587223 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -26,13 +26,11 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: def maybe_skip_pertoken_nvfp4( x_dtype: torch.dtype = torch.bfloat16, *, - return_transpose: bool = False, + return_transpose: bool = False, # pylint: disable=unused-argument 
with_2d_quantization: bool = False, ) -> None: if x_dtype == torch.float32: pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") - if return_transpose: - pytest.skip("Per-token NVFP4 currently supports rowwise-only quantization") if with_2d_quantization: pytest.skip("Per-token NVFP4 does not support 2D quantization") diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 5da55b14b6..06de1a06f7 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -170,6 +170,9 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) + if recipe_name == "nvfp4_pertoken" and module_type in ("linear", "layernorm_linear"): + if dtype != torch.bfloat16: + pytest.skip("Per-token NVFP4 activation supports BF16 inputs only in this test") if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, @@ -183,16 +186,6 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - if recipe_name == "nvfp4_pertoken" and module_type in ( - "linear", - "layernorm_linear", - "ops_linear", - "grouped_linear", - ): - pytest.skip( - "Per-token NVFP4 currently supports rowwise-only quantization paths " - "(columnwise usage is unsupported for these modules)." 
- ) def _maybe_skip_unsupported_recipe_shape( diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 0d86022cc1..eab27a6e7e 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -103,8 +103,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, auto dtype = input_tensor->dtype(); const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; if (per_token_activation) { - NVTE_CHECK(!output_tensor->has_columnwise_data(), - "Per-token NVFP4 quantization supports rowwise-only output."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); nvfp4::quantize_pertoken(*input_tensor, noop_tensor, output_tensor, stream); @@ -251,8 +249,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens auto dtype = grad_tensor->dtype(); const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; if (per_token_activation) { - NVTE_CHECK(!output_tensor->has_columnwise_data(), - "Per-token NVFP4 quantization supports rowwise-only output."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); nvfp4::quantize_pertoken(*grad_tensor, noop_tensor, output_tensor, stream); diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 5e1e23f5d5..3f6f809d32 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -161,6 +161,157 @@ void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, cons #endif } +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + 
compute_pertoken_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_per_token_amax, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + } + const float thread_max = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + if (threadIdx.x == 0) { + output_per_token_amax[row_idx] = row_amax; + } +#endif +} + +template +void launch_compute_pertoken_amax(const int num_rows, const int num_cols, const IType *input, + float *output_per_token_amax, cudaStream_t stream, + const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", + num_cols); + dim3 grid(num_rows); + dim3 block(PERTOKEN_BLOCK_SIZE); + + compute_pertoken_amax_kernel + <<>>(num_rows, num_cols, input, output_per_token_amax, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + quantize_pertoken_nvfp4_columnwise_kernel( + const int num_rows, const 
int num_cols, const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, const int scale_stride_t, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using namespace detail; + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const int col_idx = blockIdx.x; + if (col_idx >= num_cols) return; + + constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; + constexpr float float_max = TypeExtrema::max; + constexpr float one = 1.0f; + const float2 one_2x{one, one}; + const int num_row_blocks = num_rows / PERTOKEN_SF_VEC_SIZE; + + for (int row_block = threadIdx.x; row_block < num_row_blocks; row_block += BLOCK_SIZE) { + const int row_start = row_block * PERTOKEN_SF_VEC_SIZE; + + float vals[PERTOKEN_SF_VEC_SIZE]; + float s_enc[PERTOKEN_SF_VEC_SIZE]; + float scaled_block_amax = 0.0f; +#pragma unroll + for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { + const int row_idx = row_start + i; + const float val = static_cast(input[row_idx * num_cols + col_idx]); + const float S_enc = compute_global_encode_scaling_factor_FP4(per_token_amax[row_idx]); + vals[i] = val; + s_enc[i] = S_enc; + scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); + } + + const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); + const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); + output_scales_t[col_idx * scale_stride_t + row_block] = S_dec_b_fp8; + + float scaled_vals[PERTOKEN_SF_VEC_SIZE]; +#pragma unroll + for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { + const float S_dec_rowwise = 1.0f / s_enc[i]; + const float block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); + scaled_vals[i] = vals[i] * block_scale_inverse; + } + + uint8_t *out_ptr = output_data_t + col_idx * (num_rows / 2) + row_start / 2; + auto *out_fp4 = reinterpret_cast(out_ptr); +#pragma unroll + for (int j = 
0; j < PERTOKEN_SF_VEC_SIZE; j += 4) { + const float2 in01 = make_float2(scaled_vals[j], scaled_vals[j + 1]); + const float2 in23 = make_float2(scaled_vals[j + 2], scaled_vals[j + 3]); + out_fp4[j / 4] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, one_2x, /*rbits=*/0u); + } + } +#endif +} + +template +void launch_quantize_pertoken_nvfp4_columnwise( + const int num_rows, const int num_cols, const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, const float *per_token_amax, const int scale_stride_t, + cudaStream_t stream, const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_rows % PERTOKEN_SF_VEC_SIZE == 0, "num_rows must be a multiple of ", + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", + num_rows); + dim3 grid(num_cols); + dim3 block(PERTOKEN_BLOCK_SIZE); + + quantize_pertoken_nvfp4_columnwise_kernel + <<>>(num_rows, num_cols, input, output_data_t, output_scales_t, + per_token_amax, scale_stride_t, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + } // namespace quantize_pertoken_kernel inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *output, @@ -172,12 +323,8 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o CheckOutputTensor(*output, "output", false); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - NVTE_CHECK(output->amax.dptr != nullptr, "Per-token amax tensor must be allocated."); - NVTE_CHECK(!output->has_columnwise_data(), - "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(output->has_data() || 
output->has_columnwise_data(), + "NVFP4 output tensor must be allocated."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); const size_t rows = input.flat_first_dim(); @@ -187,22 +334,86 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE, "."); const auto *noop_ptr = reinterpret_cast(noop->data.dptr); - auto *data_ptr = reinterpret_cast(output->data.dptr); - auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); auto *amax_ptr = reinterpret_cast(output->amax.dptr); + auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + auto *per_token_amax_ptr = (amax_ptr != nullptr) ? amax_ptr : amax_colwise_ptr; + NVTE_CHECK(per_token_amax_ptr != nullptr, "Per-token amax tensor must be allocated."); + if (amax_ptr != nullptr) { + NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); + } + if (amax_colwise_ptr != nullptr) { + NVTE_CHECK(output->columnwise_amax.numel() == rows, "Per-token columnwise amax must have ", + rows, " entries, got ", output->columnwise_amax.shape, "."); + } const int *row_offsets = nullptr; - const int scale_stride = static_cast(output->scale_inv.shape.back()); if (input.dtype() == DType::kBFloat16) { - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( - static_cast(rows), static_cast(cols), - reinterpret_cast(input.data.dptr), row_offsets, data_ptr, scale_ptr, - amax_ptr, scale_stride, stream, noop_ptr); + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = 
reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_pertoken_kernel::launch_compute_pertoken_amax<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } } else if (input.dtype() == DType::kFloat16) { - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( - static_cast(rows), static_cast(cols), - reinterpret_cast(input.data.dptr), row_offsets, data_ptr, scale_ptr, amax_ptr, - scale_stride, stream, noop_ptr); + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be 
allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_pertoken_kernel::launch_compute_pertoken_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } } else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. 
" diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 9cf58f9dce..4895054758 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -80,15 +80,15 @@ def _maybe_apply_nvfp4_pertoken_output_rescale( gelu: bool, accumulate: bool, ) -> None: - """Apply per-token NVFP4 global-scale correction for TN forward GEMM outputs. + """Apply per-token NVFP4 global-scale correction for forward GEMM outputs. Current NVFP4 GEMM alpha path consumes one scalar amax. Per-token NVFP4 stores - rowwise amax vector in B._amax_rowwise, so we correct by row using ratio - (amax[row] / amax[0]). If bias was fused in epilogue, remove/reapply it around + rowwise amax vector in B, so we correct by row using ratio (amax[row] / amax[0]) + when B is not transposed. If bias was fused in epilogue, remove/reapply it around the row rescale to avoid bias distortion. """ - if grad or gelu or accumulate or layout != "TN": + if grad or gelu or accumulate or layout[1] != "N": return if not isinstance(B, NVFP4TensorStorage): return @@ -96,7 +96,7 @@ def _maybe_apply_nvfp4_pertoken_output_rescale( return if out.numel() == 0: return - amax = B._amax_rowwise + amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise if amax is None or amax.numel() <= 1: return @@ -194,13 +194,15 @@ def general_gemm( requested_out_dtype = out_dtype needs_fp32_rescale_path = ( - layout == "TN" + layout[1] == "N" and not grad and not gelu and not accumulate and isinstance(B, NVFP4TensorStorage) - and B._amax_rowwise is not None - and B._amax_rowwise.numel() > 1 + and ( + (B._amax_rowwise is not None and B._amax_rowwise.numel() > 1) + or (B._amax_columnwise is not None and B._amax_columnwise.numel() > 1) + ) and quantization_params is None and out is None and requested_out_dtype is not None diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp 
b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2df3b66553..17f86d63d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; @@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0cf2025f1b..e2dba46370 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,8 +152,9 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = 
dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_DACT_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f1654f5525..b05d399414 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -944,7 +944,8 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + const size_t amax_size = per_token_activation ? 4 * flat_first_dim(shape_list[i]) : 4; + buffer_size = offset + amax_size; } // Allocate full buffer @@ -957,8 +958,11 @@ std::tuple, std::vector, bool> bulk_alloc buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + const std::vector amax_shape = + per_token_activation ? 
std::vector{flat_first_dim(shape_list[i])} + : std::vector{1}; amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -1003,7 +1007,7 @@ std::tuple, std::vector, bool> bulk_alloc } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_columnwise_list[i])); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); @@ -1281,8 +1285,6 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, if (quantizer.per_token_activation) { NVTE_CHECK(!quantizer.with_rht, "Per-token NVFP4 split quantize does not support RHT."); - NVTE_CHECK(!quantizer.columnwise_usage, - "Per-token NVFP4 split quantize currently supports rowwise-only quantization."); NVTE_CHECK(!quantizer.with_2d_quantization, "Per-token NVFP4 split quantize does not support 2D quantization."); NVTE_CHECK(!quantizer.stochastic_rounding, diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index fb4c7aa1c9..3975c01fa5 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,8 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if 
(!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output @@ -357,8 +358,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d6fedc707b..6cc6560d8b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1779,7 +1779,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, bit32_tensor_opts); + const int64_t amax_rows = this->per_token_activation ? static_cast(flat_first_dim) : 1; + amax_columnwise = at::empty({amax_rows}, bit32_tensor_opts); } // Convert tensors to Python @@ -2105,7 +2106,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, opts); + const int64_t amax_rows = + this->per_token_activation ? 
static_cast(flat_first_dim) : 1; + amax_columnwise = at::empty({amax_rows}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } } else { // columnwise_usage == false @@ -2141,7 +2144,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*columnwise_scale_inv)); out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(*amax_columnwise)); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -2258,8 +2261,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Per-token NVFP4 activation does not support 2D quantization."); NVTE_CHECK(!this->stochastic_rounding, "Per-token NVFP4 activation does not support stochastic rounding."); - NVTE_CHECK(!this->columnwise_usage, - "Per-token NVFP4 activation currently supports rowwise-only quantization."); NVTE_CHECK(!this->with_amax_reduction, "Per-token NVFP4 activation does not support amax reduction."); NVTE_CHECK(input.dtype() == DType::kBFloat16 || input.dtype() == DType::kFloat16, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 6a5e400592..430af6c581 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -546,6 +546,71 @@ def _quantize_blockwise_reference( return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + @classmethod + def _quantize_blockwise_pertoken_columnwise_reference( + cls, + x: torch.Tensor, + global_amax: torch.Tensor, + tile_len_x: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if x.ndim != 2: + raise ValueError( + "_quantize_blockwise_pertoken_columnwise_reference expects a 2D tensor, got" + f" {x.ndim}D with shape {x.shape}" + ) + + m, n = x.shape + x = 
x.view(m, n // tile_len_x, tile_len_x) + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + + global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=global_encode_scale.device, + dtype=torch.float32, + ), + ) + global_encode_scale = torch.where( + global_encode_scale == 0.0, + torch.ones_like(global_encode_scale), + global_encode_scale, + ) + global_decode_scale = torch.div(1.0, global_encode_scale) + global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) + + decode_scale = torch.amax( + torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, + dim=-1, + keepdim=True, + ) + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + scaled_x = x.to(torch.float32) * encode_scale + clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) + + return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + @staticmethod def _pad_tensor( tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] @@ -650,8 +715,6 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ "Per-token activation only supports NVFP4 1x16 tile shape, " f"got {self.quant_tile_shape}" ) - if self.columnwise_usage: - raise ValueError("Per-token activation 
reference supports rowwise-only usage.") global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) global_amax_col = global_amax_row else: @@ -696,15 +759,22 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] ) - qx_t, sx_t = self._quantize_blockwise_reference( - x_t_padded, - global_amax_col, - self.quant_tile_shape[1], - self.quant_tile_shape[0], - pow_2_scales=self.pow_2_scales, - per_token_activation=False, - eps=self.eps, - ) + if self.per_token_activation: + qx_t, sx_t = self._quantize_blockwise_pertoken_columnwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + ) + else: + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + per_token_activation=False, + eps=self.eps, + ) qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) @@ -895,22 +965,15 @@ def qgemm( sw = sw.to(torch.float32) factor = 6.0 * 6.0 * 448.0 * 448.0 - if ( - qresult_x.global_amax_row.numel() != 1 - or qresult_w.global_amax_row.numel() != 1 - or qresult_w.global_amax_col.numel() != 1 - or qresult_x.global_amax_col.numel() != 1 - ): - raise ValueError( - "NVFP4QuantizerRef.qgemm expects scalar global amax values; " - "per-token amax vectors are not supported in reference GEMM." 
- ) - if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col else: partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row - alpha = torch.div(partial_alpha, factor).squeeze(-1) + if partial_alpha.numel() > 1 and partial_alpha.numel() == high_precision_x.shape[0]: + partial_alpha = partial_alpha.view(-1, 1) + else: + partial_alpha = partial_alpha.squeeze(-1) + alpha = torch.div(partial_alpha, factor) M, K = high_precision_x.shape N, K_w = high_precision_w.shape diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index cd63fb5221..53f77da9e4 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -362,8 +362,9 @@ def make_empty( device=device, pin_memory=pin_memory, ) + amax_rows = flat_first_dim if self.per_token_activation else 1 amax_columnwise = torch.zeros( - 1, dtype=torch.float32, device=device, pin_memory=pin_memory + amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory ) # Construct FP8 tensor From cfd13bb97392fa971a0d1b7adbb3904514cdbfec Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 02:03:16 -0700 Subject: [PATCH 03/45] Add fp32 Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 -- .../nvfp4/test_nvfp4_quantize_exact.py | 10 +-- tests/pytorch/test_backward_override.py | 3 - .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 70 ++++++++++++++++--- .../pytorch/csrc/extensions/cast.cpp | 3 - transformer_engine/pytorch/csrc/quantizer.cpp | 2 - 6 files changed, 63 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index d22442cd64..231fb62468 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -17,13 +17,10 @@ def maybe_skip_pertoken_nvfp4_gemm( - x_dtype: 
torch.dtype, *, accumulate: bool, x_columnwise: bool, ) -> None: - if x_dtype == torch.float32: - pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") if accumulate: pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") if x_columnwise: @@ -284,7 +281,6 @@ def test_nvfp4_gemm_versus_reference( ): if per_token_activation: maybe_skip_pertoken_nvfp4_gemm( - x_dtype=x_dtype, accumulate=accumulate, x_columnwise=is_x_columnwise, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 7e2a587223..93359b6179 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -24,13 +24,10 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: def maybe_skip_pertoken_nvfp4( - x_dtype: torch.dtype = torch.bfloat16, *, return_transpose: bool = False, # pylint: disable=unused-argument with_2d_quantization: bool = False, ) -> None: - if x_dtype == torch.float32: - pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") if with_2d_quantization: pytest.skip("Per-token NVFP4 does not support 2D quantization") @@ -185,7 +182,6 @@ def test_quantization_block_tiling_versus_reference( ) -> None: if per_token_activation: maybe_skip_pertoken_nvfp4( - x_dtype=x_dtype, return_transpose=return_transpose, with_2d_quantization=with_2d_quantization, ) @@ -239,7 +235,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x = torch.zeros((M, N), dtype=x_dtype, device=device) if per_token_activation: - maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -359,7 +355,7 @@ def test_nvfp4_quantization_boundary_values( x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) if per_token_activation: - maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + 
maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -465,7 +461,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( assert not x_nc.is_contiguous() if per_token_activation: - maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 06de1a06f7..6f11069e28 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -170,9 +170,6 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4_pertoken" and module_type in ("linear", "layernorm_linear"): - if dtype != torch.bfloat16: - pytest.skip("Per-token NVFP4 activation supports BF16 inputs only in this test") if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 3f6f809d32..feacc2ff6a 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -37,6 +37,26 @@ using namespace ptx; constexpr int PERTOKEN_BLOCK_SIZE = 256; constexpr int PERTOKEN_SF_VEC_SIZE = 16; +template +__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, + const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); + dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); + } else { + ptx::abs_max_2x(dst, dst, val); + } +} + +template +__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + return 
fmaxf(fabsf(val.x), fabsf(val.y)); + } else { + return static_cast(__hmax(__habs(val.x), __habs(val.y))); + } +} + template __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -67,10 +87,9 @@ __launch_bounds__(BLOCK_SIZE) IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { const IType2 val = input_row[i]; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + abs_max_2x_update(thread_amax_2x, val); } - const float thread_max = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float thread_max = abs_max_2x_to_float(thread_amax_2x); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -99,10 +118,9 @@ __launch_bounds__(BLOCK_SIZE) reinterpret_cast(input + actual_row * num_cols + col_start); for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; ++j) { vals[j] = input_block[j]; - ptx::abs_max_2x(block_amax_2x, block_amax_2x, vals[j]); + abs_max_2x_update(block_amax_2x, vals[j]); } - const float block_max = - static_cast(__hmax(__habs(block_amax_2x.x), __habs(block_amax_2x.y))); + const float block_max = abs_max_2x_to_float(block_amax_2x); const float S_dec_b_f32 = fminf(block_max * global_encode_scale_multiplier, detail::TypeExtrema::max); @@ -186,10 +204,9 @@ __launch_bounds__(BLOCK_SIZE) IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { const IType2 val = input_row[i]; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + abs_max_2x_update(thread_amax_2x, val); } - const float thread_max = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float thread_max = abs_max_2x_to_float(thread_amax_2x); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -414,10 +431,43 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o 
static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_pertoken_kernel::launch_compute_pertoken_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } } 
else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. " - "Expected BFloat16 or Float16."); + "Expected BFloat16, Float16, or Float32."); } #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b05d399414..9423aa7296 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1608,9 +1608,6 @@ std::tuple quantize_nvfp4_pertoken(at::Tenso NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device"); - NVTE_CHECK(input.scalar_type() == at::ScalarType::BFloat16 || - input.scalar_type() == at::ScalarType::Half, - "Input must be BFloat16 or Half"); const int num_rows = input.size(0); const int num_cols = input.size(1); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6cc6560d8b..6e6e38a1dd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2263,8 +2263,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Per-token NVFP4 activation does not support stochastic rounding."); NVTE_CHECK(!this->with_amax_reduction, "Per-token NVFP4 activation does not support amax reduction."); - NVTE_CHECK(input.dtype() == DType::kBFloat16 || input.dtype() == DType::kFloat16, - "Per-token NVFP4 activation supports BF16/FP16 inputs only."); NVTE_CHECK(cols % 16 == 0, "Per-token NVFP4 activation requires last dim divisible by 16."); quant_config.set_nvfp4_per_token_activation(true); } From 866d337ff0754682b9b674bbfae3fddc4500c71f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 15:32:57 -0700 Subject: [PATCH 04/45] Clean up tests Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 8 +++- 
tests/pytorch/test_cuda_graphs.py | 20 +++++++-- tests/pytorch/test_sanity.py | 31 ++++++++++--- tests/pytorch/utils.py | 15 +++++++ transformer_engine/common/recipe/__init__.py | 1 - .../pytorch/cpp_extensions/gemm.py | 45 +++++++++++++++++++ 6 files changed, 109 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 6f11069e28..c91442562f 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -1042,6 +1042,7 @@ def test_grouped_linear_backward_override_matches_reference( quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) module_quantized_ref = te.GroupedLinear( num_gemms, @@ -1280,6 +1281,7 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1724,7 +1726,11 @@ def test_backward_override_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = (None, "high_precision", "dequantized") + modes = ( + ("high_precision", "dequantized") + if recipe_name == "nvfp4_pertoken" + else (None, "high_precision", "dequantized") + ) mode_results: dict[str, dict[str, float] | str] = {} for mode in modes: diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index a782dadc60..8a01acf0eb 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -20,17 +20,19 @@ is_fp8_available, is_fp8_block_scaling_available, is_mxfp8_available, + 
is_nvfp4_available, is_bf16_available, ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Check if FP8 is supported. fp8_available = is_fp8_available() fp8_block_scaling_available = is_fp8_block_scaling_available() mxfp8_available = is_mxfp8_available() +nvfp4_available = is_nvfp4_available() # Reset RNG states. reset_rng_states() @@ -62,6 +64,14 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_per_token(): + nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -88,7 +98,9 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_per_token()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -360,7 +372,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", 
"dequantized")) def test_make_graphed_callables( *, @@ -390,6 +402,8 @@ def test_make_graphed_callables( f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" ) if fp8 and fp8_recipe.nvfp4(): + if getattr(fp8_recipe, "per_token_activation", False) and module == "mha": + pytest.skip("Per-token NVFP4 CUDA graph coverage applies to GEMM modules.") if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe" @@ -448,7 +462,7 @@ def test_make_graphed_callables( ) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables_with_fp8_weight_caching( *, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 7f2f24fd69..c3951c28ed 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -33,17 +33,19 @@ checkpoint, QuantizedTensor, is_bf16_available, + is_nvfp4_available, ) from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, _ = is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. 
seed = 1234 @@ -93,9 +95,18 @@ def nvfp4_vanilla(): return nvfp4_recipe +def nvfp4_per_token(): + nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) @@ -103,6 +114,9 @@ def nvfp4_vanilla(): fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) +fp8_recipes_with_per_token = fp8_recipes.copy() +if nvfp4_available: + fp8_recipes_with_per_token.insert(-1, nvfp4_per_token()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -402,7 +416,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -450,7 +464,7 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -488,7 +502,7 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) 
@pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -529,7 +543,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -563,7 +577,12 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - pytest.skip("NVFP4 not supported for grouped linear") + if dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") + if not getattr(fp8_recipe, "per_token_activation", False): + pytest.skip("Only per-token NVFP4 is supported for grouped linear") + if fp8_model_params: + pytest.skip("Per-token NVFP4 grouped linear does not support FP8 model params") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 04ac2becbc..6e58538a4a 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -160,12 +160,27 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: raise ValueError(f"Unsupported quantization scheme ({name})") +def recipe_id(fp8_recipe: Optional[Recipe]) 
-> str: + """Readable pytest id for FP8/FP4 recipes.""" + if fp8_recipe is None: + return "None" + if fp8_recipe.nvfp4() and getattr(fp8_recipe, "per_token_activation", False): + return "NVFP4PerTokenBlockScaling" + return type(fp8_recipe).__name__ + + def skip_unsupported_backward_override( layer_type: str, quant_recipe: Optional[Recipe], backward_override: Optional[str], ) -> None: """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" + if ( + quant_recipe is not None + and getattr(quant_recipe, "per_token_activation", False) + and backward_override is None + ): + pytest.skip("Per-token NVFP4 requires an explicit backward override.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e59d01d82a..8f549c8979 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -511,7 +511,6 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - # Quantization params # Note: RHT is currently only applied to column-wise usage so that # it can be used for wgrad GEMM. 
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 4895054758..de693d823a 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -114,6 +114,14 @@ def _maybe_apply_nvfp4_pertoken_output_rescale( out_2d.mul_(ratios) +def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: + """Whether tensor carries per-token NVFP4 global amax metadata.""" + if not isinstance(tensor, NVFP4TensorStorage): + return False + amax = tensor._amax_rowwise if tensor._amax_rowwise is not None else tensor._amax_columnwise + return amax is not None and amax.numel() > 1 + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -303,6 +311,43 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] + use_pertoken_unfused_fprop = ( + not grad + and not gelu + and not accumulate + and layout[1] == "N" + and D_dtype is None + and all(q is None for q in quantization_params) + and any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B) + ) + if use_pertoken_unfused_fprop: + out_init = out[0] if single_output else None + if single_output: + start_idx = 0 + out_views = [] + for i in range(num_gemms): + size = m_splits[i] + out_views.append(out_init[start_idx : start_idx + size]) + start_idx += size + else: + out_views = out + for i in range(num_gemms): + if out_views[i].numel() == 0: + continue + gemm_out, _, _, _ = general_gemm( + A[i], + B[i], + quantization_params=None, + out_dtype=out_views[i].dtype, + layout=layout, + bias=bias[i] if use_bias else None, + use_split_accumulator=use_split_accumulator, + ) + out_views[i].copy_(gemm_out) + if single_output: + out = out_init + return out, bias, gelu_input + if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode" if single_output: From 5a6ea130cc57a7ed0eabcb21ea9213e818dc093d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 15:43:04 
-0700 Subject: [PATCH 05/45] Clean up ref Signed-off-by: Ziang Li --- .../custom_recipes/quantization_nvfp4.py | 118 +++++------------- 1 file changed, 29 insertions(+), 89 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 430af6c581..d57ea792dd 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -449,10 +449,14 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, - per_token_activation: bool, + per_token_rowwise: bool = False, + per_token_columnwise: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: + assert not ( + per_token_rowwise and per_token_columnwise + ), "Per-token rowwise and columnwise reference modes are mutually exclusive." if x.ndim != 2: raise ValueError( f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape" @@ -491,10 +495,10 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: - if per_token_activation: + if per_token_rowwise: global_amax = global_amax.to(torch.float32).view(m, 1, 1) - else: - global_amax = global_amax.to(torch.float32) + if per_token_columnwise: + global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( @@ -517,9 +521,16 @@ def _quantize_blockwise_reference( global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - # Match the kernel's default path: fold the FP4 reciprocal into the - # global scale multiplier, but keep the final reciprocal exact. 
- decode_scale = vec_max * global_encode_scale_multiplier + if per_token_columnwise: + decode_scale = torch.amax( + torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, + dim=-1, + keepdim=True, + ) + else: + # Match the kernel's default path: fold the FP4 reciprocal into the + # global scale multiplier, but keep the final reciprocal exact. + decode_scale = vec_max * global_encode_scale_multiplier decode_scale = torch.min( decode_scale, torch.tensor( @@ -546,71 +557,6 @@ def _quantize_blockwise_reference( return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) - @classmethod - def _quantize_blockwise_pertoken_columnwise_reference( - cls, - x: torch.Tensor, - global_amax: torch.Tensor, - tile_len_x: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if x.ndim != 2: - raise ValueError( - "_quantize_blockwise_pertoken_columnwise_reference expects a 2D tensor, got" - f" {x.ndim}D with shape {x.shape}" - ) - - m, n = x.shape - x = x.view(m, n // tile_len_x, tile_len_x) - FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) - FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) - - global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) - global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) - global_encode_scale = torch.min( - global_encode_scale, - torch.tensor( - torch.finfo(torch.float32).max, - device=global_encode_scale.device, - dtype=torch.float32, - ), - ) - global_encode_scale = torch.where( - global_encode_scale == 0.0, - torch.ones_like(global_encode_scale), - global_encode_scale, - ) - global_decode_scale = torch.div(1.0, global_encode_scale) - global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - - decode_scale = torch.amax( - torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, - dim=-1, - keepdim=True, - ) - decode_scale = torch.min( - decode_scale, - torch.tensor( - torch.finfo(torch.float32).max, - 
device=decode_scale.device, - dtype=torch.float32, - ), - ) - decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - decode_scale = decode_scale.to(torch.float8_e4m3fn) - - encode_scale = torch.min( - torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), - torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale.device, - dtype=torch.float32, - ), - ) - scaled_x = x.to(torch.float32) * encode_scale - clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) - - return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) - @staticmethod def _pad_tensor( tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] @@ -741,7 +687,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - per_token_activation=self.per_token_activation, + per_token_rowwise=self.per_token_activation, eps=self.eps, ) if transpose_scales: @@ -759,22 +705,15 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] ) - if self.per_token_activation: - qx_t, sx_t = self._quantize_blockwise_pertoken_columnwise_reference( - x_t_padded, - global_amax_col, - self.quant_tile_shape[1], - ) - else: - qx_t, sx_t = self._quantize_blockwise_reference( - x_t_padded, - global_amax_col, - self.quant_tile_shape[1], - self.quant_tile_shape[0], - pow_2_scales=self.pow_2_scales, - per_token_activation=False, - eps=self.eps, - ) + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + per_token_columnwise=self.per_token_activation, + eps=self.eps, + ) qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) @@ -965,6 +904,7 @@ def qgemm( sw = sw.to(torch.float32) factor = 6.0 * 6.0 * 448.0 * 448.0 + if gemm_type == quantization.GEMMType.WGRAD: 
partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col else: From ee0aafb3ee360b51d8b151177d93f3b712909b00 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 17:58:25 -0700 Subject: [PATCH 06/45] Clean up gemm wrapper Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 57 ++++--- transformer_engine/common/recipe/__init__.py | 1 + .../pytorch/cpp_extensions/gemm.py | 152 ++++++++---------- 3 files changed, 105 insertions(+), 105 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index c91442562f..0921035d1e 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -99,6 +99,12 @@ def backward_override(request: pytest.FixtureRequest) -> str: return request.param +def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]: + if recipe_name == "nvfp4_pertoken" and "backward_override" not in recipe_kwargs: + recipe_kwargs["backward_override"] = "dequantized" + return make_recipe(recipe_name, **recipe_kwargs) + + # -------------------------- # Test cases # -------------------------- @@ -185,6 +191,11 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") +def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: + if recipe_name == "nvfp4_pertoken": + pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") + + def _maybe_skip_unsupported_recipe_shape( recipe_name: str, input_shape: tuple[int, ...], @@ -856,7 +867,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, 
backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1040,7 +1051,7 @@ def test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) @@ -1209,9 +1220,11 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - default_recipe = make_recipe(recipe_name) + default_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) + expected_default_mode = default_recipe.backward_override + expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) ( @@ -1220,10 +1233,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer, default_reduce_and_update, ) = default_ctx - assert default_mode is None - assert default_fp8 - assert default_grad_output_quantizer is not None - assert default_reduce_and_update + assert default_mode == expected_default_mode + assert default_fp8 == expected_default_fp8 + assert (default_grad_output_quantizer is not None) == expected_default_fp8 + assert default_reduce_and_update == expected_default_fp8 *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe) switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = ( @@ -1241,10 +1254,10 @@ def 
test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer_after, default_reduce_and_update_after, ) = default_ctx_after - assert default_mode_after is None - assert default_fp8_after - assert default_grad_output_quantizer_after is not None - assert default_reduce_and_update_after + assert default_mode_after == expected_default_mode + assert default_fp8_after == expected_default_fp8 + assert (default_grad_output_quantizer_after is not None) == expected_default_fp8 + assert default_reduce_and_update_after == expected_default_fp8 @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1279,9 +1292,11 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") - default_recipe = make_recipe(recipe_name) + default_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) + expected_default_mode = default_recipe.backward_override + expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1291,9 +1306,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe, ) default_mode, default_fp8, default_reduce_and_update = default_ctx - assert default_mode is None - assert default_fp8 - assert default_reduce_and_update + assert default_mode == expected_default_mode + assert default_fp8 == expected_default_fp8 + assert default_reduce_and_update == expected_default_fp8 *_, switched_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1315,9 +1330,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe, ) default_mode_after, default_fp8_after, default_reduce_and_update_after = 
default_ctx_after - assert default_mode_after is None - assert default_fp8_after - assert default_reduce_and_update_after + assert default_mode_after == expected_default_mode + assert default_fp8_after == expected_default_fp8 + assert default_reduce_and_update_after == expected_default_fp8 @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1344,10 +1359,11 @@ def test_fused_linear_paths_match_backward_override_reference( _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") + _maybe_skip_unsupported_fused_ops(recipe_name) reset_rng_states() - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1483,11 +1499,12 @@ def test_fused_bias_activation_matches_masked_linear_backward( _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear") + _maybe_skip_unsupported_fused_ops(recipe_name) reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 8f549c8979..e59d01d82a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -511,6 +511,7 @@ def __post_init__(self) -> None: assert ( self.backward_override in 
_BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." + # Quantization params # Note: RHT is currently only applied to column-wise usage so that # it can be used for wgrad GEMM. diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index de693d823a..f19b175969 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -70,50 +70,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def _maybe_apply_nvfp4_pertoken_output_rescale( - out: torch.Tensor, - B: torch.Tensor, - *, - layout: str, - bias: Optional[torch.Tensor], - grad: bool, - gelu: bool, - accumulate: bool, -) -> None: - """Apply per-token NVFP4 global-scale correction for forward GEMM outputs. - - Current NVFP4 GEMM alpha path consumes one scalar amax. Per-token NVFP4 stores - rowwise amax vector in B, so we correct by row using ratio (amax[row] / amax[0]) - when B is not transposed. If bias was fused in epilogue, remove/reapply it around - the row rescale to avoid bias distortion. 
- """ - - if grad or gelu or accumulate or layout[1] != "N": - return - if not isinstance(B, NVFP4TensorStorage): - return - if not isinstance(out, torch.Tensor) or is_custom(out): - return - if out.numel() == 0: - return - amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise - if amax is None or amax.numel() <= 1: - return - - out_2d = out.reshape(-1, out.shape[-1]) - if amax.numel() != out_2d.shape[0]: - return - - ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) - if bias is not None: - bias_cast = bias.to(dtype=out.dtype) - out_2d.sub_(bias_cast) - out_2d.mul_(ratios) - out_2d.add_(bias_cast) - else: - out_2d.mul_(ratios) - - def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: """Whether tensor carries per-token NVFP4 global amax metadata.""" if not isinstance(tensor, NVFP4TensorStorage): @@ -200,24 +156,6 @@ def general_gemm( # FP8 block-scaling requires split accumulator use_split_accumulator = True - requested_out_dtype = out_dtype - needs_fp32_rescale_path = ( - layout[1] == "N" - and not grad - and not gelu - and not accumulate - and isinstance(B, NVFP4TensorStorage) - and ( - (B._amax_rowwise is not None and B._amax_rowwise.numel() > 1) - or (B._amax_columnwise is not None and B._amax_columnwise.numel() > 1) - ) - and quantization_params is None - and out is None - and requested_out_dtype is not None - and requested_out_dtype != torch.float32 - ) - effective_out_dtype = torch.float32 if needs_fp32_rescale_path else requested_out_dtype - args = ( A, transa, # transa @@ -225,7 +163,7 @@ def general_gemm( transb, # transb out, quantization_params, - TE_DType[effective_out_dtype] if effective_out_dtype is not None else None, + TE_DType[out_dtype] if out_dtype is not None else None, bias, bias_dtype, gelu, @@ -245,18 +183,57 @@ def general_gemm( "beta": beta, } - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) - _maybe_apply_nvfp4_pertoken_output_rescale( - out, - B, - layout=layout, - 
bias=bias, - grad=grad, - gelu=gelu, - accumulate=accumulate, - ) - if needs_fp32_rescale_path: - out = out.to(dtype=requested_out_dtype) + if not _is_nvfp4_pertoken_tensor(B): + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + else: + assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only." + assert not grad, "Per-token NVFP4 GEMM currently supports fprop only." + assert not gelu, "Per-token NVFP4 GEMM currently does not support fused GELU." + assert not accumulate, "Per-token NVFP4 GEMM currently does not support accumulation." + assert ( + quantization_params is None + ), "Per-token NVFP4 GEMM currently does not support output quantization." + assert out is None or ( + isinstance(out, torch.Tensor) and not is_custom(out) + ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." + requested_out = out + requested_out_dtype = out_dtype + fp32_out = ( + torch.empty_like(requested_out, dtype=torch.float32) + if requested_out is not None + else None + ) + # Override only output, output quantizer, and output dtype for the FP32 correction path. 
+ args = ( + *args[:4], + fp32_out, + None, + TE_DType[torch.float32], + *args[7:], + ) + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + + assert isinstance(out, torch.Tensor) and not is_custom(out) + assert out.numel() > 0 + amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise + assert amax is not None and amax.numel() > 1 + + out_2d = out.reshape(-1, out.shape[-1]) + assert amax.numel() == out_2d.shape[0] + ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) + if bias is not None: + bias_cast = bias.to(dtype=out.dtype) + out_2d.sub_(bias_cast) + out_2d.mul_(ratios) + out_2d.add_(bias_cast) + else: + out_2d.mul_(ratios) + + if requested_out is not None: + requested_out.copy_(out.to(dtype=requested_out.dtype)) + out = requested_out + elif requested_out_dtype is not None and requested_out_dtype != torch.float32: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) @@ -311,16 +288,21 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - use_pertoken_unfused_fprop = ( - not grad - and not gelu - and not accumulate - and layout[1] == "N" - and D_dtype is None - and all(q is None for q in quantization_params) - and any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B) - ) - if use_pertoken_unfused_fprop: + if any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B): + assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." + assert not grad, "Per-token NVFP4 grouped GEMM currently supports fprop only." + assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." + assert ( + not accumulate + ), "Per-token NVFP4 grouped GEMM currently does not support accumulation." + assert D_dtype is None, "Per-token NVFP4 grouped GEMM currently does not support D_dtype." 
+ assert all( + q is None for q in quantization_params + ), "Per-token NVFP4 grouped GEMM currently does not support output quantization." + if single_output: + assert ( + m_splits is not None + ), "Per-token NVFP4 grouped GEMM requires m_splits with single output." out_init = out[0] if single_output else None if single_output: start_idx = 0 From e852804fddbc17c3b610a6dad7fc3d4533e0802a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 18:18:30 -0700 Subject: [PATCH 07/45] Clean up test Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 28 +++++---------- .../nvfp4/test_nvfp4_quantize_exact.py | 34 ++++--------------- 2 files changed, 15 insertions(+), 47 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 231fb62468..27b5d0626f 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -16,17 +16,6 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) -def maybe_skip_pertoken_nvfp4_gemm( - *, - accumulate: bool, - x_columnwise: bool, -) -> None: - if accumulate: - pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") - if x_columnwise: - pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") - - def check_nvfp4_gemm_versus_reference( x_dtype: torch.dtype, w_dtype: torch.dtype, @@ -171,7 +160,7 @@ def check_nvfp4_gemm_versus_reference( qresult_w=w_nvfp4_ref, ) - # Native TE GEMM path + # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) # Allocate cuBLAS workspace workspace = torch.empty(4, dtype=torch.uint8, device=device) @@ -258,14 +247,13 @@ def check_nvfp4_gemm_versus_reference( "is_x_columnwise, is_w_columnwise", [ (False, False), # TN - (False, True), # NN - (True, False), # TT + (True, False), # NN (True, True), # NT ], - ids=["rowxrow", "rowxcol", "colxrow", "colxcol"], + ids=["rowxrow", "colxrow", 
"colxcol"], ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_gemm_versus_reference( M: int, @@ -280,10 +268,10 @@ def test_nvfp4_gemm_versus_reference( per_token_activation: bool, ): if per_token_activation: - maybe_skip_pertoken_nvfp4_gemm( - accumulate=accumulate, - x_columnwise=is_x_columnwise, - ) + if accumulate: + pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") + if is_x_columnwise: + pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 93359b6179..a804e8f1ba 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -23,15 +23,6 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: return repeated -def maybe_skip_pertoken_nvfp4( - *, - return_transpose: bool = False, # pylint: disable=unused-argument - with_2d_quantization: bool = False, -) -> None: - if with_2d_quantization: - pytest.skip("Per-token NVFP4 does not support 2D quantization") - - def check_quantization_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, @@ -168,7 +159,7 @@ def check_quantization_nvfp4_versus_reference( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, @@ -180,11 +171,9 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization: bool, per_token_activation: bool, ) -> None: - if per_token_activation: - maybe_skip_pertoken_nvfp4( - 
return_transpose=return_transpose, - with_2d_quantization=with_2d_quantization, - ) + if per_token_activation and with_2d_quantization: + pytest.skip("Per-token NVFP4 does not support 2D quantization") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, @@ -211,7 +200,7 @@ def test_quantization_block_tiling_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, @@ -234,9 +223,6 @@ def test_nvfp4_quantization_extrema_versus_reference( else: x = torch.zeros((M, N), dtype=x_dtype, device=device) - if per_token_activation: - maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) - nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -318,7 +304,7 @@ def test_nvfp4_quantization_extrema_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, @@ -354,9 +340,6 @@ def test_nvfp4_quantization_boundary_values( row[1::2] = upper x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) - if per_token_activation: - maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) - nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -438,7 +421,7 @@ def test_nvfp4_quantization_boundary_values( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", 
"nvfp4_pertoken"] ) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, @@ -460,9 +443,6 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nc = x_base.t() # shape (N, M), non-contiguous assert not x_nc.is_contiguous() - if per_token_activation: - maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) - nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, From 9dbb3ad02f1550764805363f2cf2e8c7ff8084b4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 18:45:02 -0700 Subject: [PATCH 08/45] Clean up Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 +--- tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py | 16 ++++------------ tests/pytorch/test_sanity.py | 12 +++--------- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 27b5d0626f..b2862cc63d 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -252,9 +252,7 @@ def check_nvfp4_gemm_versus_reference( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index a804e8f1ba..d21e6a6e37 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -158,9 +158,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", 
[False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -199,9 +197,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -303,9 +299,7 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -420,9 +414,7 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c3951c28ed..73c291cc15 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -33,7 +33,6 @@ checkpoint, QuantizedTensor, is_bf16_available, - is_nvfp4_available, ) from transformer_engine.common import recipe import transformer_engine_torch as tex @@ -45,7 +44,7 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = 
te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -nvfp4_available, _ = is_nvfp4_available(return_reason=True) +nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. seed = 1234 @@ -543,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -577,12 +576,7 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - if dtype == torch.float16: - pytest.skip("FP16 output for NVFP4 not supported") - if not getattr(fp8_recipe, "per_token_activation", False): - pytest.skip("Only per-token NVFP4 is supported for grouped linear") - if fp8_model_params: - pytest.skip("Per-token NVFP4 grouped linear does not support FP8 model params") + pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): From 475de8a604c43775129db2993709bf052fd07ac7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 18:48:50 -0700 Subject: [PATCH 09/45] Rename and reformat Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- .../nvfp4/test_nvfp4_quantize_exact.py | 8 ++-- tests/pytorch/test_backward_override.py | 16 ++++---- tests/pytorch/utils.py | 4 +- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 37 +++++++++++-------- 5 files changed, 36 insertions(+), 31 
deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index b2862cc63d..3708205aef 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -252,7 +252,7 @@ def check_nvfp4_gemm_versus_reference( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index d21e6a6e37..cf801639b7 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -158,7 +158,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -197,7 +197,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -299,7 +299,7 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) 
-@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -414,7 +414,7 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 0921035d1e..2156d6cef0 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -79,7 +79,7 @@ id="NVFP4BlockScaling", ), pytest.param( - "nvfp4_pertoken", + "nvfp4_per_token", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4PerTokenBlockScaling", ), @@ -100,7 +100,7 @@ def backward_override(request: pytest.FixtureRequest) -> str: def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]: - if recipe_name == "nvfp4_pertoken" and "backward_override" not in recipe_kwargs: + if recipe_name == "nvfp4_per_token" and "backward_override" not in recipe_kwargs: recipe_kwargs["backward_override"] = "dequantized" return make_recipe(recipe_name, **recipe_kwargs) @@ -176,7 +176,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name in ("nvfp4", "nvfp4_pertoken"): + if recipe_name in ("nvfp4", "nvfp4_per_token"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -192,7 +192,7 @@ def 
_maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: - if recipe_name == "nvfp4_pertoken": + if recipe_name == "nvfp4_per_token": pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") @@ -211,7 +211,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + if recipe_name in ("nvfp4", "nvfp4_per_token") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -238,7 +238,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + if recipe_name in ("nvfp4", "nvfp4_per_token") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -259,7 +259,7 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( @@ -1745,7 +1745,7 @@ def test_backward_override_memory_peak_report( modes = ( ("high_precision", "dequantized") - if recipe_name == "nvfp4_pertoken" + if recipe_name == "nvfp4_per_token" else (None, "high_precision", "dequantized") ) mode_results: dict[str, dict[str, float] | str] = {} diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 6e58538a4a..b88bcd31b5 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -115,7 +115,7 @@ def quantization_tols(name: 
str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_pertoken"): + if name in ("nvfp4", "nvfp4_per_token"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -149,7 +149,7 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) - if name == "nvfp4_pertoken": + if name == "nvfp4_per_token": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index feacc2ff6a..36eb05115d 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -62,11 +62,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_kernel( - const int num_rows, const int num_cols, const IType *__restrict__ input, - const int *__restrict__ row_offsets, uint8_t *__restrict__ output_data, - fp8e4m3 *__restrict__ output_scales, float *__restrict__ output_per_token_amax, - const int scale_stride, const float *__restrict__ noop) { + quantize_pertoken_nvfp4_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + const int *__restrict__ row_offsets, + uint8_t *__restrict__ output_data, + fp8e4m3 *__restrict__ output_scales, + float *__restrict__ output_per_token_amax, + const int scale_stride, const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -244,11 +246,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - 
quantize_pertoken_nvfp4_columnwise_kernel( - const int num_rows, const int num_cols, const IType *__restrict__ input, - uint8_t *__restrict__ output_data_t, fp8e4m3 *__restrict__ output_scales_t, - const float *__restrict__ per_token_amax, const int scale_stride_t, - const float *__restrict__ noop) { + quantize_pertoken_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, + fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, + const int scale_stride_t, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -307,16 +311,17 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_quantize_pertoken_nvfp4_columnwise( - const int num_rows, const int num_cols, const IType *input, uint8_t *output_data_t, - fp8e4m3 *output_scales_t, const float *per_token_amax, const int scale_stride_t, - cudaStream_t stream, const float *noop = nullptr) { +void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num_cols, + const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, + const float *per_token_amax, + const int scale_stride_t, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; NVTE_CHECK(num_rows % PERTOKEN_SF_VEC_SIZE == 0, "num_rows must be a multiple of ", - PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", - num_rows); + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", num_rows); dim3 grid(num_cols); dim3 block(PERTOKEN_BLOCK_SIZE); From 62a1c1ed95c51423f198020c40d9404688788d01 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 22:27:25 -0700 Subject: [PATCH 10/45] Avoid partial amax folding in gemm Signed-off-by: Ziang Li --- .../pytorch/cpp_extensions/gemm.py | 55 ++++++++++++------- 1 
file changed, 34 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index f19b175969..fec82b8a02 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -78,6 +78,22 @@ def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: return amax is not None and amax.numel() > 1 +def _nvfp4_pertoken_gemm_input( + tensor: NVFP4TensorStorage, +) -> Tuple[NVFP4TensorStorage, torch.Tensor]: + """Return a GEMM alias with identity activation amax and the original per-token amax.""" + metadata = tensor.get_metadata() + if tensor._amax_rowwise is not None: + amax = tensor._amax_rowwise + assert amax is not None and amax.numel() > 1 + metadata["amax_rowwise"] = amax.new_ones(1) + else: + amax = tensor._amax_columnwise + assert amax is not None and amax.numel() > 1 + metadata["amax_columnwise"] = amax.new_ones(1) + return NVFP4TensorStorage(**metadata), amax + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -196,38 +212,35 @@ def general_gemm( assert out is None or ( isinstance(out, torch.Tensor) and not is_custom(out) ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." - requested_out = out - requested_out_dtype = out_dtype + # cuBLAS folds the first activation amax into GEMM alpha. Keep per-token amax out of + # alpha by using identity here, then apply the true per-token scale in FP32 below. + gemm_B, amax = _nvfp4_pertoken_gemm_input(B) + per_token_scales = amax.view(-1, 1) + + requested_out, requested_out_dtype = out, out_dtype fp32_out = ( torch.empty_like(requested_out, dtype=torch.float32) if requested_out is not None else None ) - # Override only output, output quantizer, and output dtype for the FP32 correction path. 
- args = ( - *args[:4], - fp32_out, - None, - TE_DType[torch.float32], - *args[7:], - ) - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) - - assert isinstance(out, torch.Tensor) and not is_custom(out) - assert out.numel() > 0 - amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise - assert amax is not None and amax.numel() > 1 - + gemm_args = list(args) + gemm_args[2] = gemm_B # B + gemm_args[4] = fp32_out # out + gemm_args[5] = None # quantization_params + gemm_args[6] = TE_DType[torch.float32] # out_dtype + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) out_2d = out.reshape(-1, out.shape[-1]) + + assert amax.dtype == torch.float32 and out.dtype == torch.float32 assert amax.numel() == out_2d.shape[0] - ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) + if bias is not None: - bias_cast = bias.to(dtype=out.dtype) + bias_cast = bias.to(dtype=torch.float32) out_2d.sub_(bias_cast) - out_2d.mul_(ratios) + out_2d.mul_(per_token_scales) out_2d.add_(bias_cast) else: - out_2d.mul_(ratios) + out_2d.mul_(per_token_scales) if requested_out is not None: requested_out.copy_(out.to(dtype=requested_out.dtype)) From 44e4e0fd5c28b2c1c7899be2e8b44f8ffc404cdb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 22:27:41 -0700 Subject: [PATCH 11/45] Expand test coverage Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 138 +++++++++++++++++- .../nvfp4/test_nvfp4_quantize_exact.py | 12 ++ 2 files changed, 149 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 3708205aef..1a6784ed24 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,7 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer -from 
transformer_engine.pytorch.cpp_extensions import general_gemm +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -222,6 +222,98 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) +def check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + m_splits: list[int], + k: int, + n: int, + *, + use_bias: bool, + single_output: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(23) + torch.cuda.manual_seed(23) + + num_gemms = len(m_splits) + + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + per_token_activation=True, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_nvfp4 = [] + w_nvfp4 = [] + bias = [] + expected = [] + for m in m_splits: + x = torch.randn((m, k), dtype=x_dtype, device=device) + w = torch.randn((n, k), dtype=w_dtype, device=device) + x_nvfp4.append( + x_quantizer.update_quantized( + x, x_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + ) + w_nvfp4.append( + w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + ) + bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None) + expected.append( + general_gemm( + w_nvfp4[-1], + x_nvfp4[-1], + out_dtype=out_dtype, + layout="TN", + bias=bias[-1], + )[0] + ) + + if single_output: + out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)] + else: + out = 
[torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits] + + grouped_out, _, _ = general_grouped_gemm( + w_nvfp4, + x_nvfp4, + out, + quantization_params=[None] * num_gemms, + out_dtype=out_dtype, + layout="TN", + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, + ) + + if single_output: + grouped_slices = torch.split(grouped_out, m_splits, dim=0) + else: + grouped_slices = grouped_out + for grouped, ref in zip(grouped_slices, expected): + torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) + + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, K, N", @@ -283,3 +375,47 @@ def test_nvfp4_gemm_versus_reference( w_columnwise=is_w_columnwise, per_token_activation=per_token_activation, ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "m_splits, k, n", + [ + ([32, 48, 48], 128, 128), + ([64, 80, 112], 128, 256), + ([64, 80, 112], 256, 256), + ([64, 80, 112], 1024, 256), + ([256, 256, 512], 1024, 1024), + ([1024, 1536, 1536], 512, 3072), + ([16, 32, 64], 128, 96), + ([80, 96, 128], 640, 304), + ([320, 336, 352], 3072, 992), + ([64, 80, 112], 64, 256), + ([32, 48, 48], 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) +@pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) +def test_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + m_splits: list[int], + k: int, + n: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + use_bias: bool, + single_output: bool, +): + check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + x_dtype=x_dtype, + w_dtype=w_dtype, + 
out_dtype=out_dtype, + m_splits=m_splits, + k=k, + n=n, + use_bias=use_bias, + single_output=single_output, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index cf801639b7..098807b685 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -75,6 +75,7 @@ def check_quantization_nvfp4_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise # Reference quantization quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) @@ -105,6 +106,7 @@ def check_quantization_nvfp4_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col qx = unpack_fp4(qx) qx_t = unpack_fp4(qx_t) if qx_t is not None else None @@ -124,6 +126,7 @@ def check_quantization_nvfp4_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -249,6 +252,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -270,6 +274,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -282,6 +287,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_sx_t_shape = 
sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -364,6 +370,7 @@ def test_nvfp4_quantization_boundary_values( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -385,6 +392,7 @@ def test_nvfp4_quantization_boundary_values( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -398,6 +406,7 @@ def test_nvfp4_quantization_boundary_values( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -465,6 +474,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -486,6 +496,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col # Quantized must match torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -500,5 +511,6 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, 
rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) From 4755f09941fee8cf2981fa3715277cf93b9e1b5f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 22:39:35 -0700 Subject: [PATCH 12/45] Expand more tests Signed-off-by: Ziang Li --- tests/pytorch/test_recipe.py | 5 +++-- tests/pytorch/test_torch_compile.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 91d4b89013..b44f27765a 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -509,6 +509,7 @@ def test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) @pytest.mark.parametrize( "M, N", [ @@ -524,8 +525,8 @@ def test_quantizer_update(self, module_class): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, M, N): - q = NVFP4Quantizer() +def test_fp4_dequantize(dtype, per_token_activation, M, N): + q = NVFP4Quantizer(per_token_activation=per_token_activation) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) dequantized_tensor = starting_tensor.dequantize() diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 9d0ed79888..d67c5e77b7 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -32,6 +32,7 @@ is_fp8_block_scaling_available, is_nvfp4_available, ) +from utils import recipe_id fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @@ -47,6 +48,7 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) + 
_all_recipes.append(recipe.NVFP4BlockScaling(per_token_activation=True)) # --------------------------------------------------------------------------- @@ -303,7 +305,7 @@ def fn(inp): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=recipe_id) def test_autocast_sanity(fp8_recipe): """Smoke test: torch.nn.Linear inside a single te.autocast with each built-in recipe. Forward + backward under torch.compile(fullgraph=True).""" From 55286ed5eb75e0147f361d09f1be11e95c71219a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 23:06:17 -0700 Subject: [PATCH 13/45] Turn on test for grouped linear sanity Signed-off-by: Ziang Li --- tests/pytorch/test_sanity.py | 7 ++-- .../tensor/storage/grouped_tensor_storage.py | 34 +++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 73c291cc15..c7527ecfe4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -542,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -576,7 +576,10 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - pytest.skip("NVFP4 not supported for grouped linear") + if not getattr(fp8_recipe, "per_token_activation", False): + pytest.skip("NVFP4 not supported for 
grouped linear") + if dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 5f12c3ed8c..1732abf57c 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -662,6 +662,10 @@ def make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): + per_token_activation = getattr(quantizer, "per_token_activation", False) + total_amax_elements = ( + sum(math.prod(s[:-1]) for s in shape) if per_token_activation else num_tensors + ) if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) @@ -675,8 +679,7 @@ def make_grouped_tensor( total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) - # Amax buffer - one per tensor - amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(total_amax_elements, dtype=torch.float32, device=device) if columnwise_usage: # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) @@ -693,8 +696,9 @@ def make_grouped_tensor( columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) - # Columnwise amax buffer - one per tensor - columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + columnwise_amax = torch.empty( + total_amax_elements, dtype=torch.float32, device=device + ) elif quantizer._get_compatible_recipe().float8_block_scaling(): if rowwise_usage: 
# Allocate rowwise data buffer (1D flattened, uint8) @@ -891,6 +895,13 @@ def split_into_quantized_tensors( cum += math.prod(scale_shape) columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + nvfp4_per_token_amax_offsets = None + if recipe.nvfp4() and getattr(self.quantizer, "per_token_activation", False): + cum = 0 + nvfp4_per_token_amax_offsets = [0] + for i in range(self.num_tensors): + cum += math.prod(self.tensor_shapes[i][:-1]) + nvfp4_per_token_amax_offsets.append(cum) for i in range(self.num_tensors): quantizer = self.quantizer @@ -1083,12 +1094,21 @@ def split_into_quantized_tensors( cscale_shape ) - # Extract amax - one per tensor if self.amax is not None: - amax_rowwise = self.amax[i : i + 1] + if nvfp4_per_token_amax_offsets is not None: + amax_start = nvfp4_per_token_amax_offsets[i] + amax_end = nvfp4_per_token_amax_offsets[i + 1] + amax_rowwise = self.amax[amax_start:amax_end] + else: + amax_rowwise = self.amax[i : i + 1] if self.columnwise_amax is not None: - amax_columnwise = self.columnwise_amax[i : i + 1] + if nvfp4_per_token_amax_offsets is not None: + amax_start = nvfp4_per_token_amax_offsets[i] + amax_end = nvfp4_per_token_amax_offsets[i + 1] + amax_columnwise = self.columnwise_amax[amax_start:amax_end] + else: + amax_columnwise = self.columnwise_amax[i : i + 1] if quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage From e4829b8038bce6ede33c22eea395d63defd4283a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 23:39:05 -0700 Subject: [PATCH 14/45] Rename pertoken to per_token Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 6 +- .../common/cast/dispatch/quantize.cuh | 6 +- ...nvfp4.cuh => quantize_per_token_nvfp4.cuh} | 102 +++++++++--------- .../pytorch/cpp_extensions/gemm.py | 10 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/cast.cpp | 2 +- .../pytorch/csrc/extensions/pybind.cpp | 2 +- 7 files changed, 65 
insertions(+), 65 deletions(-) rename transformer_engine/common/cast/nvfp4/{quantize_pertoken_nvfp4.cuh => quantize_per_token_nvfp4.cuh} (83%) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 1a6784ed24..5fdb0c7d26 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -222,7 +222,7 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) -def check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( +def check_nvfp4_per_token_grouped_gemm_matches_per_gemm( x_dtype: torch.dtype, w_dtype: torch.dtype, out_dtype: torch.dtype, @@ -399,7 +399,7 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) -def test_nvfp4_pertoken_grouped_gemm_matches_per_gemm( +def test_nvfp4_per_token_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, n: int, @@ -409,7 +409,7 @@ def test_nvfp4_pertoken_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, ): - check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + check_nvfp4_per_token_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, w_dtype=w_dtype, out_dtype=out_dtype, diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index eab27a6e7e..1200979f6b 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,7 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" -#include "../nvfp4/quantize_pertoken_nvfp4.cuh" +#include "../nvfp4/quantize_per_token_nvfp4.cuh" #include 
"../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -105,7 +105,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_pertoken(*input_tensor, noop_tensor, output_tensor, stream); + nvfp4::quantize_per_token(*input_tensor, noop_tensor, output_tensor, stream); break; } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -251,7 +251,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_pertoken(*grad_tensor, noop_tensor, output_tensor, stream); + nvfp4::quantize_per_token(*grad_tensor, noop_tensor, output_tensor, stream); break; } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh similarity index 83% rename from transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh rename to transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh index 36eb05115d..c4b16c557e 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file quantize_pertoken_nvfp4.cuh +/*! \file quantize_per_token_nvfp4.cuh * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. 
*/ @@ -29,7 +29,7 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { -namespace quantize_pertoken_kernel { +namespace quantize_per_token_kernel { using namespace core; using namespace ptx; @@ -62,13 +62,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - const int *__restrict__ row_offsets, - uint8_t *__restrict__ output_data, - fp8e4m3 *__restrict__ output_scales, - float *__restrict__ output_per_token_amax, - const int scale_stride, const float *__restrict__ noop) { + quantize_per_token_nvfp4_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + const int *__restrict__ row_offsets, + uint8_t *__restrict__ output_data, + fp8e4m3 *__restrict__ output_scales, + float *__restrict__ output_per_token_amax, + const int scale_stride, const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -159,11 +159,11 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input, - const int *row_offsets, uint8_t *output_data, - fp8e4m3 *output_scales, float *output_per_token_amax, - const int scale_stride, cudaStream_t stream, - const float *noop = nullptr) { +void launch_quantize_per_token_nvfp4(const int num_rows, const int num_cols, const IType *input, + const int *row_offsets, uint8_t *output_data, + fp8e4m3 *output_scales, float *output_per_token_amax, + const int scale_stride, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; @@ -172,7 +172,7 @@ void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, cons dim3 grid(num_rows); dim3 block(PERTOKEN_BLOCK_SIZE); - quantize_pertoken_nvfp4_kernel + 
quantize_per_token_nvfp4_kernel <<>>(num_rows, num_cols, input, row_offsets, output_data, output_scales, output_per_token_amax, scale_stride, noop); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -186,10 +186,10 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - compute_pertoken_amax_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - float *__restrict__ output_per_token_amax, - const float *__restrict__ noop) { + compute_per_token_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_per_token_amax, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -222,9 +222,9 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_compute_pertoken_amax(const int num_rows, const int num_cols, const IType *input, - float *output_per_token_amax, cudaStream_t stream, - const float *noop = nullptr) { +void launch_compute_per_token_amax(const int num_rows, const int num_cols, const IType *input, + float *output_per_token_amax, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; @@ -233,7 +233,7 @@ void launch_compute_pertoken_amax(const int num_rows, const int num_cols, const dim3 grid(num_rows); dim3 block(PERTOKEN_BLOCK_SIZE); - compute_pertoken_amax_kernel + compute_per_token_amax_kernel <<>>(num_rows, num_cols, input, output_per_token_amax, noop); NVTE_CHECK_CUDA(cudaGetLastError()); #else @@ -246,13 +246,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - uint8_t *__restrict__ output_data_t, - fp8e4m3 *__restrict__ output_scales_t, - const float *__restrict__ per_token_amax, - const int 
scale_stride_t, - const float *__restrict__ noop) { + quantize_per_token_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, + fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, + const int scale_stride_t, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -311,12 +311,12 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num_cols, - const IType *input, uint8_t *output_data_t, - fp8e4m3 *output_scales_t, - const float *per_token_amax, - const int scale_stride_t, cudaStream_t stream, - const float *noop = nullptr) { +void launch_quantize_per_token_nvfp4_columnwise(const int num_rows, const int num_cols, + const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, + const float *per_token_amax, + const int scale_stride_t, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; @@ -325,7 +325,7 @@ void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num dim3 grid(num_cols); dim3 block(PERTOKEN_BLOCK_SIZE); - quantize_pertoken_nvfp4_columnwise_kernel + quantize_per_token_nvfp4_columnwise_kernel <<>>(num_rows, num_cols, input, output_data_t, output_scales_t, per_token_amax, scale_stride_t, noop); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -334,10 +334,10 @@ void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num #endif } -} // namespace quantize_pertoken_kernel +} // namespace quantize_per_token_kernel -inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { +inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED 
checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -351,9 +351,9 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - NVTE_CHECK(cols % quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE == 0, + NVTE_CHECK(cols % quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE == 0, "Per-token NVFP4 quantization requires last dim divisible by ", - quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE, "."); + quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE, "."); const auto *noop_ptr = reinterpret_cast(noop->data.dptr); auto *amax_ptr = reinterpret_cast(output->amax.dptr); @@ -379,11 +379,11 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_ptr = reinterpret_cast(output->data.dptr); auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4<__nv_bfloat16>( static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); } else { - quantize_pertoken_kernel::launch_compute_pertoken_amax<__nv_bfloat16>( + quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, noop_ptr); } @@ -399,7 +399,7 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise<__nv_bfloat16>( + 
quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise<__nv_bfloat16>( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } @@ -412,11 +412,11 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_ptr = reinterpret_cast(output->data.dptr); auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4( static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); } else { - quantize_pertoken_kernel::launch_compute_pertoken_amax( + quantize_per_token_kernel::launch_compute_per_token_amax( static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, noop_ptr); } @@ -432,7 +432,7 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } @@ -445,11 +445,11 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_ptr = reinterpret_cast(output->data.dptr); auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4( static_cast(rows), 
static_cast(cols), input_ptr, row_offsets, data_ptr, scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); } else { - quantize_pertoken_kernel::launch_compute_pertoken_amax( + quantize_per_token_kernel::launch_compute_per_token_amax( static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, noop_ptr); } @@ -465,7 +465,7 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index fec82b8a02..d23fdf1b59 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -70,7 +70,7 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: +def _is_nvfp4_per_token_tensor(tensor: torch.Tensor) -> bool: """Whether tensor carries per-token NVFP4 global amax metadata.""" if not isinstance(tensor, NVFP4TensorStorage): return False @@ -78,7 +78,7 @@ def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: return amax is not None and amax.numel() > 1 -def _nvfp4_pertoken_gemm_input( +def _nvfp4_per_token_gemm_input( tensor: NVFP4TensorStorage, ) -> Tuple[NVFP4TensorStorage, torch.Tensor]: """Return a GEMM alias with identity activation amax and the original per-token amax.""" @@ -199,7 +199,7 @@ def general_gemm( "beta": beta, } - if not _is_nvfp4_pertoken_tensor(B): + if not 
_is_nvfp4_per_token_tensor(B): out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) else: assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only." @@ -214,7 +214,7 @@ def general_gemm( ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." # cuBLAS folds the first activation amax into GEMM alpha. Keep per-token amax out of # alpha by using identity here, then apply the true per-token scale in FP32 below. - gemm_B, amax = _nvfp4_pertoken_gemm_input(B) + gemm_B, amax = _nvfp4_per_token_gemm_input(B) per_token_scales = amax.view(-1, 1) requested_out, requested_out_dtype = out, out_dtype @@ -301,7 +301,7 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - if any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B): + if any(_is_nvfp4_per_token_tensor(tensor) for tensor in B): assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." assert not grad, "Per-token NVFP4 grouped GEMM currently supports fprop only." assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." 
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 06478b54e0..f62853bb2b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -326,7 +326,7 @@ py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); -std::tuple quantize_nvfp4_pertoken(at::Tensor input); +std::tuple quantize_nvfp4_per_token(at::Tensor input); std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9423aa7296..ba75867a15 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1603,7 +1603,7 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } -std::tuple quantize_nvfp4_pertoken(at::Tensor input) { +std::tuple quantize_nvfp4_per_token(at::Tensor input) { init_extension(); NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 4021792f86..b2d74205cc 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,7 +145,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); - m.def("quantize_nvfp4_pertoken", transformer_engine::pytorch::quantize_nvfp4_pertoken, + m.def("quantize_nvfp4_per_token", transformer_engine::pytorch::quantize_nvfp4_per_token, 
"Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); From dbbdecbf195a1aa68cce37987d4e2f489865a938 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 27 Apr 2026 02:09:11 -0700 Subject: [PATCH 15/45] Expand .cu test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 189 +++++++++++++++--- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- 2 files changed, 158 insertions(+), 33 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 15d7c695c9..c59c895965 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float), block_amax = std::max(block_amax, std::abs(elt)); } - // 2. Compute E4M3 scaling factor - // Compute per-block encoding/decoding scaling factor - const float S_dec_b = block_amax / 6.0f; - - // Scale & Store per-block decoding scaling factor - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // Compute and store the per-block FP8 decode scale + const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f)); + const fp8e4m3 S_dec_b_fp8 = static_cast(fminf(S_dec_b, Numeric_Traits::maxNorm)); const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 
0.f : + fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits::maxNorm); const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; @@ -317,11 +315,69 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride, const size_t scales_stride_t, const bool use_fast_math, - const bool use_2d_quantization = false) + const bool use_2d_quantization = false, + std::vector *per_token_amax = nullptr) { std::vector input_t = create_transpose(input, rows, cols); - if (use_2d_quantization) { + if (per_token_amax != nullptr) { + constexpr size_t kBlockSize = 16; + constexpr float fp4_max_inv = 1.0f / 6.0f; + constexpr float float_max = Numeric_Traits::maxNorm; + + per_token_amax->resize(rows, 0.0f); + for (size_t row = 0; row < rows; ++row) { + float row_amax = 0.0f; + for (size_t col = 0; col < cols; ++col) { + row_amax = fmaxf(row_amax, fabsf(static_cast(input[row * cols + col]))); + } + (*per_token_amax)[row] = row_amax; + quantize_nvfp4(OP, + input + row * cols, + output + row * (cols / 2), + scales + row * scales_stride, + 1, + cols, + scales_stride, + row_amax, + use_fast_math, + use_2d_quantization); + } + + for (size_t col = 0; col < cols; ++col) { + for (size_t row_start = 0; row_start < rows; row_start += kBlockSize) { + float vals[kBlockSize]; + float s_enc[kBlockSize]; + float scaled_block_amax = 0.0f; + for (size_t i = 0; i < kBlockSize; ++i) { + const size_t row = row_start + i; + const float val = static_cast(input[row * cols + col]); + const float S_enc = + compute_global_encode_scaling_factor_FP4((*per_token_amax)[row], false); + vals[i] = val; + s_enc[i] = S_enc; + scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); + } + + const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b_f32); + scales_t[col * scales_stride_t + row_start / kBlockSize] = S_dec_b_fp8; + + for (size_t i = 0; i < kBlockSize; i += 2) { + const float 
S_dec_rowwise_x = 1.0f / s_enc[i]; + const float S_dec_rowwise_y = 1.0f / s_enc[i + 1]; + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); + const float S_enc_b_fp8_x = + fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_x), float_max); + const float S_enc_b_fp8_y = + fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_y), float_max); + const float2 scaled_elt_pair = {vals[i] * S_enc_b_fp8_x, + vals[i + 1] * S_enc_b_fp8_y}; + output_t[(col * rows + row_start + i) / 2] = fp4e2m1x2(scaled_elt_pair); + } + } + } + } else if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); @@ -526,10 +582,20 @@ void compareResults_nvfp4(const Tensor &test, compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); } +void compare_per_token_amax(const Tensor &test_amax, const std::vector &ref_amax) { + test_amax.to_cpu(); + const float *test_amax_data = test_amax.rowwise_cpu_dptr(); + for (size_t row = 0; row < ref_amax.size(); ++row) { + ASSERT_EQ(test_amax_data[row], ref_amax[row]) + << "Per-token amax mismatch at row " << row; + } +} + template void performTest(float (*OP)(const float), const std::vector& shape, - const bool use_fast_math) { + const bool use_fast_math, + const bool per_token_activation = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -557,6 +623,7 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + Tensor per_token_amax; std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -565,28 +632,53 @@ void performTest(float (*OP)(const float), fillCase(&input, InputsFillCase::uniform); - // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues - const float amax = 
448.0f * 6.0f * 8.0f; - - // Set 2nd stage NVFP4 scaling factor - output.set_tensor_amax(amax); - output.set_tensor_amax_columnwise(amax); - bool use_2d_quantization = false; - - compute_ref(OP, - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - amax, - rows, - cols, - scales_stride, - scales_stride_t, - use_fast_math, - use_2d_quantization); + std::vector ref_per_token_amax; + if (per_token_activation) { + per_token_amax = Tensor("per_token_amax", std::vector{rows}, DType::kFloat32); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + 0.0f, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization, + &ref_per_token_amax); + + std::vector per_token_amax_shape = {rows}; + NVTEBasicTensor amax_tensor = {per_token_amax.rowwise_dptr(), + static_cast(DType::kFloat32), + nvte_make_shape(per_token_amax_shape.data(), + per_token_amax_shape.size())}; + NVTETensor output_tensor = output.data(); + nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor)); + } else { + // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues + const float amax = 448.0f * 6.0f * 8.0f; + // Set 2nd stage NVFP4 scaling factor + output.set_tensor_amax(amax); + output.set_tensor_amax_columnwise(amax); + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + amax, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization); + } // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed @@ -600,6 +692,7 @@ void performTest(float (*OP)(const float), // Set 2D quantization based on compile-time flag quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + 
quant_config.set_nvfp4_per_token_activation(per_token_activation); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -646,6 +739,10 @@ void performTest(float (*OP)(const float), ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, scale_mismatches_num); + + if (per_token_activation) { + compare_per_token_amax(per_token_amax, ref_per_token_amax); + } } std::vector> tensor_dims = { @@ -678,6 +775,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, + bool, bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { @@ -693,6 +791,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); + const bool per_token_activation = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -710,7 +809,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math); + performTest(OP, tensor_dims, use_fast_math, per_token_activation); ); } @@ -733,6 +832,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), ::testing::Values(DType::kBFloat16), + ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -746,3 +846,28 @@ INSTANTIATE_TEST_SUITE_P( } return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTestPerToken, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kBFloat16, 
DType::kFloat32), + ::testing::Values(false), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } + if (std::get<4>(info.param)) { + name += "XPER_TOKEN"; + } + return name; + }); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 5fdb0c7d26..ef6eda8dcd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -118,7 +118,7 @@ def check_nvfp4_gemm_versus_reference( x_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=not per_token_activation, + columnwise=True, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), From 2374a6e0757dff7e09bfee3f2ef2a07b026aa20b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 12:08:19 -0700 Subject: [PATCH 16/45] Format after rebase Signed-off-by: Ziang Li --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 9436b94939..85e858e146 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,8 +34,9 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t amax_numel, const size_t N, const size_t M, - const size_t scale_stride, const size_t num_scale_tiles_X) { + const float *const tensor_amax, const size_t amax_numel, const 
size_t N, + const size_t M, const size_t scale_stride, + const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -110,11 +111,12 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), input.amax.numel(), N, Mread, input.scale_inv.shape.back(), - num_scale_tiles_X);); // NOLINT(*) -); // NOLINT(*) + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), input.amax.numel(), N, Mread, + input.scale_inv.shape.back(), + num_scale_tiles_X);); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); From 57982850fb4189d4fafbac39e1afe0afa4d5ba22 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 12:09:53 -0700 Subject: [PATCH 17/45] Fix test after rebase Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index c59c895965..cc45b2fce5 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -582,9 +582,13 @@ void compareResults_nvfp4(const Tensor &test, compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); } -void compare_per_token_amax(const Tensor &test_amax, const std::vector &ref_amax) { - test_amax.to_cpu(); - const float *test_amax_data = test_amax.rowwise_cpu_dptr(); +void compare_per_token_amax(const float *test_amax, const std::vector &ref_amax) { + 
std::vector test_amax_data(ref_amax.size()); + ASSERT_EQ(cudaMemcpy(test_amax_data.data(), + test_amax, + ref_amax.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); for (size_t row = 0; row < ref_amax.size(); ++row) { ASSERT_EQ(test_amax_data[row], ref_amax[row]) << "Per-token amax mismatch at row " << row; @@ -623,7 +627,8 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); - Tensor per_token_amax; + float *per_token_amax = nullptr; + float *per_token_columnwise_amax = nullptr; std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -635,7 +640,6 @@ void performTest(float (*OP)(const float), bool use_2d_quantization = false; std::vector ref_per_token_amax; if (per_token_activation) { - per_token_amax = Tensor("per_token_amax", std::vector{rows}, DType::kFloat32); compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -651,20 +655,44 @@ void performTest(float (*OP)(const float), use_2d_quantization, &ref_per_token_amax); + NVTETensor output_tensor = output.data(); + NVTEBasicTensor old_amax; + NVTEBasicTensor old_columnwise_amax; + nvte_get_tensor_param_v2(output_tensor, kNVTEAmax, &old_amax, sizeof(old_amax), nullptr); + nvte_get_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &old_columnwise_amax, + sizeof(old_columnwise_amax), nullptr); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + if (old_columnwise_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_columnwise_amax.data_ptr)); + } + NVTE_CHECK_CUDA(cudaMalloc(&per_token_amax, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMalloc(&per_token_columnwise_amax, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(per_token_amax, 0, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(per_token_columnwise_amax, 0, rows * sizeof(float))); std::vector 
per_token_amax_shape = {rows}; - NVTEBasicTensor amax_tensor = {per_token_amax.rowwise_dptr(), + NVTEBasicTensor amax_tensor = {per_token_amax, static_cast(DType::kFloat32), nvte_make_shape(per_token_amax_shape.data(), per_token_amax_shape.size())}; - NVTETensor output_tensor = output.data(); + NVTEBasicTensor columnwise_amax_tensor = {per_token_columnwise_amax, + static_cast(DType::kFloat32), + nvte_make_shape(per_token_amax_shape.data(), + per_token_amax_shape.size())}; nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor)); + nvte_set_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &columnwise_amax_tensor, + sizeof(columnwise_amax_tensor)); } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues const float amax = 448.0f * 6.0f * 8.0f; + // Set 2nd stage NVFP4 scaling factor output.set_tensor_amax(amax); output.set_tensor_amax_columnwise(amax); + bool use_2d_quantization = false; + compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), From 233bb4456cee29c767d6a632dd00cc53479a8558 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 16:24:16 -0700 Subject: [PATCH 18/45] Clean up cpp test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 81 +++++++++---------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index cc45b2fce5..f7a16539cc 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -582,10 +582,16 @@ void compareResults_nvfp4(const Tensor &test, compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); } -void compare_per_token_amax(const float *test_amax, const std::vector &ref_amax) { +void compare_per_token_amax(const Tensor &output, const std::vector &ref_amax) { + NVTEBasicTensor amax; + 
nvte_get_tensor_param_v2(output.data(), kNVTEAmax, &amax, sizeof(amax), nullptr); + ASSERT_NE(amax.data_ptr, nullptr); + ASSERT_EQ(amax.shape.ndim, 1); + ASSERT_EQ(amax.shape.data[0], ref_amax.size()); + std::vector test_amax_data(ref_amax.size()); ASSERT_EQ(cudaMemcpy(test_amax_data.data(), - test_amax, + amax.data_ptr, ref_amax.size() * sizeof(float), cudaMemcpyDeviceToHost), cudaSuccess); @@ -595,6 +601,32 @@ void compare_per_token_amax(const float *test_amax, const std::vector &re } } +void set_per_token_amax_metadata(Tensor &output, const size_t rows) { + const std::vector shape = {rows}; + NVTETensor output_tensor = output.data(); + + auto replace_amax = [&](const NVTETensorParam param) { + NVTEBasicTensor old_amax; + nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float))); + + NVTEBasicTensor amax_tensor = {amax, + static_cast(DType::kFloat32), + nvte_make_shape(shape.data(), shape.size())}; + nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); + return amax; + }; + + replace_amax(kNVTEAmax); + replace_amax(kNVTEColumnwiseAmax); +} + template void performTest(float (*OP)(const float), const std::vector& shape, @@ -627,8 +659,6 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); - float *per_token_amax = nullptr; - float *per_token_columnwise_amax = nullptr; std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -637,9 +667,12 @@ void performTest(float (*OP)(const float), fillCase(&input, InputsFillCase::uniform); - bool use_2d_quantization = false; + // Golden value of amax chosen 
to make the 2nd-stage scaling mantissa zero and avoid rounding issues + const float amax = 448.0f * 6.0f * 8.0f; std::vector ref_per_token_amax; + bool use_2d_quantization = false; if (per_token_activation) { + set_per_token_amax_metadata(output, rows); compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -654,45 +687,10 @@ void performTest(float (*OP)(const float), use_fast_math, use_2d_quantization, &ref_per_token_amax); - - NVTETensor output_tensor = output.data(); - NVTEBasicTensor old_amax; - NVTEBasicTensor old_columnwise_amax; - nvte_get_tensor_param_v2(output_tensor, kNVTEAmax, &old_amax, sizeof(old_amax), nullptr); - nvte_get_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &old_columnwise_amax, - sizeof(old_columnwise_amax), nullptr); - if (old_amax.data_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); - } - if (old_columnwise_amax.data_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaFree(old_columnwise_amax.data_ptr)); - } - NVTE_CHECK_CUDA(cudaMalloc(&per_token_amax, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMalloc(&per_token_columnwise_amax, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(per_token_amax, 0, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(per_token_columnwise_amax, 0, rows * sizeof(float))); - std::vector per_token_amax_shape = {rows}; - NVTEBasicTensor amax_tensor = {per_token_amax, - static_cast(DType::kFloat32), - nvte_make_shape(per_token_amax_shape.data(), - per_token_amax_shape.size())}; - NVTEBasicTensor columnwise_amax_tensor = {per_token_columnwise_amax, - static_cast(DType::kFloat32), - nvte_make_shape(per_token_amax_shape.data(), - per_token_amax_shape.size())}; - nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor)); - nvte_set_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &columnwise_amax_tensor, - sizeof(columnwise_amax_tensor)); } else { - // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues - const 
float amax = 448.0f * 6.0f * 8.0f; - // Set 2nd stage NVFP4 scaling factor output.set_tensor_amax(amax); output.set_tensor_amax_columnwise(amax); - - bool use_2d_quantization = false; - compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -707,6 +705,7 @@ void performTest(float (*OP)(const float), use_fast_math, use_2d_quantization); } + // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed @@ -769,7 +768,7 @@ void performTest(float (*OP)(const float), scale_mismatches_num); if (per_token_activation) { - compare_per_token_amax(per_token_amax, ref_per_token_amax); + compare_per_token_amax(output, ref_per_token_amax); } } From 47c9cde6f07ad46ca902c24e1337172fe65d786d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 16:31:27 -0700 Subject: [PATCH 19/45] Extend cpp dequantize test Signed-off-by: Ziang Li --- tests/cpp/operator/test_dequantize_nvfp4.cu | 117 +++++++++++++++++--- 1 file changed, 99 insertions(+), 18 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 96e85cb5ed..f932c7dd7a 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -42,7 +42,7 @@ float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) { template void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, const fp8e4m3 *scales, - float amax, + const std::vector &amax, OType *output, size_t rows, size_t cols, @@ -55,7 +55,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, for (size_t row = 0; row < rows; ++row) { for (size_t block = 0; block < Mread; ++block) { const fp8e4m3 scale = scales[row * scale_stride + block]; - const float final_scale = static_cast(scale) * amax * factor_inv; + const float final_scale = + static_cast(scale) * (amax.size() == 1 ? 
amax[0] : amax[row]) * factor_inv; for (size_t pair_idx = 0; pair_idx < bytes_per_block; ++pair_idx) { const size_t byte_idx = @@ -74,6 +75,43 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, } } +void set_per_token_amax_metadata(Tensor &output, const size_t rows) { + const std::vector shape = {rows}; + NVTETensor output_tensor = output.data(); + + auto replace_amax = [&](const NVTETensorParam param) { + NVTEBasicTensor old_amax; + nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float))); + + NVTEBasicTensor amax_tensor = {amax, + static_cast(DType::kFloat32), + nvte_make_shape(shape.data(), shape.size())}; + nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); + }; + + replace_amax(kNVTEAmax); + replace_amax(kNVTEColumnwiseAmax); +} + +std::vector get_amax_values(const Tensor &tensor) { + NVTEBasicTensor amax; + nvte_get_tensor_param_v2(tensor.data(), kNVTEAmax, &amax, sizeof(amax), nullptr); + const size_t numel = amax.shape.ndim == 0 ? 1 : amax.shape.data[0]; + std::vector amax_values(numel); + if (numel > 0) { + NVTE_CHECK_CUDA(cudaMemcpy(amax_values.data(), amax.data_ptr, numel * sizeof(float), + cudaMemcpyDeviceToHost)); + } + return amax_values; +} + template float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { t.to_cpu(); @@ -88,7 +126,8 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // Quantize a high-precision input to NVFP4, then dequantize and compare // against a CPU reference computed from the quantized data. 
template -void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, + const bool per_token_activation) { using namespace test; DType otype = TypeInfo::dtype; @@ -97,14 +136,22 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (per_token_activation) { + set_per_token_amax_metadata(quantized, rows); + } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized.set_tensor_amax(0.0f); } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized.data(), 0); + if (per_token_activation) { + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); + } else { + nvte_quantize(input.data(), quantized.data(), 0); + } cudaDeviceSynchronize(); } @@ -120,7 +167,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { const uint8_t *fp4_data = reinterpret_cast(quantized.rowwise_cpu_dptr()); const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr(); - const float amax_val = quantized.amax(); + const std::vector amax_val = get_amax_values(quantized); const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape(); const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; @@ -137,7 +184,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. 
template -void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, + const bool per_token_activation) { using namespace test; DType otype = TypeInfo::dtype; @@ -146,14 +194,22 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (per_token_activation) { + set_per_token_amax_metadata(quantized_compact, rows); + } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized_compact.set_tensor_amax(0.0f); } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized_compact.data(), 0); + if (per_token_activation) { + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); + } else { + nvte_quantize(input.data(), quantized_compact.data(), 0); + } cudaDeviceSynchronize(); } @@ -165,13 +221,30 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - quantized_swizzled.set_tensor_amax(0.0f); + if (per_token_activation) { + set_per_token_amax_metadata(quantized_swizzled, rows); + } else { + quantized_swizzled.set_tensor_amax(0.0f); + } quantized_swizzled.set_with_gemm_swizzled_scales(true); // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including zero-init data). 
quantized_compact.to_cpu(); - quantized_swizzled.set_tensor_amax(quantized_compact.amax()); + if (per_token_activation) { + NVTEBasicTensor compact_amax; + NVTEBasicTensor swizzled_amax; + nvte_get_tensor_param_v2(quantized_compact.data(), kNVTEAmax, &compact_amax, + sizeof(compact_amax), nullptr); + nvte_get_tensor_param_v2(quantized_swizzled.data(), kNVTEAmax, &swizzled_amax, + sizeof(swizzled_amax), nullptr); + if (rows > 0) { + NVTE_CHECK_CUDA(cudaMemcpy(swizzled_amax.data_ptr, compact_amax.data_ptr, + rows * sizeof(float), cudaMemcpyDeviceToDevice)); + } + } else { + quantized_swizzled.set_tensor_amax(quantized_compact.amax()); + } // Copy FP4 data after from_cpu() to avoid being overwritten const size_t data_bytes = rows * cols / 2; @@ -227,7 +300,8 @@ std::vector> nvfp4_tensor_dims = { class DequantizeNVFP4TestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { @@ -237,10 +311,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); + const bool per_token_activation = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second); + tensor_size.first, tensor_size.second, per_token_activation); ); } @@ -249,19 +324,22 @@ INSTANTIATE_TEST_SUITE_P( DequantizeNVFP4TestSuite, ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + - test::typeName(std::get<1>(info.param)); + 
test::typeName(std::get<1>(info.param)) + "X" + + (std::get<2>(info.param) ? "PerToken" : "PerTensor"); return name; } ); class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) { @@ -271,10 +349,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); + const bool per_token_activation = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second); + tensor_size.first, tensor_size.second, per_token_activation); ); } @@ -283,12 +362,14 @@ INSTANTIATE_TEST_SUITE_P( DequantizeNVFP4SwizzledTestSuite, ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + + (std::get<2>(info.param) ? 
"PerToken" : "PerTensor") + "X" + "Swizzled"; return name; } From 21a19f5ecf882b0c5faa463dcdf2721b6a9692dd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 17:23:27 -0700 Subject: [PATCH 20/45] Only pass `per_token_activation` to forward activation quantizer and clean up Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 56 +++++++------------ tests/pytorch/test_recipe.py | 27 ++++++++- tests/pytorch/utils.py | 6 -- .../pytorch/cpp_extensions/gemm.py | 12 +++- transformer_engine/pytorch/quantization.py | 4 +- 5 files changed, 59 insertions(+), 46 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 2156d6cef0..ed099314f8 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -99,12 +99,6 @@ def backward_override(request: pytest.FixtureRequest) -> str: return request.param -def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]: - if recipe_name == "nvfp4_per_token" and "backward_override" not in recipe_kwargs: - recipe_kwargs["backward_override"] = "dequantized" - return make_recipe(recipe_name, **recipe_kwargs) - - # -------------------------- # Test cases # -------------------------- @@ -867,7 +861,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = _make_backward_test_recipe(recipe_name) + quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1051,7 +1045,7 @@ def test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = _make_backward_test_recipe(recipe_name) + quantized_ref_recipe = make_recipe(recipe_name) mode_recipe 
= make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) @@ -1220,11 +1214,9 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - default_recipe = _make_backward_test_recipe(recipe_name) + default_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) - expected_default_mode = default_recipe.backward_override - expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) ( @@ -1233,10 +1225,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer, default_reduce_and_update, ) = default_ctx - assert default_mode == expected_default_mode - assert default_fp8 == expected_default_fp8 - assert (default_grad_output_quantizer is not None) == expected_default_fp8 - assert default_reduce_and_update == expected_default_fp8 + assert default_mode is None + assert default_fp8 + assert default_grad_output_quantizer is not None + assert default_reduce_and_update *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe) switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = ( @@ -1254,10 +1246,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer_after, default_reduce_and_update_after, ) = default_ctx_after - assert default_mode_after == expected_default_mode - assert default_fp8_after == expected_default_fp8 - assert (default_grad_output_quantizer_after is not None) == expected_default_fp8 - assert default_reduce_and_update_after == expected_default_fp8 + assert default_mode_after is 
None + assert default_fp8_after + assert default_grad_output_quantizer_after is not None + assert default_reduce_and_update_after @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1292,11 +1284,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") - default_recipe = _make_backward_test_recipe(recipe_name) + default_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) - expected_default_mode = default_recipe.backward_override - expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1306,9 +1296,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe, ) default_mode, default_fp8, default_reduce_and_update = default_ctx - assert default_mode == expected_default_mode - assert default_fp8 == expected_default_fp8 - assert default_reduce_and_update == expected_default_fp8 + assert default_mode is None + assert default_fp8 + assert default_reduce_and_update *_, switched_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1330,9 +1320,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe, ) default_mode_after, default_fp8_after, default_reduce_and_update_after = default_ctx_after - assert default_mode_after == expected_default_mode - assert default_fp8_after == expected_default_fp8 - assert default_reduce_and_update_after == expected_default_fp8 + assert default_mode_after is None + assert default_fp8_after + assert default_reduce_and_update_after @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1363,7 +1353,7 @@ def 
test_fused_linear_paths_match_backward_override_reference( reset_rng_states() - quantized_ref_recipe = _make_backward_test_recipe(recipe_name) + quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1504,7 +1494,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = _make_backward_test_recipe(recipe_name) + quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1743,11 +1733,7 @@ def test_backward_override_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = ( - ("high_precision", "dequantized") - if recipe_name == "nvfp4_per_token" - else (None, "high_precision", "dequantized") - ) + modes = (None, "high_precision", "dequantized") mode_results: dict[str, dict[str, float] | str] = {} for mode in modes: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index b44f27765a..f12148232c 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -25,10 +25,16 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, + NVFP4BlockScalingRecipeState, _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops -from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) # Check if FP8 is supported fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -507,6 +513,25 @@ 
def test_quantizer_update(self, module_class): y = module(x) +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +def test_nvfp4_per_token_quantizer_roles(): + recipe = NVFP4BlockScaling(per_token_activation=True) + + forward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=3, + ).make_quantizers() + assert [q.per_token_activation for q in forward_quantizers] == [True, False, True] + + backward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="backward", + num_quantizers=2, + ).make_quantizers() + assert [q.per_token_activation for q in backward_quantizers] == [False, False] + + @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index b88bcd31b5..3dc4cdffe8 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -175,12 +175,6 @@ def skip_unsupported_backward_override( backward_override: Optional[str], ) -> None: """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" - if ( - quant_recipe is not None - and getattr(quant_recipe, "per_token_activation", False) - and backward_override is None - ): - pytest.skip("Per-token NVFP4 requires an explicit backward override.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d23fdf1b59..79a7d28df5 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -203,7 +203,11 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) else: assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B 
only." - assert not grad, "Per-token NVFP4 GEMM currently supports fprop only." + if grad: + raise RuntimeError( + "Per-token NVFP4 GEMM currently supports fprop only. " + "Backward NVFP4 gradient quantizers should use scalar global amax." + ) assert not gelu, "Per-token NVFP4 GEMM currently does not support fused GELU." assert not accumulate, "Per-token NVFP4 GEMM currently does not support accumulation." assert ( @@ -303,7 +307,11 @@ def general_grouped_gemm( if any(_is_nvfp4_per_token_tensor(tensor) for tensor in B): assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." - assert not grad, "Per-token NVFP4 grouped GEMM currently supports fprop only." + if grad: + raise RuntimeError( + "Per-token NVFP4 grouped GEMM currently supports fprop only. " + "Backward NVFP4 gradient quantizers should use scalar global amax." + ) assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." assert ( not accumulate diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 6ffca84a7d..2cb6c21946 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1375,7 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, - per_token_activation=self.recipe.per_token_activation, + per_token_activation=self.recipe.per_token_activation and idx % 3 != 1, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] @@ -1390,7 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, - per_token_activation=self.recipe.per_token_activation, + 
per_token_activation=False, ) for _ in range(self.num_quantizers) ] From 75c19d0e172974a620fc38248851895f23a6c583 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 17:46:18 -0700 Subject: [PATCH 21/45] Minor fix test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 2 +- tests/pytorch/test_sanity.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed099314f8..15f08975e2 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -255,7 +255,7 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." 
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c7527ecfe4..bb1c952163 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -580,6 +580,8 @@ def test_sanity_grouped_linear( pytest.skip("NVFP4 not supported for grouped linear") if dtype == torch.float16: pytest.skip("FP16 output for NVFP4 not supported") + if backward_override is None and dtype != torch.bfloat16: + pytest.skip("NVFP4 grouped default backward requires BF16 grad output") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): From a3e8305867ca7bb02b3f93076951808df0642427 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 4 May 2026 01:15:34 -0700 Subject: [PATCH 22/45] Improve accuracy by unfolding weight per-tensor fp32 Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 110 ++++++++++++++++++ .../pytorch/cpp_extensions/gemm.py | 53 ++++++--- 2 files changed, 144 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index ef6eda8dcd..7c9f2c7eb6 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -314,6 +314,78 @@ def check_nvfp4_per_token_grouped_gemm_matches_per_gemm( torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) +def check_nvfp4_per_token_gemm_matches_emulated( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(37) + torch.cuda.manual_seed(37) + + x = torch.randn((M, K), dtype=x_dtype, device=device) + w = torch.randn((N, K), dtype=w_dtype, device=device) + + x_per_token_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + 
per_token_activation=True, + ) + x_tensorwise_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_per_token = x_per_token_quantizer.update_quantized( + x, x_per_token_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + w_nvfp4 = w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + y_per_token = general_gemm(w_nvfp4, x_per_token, out_dtype=out_dtype, layout="TN")[0] + + emulated_rows = [] + for i in range(M): + x_padded = torch.zeros((16, K), dtype=x_dtype, device=device) + x_padded[0].copy_(x[i]) + x_tensorwise = x_tensorwise_quantizer.update_quantized( + x_padded, + x_tensorwise_quantizer.make_empty(x_padded.shape, dtype=x_dtype, device=device), + ) + emulated_rows.append( + general_gemm(w_nvfp4, x_tensorwise, out_dtype=out_dtype, layout="TN")[0][:1] + ) + + y_emulated = torch.cat(emulated_rows, dim=0) + if out_dtype == torch.bfloat16: + torch.testing.assert_close(y_per_token, y_emulated, atol=0.0, rtol=7.8e-3) + else: + torch.testing.assert_close(y_per_token, y_emulated, atol=3.0517578125e-5, rtol=0.0) + + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, K, N", @@ -419,3 +491,41 @@ def test_nvfp4_per_token_grouped_gemm_matches_per_gemm( use_bias=use_bias, single_output=single_output, ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 
112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +def test_nvfp4_per_token_gemm_matches_emulated( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, +): + check_nvfp4_per_token_gemm_matches_emulated( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, + ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 79a7d28df5..394778f96b 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -78,20 +78,34 @@ def _is_nvfp4_per_token_tensor(tensor: torch.Tensor) -> bool: return amax is not None and amax.numel() > 1 -def _nvfp4_per_token_gemm_input( - tensor: NVFP4TensorStorage, -) -> Tuple[NVFP4TensorStorage, torch.Tensor]: - """Return a GEMM alias with identity activation amax and the original per-token amax.""" - metadata = tensor.get_metadata() - if tensor._amax_rowwise is not None: - amax = tensor._amax_rowwise - assert amax is not None and amax.numel() > 1 - metadata["amax_rowwise"] = amax.new_ones(1) +def _nvfp4_per_token_gemm_inputs( + A: NVFP4TensorStorage, + B: NVFP4TensorStorage, + *, + transa: bool, +) -> Tuple[NVFP4TensorStorage, NVFP4TensorStorage, torch.Tensor]: + """Return GEMM aliases and FP32 output scales for per-token NVFP4.""" + A_metadata = A.get_metadata() + weight_amax = A._amax_rowwise if transa else A._amax_columnwise + assert weight_amax is not None and weight_amax.numel() == 1 + A_metadata["amax_rowwise" if transa else "amax_columnwise"] = weight_amax.new_ones(1) + + B_metadata = B.get_metadata() + if B._amax_rowwise is not None: + activation_amax = B._amax_rowwise + assert activation_amax.numel() > 1 + B_metadata["amax_rowwise"] = 
activation_amax.new_ones(1) else: - amax = tensor._amax_columnwise - assert amax is not None and amax.numel() > 1 - metadata["amax_columnwise"] = amax.new_ones(1) - return NVFP4TensorStorage(**metadata), amax + activation_amax = B._amax_columnwise + assert activation_amax is not None and activation_amax.numel() > 1 + B_metadata["amax_columnwise"] = activation_amax.new_ones(1) + + assert activation_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 + return ( + NVFP4TensorStorage(**A_metadata), + NVFP4TensorStorage(**B_metadata), + (activation_amax * weight_amax).view(-1, 1), + ) def general_gemm( @@ -216,10 +230,10 @@ def general_gemm( assert out is None or ( isinstance(out, torch.Tensor) and not is_custom(out) ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." - # cuBLAS folds the first activation amax into GEMM alpha. Keep per-token amax out of - # alpha by using identity here, then apply the true per-token scale in FP32 below. - gemm_B, amax = _nvfp4_per_token_gemm_input(B) - per_token_scales = amax.view(-1, 1) + assert isinstance(A, NVFP4TensorStorage), "Per-token NVFP4 GEMM currently requires NVFP4 A." + # cuBLAS folds NVFP4 global amax values into GEMM alpha. Keep the per-token + # recipe's global scales out of alpha and apply them in FP32 below. 
+ gemm_A, gemm_B, per_token_scales = _nvfp4_per_token_gemm_inputs(A, B, transa=transa) requested_out, requested_out_dtype = out, out_dtype fp32_out = ( @@ -228,6 +242,7 @@ def general_gemm( else None ) gemm_args = list(args) + gemm_args[0] = gemm_A # A gemm_args[2] = gemm_B # B gemm_args[4] = fp32_out # out gemm_args[5] = None # quantization_params @@ -235,8 +250,8 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) out_2d = out.reshape(-1, out.shape[-1]) - assert amax.dtype == torch.float32 and out.dtype == torch.float32 - assert amax.numel() == out_2d.shape[0] + assert per_token_scales.dtype == torch.float32 and out.dtype == torch.float32 + assert per_token_scales.numel() == out_2d.shape[0] if bias is not None: bias_cast = bias.to(dtype=torch.float32) From 027cb79120648462920c1bd70688a00d152fe250 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 4 May 2026 23:02:03 -0700 Subject: [PATCH 23/45] Fold row-wise quantization Signed-off-by: Ziang Li --- .../common/cast/dispatch/quantize.cuh | 12 +- .../cast/nvfp4/quantize_per_token_nvfp4.cuh | 285 +++++------------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 58 ++-- .../quantize_transpose_nvfp4_tuned_1D.cuh | 62 ++-- .../common/transpose/cast_transpose.h | 2 +- ...quantize_transpose_vector_blockwise_fp4.cu | 88 ++++-- 6 files changed, 211 insertions(+), 296 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 1200979f6b..c204790861 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -105,7 +105,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_per_token(*input_tensor, noop_tensor, output_tensor, stream); + 
nvfp4::quantize_per_token(*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, + stream); break; } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -134,7 +135,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + /*per_token_rowwise_scaling=*/false, /*noop_tensor=*/noop_tensor->data, + /*stream=*/stream); } break; } @@ -251,7 +253,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_per_token(*grad_tensor, noop_tensor, output_tensor, stream); + nvfp4::quantize_per_token(*grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, + stream); break; } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -280,7 +283,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + /*per_token_rowwise_scaling=*/false, /*noop_tensor=*/noop_tensor->data, + /*stream=*/stream); } break; } diff --git a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh index c4b16c557e..53c361c22d 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -18,9 +18,11 @@ #include #include "../../common.h" +#include "../../transpose/cast_transpose.h" #include 
"../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" +#include "quantize_transpose_nvfp4.cuh" #if FP4_TYPE_SUPPORTED #include @@ -57,130 +59,6 @@ __device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val } } -template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(BLOCK_SIZE) -#endif - quantize_per_token_nvfp4_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - const int *__restrict__ row_offsets, - uint8_t *__restrict__ output_data, - fp8e4m3 *__restrict__ output_scales, - float *__restrict__ output_per_token_amax, - const int scale_stride, const float *__restrict__ noop) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - using namespace detail; - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - - using IType2 = typename ptx::FPx2; - - const int row_idx = blockIdx.x; - if (row_idx >= num_rows) return; - - const int actual_row = (row_offsets != nullptr) ? row_offsets[row_idx] : row_idx; - if (actual_row < 0) return; - - const int num_vec2 = num_cols / 2; - const IType2 *input_row = reinterpret_cast(input + actual_row * num_cols); - - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { - const IType2 val = input_row[i]; - abs_max_2x_update(thread_amax_2x, val); - } - const float thread_max = abs_max_2x_to_float(thread_amax_2x); - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - float row_amax = - BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); - - __shared__ float shared_s_enc; - if (threadIdx.x == 0) { - const float s_enc = compute_global_encode_scaling_factor_FP4(row_amax); - output_per_token_amax[row_idx] = row_amax; - shared_s_enc = s_enc; - } - __syncthreads(); - const float S_enc = shared_s_enc; - const float S_dec_rowwise = 1.0 / S_enc; - constexpr float 
fp4_max_inv = 1.0f / detail::TypeExtrema::max; - const float global_encode_scale_multiplier = S_enc * fp4_max_inv; - - const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE; - for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) { - const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE; - - IType2 block_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - alignas(8) IType2 vals[PERTOKEN_SF_VEC_SIZE / 2]; - const IType2 *input_block = - reinterpret_cast(input + actual_row * num_cols + col_start); - for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; ++j) { - vals[j] = input_block[j]; - abs_max_2x_update(block_amax_2x, vals[j]); - } - const float block_max = abs_max_2x_to_float(block_amax_2x); - - const float S_dec_b_f32 = - fminf(block_max * global_encode_scale_multiplier, detail::TypeExtrema::max); - const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); - output_scales[row_idx * scale_stride + sf_idx] = S_dec_b_fp8; - - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = - fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - - uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2; - if constexpr (std::is_same_v) { - auto *out_fp4_8x = reinterpret_cast(out_ptr); - for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 4) { - const uint64_t elts03 = *reinterpret_cast(&vals[j]); - const uint64_t elts47 = *reinterpret_cast(&vals[j + 2]); - out_fp4_8x[j / 4] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest( - elts03, elts47, block_scale_inverse); - } - } else { - auto *out_fp4 = reinterpret_cast(out_ptr); - for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 2) { - const float2 in01 = - make_float2(static_cast(vals[j].x), static_cast(vals[j].y)); - const float2 in23 = - make_float2(static_cast(vals[j + 1].x), static_cast(vals[j + 1].y)); - out_fp4[j / 2] = ptx::mul_cvt_fp32_to_fp4_4x( - in01, 
in23, block_scale_inverse_2x, /*rbits=*/0u); - } - } - } -#endif -} - -template -void launch_quantize_per_token_nvfp4(const int num_rows, const int num_cols, const IType *input, - const int *row_offsets, uint8_t *output_data, - fp8e4m3 *output_scales, float *output_per_token_amax, - const int scale_stride, cudaStream_t stream, - const float *noop = nullptr) { -#if FP4_TYPE_SUPPORTED - if (num_rows == 0 || num_cols == 0) return; - - NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ", - PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 quantization, got ", num_cols); - dim3 grid(num_rows); - dim3 block(PERTOKEN_BLOCK_SIZE); - - quantize_per_token_nvfp4_kernel - <<>>(num_rows, num_cols, input, row_offsets, output_data, - output_scales, output_per_token_amax, scale_stride, noop); - NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif -} - template __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -337,8 +215,10 @@ void launch_quantize_per_token_nvfp4_columnwise(const int num_rows, const int nu } // namespace quantize_per_token_kernel inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { + const QuantizationConfig *quant_config, cudaStream_t stream) { #if FP4_TYPE_SUPPORTED + using namespace detail; + checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); CheckInputTensor(input, "input"); @@ -368,111 +248,86 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * NVTE_CHECK(output->columnwise_amax.numel() == rows, "Per-token columnwise amax must have ", rows, " entries, got ", output->columnwise_amax.shape, "."); } - const int *row_offsets = nullptr; if (input.dtype() == DType::kBFloat16) { const auto *input_ptr = reinterpret_cast(input.data.dptr); - if (output->has_data()) { - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 
type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); - NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); - auto *data_ptr = reinterpret_cast(output->data.dptr); - auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); - const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_per_token_kernel::launch_quantize_per_token_nvfp4<__nv_bfloat16>( - static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, - scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } else if (input.dtype() == DType::kFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + quantize_per_token_kernel::launch_compute_per_token_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + quantize_per_token_kernel::launch_compute_per_token_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } else { + NVTE_ERROR( + "Unsupported input dtype for per-token NVFP4 quantization. 
" + "Expected BFloat16, Float16, or Float32."); + } + + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + + QuantizationConfig per_token_quant_config; + if (quant_config != nullptr) { + per_token_quant_config = *quant_config; + } + per_token_quant_config.nvfp4_per_token_activation = true; + per_token_quant_config.nvfp4_2d_quantization = false; + + const bool use_optimized_kernel = + (input.dtype() == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0); + if (use_optimized_kernel) { + quantize_transpose(input, noop, output, + &per_token_quant_config, stream); } else { - quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( - static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, - noop_ptr); + quantize_transpose_vector_blockwise_fp4( + /*input=*/input.data, /*global_amax=*/output->amax, + /*scale_inv=*/output->scale_inv, /*scale_inv_t=*/output->columnwise_scale_inv, + /*output=*/output->data, /*output_t=*/output->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/true, /*return_transpose=*/false, + /*pow2_scale=*/false, /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/per_token_quant_config.stochastic_rounding, + /*rng_state=*/per_token_quant_config.rng_state, /*use_2d_quantization=*/false, + /*per_token_rowwise_scaling=*/true, /*noop_tensor=*/noop->data, /*stream=*/stream); } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), - "Columnwise output must have FP4 type."); - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated."); - if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { - 
NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); - auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + } + + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + if (input.dtype() == DType::kBFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise<__nv_bfloat16>( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); - } - } else if (input.dtype() == DType::kFloat16) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - if (output->has_data()) { - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); - NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); - auto *data_ptr = reinterpret_cast(output->data.dptr); - auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); - const int scale_stride = static_cast(output->scale_inv.shape.back()); - 
quantize_per_token_kernel::launch_quantize_per_token_nvfp4( - static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, - scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); - } else { - quantize_per_token_kernel::launch_compute_per_token_amax( - static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, - noop_ptr); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), - "Columnwise output must have FP4 type."); - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated."); - if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); - auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + } else if (input.dtype() == DType::kFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); - } - } else if (input.dtype() == DType::kFloat32) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - if (output->has_data()) { - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); - NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); - auto *data_ptr = reinterpret_cast(output->data.dptr); - auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); - const int scale_stride = static_cast(output->scale_inv.shape.back()); - 
quantize_per_token_kernel::launch_quantize_per_token_nvfp4( - static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, - scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); - } else { - quantize_per_token_kernel::launch_compute_per_token_amax( - static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, - noop_ptr); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), - "Columnwise output must have FP4 type."); - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated."); - if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); - auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } - } else { - NVTE_ERROR( - "Unsupported input dtype for per-token NVFP4 quantization. 
" - "Expected BFloat16, Float16, or Float32."); } #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index f164636e38..331c78df51 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -108,7 +108,8 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE, + bool PER_TOKEN_ROWWISE> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -509,12 +510,22 @@ __global__ void __launch_bounds__(THREADS_NUM) } // 2. Compute E4M3 scaling factor + const size_t row_idx_global = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + float S_enc_rowwise_block = S_enc_rowwise; + if constexpr (PER_TOKEN_ROWWISE) { + S_enc_rowwise_block = + row_idx_global < rows + ? compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx_global]) + : 1.0f; + } + const float S_dec_rowwise_block = + PER_TOKEN_ROWWISE ? 
1.0 / S_enc_rowwise_block : S_dec_rowwise; const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_Y = row_idx_global; const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; @@ -527,8 +538,9 @@ __global__ void __launch_bounds__(THREADS_NUM) // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise_block), + float_max); // S_enc_b_fp8 const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; // 3. Scale elements @@ -1162,11 +1174,14 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool per_token_rowwise = quant_config ? quant_config->nvfp4_per_token_activation : false; + NVTE_CHECK(!per_token_rowwise || !use_2d_quantization, + "Per-token NVFP4 quantization does not support 2D quantization."); // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to // return the transposed data. // TODO(Frank): Is there a better way to do this? 
- bool return_transpose = output->has_columnwise_data(); + bool return_transpose = output->has_columnwise_data() && !per_token_rowwise; if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); @@ -1186,6 +1201,8 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!per_token_rowwise || output->amax.dptr != nullptr, + "Per-token NVFP4 rowwise quantization requires rowwise amax."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); @@ -1268,20 +1285,23 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; + TRANSFORMER_ENGINE_SWITCH_CONDITION(per_token_rowwise, PER_TOKEN_ROWWISE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel; - } + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); + 
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index fc337f6078..172b35b245 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -261,14 +261,12 @@ __device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_pt } } -template -__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, - fp4e2m1x2 *__restrict__ sOut_ptr, - nvfp4_scale_t *__restrict__ sSFrowwise_ptr, - const float S_enc_rowwise, const int stage_Y, - const int stage_X, const int buff_in, - const int buff_out, RNG_t &rng, uint4 &random_uint4, - int &rnd_idx) { +template +__device__ __forceinline__ void rowwise_scaling( + const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, const int buff_out, const float *amax_rowwise_ptr, + const size_t row_offset, const size_t rows, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; const auto &sIn = *reinterpret_cast(sIn_ptr); @@ -315,9 +313,17 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt } const float block_amax = get_amax_of_pair(thread_amax_2x); - const nvfp4_scale_t 
S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + float S_enc_rowwise_block = S_enc_rowwise; + if constexpr (PER_TOKEN_ROWWISE) { + S_enc_rowwise_block = + row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) + : 1.0f; + } + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); // Store scaling factors to SMEM buffer (R2S) if (SF_storing_thread) { @@ -350,7 +356,8 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt } } -template +template __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -571,9 +578,9 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, - rng, random_uint4, rnd_idx); + amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { colwise_scaling( @@ -680,10 +687,11 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + const bool per_token_rowwise = quant_config ? quant_config->nvfp4_per_token_activation : false; // If transposed output is allocated, return the transposed data // Otherwise, it's not necesary to return the transposed data. 
- const bool return_transpose = output->has_columnwise_data(); + const bool return_transpose = output->has_columnwise_data() && !per_token_rowwise; checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -694,6 +702,8 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!per_token_rowwise || output->amax.dptr != nullptr, + "Per-token NVFP4 rowwise quantization requires rowwise amax."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), @@ -783,16 +793,20 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - }););); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + per_token_rowwise, PER_TOKEN_ROWWISE, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });););); #else 
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a5ec2306b1..d2f8ba384a 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -68,7 +68,7 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, - const SimpleTensor &noop_tensor, cudaStream_t stream); + const bool per_token_rowwise_scaling, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index d3d3dceca9..64e4e09f89 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -316,7 +316,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kPerTokenRowwiseScaling> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -509,8 +509,20 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; } // Step 2.4: Compute scale - ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); - float 
encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + const size_t row_idx = block_idx_y * kTileDim + r_s; + float row_global_encode_scale = global_encode_scale; + if constexpr (kPerTokenRowwiseScaling) { + row_global_encode_scale = + row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; + } + const float row_global_encode_scale_multiplier = kPerTokenRowwiseScaling + ? row_global_encode_scale * fp4_max_inv + : global_encode_scale_multiplier; + const float row_global_decode_scale = + kPerTokenRowwiseScaling ? 1.0f / row_global_encode_scale : global_decode_scale; + ScaleType scale_inv = + ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -709,7 +721,7 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, - const SimpleTensor& noop_tensor, cudaStream_t stream) { + const bool per_token_rowwise_scaling, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -722,6 +734,10 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || !use_2d_quantization, "2D block quantization is only supported when return_identity is true."); + NVTE_CHECK(!per_token_rowwise_scaling || (return_identity && !return_transpose), + "Per-token NVFP4 rowwise scaling only supports rowwise quantization."); + NVTE_CHECK(!per_token_rowwise_scaling || !use_2d_quantization, + "Per-token NVFP4 rowwise scaling does not support 2D quantization."); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -801,35 +817,41 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( use_2d_quantization, kIs2DBlockScaling, - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, - float, InputType, OutputType, ScaleType, kSwizzledScale, - kApplyStochasticRounding, kIs2DBlockScaling>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared memory size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, - num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, - noop_ptr);) // kIs2DBlockScaling - ) // kApplyStochasticRounding - ) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + TRANSFORMER_ENGINE_SWITCH_CONDITION( + per_token_rowwise_scaling, kPerTokenRowwiseScaling, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling, + kPerTokenRowwiseScaling>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + 
reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, scale_stride_y, + scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, + epsilon, rng_state, + noop_ptr);) // kPerTokenRowwiseScaling + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else From 93a06ad6351e3905d44300179d259b79bff73b39 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 4 May 2026 23:42:29 -0700 Subject: [PATCH 24/45] Drop column wise Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 36 +++-- .../nvfp4/test_nvfp4_quantize_exact.py | 32 +++- tests/pytorch/test_backward_override.py | 23 ++- tests/pytorch/utils.py | 7 + .../cast/nvfp4/quantize_per_token_nvfp4.cuh | 151 ++---------------- .../pytorch/csrc/extensions/cast.cpp | 16 +- transformer_engine/pytorch/csrc/quantizer.cpp | 20 ++- .../pytorch/tensor/nvfp4_tensor.py | 6 +- 8 files changed, 100 insertions(+), 191 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index f7a16539cc..a6d4105702 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -560,13 +560,12 @@ void print_detailed_tensor_comparison(const std::string& name, void compareResults_nvfp4(const Tensor &test, const void *ref, const void *ref_t, const int rows, const int cols, - double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + bool dump_data = false, bool compare_columnwise = true) { if (if_on_gpus) test.to_cpu(); const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); - const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); const fp4e2m1 
*ref_data = reinterpret_cast(ref); - const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); // Print detailed element-by-element comparison // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); @@ -575,11 +574,17 @@ void compareResults_nvfp4(const Tensor &test, // Optionally dump tensor data to files for detailed analysis if (dump_data) { dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); - dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); } compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); - compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); + if (compare_columnwise) { + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + if (dump_data) { + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); + } } void compare_per_token_amax(const Tensor &output, const std::vector &ref_amax) { @@ -624,7 +629,6 @@ void set_per_token_amax_metadata(Tensor &output, const size_t rows) { }; replace_amax(kNVTEAmax); - replace_amax(kNVTEColumnwiseAmax); } template @@ -658,7 +662,7 @@ void performTest(float (*OP)(const float), const size_t scales_stride_t = blocks_X_t; Tensor input("input", shape, itype); - Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + Tensor output("output", shape, otype, true, !per_token_activation, NVTE_NVFP4_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -749,12 +753,8 @@ void performTest(float (*OP)(const float), const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis - compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); - - const fp8e4m3* 
kernel_scales = output.rowwise_cpu_scale_inv_ptr(); - const fp8e4m3* ref_scales_ptr = ref_scales.get(); - const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); - const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, + false, !per_token_activation); size_t scale_mismatches_num = 0; compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), @@ -762,10 +762,12 @@ void performTest(float (*OP)(const float), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); - compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_t.get(), - unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, - scale_mismatches_num); + if (!per_token_activation) { + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); + } if (per_token_activation) { compare_per_token_amax(output, ref_per_token_amax); diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 098807b685..4458e408ba 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -82,7 +82,7 @@ def check_quantization_nvfp4_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=return_transpose, + columnwise=(return_transpose and not per_token_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, @@ -119,7 +119,7 @@ def check_quantization_nvfp4_versus_reference( torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose: + if return_transpose and not per_token_activation: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) # Compare only the valid portion of transpose scale tensors @@ -127,6 
+127,10 @@ def check_quantization_nvfp4_versus_reference( sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) + elif return_transpose: + assert qx_t is None + assert sx_t is None + assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -257,7 +261,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=return_transpose, + columnwise=(return_transpose and not per_token_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), @@ -282,12 +286,16 @@ def test_nvfp4_quantization_extrema_versus_reference( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose: + if return_transpose and not per_token_activation: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) + elif return_transpose: + assert qx_t is None + assert sx_t is None + assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -375,7 +383,7 @@ def test_nvfp4_quantization_boundary_values( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=return_transpose, + columnwise=(return_transpose and not per_token_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), @@ -401,12 +409,16 @@ def test_nvfp4_quantization_boundary_values( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose: + if return_transpose and not per_token_activation: 
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) + elif return_transpose: + assert qx_t is None + assert sx_t is None + assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -479,7 +491,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=return_transpose, + columnwise=(return_transpose and not per_token_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), @@ -506,11 +518,15 @@ def test_nvfp4_quantization_noncontiguous_inputs( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose: + if return_transpose and not per_token_activation: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) + elif return_transpose: + assert qx_t is None + assert sx_t is None + assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 15f08975e2..b96d75cfff 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -190,6 +190,12 @@ def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") +def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: + if recipe_name == "nvfp4_per_token": + return 
make_recipe(recipe_name, backward_override="dequantized") + return make_recipe(recipe_name) + + def _maybe_skip_unsupported_recipe_shape( recipe_name: str, input_shape: tuple[int, ...], @@ -861,7 +867,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1045,7 +1051,7 @@ def test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) @@ -1215,6 +1221,7 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") default_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override(module_type, default_recipe, None) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1285,6 +1292,7 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") default_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override("grouped_linear", default_recipe, None) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) @@ -1353,7 +1361,7 @@ def 
test_fused_linear_paths_match_backward_override_reference( reset_rng_states() - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1494,7 +1502,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1636,6 +1644,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_override_switch reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + _maybe_skip_unsupported_fused_ops(recipe_name) # Build a Userbuffers-eligible fuser and representative inputs. 
linear = te_ops.BasicLinear( @@ -1733,7 +1742,11 @@ def test_backward_override_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = (None, "high_precision", "dequantized") + modes = ( + ("high_precision", "dequantized") + if recipe_name == "nvfp4_per_token" + else (None, "high_precision", "dequantized") + ) mode_results: dict[str, dict[str, float] | str] = {} for mode in modes: diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 3dc4cdffe8..1b2f9fe987 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -175,6 +175,13 @@ def skip_unsupported_backward_override( backward_override: Optional[str], ) -> None: """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" + if ( + quant_recipe is not None + and quant_recipe.nvfp4() + and getattr(quant_recipe, "per_token_activation", False) + and backward_override is None + ): + pytest.skip("Per-token NVFP4 does not support default quantized backward.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh index 53c361c22d..dec7248803 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -119,99 +119,6 @@ void launch_compute_per_token_amax(const int num_rows, const int num_cols, const #endif } -template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(BLOCK_SIZE) -#endif - quantize_per_token_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - uint8_t *__restrict__ output_data_t, - fp8e4m3 *__restrict__ output_scales_t, - const float *__restrict__ per_token_amax, - const int 
scale_stride_t, - const float *__restrict__ noop) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - using namespace detail; - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - - const int col_idx = blockIdx.x; - if (col_idx >= num_cols) return; - - constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; - constexpr float float_max = TypeExtrema::max; - constexpr float one = 1.0f; - const float2 one_2x{one, one}; - const int num_row_blocks = num_rows / PERTOKEN_SF_VEC_SIZE; - - for (int row_block = threadIdx.x; row_block < num_row_blocks; row_block += BLOCK_SIZE) { - const int row_start = row_block * PERTOKEN_SF_VEC_SIZE; - - float vals[PERTOKEN_SF_VEC_SIZE]; - float s_enc[PERTOKEN_SF_VEC_SIZE]; - float scaled_block_amax = 0.0f; -#pragma unroll - for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { - const int row_idx = row_start + i; - const float val = static_cast(input[row_idx * num_cols + col_idx]); - const float S_enc = compute_global_encode_scaling_factor_FP4(per_token_amax[row_idx]); - vals[i] = val; - s_enc[i] = S_enc; - scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); - } - - const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); - const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); - output_scales_t[col_idx * scale_stride_t + row_block] = S_dec_b_fp8; - - float scaled_vals[PERTOKEN_SF_VEC_SIZE]; -#pragma unroll - for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { - const float S_dec_rowwise = 1.0f / s_enc[i]; - const float block_scale_inverse = - fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); - scaled_vals[i] = vals[i] * block_scale_inverse; - } - - uint8_t *out_ptr = output_data_t + col_idx * (num_rows / 2) + row_start / 2; - auto *out_fp4 = reinterpret_cast(out_ptr); -#pragma unroll - for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 4) { - const float2 in01 = make_float2(scaled_vals[j], scaled_vals[j + 1]); - const float2 in23 = make_float2(scaled_vals[j + 2], scaled_vals[j + 
3]); - out_fp4[j / 4] = ptx::mul_cvt_fp32_to_fp4_4x( - in01, in23, one_2x, /*rbits=*/0u); - } - } -#endif -} - -template -void launch_quantize_per_token_nvfp4_columnwise(const int num_rows, const int num_cols, - const IType *input, uint8_t *output_data_t, - fp8e4m3 *output_scales_t, - const float *per_token_amax, - const int scale_stride_t, cudaStream_t stream, - const float *noop = nullptr) { -#if FP4_TYPE_SUPPORTED - if (num_rows == 0 || num_cols == 0) return; - - NVTE_CHECK(num_rows % PERTOKEN_SF_VEC_SIZE == 0, "num_rows must be a multiple of ", - PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", num_rows); - dim3 grid(num_cols); - dim3 block(PERTOKEN_BLOCK_SIZE); - - quantize_per_token_nvfp4_columnwise_kernel - <<>>(num_rows, num_cols, input, output_data_t, output_scales_t, - per_token_amax, scale_stride_t, noop); - NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif -} - } // namespace quantize_per_token_kernel inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor *output, @@ -225,8 +132,9 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * CheckOutputTensor(*output, "output", false); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "NVFP4 output tensor must be allocated."); + NVTE_CHECK(output->has_data(), "Per-token NVFP4 quantization requires rowwise output."); + NVTE_CHECK(!output->has_columnwise_data(), + "Per-token NVFP4 quantization does not produce columnwise output."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); const size_t rows = input.flat_first_dim(); @@ -237,33 +145,22 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * const auto *noop_ptr = reinterpret_cast(noop->data.dptr); auto *amax_ptr = 
reinterpret_cast(output->amax.dptr); - auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); - auto *per_token_amax_ptr = (amax_ptr != nullptr) ? amax_ptr : amax_colwise_ptr; - NVTE_CHECK(per_token_amax_ptr != nullptr, "Per-token amax tensor must be allocated."); - if (amax_ptr != nullptr) { - NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, - " entries, got ", output->amax.shape, "."); - } - if (amax_colwise_ptr != nullptr) { - NVTE_CHECK(output->columnwise_amax.numel() == rows, "Per-token columnwise amax must have ", - rows, " entries, got ", output->columnwise_amax.shape, "."); - } + NVTE_CHECK(amax_ptr != nullptr, "Per-token rowwise amax tensor must be allocated."); + NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); if (input.dtype() == DType::kBFloat16) { const auto *input_ptr = reinterpret_cast(input.data.dptr); quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( - static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, - noop_ptr); + static_cast(rows), static_cast(cols), input_ptr, amax_ptr, stream, noop_ptr); } else if (input.dtype() == DType::kFloat16) { const auto *input_ptr = reinterpret_cast(input.data.dptr); quantize_per_token_kernel::launch_compute_per_token_amax( - static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, - noop_ptr); + static_cast(rows), static_cast(cols), input_ptr, amax_ptr, stream, noop_ptr); } else if (input.dtype() == DType::kFloat32) { const auto *input_ptr = reinterpret_cast(input.data.dptr); quantize_per_token_kernel::launch_compute_per_token_amax( - static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, - noop_ptr); + static_cast(rows), static_cast(cols), input_ptr, amax_ptr, stream, noop_ptr); } else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. 
" @@ -299,36 +196,6 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * /*per_token_rowwise_scaling=*/true, /*noop_tensor=*/noop->data, /*stream=*/stream); } } - - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), - "Columnwise output must have FP4 type."); - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated."); - if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); - auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - if (input.dtype() == DType::kBFloat16) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise<__nv_bfloat16>( - static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, - per_token_amax_ptr, scale_stride_t, stream, noop_ptr); - } else if (input.dtype() == DType::kFloat16) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( - static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, - per_token_amax_ptr, scale_stride_t, stream, noop_ptr); - } else if (input.dtype() == DType::kFloat32) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( - static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, - per_token_amax_ptr, scale_stride_t, stream, noop_ptr); - } - } #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); #endif diff --git 
a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ba75867a15..9e9964b99b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -798,10 +798,12 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const bool per_token_activation = quantizer_cpp_list[0]->per_token_activation; + NVTE_CHECK(!per_token_activation || rowwise_usage, + "Per-token NVFP4 quantization requires rowwise usage."); + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage && !per_token_activation; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; - const bool per_token_activation = quantizer_cpp_list[0]->per_token_activation; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; constexpr size_t scale_elem_size = 1; @@ -944,8 +946,7 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - const size_t amax_size = per_token_activation ? 4 * flat_first_dim(shape_list[i]) : 4; - buffer_size = offset + amax_size; + buffer_size = offset + 4; } // Allocate full buffer @@ -958,11 +959,8 @@ std::tuple, std::vector, bool> bulk_alloc buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - const std::vector amax_shape = - per_token_activation ? 
std::vector{flat_first_dim(shape_list[i])} - : std::vector{1}; amax_columnwise_list.emplace_back( - make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); } } @@ -1007,7 +1005,7 @@ std::tuple, std::vector, bool> bulk_alloc } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - getTensorShape(amax_columnwise_list[i])); + std::vector{1}); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6e6e38a1dd..0b9119924e 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1748,6 +1748,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + NVTE_CHECK(!this->per_token_activation || rowwise_usage, + "Per-token NVFP4 quantization requires rowwise usage."); + const bool columnwise_usage = this->columnwise_usage && !this->per_token_activation; const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1779,8 +1782,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - const int64_t amax_rows = this->per_token_activation ? 
static_cast(flat_first_dim) : 1; - amax_columnwise = at::empty({amax_rows}, bit32_tensor_opts); + amax_columnwise = at::empty({1}, bit32_tensor_opts); } // Convert tensors to Python @@ -1865,7 +1867,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - getTensorShape(amax_columnwise)); + std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -1895,6 +1897,9 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional rowwise_amax; std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + NVTE_CHECK(!this->per_token_activation || rowwise_usage, + "Per-token NVFP4 grouped quantization requires rowwise usage."); + const bool columnwise_usage = this->columnwise_usage && !this->per_token_activation; const int64_t total_data_elements = total_elements / 2; @@ -2041,6 +2046,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + NVTE_CHECK(!this->per_token_activation || rowwise_usage, + "Per-token NVFP4 quantization requires rowwise usage."); + const bool columnwise_usage = this->columnwise_usage && !this->per_token_activation; // Coerce row-wise data if (rowwise_usage) { @@ -2106,9 +2114,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - const int64_t amax_rows = - this->per_token_activation ? 
static_cast(flat_first_dim) : 1; - amax_columnwise = at::empty({amax_rows}, opts); + amax_columnwise = at::empty({1}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } } else { // columnwise_usage == false @@ -2144,7 +2150,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*columnwise_scale_inv)); out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, - getTensorShape(*amax_columnwise)); + std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 53f77da9e4..97b81933db 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -345,7 +345,8 @@ def make_empty( columnwise_data = None columnwise_scale_inv = None amax_columnwise = None - if self.columnwise_usage: + columnwise_usage = self.columnwise_usage and not self.per_token_activation + if columnwise_usage: # enforce 2D shape to avoid [S, B, H] shape and B and be 1 # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero shape_2d = tuple([flat_first_dim, shape[-1]]) @@ -362,9 +363,8 @@ def make_empty( device=device, pin_memory=pin_memory, ) - amax_rows = flat_first_dim if self.per_token_activation else 1 amax_columnwise = torch.zeros( - amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory + 1, dtype=torch.float32, device=device, pin_memory=pin_memory ) # Construct FP8 tensor From db1c2a63820ddf1450bf0ad3579ba1a1b94607ab Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 00:05:21 -0700 Subject: [PATCH 25/45] Clean up Signed-off-by: Ziang Li --- docs/envvars.rst | 2 +- tests/cpp/operator/test_dequantize_nvfp4.cu | 2 +- .../cast/nvfp4/quantize_per_token_nvfp4.cuh | 59 +++++++++---------- 
transformer_engine/common/recipe/__init__.py | 6 +- .../pytorch/cpp_extensions/gemm.py | 2 +- transformer_engine/pytorch/csrc/extensions.h | 2 - .../pytorch/csrc/extensions/cast.cpp | 45 -------------- .../pytorch/csrc/extensions/pybind.cpp | 2 - .../pytorch/tensor/nvfp4_tensor.py | 2 +- .../tensor/storage/grouped_tensor_storage.py | 1 + 10 files changed, 36 insertions(+), 87 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 58988b5473..8f90814d10 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -285,7 +285,7 @@ Kernel Configuration :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe in GroupedLinear split-quantize paths. When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), NVFP4 rowwise ``amax`` metadata stores one FP32 value per token (row) instead of a single scalar. + :Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), forward activation quantizers store NVFP4 rowwise ``amax`` metadata as one FP32 value per token (row) instead of a single scalar. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index f932c7dd7a..e61ab894d6 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -213,7 +213,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, cudaDeviceSynchronize(); } - // Dequantize with compact scales → reference output + // Dequantize with compact scales → reference output. 
Tensor output_compact("output_compact", std::vector{rows, cols}, otype, true, false); nvte_dequantize(quantized_compact.data(), output_compact.data(), 0); cudaDeviceSynchronize(); diff --git a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh index dec7248803..5176d01d10 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -8,8 +8,8 @@ * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. */ -#ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ -#define TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ +#ifndef TRANSFORMER_ENGINE_QUANTIZE_PER_TOKEN_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_PER_TOKEN_NVFP4_CUH_ #include #include @@ -167,34 +167,31 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * "Expected BFloat16, Float16, or Float32."); } - if (output->has_data()) { - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); - NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); - - QuantizationConfig per_token_quant_config; - if (quant_config != nullptr) { - per_token_quant_config = *quant_config; - } - per_token_quant_config.nvfp4_per_token_activation = true; - per_token_quant_config.nvfp4_2d_quantization = false; - - const bool use_optimized_kernel = - (input.dtype() == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0); - if (use_optimized_kernel) { - quantize_transpose(input, noop, output, - &per_token_quant_config, stream); - } else { - quantize_transpose_vector_blockwise_fp4( - /*input=*/input.data, /*global_amax=*/output->amax, - /*scale_inv=*/output->scale_inv, /*scale_inv_t=*/output->columnwise_scale_inv, - /*output=*/output->data, 
/*output_t=*/output->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/true, /*return_transpose=*/false, - /*pow2_scale=*/false, /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/per_token_quant_config.stochastic_rounding, - /*rng_state=*/per_token_quant_config.rng_state, /*use_2d_quantization=*/false, - /*per_token_rowwise_scaling=*/true, /*noop_tensor=*/noop->data, /*stream=*/stream); - } + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + + QuantizationConfig per_token_quant_config; + if (quant_config != nullptr) { + per_token_quant_config = *quant_config; + } + per_token_quant_config.nvfp4_per_token_activation = true; + per_token_quant_config.nvfp4_2d_quantization = false; + + const bool use_optimized_kernel = + (input.dtype() == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0); + if (use_optimized_kernel) { + quantize_transpose(input, noop, output, &per_token_quant_config, + stream); + } else { + quantize_transpose_vector_blockwise_fp4( + /*input=*/input.data, /*global_amax=*/output->amax, + /*scale_inv=*/output->scale_inv, /*scale_inv_t=*/output->columnwise_scale_inv, + /*output=*/output->data, /*output_t=*/output->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/true, /*return_transpose=*/false, + /*pow2_scale=*/false, /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/per_token_quant_config.stochastic_rounding, + /*rng_state=*/per_token_quant_config.rng_state, /*use_2d_quantization=*/false, + /*per_token_rowwise_scaling=*/true, /*noop_tensor=*/noop->data, /*stream=*/stream); } #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); @@ -205,4 +202,4 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * } // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ +#endif // 
TRANSFORMER_ENGINE_QUANTIZE_PER_TOKEN_NVFP4_CUH_ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e59d01d82a..c2e3ac334e 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -479,9 +479,9 @@ class NVFP4BlockScaling(Recipe): disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. per_token_activation : bool, default = False - If set to `True`, GroupedLinear activation split quantization uses per-token - (per-row) NVFP4 global amax values. In this mode, rowwise ``amax`` metadata - is stored as a vector with one FP32 value per token. + If set to `True`, forward activation quantizers use per-token (per-row) + NVFP4 global amax values. In this mode, rowwise ``amax`` metadata is + stored as a vector with one FP32 value per token. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. 
None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 394778f96b..49beef0778 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -364,7 +364,7 @@ def general_grouped_gemm( out_views[i].copy_(gemm_out) if single_output: out = out_init - return out, bias, gelu_input + return out, grad_bias, gelu_input if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f62853bb2b..4a2ea7412b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -326,8 +326,6 @@ py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); -std::tuple quantize_nvfp4_per_token(at::Tensor input); - std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9e9964b99b..3497f1aa59 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1601,50 +1601,5 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } -std::tuple quantize_nvfp4_per_token(at::Tensor input) { - init_extension(); - - NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); - NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device"); - - const int num_rows = input.size(0); - const int num_cols = input.size(1); - NVTE_CHECK(num_cols % 16 == 0, - "num_cols must be a multiple 
of 16 for per-token NVFP4 quantization"); - - if (num_rows == 0) { - auto options = input.options(); - return {at::empty({0, num_cols / 2}, options.dtype(at::kByte)), - at::empty({0, num_cols / 16}, options.dtype(at::kByte)), - at::empty({0}, options.dtype(at::kFloat))}; - } - - auto input_contig = input.contiguous(); - auto options = input_contig.options(); - - auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte)); - auto output_scales = at::empty({num_rows, num_cols / 16}, options.dtype(at::kByte)); - auto output_per_token_amax = at::empty({num_rows}, options.dtype(at::kFloat)); - - auto te_input = makeTransformerEngineTensor(input_contig); - TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); - te_output.set_rowwise_data( - output_data.data_ptr(), DType::kFloat4E2M1, - std::vector{static_cast(num_rows), static_cast(num_cols)}); - te_output.set_rowwise_scale_inv( - output_scales.data_ptr(), DType::kFloat8E4M3, - std::vector{static_cast(num_rows), static_cast(num_cols / 16)}); - te_output.set_amax(output_per_token_amax.data_ptr(), DType::kFloat32, - std::vector{static_cast(num_rows)}); - QuantizationConfigWrapper quant_config; - quant_config.set_nvfp4_per_token_activation(true); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - NVTE_SCOPED_GIL_RELEASE( - { nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, stream); }); - - return {output_data, output_scales, output_per_token_amax}; -} - } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index b2d74205cc..eb7576d905 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,8 +145,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", 
transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); - m.def("quantize_nvfp4_per_token", transformer_engine::pytorch::quantize_nvfp4_per_token, - "Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 97b81933db..d6674f752e 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,7 +128,7 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for gradients.""" stochastic_rounding: bool - """Per-token activation quantization path (grouped split quantize).""" + """Per-token activation quantization path.""" per_token_activation: bool """RHT matrix random sign mask""" diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 1732abf57c..2ea6ef958a 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -663,6 +663,7 @@ def make_grouped_tensor( amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): per_token_activation = getattr(quantizer, "per_token_activation", False) + columnwise_usage = columnwise_usage and not per_token_activation total_amax_elements = ( sum(math.prod(s[:-1]) for s in shape) if per_token_activation else num_tensors ) From 9eb06c79ef065e06f686a22b5431863eb94acb36 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 00:29:27 -0700 Subject: [PATCH 26/45] 
Clean up Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 39 ------------------- tests/cpp/operator/test_dequantize_nvfp4.cu | 3 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- .../cast/nvfp4/quantize_per_token_nvfp4.cuh | 12 +++--- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 2 +- 5 files changed, 9 insertions(+), 49 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index a6d4105702..1472a01c32 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -321,10 +321,6 @@ void compute_ref(float (*OP)(const float), std::vector input_t = create_transpose(input, rows, cols); if (per_token_amax != nullptr) { - constexpr size_t kBlockSize = 16; - constexpr float fp4_max_inv = 1.0f / 6.0f; - constexpr float float_max = Numeric_Traits::maxNorm; - per_token_amax->resize(rows, 0.0f); for (size_t row = 0; row < rows; ++row) { float row_amax = 0.0f; @@ -343,40 +339,6 @@ void compute_ref(float (*OP)(const float), use_fast_math, use_2d_quantization); } - - for (size_t col = 0; col < cols; ++col) { - for (size_t row_start = 0; row_start < rows; row_start += kBlockSize) { - float vals[kBlockSize]; - float s_enc[kBlockSize]; - float scaled_block_amax = 0.0f; - for (size_t i = 0; i < kBlockSize; ++i) { - const size_t row = row_start + i; - const float val = static_cast(input[row * cols + col]); - const float S_enc = - compute_global_encode_scaling_factor_FP4((*per_token_amax)[row], false); - vals[i] = val; - s_enc[i] = S_enc; - scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); - } - - const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b_f32); - scales_t[col * scales_stride_t + row_start / kBlockSize] = S_dec_b_fp8; - - for (size_t i = 0; i < kBlockSize; i += 2) { - const float S_dec_rowwise_x = 1.0f / s_enc[i]; - const float 
S_dec_rowwise_y = 1.0f / s_enc[i + 1]; - const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); - const float S_enc_b_fp8_x = - fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_x), float_max); - const float S_enc_b_fp8_y = - fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_y), float_max); - const float2 scaled_elt_pair = {vals[i] * S_enc_b_fp8_x, - vals[i + 1] * S_enc_b_fp8_y}; - output_t[(col * rows + row_start + i) / 2] = fp4e2m1x2(scaled_elt_pair); - } - } - } } else if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; @@ -625,7 +587,6 @@ void set_per_token_amax_metadata(Tensor &output, const size_t rows) { static_cast(DType::kFloat32), nvte_make_shape(shape.data(), shape.size())}; nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); - return amax; }; replace_amax(kNVTEAmax); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index e61ab894d6..e9d6b5b525 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -97,7 +97,6 @@ void set_per_token_amax_metadata(Tensor &output, const size_t rows) { }; replace_amax(kNVTEAmax); - replace_amax(kNVTEColumnwiseAmax); } std::vector get_amax_values(const Tensor &tensor) { @@ -213,7 +212,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, cudaDeviceSynchronize(); } - // Dequantize with compact scales → reference output. + // Dequantize with compact scales to get the reference output. 
Tensor output_compact("output_compact", std::vector{rows, cols}, otype, true, false); nvte_dequantize(quantized_compact.data(), output_compact.data(), 0); cudaDeviceSynchronize(); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 7c9f2c7eb6..45aa34f414 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -118,7 +118,7 @@ def check_nvfp4_gemm_versus_reference( x_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=True, + columnwise=not per_token_activation, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), diff --git a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh index 5176d01d10..824509f299 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -36,8 +36,8 @@ namespace quantize_per_token_kernel { using namespace core; using namespace ptx; -constexpr int PERTOKEN_BLOCK_SIZE = 256; -constexpr int PERTOKEN_SF_VEC_SIZE = 16; +constexpr int PER_TOKEN_BLOCK_SIZE = 256; +constexpr int PER_TOKEN_SF_VEC_SIZE = 16; template __device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, @@ -109,9 +109,9 @@ void launch_compute_per_token_amax(const int num_rows, const int num_cols, const NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", num_cols); dim3 grid(num_rows); - dim3 block(PERTOKEN_BLOCK_SIZE); + dim3 block(PER_TOKEN_BLOCK_SIZE); - compute_per_token_amax_kernel + compute_per_token_amax_kernel <<>>(num_rows, num_cols, input, output_per_token_amax, noop); NVTE_CHECK_CUDA(cudaGetLastError()); #else @@ -139,9 +139,9 @@ inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor * const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); 
- NVTE_CHECK(cols % quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE == 0, + NVTE_CHECK(cols % quantize_per_token_kernel::PER_TOKEN_SF_VEC_SIZE == 0, "Per-token NVFP4 quantization requires last dim divisible by ", - quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE, "."); + quantize_per_token_kernel::PER_TOKEN_SF_VEC_SIZE, "."); const auto *noop_ptr = reinterpret_cast(noop->data.dptr); auto *amax_ptr = reinterpret_cast(output->amax.dptr); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 331c78df51..08e4855a1b 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -520,7 +520,7 @@ __global__ void __launch_bounds__(THREADS_NUM) : 1.0f; } const float S_dec_rowwise_block = - PER_TOKEN_ROWWISE ? 1.0 / S_enc_rowwise_block : S_dec_rowwise; + PER_TOKEN_ROWWISE ? 1.0f / S_enc_rowwise_block : S_dec_rowwise; const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); From 21274d81d86b822c5ce1abc70deded14ba23a70a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 00:48:04 -0700 Subject: [PATCH 27/45] Clean up column wise Signed-off-by: Ziang Li --- .../common/cast/dispatch/quantize.cuh | 21 +- .../cast/nvfp4/quantize_per_token_nvfp4.cuh | 205 ------------------ .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 131 +++++++++++ .../custom_recipes/quantization_nvfp4.py | 22 +- 4 files changed, 143 insertions(+), 236 deletions(-) delete mode 100644 transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index c204790861..7b063dbf6a 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,7 +21,6 @@ #include 
"../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" -#include "../nvfp4/quantize_per_token_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -105,9 +104,11 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_per_token(*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, - stream); - break; + NVTE_CHECK(output_tensor->has_data(), + "Per-token NVFP4 quantization requires rowwise output."); + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Per-token NVFP4 quantization does not produce columnwise output."); + nvfp4::compute_per_token_amax(*input_tensor, noop_tensor, output_tensor, stream); } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -135,7 +136,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*per_token_rowwise_scaling=*/false, /*noop_tensor=*/noop_tensor->data, + /*per_token_rowwise_scaling=*/per_token_activation, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; @@ -249,14 +250,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); - const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; - if (per_token_activation) { - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_per_token(*grad_tensor, 
noop_tensor, output_tensor, &quant_config_cpp, - stream); - break; - } + NVTE_CHECK(!quant_config_cpp.nvfp4_per_token_activation, + "Per-token NVFP4 quantization is only supported for forward activation tensors."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); diff --git a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh deleted file mode 100644 index 824509f299..0000000000 --- a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh +++ /dev/null @@ -1,205 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file quantize_per_token_nvfp4.cuh - * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. 
- */ - -#ifndef TRANSFORMER_ENGINE_QUANTIZE_PER_TOKEN_NVFP4_CUH_ -#define TRANSFORMER_ENGINE_QUANTIZE_PER_TOKEN_NVFP4_CUH_ - -#include -#include - -#include -#include - -#include "../../common.h" -#include "../../transpose/cast_transpose.h" -#include "../../util/ptx.cuh" -#include "../../utils.cuh" -#include "core_nvfp4.cuh" -#include "quantize_transpose_nvfp4.cuh" - -#if FP4_TYPE_SUPPORTED -#include -#endif - -namespace transformer_engine { -namespace dispatch { -namespace nvfp4 { -namespace quantize_per_token_kernel { - -using namespace core; -using namespace ptx; - -constexpr int PER_TOKEN_BLOCK_SIZE = 256; -constexpr int PER_TOKEN_SF_VEC_SIZE = 16; - -template -__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, - const ptx::FPx2 &val) { - if constexpr (std::is_same_v) { - dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); - dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); - } else { - ptx::abs_max_2x(dst, dst, val); - } -} - -template -__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { - if constexpr (std::is_same_v) { - return fmaxf(fabsf(val.x), fabsf(val.y)); - } else { - return static_cast(__hmax(__habs(val.x), __habs(val.y))); - } -} - -template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(BLOCK_SIZE) -#endif - compute_per_token_amax_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - float *__restrict__ output_per_token_amax, - const float *__restrict__ noop) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - - using IType2 = typename ptx::FPx2; - - const int row_idx = blockIdx.x; - if (row_idx >= num_rows) return; - - const int num_vec2 = num_cols / 2; - const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); - - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { - const IType2 val = input_row[i]; 
- abs_max_2x_update(thread_amax_2x, val); - } - const float thread_max = abs_max_2x_to_float(thread_amax_2x); - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - float row_amax = - BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); - - if (threadIdx.x == 0) { - output_per_token_amax[row_idx] = row_amax; - } -#endif -} - -template -void launch_compute_per_token_amax(const int num_rows, const int num_cols, const IType *input, - float *output_per_token_amax, cudaStream_t stream, - const float *noop = nullptr) { -#if FP4_TYPE_SUPPORTED - if (num_rows == 0 || num_cols == 0) return; - - NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", - num_cols); - dim3 grid(num_rows); - dim3 block(PER_TOKEN_BLOCK_SIZE); - - compute_per_token_amax_kernel - <<>>(num_rows, num_cols, input, output_per_token_amax, noop); - NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif -} - -} // namespace quantize_per_token_kernel - -inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor *output, - const QuantizationConfig *quant_config, cudaStream_t stream) { -#if FP4_TYPE_SUPPORTED - using namespace detail; - - checkCuDriverContext(stream); - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output", false); - - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data(), "Per-token NVFP4 quantization requires rowwise output."); - NVTE_CHECK(!output->has_columnwise_data(), - "Per-token NVFP4 quantization does not produce columnwise output."); - NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - NVTE_CHECK(cols % 
quantize_per_token_kernel::PER_TOKEN_SF_VEC_SIZE == 0, - "Per-token NVFP4 quantization requires last dim divisible by ", - quantize_per_token_kernel::PER_TOKEN_SF_VEC_SIZE, "."); - - const auto *noop_ptr = reinterpret_cast(noop->data.dptr); - auto *amax_ptr = reinterpret_cast(output->amax.dptr); - NVTE_CHECK(amax_ptr != nullptr, "Per-token rowwise amax tensor must be allocated."); - NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, - " entries, got ", output->amax.shape, "."); - - if (input.dtype() == DType::kBFloat16) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( - static_cast(rows), static_cast(cols), input_ptr, amax_ptr, stream, noop_ptr); - } else if (input.dtype() == DType::kFloat16) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - quantize_per_token_kernel::launch_compute_per_token_amax( - static_cast(rows), static_cast(cols), input_ptr, amax_ptr, stream, noop_ptr); - } else if (input.dtype() == DType::kFloat32) { - const auto *input_ptr = reinterpret_cast(input.data.dptr); - quantize_per_token_kernel::launch_compute_per_token_amax( - static_cast(rows), static_cast(cols), input_ptr, amax_ptr, stream, noop_ptr); - } else { - NVTE_ERROR( - "Unsupported input dtype for per-token NVFP4 quantization. 
" - "Expected BFloat16, Float16, or Float32."); - } - - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); - - QuantizationConfig per_token_quant_config; - if (quant_config != nullptr) { - per_token_quant_config = *quant_config; - } - per_token_quant_config.nvfp4_per_token_activation = true; - per_token_quant_config.nvfp4_2d_quantization = false; - - const bool use_optimized_kernel = - (input.dtype() == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0); - if (use_optimized_kernel) { - quantize_transpose(input, noop, output, &per_token_quant_config, - stream); - } else { - quantize_transpose_vector_blockwise_fp4( - /*input=*/input.data, /*global_amax=*/output->amax, - /*scale_inv=*/output->scale_inv, /*scale_inv_t=*/output->columnwise_scale_inv, - /*output=*/output->data, /*output_t=*/output->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/true, /*return_transpose=*/false, - /*pow2_scale=*/false, /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/per_token_quant_config.stochastic_rounding, - /*rng_state=*/per_token_quant_config.rng_state, /*use_2d_quantization=*/false, - /*per_token_rowwise_scaling=*/true, /*noop_tensor=*/noop->data, /*stream=*/stream); - } -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif -} - -} // namespace nvfp4 -} // namespace dispatch -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_QUANTIZE_PER_TOKEN_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 08e4855a1b..14909435ac 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -16,6 +16,9 @@ #include #include +#include +#include + #include "../../common.h" #include "../../util/math.h" 
#include "../../util/ptx.cuh" @@ -27,6 +30,134 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { +namespace per_token_amax_kernel { + +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr int PER_TOKEN_BLOCK_SIZE = 256; +constexpr int PER_TOKEN_SF_VEC_SIZE = 16; + +template +__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, + const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); + dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); + } else { + ptx::abs_max_2x(dst, dst, val); + } +} + +template +__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + return fmaxf(fabsf(val.x), fabsf(val.y)); + } else { + return static_cast(__hmax(__habs(val.x), __habs(val.y))); + } +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + compute_per_token_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_per_token_amax, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + abs_max_2x_update(thread_amax_2x, val); + } + const float thread_max = abs_max_2x_to_float(thread_amax_2x); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + const float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + if (threadIdx.x == 0) { + 
output_per_token_amax[row_idx] = row_amax; + } +#endif +} + +template +void launch_compute_per_token_amax(const int num_rows, const int num_cols, const IType *input, + float *output_per_token_amax, cudaStream_t stream, + const float *noop = nullptr) { + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", + num_cols); + dim3 grid(num_rows); + dim3 block(PER_TOKEN_BLOCK_SIZE); + + compute_per_token_amax_kernel + <<>>(num_rows, num_cols, input, output_per_token_amax, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace per_token_amax_kernel + +inline void compute_per_token_amax(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace per_token_amax_kernel; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + NVTE_CHECK(cols % PER_TOKEN_SF_VEC_SIZE == 0, + "Per-token NVFP4 quantization requires last dim divisible by ", PER_TOKEN_SF_VEC_SIZE, + "."); + + auto *amax_ptr = reinterpret_cast(output->amax.dptr); + NVTE_CHECK(amax_ptr != nullptr, "Per-token rowwise amax tensor must be allocated."); + NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); + + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + if (input.dtype() == DType::kBFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_per_token_amax<__nv_bfloat16>(static_cast(rows), static_cast(cols), + input_ptr, amax_ptr, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_per_token_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = 
reinterpret_cast(input.data.dptr); + launch_compute_per_token_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); + } else { + NVTE_ERROR( + "Unsupported input dtype for per-token NVFP4 quantization. " + "Expected BFloat16, Float16, or Float32."); + } +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + namespace quantize_transpose_kernel { using namespace quantization_and_transposition_SF; diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d57ea792dd..36c50309f5 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -354,7 +354,7 @@ def __init__( with_rht: bool = False, with_random_sign_mask: bool = True, ): - super().__init__(rowwise=rowwise, columnwise=columnwise) + super().__init__(rowwise=rowwise, columnwise=columnwise and not per_token_activation) self.internal = True self.dtype = dtype @@ -450,13 +450,9 @@ def _quantize_blockwise_reference( *, pow_2_scales: bool, per_token_rowwise: bool = False, - per_token_columnwise: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: - assert not ( - per_token_rowwise and per_token_columnwise - ), "Per-token rowwise and columnwise reference modes are mutually exclusive." 
if x.ndim != 2: raise ValueError( f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape" @@ -497,8 +493,6 @@ def _quantize_blockwise_reference( else: if per_token_rowwise: global_amax = global_amax.to(torch.float32).view(m, 1, 1) - if per_token_columnwise: - global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( @@ -521,16 +515,9 @@ def _quantize_blockwise_reference( global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - if per_token_columnwise: - decode_scale = torch.amax( - torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, - dim=-1, - keepdim=True, - ) - else: - # Match the kernel's default path: fold the FP4 reciprocal into the - # global scale multiplier, but keep the final reciprocal exact. - decode_scale = vec_max * global_encode_scale_multiplier + # Match the kernel's default path: fold the FP4 reciprocal into the + # global scale multiplier, but keep the final reciprocal exact. 
+ decode_scale = vec_max * global_encode_scale_multiplier decode_scale = torch.min( decode_scale, torch.tensor( @@ -711,7 +698,6 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - per_token_columnwise=self.per_token_activation, eps=self.eps, ) From 4cbb43a78fe9f1ebde2003e410dc897f20946401 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 01:02:19 -0700 Subject: [PATCH 28/45] Move shared test helpers Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 40 +------------- tests/cpp/operator/test_dequantize_nvfp4.cu | 55 ++----------------- tests/cpp/test_common.cu | 53 ++++++++++++++++++ tests/cpp/test_common.h | 4 ++ 4 files changed, 65 insertions(+), 87 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 1472a01c32..ea9c145d10 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -550,48 +550,14 @@ void compareResults_nvfp4(const Tensor &test, } void compare_per_token_amax(const Tensor &output, const std::vector &ref_amax) { - NVTEBasicTensor amax; - nvte_get_tensor_param_v2(output.data(), kNVTEAmax, &amax, sizeof(amax), nullptr); - ASSERT_NE(amax.data_ptr, nullptr); - ASSERT_EQ(amax.shape.ndim, 1); - ASSERT_EQ(amax.shape.data[0], ref_amax.size()); - - std::vector test_amax_data(ref_amax.size()); - ASSERT_EQ(cudaMemcpy(test_amax_data.data(), - amax.data_ptr, - ref_amax.size() * sizeof(float), - cudaMemcpyDeviceToHost), - cudaSuccess); + const std::vector test_amax_data = output.tensor_amax_values(); + ASSERT_EQ(test_amax_data.size(), ref_amax.size()); for (size_t row = 0; row < ref_amax.size(); ++row) { ASSERT_EQ(test_amax_data[row], ref_amax[row]) << "Per-token amax mismatch at row " << row; } } -void set_per_token_amax_metadata(Tensor &output, const size_t rows) { - const std::vector shape = {rows}; - 
NVTETensor output_tensor = output.data(); - - auto replace_amax = [&](const NVTETensorParam param) { - NVTEBasicTensor old_amax; - nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr); - if (old_amax.data_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); - } - - float *amax = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float))); - - NVTEBasicTensor amax_tensor = {amax, - static_cast(DType::kFloat32), - nvte_make_shape(shape.data(), shape.size())}; - nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); - }; - - replace_amax(kNVTEAmax); -} - template void performTest(float (*OP)(const float), const std::vector& shape, @@ -637,7 +603,7 @@ void performTest(float (*OP)(const float), std::vector ref_per_token_amax; bool use_2d_quantization = false; if (per_token_activation) { - set_per_token_amax_metadata(output, rows); + output.set_tensor_amax_shape({rows}); compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index e9d6b5b525..efce28e0d1 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -75,42 +75,6 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, } } -void set_per_token_amax_metadata(Tensor &output, const size_t rows) { - const std::vector shape = {rows}; - NVTETensor output_tensor = output.data(); - - auto replace_amax = [&](const NVTETensorParam param) { - NVTEBasicTensor old_amax; - nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr); - if (old_amax.data_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); - } - - float *amax = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float))); - - NVTEBasicTensor 
amax_tensor = {amax, - static_cast(DType::kFloat32), - nvte_make_shape(shape.data(), shape.size())}; - nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); - }; - - replace_amax(kNVTEAmax); -} - -std::vector get_amax_values(const Tensor &tensor) { - NVTEBasicTensor amax; - nvte_get_tensor_param_v2(tensor.data(), kNVTEAmax, &amax, sizeof(amax), nullptr); - const size_t numel = amax.shape.ndim == 0 ? 1 : amax.shape.data[0]; - std::vector amax_values(numel); - if (numel > 0) { - NVTE_CHECK_CUDA(cudaMemcpy(amax_values.data(), amax.data_ptr, numel * sizeof(float), - cudaMemcpyDeviceToHost)); - } - return amax_values; -} - template float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { t.to_cpu(); @@ -136,7 +100,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); if (per_token_activation) { - set_per_token_amax_metadata(quantized, rows); + quantized.set_tensor_amax_shape({rows}); } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { @@ -166,7 +130,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, const uint8_t *fp4_data = reinterpret_cast(quantized.rowwise_cpu_dptr()); const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr(); - const std::vector amax_val = get_amax_values(quantized); + const std::vector amax_val = quantized.tensor_amax_values(); const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape(); const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; @@ -194,7 +158,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); if (per_token_activation) { - set_per_token_amax_metadata(quantized_compact, rows); + 
quantized_compact.set_tensor_amax_shape({rows}); } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { @@ -221,7 +185,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); if (per_token_activation) { - set_per_token_amax_metadata(quantized_swizzled, rows); + quantized_swizzled.set_tensor_amax_shape({rows}); } else { quantized_swizzled.set_tensor_amax(0.0f); } @@ -231,16 +195,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // since from_cpu() uploads all CPU buffers (including zero-init data). quantized_compact.to_cpu(); if (per_token_activation) { - NVTEBasicTensor compact_amax; - NVTEBasicTensor swizzled_amax; - nvte_get_tensor_param_v2(quantized_compact.data(), kNVTEAmax, &compact_amax, - sizeof(compact_amax), nullptr); - nvte_get_tensor_param_v2(quantized_swizzled.data(), kNVTEAmax, &swizzled_amax, - sizeof(swizzled_amax), nullptr); - if (rows > 0) { - NVTE_CHECK_CUDA(cudaMemcpy(swizzled_amax.data_ptr, compact_amax.data_ptr, - rows * sizeof(float), cudaMemcpyDeviceToDevice)); - } + quantized_swizzled.copy_tensor_amax_from(quantized_compact); } else { quantized_swizzled.set_tensor_amax(quantized_compact.amax()); } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index c756b83810..96e71f9513 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -543,6 +543,59 @@ void Tensor::set_scale(float scale) { } } +void Tensor::set_tensor_amax_shape(const std::vector &shape) { + const size_t numel = product(shape); + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "Amax shape override is only supported for NVFP4 test tensors."); + + auto old_amax = tensor_.get_amax(); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = 
nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, numel * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, numel * sizeof(float))); + tensor_.set_amax(amax, DType::kFloat32, shape); +} + +std::vector Tensor::tensor_amax_values() const { + const auto amax = tensor_.get_amax(); + NVTE_CHECK(static_cast(amax.dtype) == DType::kFloat32, "Tensor amax must be FP32."); + + const size_t numel = product(amax.shape); + if (numel == 0) { + return {}; + } + NVTE_CHECK(amax.data_ptr != nullptr, "Tensor amax is not allocated."); + + std::vector values(numel); + NVTE_CHECK_CUDA( + cudaMemcpy(values.data(), amax.data_ptr, numel * sizeof(float), cudaMemcpyDeviceToHost)); + return values; +} + +void Tensor::copy_tensor_amax_from(const Tensor &other) { + const auto other_amax = other.tensor_.get_amax(); + NVTE_CHECK(static_cast(other_amax.dtype) == DType::kFloat32, + "Source tensor amax must be FP32."); + + auto my_amax = tensor_.get_amax(); + NVTE_CHECK(static_cast(my_amax.dtype) == DType::kFloat32, + "Destination tensor amax must be FP32."); + NVTE_CHECK(areShapesEqual(my_amax.shape, other_amax.shape), "Amax shape mismatch."); + + const size_t numel = product(other_amax.shape); + if (numel == 0) { + return; + } + + NVTE_CHECK(other_amax.data_ptr != nullptr, "Source tensor amax is not allocated."); + NVTE_CHECK(my_amax.data_ptr != nullptr, "Destination tensor amax is not allocated."); + NVTE_CHECK_CUDA(cudaMemcpy(my_amax.data_ptr, other_amax.data_ptr, numel * sizeof(float), + cudaMemcpyDeviceToDevice)); +} + void Tensor::set_scale_inv(float scale_inv) { if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b8389d5833..fa46995991 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -319,6 +319,10 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_tensor_amax_shape(const std::vector &shape); + std::vector tensor_amax_values() const; + void 
copy_tensor_amax_from(const Tensor &other); + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){ tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } From d4ab1e701b4bb4577ea4375c697bb049d8f054dc Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 01:12:31 -0700 Subject: [PATCH 29/45] Minor clean up test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index b96d75cfff..de916ff562 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -183,10 +183,7 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - - -def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: - if recipe_name == "nvfp4_per_token": + if module_type == "ops_linear" and recipe_name == "nvfp4_per_token": pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") @@ -1357,7 +1354,6 @@ def test_fused_linear_paths_match_backward_override_reference( _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") - _maybe_skip_unsupported_fused_ops(recipe_name) reset_rng_states() @@ -1497,7 +1493,6 @@ def test_fused_bias_activation_matches_masked_linear_backward( _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear") - _maybe_skip_unsupported_fused_ops(recipe_name) reset_rng_states() in_features = 
input_shape[-1] @@ -1644,7 +1639,6 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_override_switch reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - _maybe_skip_unsupported_fused_ops(recipe_name) # Build a Userbuffers-eligible fuser and representative inputs. linear = te_ops.BasicLinear( From 363335b582fceb64a1e05d3ac52501b1e7d68f03 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 01:42:36 -0700 Subject: [PATCH 30/45] Readability Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 74 ++++++++++++------- .../quantize_transpose_nvfp4_tuned_1D.cuh | 18 +++-- 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 14909435ac..1888bc74f7 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -640,38 +640,58 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - // 2. Compute E4M3 scaling factor - const size_t row_idx_global = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - float S_enc_rowwise_block = S_enc_rowwise; + float block_scale_inverse; if constexpr (PER_TOKEN_ROWWISE) { - S_enc_rowwise_block = - row_idx_global < rows - ? compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx_global]) + // 2. Compute E4M3 scaling factor + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const float S_enc_rowwise_block = + scales_offset_Y < rows + ? compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[scales_offset_Y]) : 1.0f; - } - const float S_dec_rowwise_block = - PER_TOKEN_ROWWISE ? 
1.0f / S_enc_rowwise_block : S_dec_rowwise; - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + const float S_dec_rowwise_block = 1.0f / S_enc_rowwise_block; + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + + // Check boundaries + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // Check boundaries - const size_t scales_offset_Y = row_idx_global; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise_block), + float_max); // S_enc_b_fp8 + } else { + // 2. 
Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + block_scale_inverse = fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), + float_max); // S_enc_b_fp8 } - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = - fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise_block), - float_max); // S_enc_b_fp8 const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; // 3. 
Scale elements diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 172b35b245..378023b54e 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -313,17 +313,21 @@ __device__ __forceinline__ void rowwise_scaling( } const float block_amax = get_amax_of_pair(thread_amax_2x); - const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; - float S_enc_rowwise_block = S_enc_rowwise; + nvfp4_scale_t S_dec_b_fp8; + scaling_coeff_type SFcoefficient; if constexpr (PER_TOKEN_ROWWISE) { - S_enc_rowwise_block = + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + const float S_enc_rowwise_block = row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) : 1.0f; + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); + } else { + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); } - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); - const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); // Store scaling factors to SMEM buffer (R2S) if (SF_storing_thread) { From 1a4d3b0da24039197519537286a3edc6d316c954 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 17:10:15 -0700 Subject: [PATCH 31/45] Rename Signed-off-by: Ziang Li --- docs/envvars.rst | 4 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 40 ++++----- tests/cpp/operator/test_dequantize_nvfp4.cu | 43 ++++------ 
tests/cpp/test_common.h | 4 + tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 50 +++++------ .../nvfp4/test_nvfp4_quantize_exact.py | 56 ++++++------- tests/pytorch/test_backward_override.py | 22 ++--- tests/pytorch/test_cuda_graphs.py | 10 +-- tests/pytorch/test_recipe.py | 14 ++-- tests/pytorch/test_sanity.py | 18 ++-- tests/pytorch/test_torch_compile.py | 2 +- tests/pytorch/utils.py | 14 ++-- .../common/cast/dispatch/quantize.cuh | 21 ++--- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 11 ++- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 82 +++++++++---------- .../quantize_transpose_nvfp4_tuned_1D.cuh | 21 ++--- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 8 ++ transformer_engine/common/common.h | 13 ++- .../transformer_engine/transformer_engine.h | 21 +++-- transformer_engine/common/recipe/__init__.py | 12 +-- .../common/transformer_engine.cpp | 12 +-- .../common/transpose/cast_transpose.h | 2 +- ...quantize_transpose_vector_blockwise_fp4.cu | 24 +++--- .../pytorch/cpp_extensions/gemm.py | 67 +++++++-------- transformer_engine/pytorch/csrc/common.h | 2 +- .../pytorch/csrc/extensions/activation.cpp | 4 +- .../pytorch/csrc/extensions/bias.cpp | 2 +- .../pytorch/csrc/extensions/cast.cpp | 58 +++++++------ .../pytorch/csrc/extensions/normalization.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 56 +++++++------ .../pytorch/csrc/type_converters.cpp | 2 + .../custom_recipes/quantization_nvfp4.py | 20 ++--- transformer_engine/pytorch/quantization.py | 4 +- .../pytorch/tensor/nvfp4_tensor.py | 17 ++-- .../tensor/storage/grouped_tensor_storage.py | 27 +++--- .../tensor/storage/nvfp4_tensor_storage.py | 8 ++ 36 files changed, 407 insertions(+), 368 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 8f90814d10..665f7912ab 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -281,11 +281,11 @@ Kernel Configuration :Default: ``0`` :Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations. -.. 
envvar:: NVTE_NVFP4_PER_TOKEN_ACTIVATION +.. envvar:: NVTE_NVFP4_ROW_SCALED_ACTIVATION :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), forward activation quantizers store NVFP4 rowwise ``amax`` metadata as one FP32 value per token (row) instead of a single scalar. + :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index ea9c145d10..34101e8572 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -316,18 +316,18 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride_t, const bool use_fast_math, const bool use_2d_quantization = false, - std::vector *per_token_amax = nullptr) + std::vector *rowwise_amax = nullptr) { std::vector input_t = create_transpose(input, rows, cols); - if (per_token_amax != nullptr) { - per_token_amax->resize(rows, 0.0f); + if (rowwise_amax != nullptr) { + rowwise_amax->resize(rows, 0.0f); for (size_t row = 0; row < rows; ++row) { float row_amax = 0.0f; for (size_t col = 0; col < cols; ++col) { row_amax = fmaxf(row_amax, fabsf(static_cast(input[row * cols + col]))); } - (*per_token_amax)[row] = row_amax; + (*rowwise_amax)[row] = row_amax; quantize_nvfp4(OP, input + row * cols, output + row * (cols / 2), @@ -549,12 +549,12 @@ void compareResults_nvfp4(const Tensor &test, } } -void compare_per_token_amax(const Tensor &output, const std::vector &ref_amax) { +void 
compare_rowwise_amax(const Tensor &output, const std::vector &ref_amax) { const std::vector test_amax_data = output.tensor_amax_values(); ASSERT_EQ(test_amax_data.size(), ref_amax.size()); for (size_t row = 0; row < ref_amax.size(); ++row) { ASSERT_EQ(test_amax_data[row], ref_amax[row]) - << "Per-token amax mismatch at row " << row; + << "Row-scaled amax mismatch at row " << row; } } @@ -562,7 +562,7 @@ template void performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, - const bool per_token_activation = false) { + const bool row_scaled_activation = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -589,7 +589,7 @@ void performTest(float (*OP)(const float), const size_t scales_stride_t = blocks_X_t; Tensor input("input", shape, itype); - Tensor output("output", shape, otype, true, !per_token_activation, NVTE_NVFP4_1D_SCALING); + Tensor output("output", shape, otype, true, !row_scaled_activation, NVTE_NVFP4_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -600,10 +600,11 @@ void performTest(float (*OP)(const float), // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues const float amax = 448.0f * 6.0f * 8.0f; - std::vector ref_per_token_amax; + std::vector ref_rowwise_amax; bool use_2d_quantization = false; - if (per_token_activation) { + if (row_scaled_activation) { output.set_tensor_amax_shape({rows}); + output.set_rowwise_amax_is_row_scaled(true); compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -617,7 +618,7 @@ void performTest(float (*OP)(const float), scales_stride_t, use_fast_math, use_2d_quantization, - &ref_per_token_amax); + &ref_rowwise_amax); } else { // Set 2nd stage NVFP4 scaling factor output.set_tensor_amax(amax); @@ -650,7 +651,6 @@ void performTest(float (*OP)(const float), // Set 2D quantization based on compile-time flag 
quant_config.set_nvfp4_2d_quantization(use_2d_quantization); - quant_config.set_nvfp4_per_token_activation(per_token_activation); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -681,7 +681,7 @@ void performTest(float (*OP)(const float), // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, - false, !per_token_activation); + false, !row_scaled_activation); size_t scale_mismatches_num = 0; compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), @@ -689,15 +689,15 @@ void performTest(float (*OP)(const float), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); - if (!per_token_activation) { + if (!row_scaled_activation) { compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, scale_mismatches_num); } - if (per_token_activation) { - compare_per_token_amax(output, ref_per_token_amax); + if (row_scaled_activation) { + compare_rowwise_amax(output, ref_rowwise_amax); } } @@ -747,7 +747,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); - const bool per_token_activation = std::get<4>(GetParam()); + const bool row_scaled_activation = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -765,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, per_token_activation); + performTest(OP, tensor_dims, use_fast_math, row_scaled_activation); ); } @@ -804,7 +804,7 @@ INSTANTIATE_TEST_SUITE_P( }); 
INSTANTIATE_TEST_SUITE_P( - OperatorTestPerToken, + OperatorTestRowScaled, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( ::testing::Values(ActivationType::Identity), @@ -823,7 +823,7 @@ INSTANTIATE_TEST_SUITE_P( name += "X_FAST_SCALING"; } if (std::get<4>(info.param)) { - name += "XPER_TOKEN"; + name += "XROW_SCALED"; } return name; }); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index efce28e0d1..87dd0d3508 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -90,7 +90,7 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // against a CPU reference computed from the quantized data. template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, - const bool per_token_activation) { + const bool row_scaled_activation) { using namespace test; DType otype = TypeInfo::dtype; @@ -99,8 +99,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (per_token_activation) { + if (row_scaled_activation) { quantized.set_tensor_amax_shape({rows}); + quantized.set_rowwise_amax_is_row_scaled(true); } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { @@ -108,13 +109,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, } if (rows > 0 && cols > 0) { - if (per_token_activation) { - QuantizationConfigWrapper quant_config; - quant_config.set_nvfp4_per_token_activation(true); - nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); - } else { - nvte_quantize(input.data(), quantized.data(), 0); - } + nvte_quantize(input.data(), quantized.data(), 0); cudaDeviceSynchronize(); } @@ -148,7 +143,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Dequantize NVFP4 with GEMM-swizzled 
scales and compare against compact path. template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, - const bool per_token_activation) { + const bool row_scaled_activation) { using namespace test; DType otype = TypeInfo::dtype; @@ -157,8 +152,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (per_token_activation) { + if (row_scaled_activation) { quantized_compact.set_tensor_amax_shape({rows}); + quantized_compact.set_rowwise_amax_is_row_scaled(true); } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { @@ -166,13 +162,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, } if (rows > 0 && cols > 0) { - if (per_token_activation) { - QuantizationConfigWrapper quant_config; - quant_config.set_nvfp4_per_token_activation(true); - nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); - } else { - nvte_quantize(input.data(), quantized_compact.data(), 0); - } + nvte_quantize(input.data(), quantized_compact.data(), 0); cudaDeviceSynchronize(); } @@ -184,8 +174,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (per_token_activation) { + if (row_scaled_activation) { quantized_swizzled.set_tensor_amax_shape({rows}); + quantized_swizzled.set_rowwise_amax_is_row_scaled(true); } else { quantized_swizzled.set_tensor_amax(0.0f); } @@ -194,7 +185,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including 
zero-init data). quantized_compact.to_cpu(); - if (per_token_activation) { + if (row_scaled_activation) { quantized_swizzled.copy_tensor_amax_from(quantized_compact); } else { quantized_swizzled.set_tensor_amax(quantized_compact.amax()); @@ -265,11 +256,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); - const bool per_token_activation = std::get<2>(GetParam()); + const bool row_scaled_activation = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, per_token_activation); + tensor_size.first, tensor_size.second, row_scaled_activation); ); } @@ -285,7 +276,7 @@ INSTANTIATE_TEST_SUITE_P( std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + - (std::get<2>(info.param) ? "PerToken" : "PerTensor"); + (std::get<2>(info.param) ? "RowScaled" : "PerTensor"); return name; } ); @@ -303,11 +294,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); - const bool per_token_activation = std::get<2>(GetParam()); + const bool row_scaled_activation = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, per_token_activation); + tensor_size.first, tensor_size.second, row_scaled_activation); ); } @@ -323,7 +314,7 @@ INSTANTIATE_TEST_SUITE_P( std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + - (std::get<2>(info.param) ? "PerToken" : "PerTensor") + "X" + + (std::get<2>(info.param) ? 
"RowScaled" : "PerTensor") + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index fa46995991..61684e8e40 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -327,6 +327,10 @@ class Tensor { tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } + void set_rowwise_amax_is_row_scaled(bool rowwise_amax_is_row_scaled) { + tensor_.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 45aa34f414..a5052e9726 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -27,7 +27,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, - per_token_activation: bool = False, + row_scaled_activation: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -58,7 +58,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -118,11 +118,11 @@ def check_nvfp4_gemm_versus_reference( x_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=not per_token_activation, + columnwise=not row_scaled_activation, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -178,7 +178,7 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - if per_token_activation: + if row_scaled_activation: layout = ("T" if transa else "N") + ("T" if 
transb else "N") y_native = general_gemm( w_nvfp4_native, @@ -222,7 +222,7 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) -def check_nvfp4_per_token_grouped_gemm_matches_per_gemm( +def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_dtype: torch.dtype, w_dtype: torch.dtype, out_dtype: torch.dtype, @@ -248,7 +248,7 @@ def check_nvfp4_per_token_grouped_gemm_matches_per_gemm( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - per_token_activation=True, + row_scaled_activation=True, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -314,7 +314,7 @@ def check_nvfp4_per_token_grouped_gemm_matches_per_gemm( torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) -def check_nvfp4_per_token_gemm_matches_emulated( +def check_nvfp4_row_scaled_gemm_matches_emulated( x_dtype: torch.dtype, w_dtype: torch.dtype, out_dtype: torch.dtype, @@ -330,7 +330,7 @@ def check_nvfp4_per_token_gemm_matches_emulated( x = torch.randn((M, K), dtype=x_dtype, device=device) w = torch.randn((N, K), dtype=w_dtype, device=device) - x_per_token_quantizer = NVFP4Quantizer( + x_row_scaled_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, columnwise=True, @@ -338,7 +338,7 @@ def check_nvfp4_per_token_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - per_token_activation=True, + row_scaled_activation=True, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -359,13 +359,13 @@ def check_nvfp4_per_token_gemm_matches_emulated( with_post_rht_amax=False, ) - x_per_token = x_per_token_quantizer.update_quantized( - x, x_per_token_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + x_row_scaled = x_row_scaled_quantizer.update_quantized( + x, x_row_scaled_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) ) w_nvfp4 = w_quantizer.update_quantized( w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) ) - y_per_token = 
general_gemm(w_nvfp4, x_per_token, out_dtype=out_dtype, layout="TN")[0] + y_row_scaled = general_gemm(w_nvfp4, x_row_scaled, out_dtype=out_dtype, layout="TN")[0] emulated_rows = [] for i in range(M): @@ -381,9 +381,9 @@ def check_nvfp4_per_token_gemm_matches_emulated( y_emulated = torch.cat(emulated_rows, dim=0) if out_dtype == torch.bfloat16: - torch.testing.assert_close(y_per_token, y_emulated, atol=0.0, rtol=7.8e-3) + torch.testing.assert_close(y_row_scaled, y_emulated, atol=0.0, rtol=7.8e-3) else: - torch.testing.assert_close(y_per_token, y_emulated, atol=3.0517578125e-5, rtol=0.0) + torch.testing.assert_close(y_row_scaled, y_emulated, atol=3.0517578125e-5, rtol=0.0) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @@ -416,7 +416,7 @@ def check_nvfp4_per_token_gemm_matches_emulated( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) +@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -427,13 +427,13 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, - per_token_activation: bool, + row_scaled_activation: bool, ): - if per_token_activation: + if row_scaled_activation: if accumulate: - pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") + pytest.skip("Row-scaled NVFP4 GEMM output rescale does not support accumulation") if is_x_columnwise: - pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") + pytest.skip("Row-scaled NVFP4 GEMM output rescale requires rowwise activation usage") check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, @@ -445,7 +445,7 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, - per_token_activation=per_token_activation, + 
row_scaled_activation=row_scaled_activation, ) @@ -471,7 +471,7 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) -def test_nvfp4_per_token_grouped_gemm_matches_per_gemm( +def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, n: int, @@ -481,7 +481,7 @@ def test_nvfp4_per_token_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, ): - check_nvfp4_per_token_grouped_gemm_matches_per_gemm( + check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, w_dtype=w_dtype, out_dtype=out_dtype, @@ -513,7 +513,7 @@ def test_nvfp4_per_token_grouped_gemm_matches_per_gemm( @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) -def test_nvfp4_per_token_gemm_matches_emulated( +def test_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, N: int, @@ -521,7 +521,7 @@ def test_nvfp4_per_token_gemm_matches_emulated( w_dtype: torch.dtype, out_dtype: torch.dtype, ): - check_nvfp4_per_token_gemm_matches_emulated( + check_nvfp4_row_scaled_gemm_matches_emulated( x_dtype=x_dtype, w_dtype=w_dtype, out_dtype=out_dtype, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 4458e408ba..4be06bba42 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -31,7 +31,7 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, - per_token_activation: bool = False, + row_scaled_activation: bool = False, ) -> None: te_dtype = 
tex.DType.kFloat4E2M1 @@ -53,7 +53,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -82,11 +82,11 @@ def check_quantization_nvfp4_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not per_token_activation), + columnwise=(return_transpose and not row_scaled_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -119,7 +119,7 @@ def check_quantization_nvfp4_versus_reference( torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not per_token_activation: + if return_transpose and not row_scaled_activation: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) # Compare only the valid portion of transpose scale tensors @@ -165,7 +165,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) +@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -174,10 +174,10 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, - per_token_activation: bool, + row_scaled_activation: bool, ) -> None: - if per_token_activation and with_2d_quantization: - pytest.skip("Per-token NVFP4 does not support 2D quantization") + if row_scaled_activation and 
with_2d_quantization: + pytest.skip("Row-scaled NVFP4 does not support 2D quantization") check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -187,7 +187,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) @@ -204,7 +204,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) +@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -212,7 +212,7 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, - per_token_activation: bool, + row_scaled_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -234,7 +234,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) if use_cpp_allocator: @@ -261,11 +261,11 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not per_token_activation), + columnwise=(return_transpose and not row_scaled_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -286,7 +286,7 @@ def test_nvfp4_quantization_extrema_versus_reference( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] 
torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not per_token_activation: + if return_transpose and not row_scaled_activation: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] @@ -313,14 +313,14 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) +@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, - per_token_activation: bool, + row_scaled_activation: bool, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -356,7 +356,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) if use_cpp_allocator: @@ -383,11 +383,11 @@ def test_nvfp4_quantization_boundary_values( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not per_token_activation), + columnwise=(return_transpose and not row_scaled_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -409,7 +409,7 @@ def test_nvfp4_quantization_boundary_values( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not per_token_activation: + if return_transpose and not row_scaled_activation: 
torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] @@ -435,14 +435,14 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) +@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, - per_token_activation: bool, + row_scaled_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -464,7 +464,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) if use_cpp_allocator: @@ -491,11 +491,11 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not per_token_activation), + columnwise=(return_transpose and not row_scaled_activation), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - per_token_activation=per_token_activation, + row_scaled_activation=row_scaled_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) @@ -518,7 +518,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not per_token_activation: + if return_transpose and not row_scaled_activation: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] diff --git a/tests/pytorch/test_backward_override.py 
b/tests/pytorch/test_backward_override.py index de916ff562..c7c5a5b99d 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -79,9 +79,9 @@ id="NVFP4BlockScaling", ), pytest.param( - "nvfp4_per_token", + "nvfp4_row_scaled", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), - id="NVFP4PerTokenBlockScaling", + id="NVFP4RowScaledBlockScaling", ), ] @@ -170,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name in ("nvfp4", "nvfp4_per_token"): + if recipe_name in ("nvfp4", "nvfp4_row_scaled"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -183,12 +183,12 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - if module_type == "ops_linear" and recipe_name == "nvfp4_per_token": - pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") + if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled": + pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: - if recipe_name == "nvfp4_per_token": + if recipe_name == "nvfp4_row_scaled": return make_recipe(recipe_name, backward_override="dequantized") return make_recipe(recipe_name) @@ -208,7 +208,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." 
) return - if recipe_name in ("nvfp4", "nvfp4_per_token") and ( + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -235,7 +235,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name in ("nvfp4", "nvfp4_per_token") and ( + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -256,9 +256,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." 
@@ -1738,7 +1738,7 @@ def test_backward_override_memory_peak_report( modes = ( ("high_precision", "dequantized") - if recipe_name == "nvfp4_per_token" + if recipe_name == "nvfp4_row_scaled" else (None, "high_precision", "dequantized") ) mode_results: dict[str, dict[str, float] | str] = {} diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 8a01acf0eb..33ba65e0d9 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -64,8 +64,8 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe -def nvfp4_per_token(): - nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True) +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() @@ -100,7 +100,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) - fp8_recipes.append(nvfp4_per_token()) + fp8_recipes.append(nvfp4_row_scaled()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -402,8 +402,8 @@ def test_make_graphed_callables( f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" ) if fp8 and fp8_recipe.nvfp4(): - if getattr(fp8_recipe, "per_token_activation", False) and module == "mha": - pytest.skip("Per-token NVFP4 CUDA graph coverage applies to GEMM modules.") + if getattr(fp8_recipe, "row_scaled_activation", False) and module == "mha": + pytest.skip("Row-scaled NVFP4 CUDA graph coverage applies to GEMM modules.") if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe" diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 
f12148232c..ae93765183 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -514,27 +514,27 @@ def test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -def test_nvfp4_per_token_quantizer_roles(): - recipe = NVFP4BlockScaling(per_token_activation=True) +def test_nvfp4_row_scaled_quantizer_roles(): + recipe = NVFP4BlockScaling(row_scaled_activation=True) forward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="forward", num_quantizers=3, ).make_quantizers() - assert [q.per_token_activation for q in forward_quantizers] == [True, False, True] + assert [q.row_scaled_activation for q in forward_quantizers] == [True, False, True] backward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="backward", num_quantizers=2, ).make_quantizers() - assert [q.per_token_activation for q in backward_quantizers] == [False, False] + assert [q.row_scaled_activation for q in backward_quantizers] == [False, False] @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) +@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize( "M, N", [ @@ -550,8 +550,8 @@ def test_nvfp4_per_token_quantizer_roles(): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, per_token_activation, M, N): - q = NVFP4Quantizer(per_token_activation=per_token_activation) +def test_fp4_dequantize(dtype, row_scaled_activation, M, N): + q = NVFP4Quantizer(row_scaled_activation=row_scaled_activation) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) dequantized_tensor = starting_tensor.dequantize() diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index bb1c952163..760f9659d0 100644 --- a/tests/pytorch/test_sanity.py +++ 
b/tests/pytorch/test_sanity.py @@ -94,8 +94,8 @@ def nvfp4_vanilla(): return nvfp4_recipe -def nvfp4_per_token(): - nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True) +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() @@ -113,9 +113,9 @@ def nvfp4_per_token(): fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) -fp8_recipes_with_per_token = fp8_recipes.copy() +fp8_recipes_with_row_scaled = fp8_recipes.copy() if nvfp4_available: - fp8_recipes_with_per_token.insert(-1, nvfp4_per_token()) + fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -415,7 +415,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -463,7 +463,7 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -501,7 +501,7 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) 
@pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -542,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -576,7 +576,7 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - if not getattr(fp8_recipe, "per_token_activation", False): + if not getattr(fp8_recipe, "row_scaled_activation", False): pytest.skip("NVFP4 not supported for grouped linear") if dtype == torch.float16: pytest.skip("FP16 output for NVFP4 not supported") diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index d67c5e77b7..51f72b1e56 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -48,7 +48,7 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) - _all_recipes.append(recipe.NVFP4BlockScaling(per_token_activation=True)) + _all_recipes.append(recipe.NVFP4BlockScaling(row_scaled_activation=True)) # 
--------------------------------------------------------------------------- diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 1b2f9fe987..6d388c7b7f 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -115,7 +115,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_per_token"): + if name in ("nvfp4", "nvfp4_row_scaled"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -149,12 +149,12 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) - if name == "nvfp4_per_token": + if name == "nvfp4_row_scaled": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - per_token_activation=True, + row_scaled_activation=True, **recipe_kwargs, ) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -164,8 +164,8 @@ def recipe_id(fp8_recipe: Optional[Recipe]) -> str: """Readable pytest id for FP8/FP4 recipes.""" if fp8_recipe is None: return "None" - if fp8_recipe.nvfp4() and getattr(fp8_recipe, "per_token_activation", False): - return "NVFP4PerTokenBlockScaling" + if fp8_recipe.nvfp4() and getattr(fp8_recipe, "row_scaled_activation", False): + return "NVFP4RowScaledBlockScaling" return type(fp8_recipe).__name__ @@ -178,10 +178,10 @@ def skip_unsupported_backward_override( if ( quant_recipe is not None and quant_recipe.nvfp4() - and getattr(quant_recipe, "per_token_activation", False) + and getattr(quant_recipe, "row_scaled_activation", False) and backward_override is None ): - pytest.skip("Per-token NVFP4 does not support default quantized backward.") + pytest.skip("Row-scaled NVFP4 does not support default quantized backward.") if backward_override is None: return if quant_recipe is None and backward_override is not 
None: diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 7b063dbf6a..1cdfc7da2b 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -100,15 +100,15 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t rows = input_tensor->flat_first_dim(); int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); - const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; - if (per_token_activation) { + const bool rowwise_amax_is_row_scaled = output_tensor->rowwise_amax_is_row_scaled; + if (rowwise_amax_is_row_scaled) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "Per-token NVFP4 quantization does not support 2D quantization."); + "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(output_tensor->has_data(), - "Per-token NVFP4 quantization requires rowwise output."); + "Row-scaled NVFP4 quantization requires rowwise output."); NVTE_CHECK(!output_tensor->has_columnwise_data(), - "Per-token NVFP4 quantization does not produce columnwise output."); - nvfp4::compute_per_token_amax(*input_tensor, noop_tensor, output_tensor, stream); + "Row-scaled NVFP4 quantization does not produce columnwise output."); + nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream); } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -136,7 +136,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*per_token_rowwise_scaling=*/per_token_activation, /*noop_tensor=*/noop_tensor->data, + /*rowwise_amax_is_row_scaled=*/rowwise_amax_is_row_scaled, + 
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; @@ -250,8 +251,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); - NVTE_CHECK(!quant_config_cpp.nvfp4_per_token_activation, - "Per-token NVFP4 quantization is only supported for forward activation tensors."); + NVTE_CHECK(!output_tensor->rowwise_amax_is_row_scaled, + "Row-scaled NVFP4 quantization is only supported for forward quantization."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -278,7 +279,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*per_token_rowwise_scaling=*/false, /*noop_tensor=*/noop_tensor->data, + /*rowwise_amax_is_row_scaled=*/false, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 85e858e146..4013a10276 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,8 +34,8 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t amax_numel, const size_t N, - const size_t M, const size_t scale_stride, + const float *const tensor_amax, const bool rowwise_amax_is_row_scaled, + const size_t N, const size_t M, const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x 
= thread_idx % M; @@ -64,7 +64,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = (amax_numel == 1) ? tensor_amax[0] : tensor_amax[y]; + float amax = rowwise_amax_is_row_scaled ? tensor_amax[y] : tensor_amax[0]; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -91,6 +91,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; + const bool rowwise_amax_is_row_scaled = input.rowwise_amax_is_row_scaled; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -104,6 +105,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t threads = 512; const size_t blocks = DIVUP(total, threads); const size_t num_scale_tiles_X = DIVUP(Mread, static_cast(4)); + NVTE_CHECK(!rowwise_amax_is_row_scaled || input.amax.numel() == N, + "Row-scaled NVFP4 dequantization requires one rowwise amax per row."); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->data.dtype, OType, @@ -113,7 +116,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), input.amax.numel(), N, Mread, + reinterpret_cast(input.amax.dptr), rowwise_amax_is_row_scaled, N, Mread, input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 1888bc74f7..cd31a32dd6 100644 --- 
a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -30,14 +30,14 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { -namespace per_token_amax_kernel { +namespace rowwise_amax_kernel { using namespace ptx; #if FP4_TYPE_SUPPORTED -constexpr int PER_TOKEN_BLOCK_SIZE = 256; -constexpr int PER_TOKEN_SF_VEC_SIZE = 16; +constexpr int ROWWISE_AMAX_BLOCK_SIZE = 256; +constexpr int ROWWISE_AMAX_SF_VEC_SIZE = 16; template __device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, @@ -64,10 +64,10 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - compute_per_token_amax_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - float *__restrict__ output_per_token_amax, - const float *__restrict__ noop) { + compute_rowwise_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_rowwise_amax, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -94,63 +94,63 @@ __launch_bounds__(BLOCK_SIZE) BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); if (threadIdx.x == 0) { - output_per_token_amax[row_idx] = row_amax; + output_rowwise_amax[row_idx] = row_amax; } #endif } template -void launch_compute_per_token_amax(const int num_rows, const int num_cols, const IType *input, - float *output_per_token_amax, cudaStream_t stream, - const float *noop = nullptr) { +void launch_compute_rowwise_amax(const int num_rows, const int num_cols, const IType *input, + float *output_rowwise_amax, cudaStream_t stream, + const float *noop = nullptr) { if (num_rows == 0 || num_cols == 0) return; - NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", + NVTE_CHECK(num_cols % 2 == 
0, "num_cols must be even for row-scaled amax computation, got ", num_cols); dim3 grid(num_rows); - dim3 block(PER_TOKEN_BLOCK_SIZE); + dim3 block(ROWWISE_AMAX_BLOCK_SIZE); - compute_per_token_amax_kernel - <<>>(num_rows, num_cols, input, output_per_token_amax, noop); + compute_rowwise_amax_kernel + <<>>(num_rows, num_cols, input, output_rowwise_amax, noop); NVTE_CHECK_CUDA(cudaGetLastError()); } #endif // FP4_TYPE_SUPPORTED -} // namespace per_token_amax_kernel +} // namespace rowwise_amax_kernel -inline void compute_per_token_amax(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { +inline void compute_rowwise_amax(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED - using namespace per_token_amax_kernel; + using namespace rowwise_amax_kernel; const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - NVTE_CHECK(cols % PER_TOKEN_SF_VEC_SIZE == 0, - "Per-token NVFP4 quantization requires last dim divisible by ", PER_TOKEN_SF_VEC_SIZE, - "."); + NVTE_CHECK(cols % ROWWISE_AMAX_SF_VEC_SIZE == 0, + "Row-scaled NVFP4 quantization requires last dim divisible by ", + ROWWISE_AMAX_SF_VEC_SIZE, "."); auto *amax_ptr = reinterpret_cast(output->amax.dptr); - NVTE_CHECK(amax_ptr != nullptr, "Per-token rowwise amax tensor must be allocated."); - NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, + NVTE_CHECK(amax_ptr != nullptr, "Row-scaled rowwise amax tensor must be allocated."); + NVTE_CHECK(output->amax.numel() == rows, "Row-scaled rowwise amax must have ", rows, " entries, got ", output->amax.shape, "."); const auto *noop_ptr = reinterpret_cast(noop->data.dptr); if (input.dtype() == DType::kBFloat16) { const auto *input_ptr = reinterpret_cast(input.data.dptr); - launch_compute_per_token_amax<__nv_bfloat16>(static_cast(rows), static_cast(cols), - input_ptr, amax_ptr, stream, noop_ptr); + 
launch_compute_rowwise_amax<__nv_bfloat16>(static_cast(rows), static_cast(cols), + input_ptr, amax_ptr, stream, noop_ptr); } else if (input.dtype() == DType::kFloat16) { const auto *input_ptr = reinterpret_cast(input.data.dptr); - launch_compute_per_token_amax(static_cast(rows), static_cast(cols), input_ptr, - amax_ptr, stream, noop_ptr); + launch_compute_rowwise_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); } else if (input.dtype() == DType::kFloat32) { const auto *input_ptr = reinterpret_cast(input.data.dptr); - launch_compute_per_token_amax(static_cast(rows), static_cast(cols), input_ptr, - amax_ptr, stream, noop_ptr); + launch_compute_rowwise_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); } else { NVTE_ERROR( - "Unsupported input dtype for per-token NVFP4 quantization. " + "Unsupported input dtype for row-scaled NVFP4 quantization. " "Expected BFloat16, Float16, or Float32."); } #else @@ -240,7 +240,7 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / template + bool ROWWISE_AMAX_IS_ROW_SCALED> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -641,7 +641,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } float block_scale_inverse; - if constexpr (PER_TOKEN_ROWWISE) { + if constexpr (ROWWISE_AMAX_IS_ROW_SCALED) { // 2. Compute E4M3 scaling factor const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; @@ -1325,14 +1325,14 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - const bool per_token_rowwise = quant_config ? 
quant_config->nvfp4_per_token_activation : false; - NVTE_CHECK(!per_token_rowwise || !use_2d_quantization, - "Per-token NVFP4 quantization does not support 2D quantization."); + const bool rowwise_amax_is_row_scaled = output->rowwise_amax_is_row_scaled; + NVTE_CHECK(!rowwise_amax_is_row_scaled || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); // If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to // return the transposed data. // TODO(Frank): Is there a better way to do this? - bool return_transpose = output->has_columnwise_data() && !per_token_rowwise; + bool return_transpose = output->has_columnwise_data() && !rowwise_amax_is_row_scaled; if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); @@ -1352,8 +1352,8 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - NVTE_CHECK(!per_token_rowwise || output->amax.dptr != nullptr, - "Per-token NVFP4 rowwise quantization requires rowwise amax."); + NVTE_CHECK(!rowwise_amax_is_row_scaled || output->amax.dptr != nullptr, + "Row-scaled NVFP4 quantization requires rowwise amax."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); @@ -1436,11 +1436,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION(per_token_rowwise, PER_TOKEN_ROWWISE, { + 
TRANSFORMER_ENGINE_SWITCH_CONDITION(rowwise_amax_is_row_scaled, ROWWISE_AMAX_IS_ROW_SCALED, { TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { auto kernel = quantize_transpose_nvfp4_kernel; + ROWWISE_AMAX_IS_ROW_SCALED>; if constexpr (use_2d_quantization) { kernel = quantize_transpose_nvfp4_2D_kernel +template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -315,7 +315,7 @@ __device__ __forceinline__ void rowwise_scaling( nvfp4_scale_t S_dec_b_fp8; scaling_coeff_type SFcoefficient; - if constexpr (PER_TOKEN_ROWWISE) { + if constexpr (ROWWISE_AMAX_IS_ROW_SCALED) { const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; const float S_enc_rowwise_block = row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) @@ -361,7 +361,7 @@ __device__ __forceinline__ void rowwise_scaling( } template + bool ROWWISE_AMAX_IS_ROW_SCALED> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -582,7 +582,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); @@ -691,11 +691,11 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; - const bool per_token_rowwise = quant_config ? 
quant_config->nvfp4_per_token_activation : false; + const bool rowwise_amax_is_row_scaled = output->rowwise_amax_is_row_scaled; // If transposed output is allocated, return the transposed data // Otherwise, it's not necessary to return the transposed data. - const bool return_transpose = output->has_columnwise_data() && !per_token_rowwise; + const bool return_transpose = output->has_columnwise_data() && !rowwise_amax_is_row_scaled; checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -706,8 +706,8 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - NVTE_CHECK(!per_token_rowwise || output->amax.dptr != nullptr, - "Per-token NVFP4 rowwise quantization requires rowwise amax."); + NVTE_CHECK(!rowwise_amax_is_row_scaled || output->amax.dptr != nullptr, + "Row-scaled NVFP4 quantization requires rowwise amax."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), @@ -798,11 +798,12 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_fast_math, USE_FAST_MATH, TRANSFORMER_ENGINE_SWITCH_CONDITION( - per_token_rowwise, PER_TOKEN_ROWWISE, + rowwise_amax_is_row_scaled, ROWWISE_AMAX_IS_ROW_SCALED, TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; + RETURN_TRANSPOSE, + ROWWISE_AMAX_IS_ROW_SCALED>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 133f1a09e6..0af00e3ace 100644 --- 
a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -222,6 +222,14 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz TensorWrapper chunk(scaling_mode); for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { auto param_type = static_cast(param_id); + if (param_type == NVTETensorParam::kNVTEWithGEMMSwizzledScales) { + chunk.set_with_gemm_swizzled_scales(source.get_with_gemm_swizzled_scales()); + continue; + } + if (param_type == NVTETensorParam::kNVTERowwiseAmaxIsRowScaled) { + chunk.set_rowwise_amax_is_row_scaled(source.get_rowwise_amax_is_row_scaled()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index c5b4254e8b..7efdc42b58 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -173,6 +173,11 @@ struct Tensor { * Only meaningful for MXFP8 and NVFP4. */ bool with_gemm_swizzled_scales = false; + /*! \brief Whether rowwise NVFP4 amax is one value per tensor row. + * + * Only meaningful for NVFP4 tensors. + */ + bool rowwise_amax_is_row_scaled = false; /*! 
Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -183,7 +188,8 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTERowwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax - sizeof(uint8_t) // kNVTEWithGEMMSwizzledScales + sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales + sizeof(uint8_t) // kNVTERowwiseAmaxIsRowScaled }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -199,6 +205,7 @@ struct Tensor { columnwise_scale_inv.clear(); scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; + rowwise_amax_is_row_scaled = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -470,7 +477,6 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; - bool nvfp4_per_token_activation = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -480,8 +486,7 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t), // use_fast_math - sizeof(uint8_t) // nvfp4_per_token_activation + sizeof(uint8_t) // use_fast_math }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 0463d51d1c..00989fef54 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -72,6 +72,7 @@ enum NVTETensorParam { kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */ + kNVTERowwiseAmaxIsRowScaled = 8, /*!< 
Whether rowwise amax is one value per tensor row */ kNVTENumTensorParams }; @@ -370,8 +371,6 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, - /*! Whether to enable per-token (per-row) NVFP4 quantization */ - kNVTEQuantizationConfigNVFP4PerTokenActivation = 8, kNVTEQuantizationConfigNumAttributes }; @@ -767,6 +766,11 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val)); } + void set_rowwise_amax_is_row_scaled(bool rowwise_amax_is_row_scaled) { + const auto val = static_cast(rowwise_amax_is_row_scaled); + nvte_set_tensor_param_v2(tensor_, kNVTERowwiseAmaxIsRowScaled, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -803,6 +807,12 @@ class TensorWrapper { return static_cast(val); } + bool get_rowwise_amax_is_row_scaled() const { + uint8_t val = 0; + nvte_get_tensor_param_v2(tensor_, kNVTERowwiseAmaxIsRowScaled, &val, sizeof(val), nullptr); + return static_cast(val); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -1298,13 +1308,6 @@ class QuantizationConfigWrapper { sizeof(val)); } - /*! \brief Set whether to enable per-token NVFP4 quantization */ - void set_nvfp4_per_token_activation(bool nvfp4_per_token_activation) { - const auto val = static_cast(nvfp4_per_token_activation); - nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP4PerTokenActivation, - &val, sizeof(val)); - } - private: /*! \brief Wrapped NVTEQuantizationConfig. 
*/ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index c2e3ac334e..0d0b2fd37f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -478,10 +478,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. - per_token_activation : bool, default = False - If set to `True`, forward activation quantizers use per-token (per-row) - NVFP4 global amax values. In this mode, rowwise ``amax`` metadata is - stored as a vector with one FP32 value per token. + row_scaled_activation : bool, default = False + If set to `True`, forward activation quantizers emit row-scaled + NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored + as a vector with one FP32 value per tensor row. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. 
None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -495,7 +495,7 @@ class NVFP4BlockScaling(Recipe): os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" - per_token_activation: bool = os.getenv("NVTE_NVFP4_PER_TOKEN_ACTIVATION", "0") == "1" + row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -539,7 +539,7 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " - f"per_token_activation={self.per_token_activation}, " + f"row_scaled_activation={self.row_scaled_activation}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index a0a0ffa45f..e78f3d90ef 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -852,6 +852,9 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTEWithGEMMSwizzledScales: t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; + case kNVTERowwiseAmaxIsRowScaled: + t.rowwise_amax_is_row_scaled = static_cast(*reinterpret_cast(buf)); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -932,6 +935,9 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTEWithGEMMSwizzledScales: *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); break; + case kNVTERowwiseAmaxIsRowScaled: + *reinterpret_cast(buf) = static_cast(t->rowwise_amax_is_row_scaled); + break; default: 
NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -1043,9 +1049,6 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; - case kNVTEQuantizationConfigNVFP4PerTokenActivation: - bool_to_uint8(config_.nvfp4_per_token_activation, buf); - break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1101,9 +1104,6 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; - case kNVTEQuantizationConfigNVFP4PerTokenActivation: - uint8_to_bool(buf, config_.nvfp4_per_token_activation); - break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index d2f8ba384a..1a91cdc298 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -68,7 +68,7 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, - const bool per_token_rowwise_scaling, const SimpleTensor &noop_tensor, cudaStream_t stream); + const bool rowwise_amax_is_row_scaled, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 64e4e09f89..d1b476f3d2 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ 
b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -316,7 +316,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowwiseAmaxIsRowScaled> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -511,15 +511,15 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // Step 2.4: Compute scale const size_t row_idx = block_idx_y * kTileDim + r_s; float row_global_encode_scale = global_encode_scale; - if constexpr (kPerTokenRowwiseScaling) { + if constexpr (kRowwiseAmaxIsRowScaled) { row_global_encode_scale = row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; } - const float row_global_encode_scale_multiplier = kPerTokenRowwiseScaling + const float row_global_encode_scale_multiplier = kRowwiseAmaxIsRowScaled ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; const float row_global_decode_scale = - kPerTokenRowwiseScaling ? 1.0f / row_global_encode_scale : global_decode_scale; + kRowwiseAmaxIsRowScaled ? 
1.0f / row_global_encode_scale : global_decode_scale; ScaleType scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); @@ -721,7 +721,7 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, - const bool per_token_rowwise_scaling, const SimpleTensor& noop_tensor, cudaStream_t stream) { + const bool rowwise_amax_is_row_scaled, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -734,10 +734,10 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || !use_2d_quantization, "2D block quantization is only supported when return_identity is true."); - NVTE_CHECK(!per_token_rowwise_scaling || (return_identity && !return_transpose), - "Per-token NVFP4 rowwise scaling only supports rowwise quantization."); - NVTE_CHECK(!per_token_rowwise_scaling || !use_2d_quantization, - "Per-token NVFP4 rowwise scaling does not support 2D quantization."); + NVTE_CHECK(!rowwise_amax_is_row_scaled || (return_identity && !return_transpose), + "Row-scaled NVFP4 quantization only supports rowwise quantization."); + NVTE_CHECK(!rowwise_amax_is_row_scaled || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -818,14 +818,14 @@ void quantize_transpose_vector_blockwise_fp4( use_2d_quantization, kIs2DBlockScaling, TRANSFORMER_ENGINE_SWITCH_CONDITION( - per_token_rowwise_scaling, kPerTokenRowwiseScaling, + rowwise_amax_is_row_scaled, kRowwiseAmaxIsRowScaled, size_t smem_bytes = kSMemSize * sizeof(InputType); auto kernel = block_scaled_1d_cast_transpose_kernel< kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float, InputType, OutputType, ScaleType, kSwizzledScale, kApplyStochasticRounding, kIs2DBlockScaling, - kPerTokenRowwiseScaling>; + kRowwiseAmaxIsRowScaled>; if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -843,7 +843,7 @@ void quantize_transpose_vector_blockwise_fp4( row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, - noop_ptr);) // kPerTokenRowwiseScaling + noop_ptr);) // kRowwiseAmaxIsRowScaled ) // kIs2DBlockScaling ) // kApplyStochasticRounding ) // kSwizzledScale diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 49beef0778..e31743eb4b 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -70,35 +70,36 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def _is_nvfp4_per_token_tensor(tensor: torch.Tensor) -> bool: - """Whether tensor carries per-token NVFP4 global amax metadata.""" - if not isinstance(tensor, NVFP4TensorStorage): - return False - amax = tensor._amax_rowwise if tensor._amax_rowwise is not None else tensor._amax_columnwise - return amax is not None and amax.numel() > 1 +def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool: + """Whether tensor carries row-scaled NVFP4 global amax metadata.""" + return isinstance(tensor, 
NVFP4TensorStorage) and tensor._rowwise_amax_is_row_scaled -def _nvfp4_per_token_gemm_inputs( +def _nvfp4_row_scaled_gemm_inputs( A: NVFP4TensorStorage, B: NVFP4TensorStorage, *, transa: bool, ) -> Tuple[NVFP4TensorStorage, NVFP4TensorStorage, torch.Tensor]: - """Return GEMM aliases and FP32 output scales for per-token NVFP4.""" + """Return GEMM aliases and FP32 output scales for row-scaled NVFP4.""" A_metadata = A.get_metadata() weight_amax = A._amax_rowwise if transa else A._amax_columnwise assert weight_amax is not None and weight_amax.numel() == 1 A_metadata["amax_rowwise" if transa else "amax_columnwise"] = weight_amax.new_ones(1) + A_metadata["rowwise_amax_is_row_scaled"] = False B_metadata = B.get_metadata() + assert B._rowwise_amax_is_row_scaled if B._amax_rowwise is not None: activation_amax = B._amax_rowwise - assert activation_amax.numel() > 1 + assert activation_amax.numel() == 0 or activation_amax.numel() > 1 B_metadata["amax_rowwise"] = activation_amax.new_ones(1) else: activation_amax = B._amax_columnwise - assert activation_amax is not None and activation_amax.numel() > 1 + assert activation_amax is not None + assert activation_amax.numel() == 0 or activation_amax.numel() > 1 B_metadata["amax_columnwise"] = activation_amax.new_ones(1) + B_metadata["rowwise_amax_is_row_scaled"] = False assert activation_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 return ( @@ -213,27 +214,29 @@ def general_gemm( "beta": beta, } - if not _is_nvfp4_per_token_tensor(B): + if not _is_nvfp4_row_scaled_tensor(B): out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) else: - assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only." + assert layout[1] == "N", "Row-scaled NVFP4 GEMM currently supports N-layout B only." if grad: raise RuntimeError( - "Per-token NVFP4 GEMM currently supports fprop only. " + "Row-scaled NVFP4 GEMM currently supports fprop only. 
" "Backward NVFP4 gradient quantizers should use scalar global amax." ) - assert not gelu, "Per-token NVFP4 GEMM currently does not support fused GELU." - assert not accumulate, "Per-token NVFP4 GEMM currently does not support accumulation." + assert not gelu, "Row-scaled NVFP4 GEMM currently does not support fused GELU." + assert not accumulate, "Row-scaled NVFP4 GEMM currently does not support accumulation." assert ( quantization_params is None - ), "Per-token NVFP4 GEMM currently does not support output quantization." + ), "Row-scaled NVFP4 GEMM currently does not support output quantization." assert out is None or ( isinstance(out, torch.Tensor) and not is_custom(out) - ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." - assert isinstance(A, NVFP4TensorStorage), "Per-token NVFP4 GEMM currently requires NVFP4 A." - # cuBLAS folds NVFP4 global amax values into GEMM alpha. Keep the per-token + ), "Row-scaled NVFP4 GEMM currently supports only plain torch.Tensor outputs." + assert isinstance( + A, NVFP4TensorStorage + ), "Row-scaled NVFP4 GEMM currently requires NVFP4 A." + # cuBLAS folds NVFP4 global amax values into GEMM alpha. Keep the row-scaled # recipe's global scales out of alpha and apply them in FP32 below. 
- gemm_A, gemm_B, per_token_scales = _nvfp4_per_token_gemm_inputs(A, B, transa=transa) + gemm_A, gemm_B, rowwise_global_scales = _nvfp4_row_scaled_gemm_inputs(A, B, transa=transa) requested_out, requested_out_dtype = out, out_dtype fp32_out = ( @@ -250,16 +253,16 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) out_2d = out.reshape(-1, out.shape[-1]) - assert per_token_scales.dtype == torch.float32 and out.dtype == torch.float32 - assert per_token_scales.numel() == out_2d.shape[0] + assert rowwise_global_scales.dtype == torch.float32 and out.dtype == torch.float32 + assert rowwise_global_scales.numel() == out_2d.shape[0] if bias is not None: bias_cast = bias.to(dtype=torch.float32) out_2d.sub_(bias_cast) - out_2d.mul_(per_token_scales) + out_2d.mul_(rowwise_global_scales) out_2d.add_(bias_cast) else: - out_2d.mul_(per_token_scales) + out_2d.mul_(rowwise_global_scales) if requested_out is not None: requested_out.copy_(out.to(dtype=requested_out.dtype)) @@ -320,25 +323,25 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - if any(_is_nvfp4_per_token_tensor(tensor) for tensor in B): - assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): + assert layout[1] == "N", "Row-scaled NVFP4 grouped GEMM currently supports N-layout B only." if grad: raise RuntimeError( - "Per-token NVFP4 grouped GEMM currently supports fprop only. " + "Row-scaled NVFP4 grouped GEMM currently supports fprop only. " "Backward NVFP4 gradient quantizers should use scalar global amax." ) - assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." + assert not gelu, "Row-scaled NVFP4 grouped GEMM currently does not support fused GELU." assert ( not accumulate - ), "Per-token NVFP4 grouped GEMM currently does not support accumulation." 
- assert D_dtype is None, "Per-token NVFP4 grouped GEMM currently does not support D_dtype." + ), "Row-scaled NVFP4 grouped GEMM currently does not support accumulation." + assert D_dtype is None, "Row-scaled NVFP4 grouped GEMM currently does not support D_dtype." assert all( q is None for q in quantization_params - ), "Per-token NVFP4 grouped GEMM currently does not support output quantization." + ), "Row-scaled NVFP4 grouped GEMM currently does not support output quantization." if single_output: assert ( m_splits is not None - ), "Per-token NVFP4 grouped GEMM requires m_splits with single output." + ), "Row-scaled NVFP4 grouped GEMM requires m_splits with single output." out_init = out[0] if single_output else None if single_output: start_idx = 0 diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index b9f852c07d..8593dbc331 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,7 +320,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; - bool per_token_activation; + bool row_scaled_activation; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 17f86d63d6..68d016f7a9 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,7 +42,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->per_token_activation || + if (nvfp4_quantizer_cpp->row_scaled_activation || (nvfp4_quantizer_cpp->with_rht && 
nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; @@ -155,7 +155,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->per_token_activation || + if (nvfp4_quantizer_cpp->row_scaled_activation || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index e2dba46370..c64f84bf3d 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,7 +152,7 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->per_token_activation || + if (nvfp4_quantizer_cpp->row_scaled_activation || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 3497f1aa59..abd0ce5189 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -798,10 +798,11 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const bool per_token_activation = quantizer_cpp_list[0]->per_token_activation; - NVTE_CHECK(!per_token_activation || rowwise_usage, - "Per-token NVFP4 
quantization requires rowwise usage."); - const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage && !per_token_activation; + const bool rowwise_amax_is_row_scaled = quantizer_cpp_list[0]->row_scaled_activation; + NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, + "Row-scaled NVFP4 quantization requires rowwise usage."); + const auto columnwise_usage = + quantizer_cpp_list[0]->columnwise_usage && !rowwise_amax_is_row_scaled; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; @@ -880,7 +881,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); const size_t amax_size = - per_token_activation ? 4 * flat_first_dim(rowwise_data_shapes[i]) : 4; + rowwise_amax_is_row_scaled ? 4 * flat_first_dim(rowwise_data_shapes[i]) : 4; buffer_size = offset + amax_size; } @@ -895,8 +896,8 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); const std::vector amax_shape = - per_token_activation ? std::vector{flat_first_dim(rowwise_data_shapes[i])} - : std::vector{1}; + rowwise_amax_is_row_scaled ? std::vector{flat_first_dim(rowwise_data_shapes[i])} + : std::vector{1}; amax_rowwise_list.emplace_back( make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } @@ -978,9 +979,10 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? 
py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass( - rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, - amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales)); + tensor_py_list.emplace_back( + NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, + amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, rowwise_amax_is_row_scaled)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -997,6 +999,7 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + tensor_wrapper.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -1281,19 +1284,12 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } - if (quantizer.per_token_activation) { - NVTE_CHECK(!quantizer.with_rht, "Per-token NVFP4 split quantize does not support RHT."); + if (quantizer.row_scaled_activation) { + NVTE_CHECK(!quantizer.with_rht, "Row-scaled NVFP4 split quantize does not support RHT."); NVTE_CHECK(!quantizer.with_2d_quantization, - "Per-token NVFP4 split quantize does not support 2D quantization."); + "Row-scaled NVFP4 split quantize does not support 2D quantization."); NVTE_CHECK(!quantizer.stochastic_rounding, - "Per-token NVFP4 split quantize does not support stochastic rounding."); - - std::vector quant_config_list; - quant_config_list.reserve(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_list.emplace_back(QuantizationConfigWrapper()); - 
quant_config_list.back().set_nvfp4_per_token_activation(true); - } + "Row-scaled NVFP4 split quantize does not support stochastic rounding."); for (size_t i = 0; i < num_tensors; i++) { if (input_list[i].numel() == 0) { @@ -1302,8 +1298,10 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const size_t input_ndim = input_list[i].ndim(); const size_t cols = input_ndim > 0 ? input_list[i].size(input_ndim - 1) : 1; NVTE_CHECK(cols % 16 == 0, - "Per-token NVFP4 split quantize requires split inner dim divisible by 16."); - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + "Row-scaled NVFP4 split quantize requires split inner dim divisible by 16."); + output_list[i].set_rowwise_amax_is_row_scaled(true); + QuantizationConfigWrapper quant_config; + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config, stream); } return; } @@ -1405,9 +1403,9 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? 
input.size(input.ndim() - 1) : 1; - if (quantizer.per_token_activation) { + if (quantizer.row_scaled_activation) { NVTE_CHECK(input_last_dim % 16 == 0, - "Per-token NVFP4 split-quantize requires inner dim to be multiple of 16."); + "Row-scaled NVFP4 split-quantize requires inner dim to be multiple of 16."); } else { NVTE_CHECK(input_last_dim % 128 == 0, "NVFP4 multi-quantize requires inner dim to be multiple of 128."); @@ -1487,11 +1485,11 @@ std::vector split_quantize(const at::Tensor &tensor, [](const py::handle &quantizer) -> bool { return detail::IsNVFP4Quantizers(quantizer.ptr()); }); - const bool all_nvfp4_per_token_activation = + const bool all_nvfp4_row_scaled_activation = all_nvfp4_quantizers && std::all_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), [](const std::unique_ptr &quantizer) -> bool { - return static_cast(quantizer.get())->per_token_activation; + return static_cast(quantizer.get())->row_scaled_activation; }); // Choose implementation for allocating and populating tensors @@ -1499,7 +1497,7 @@ std::vector split_quantize(const at::Tensor &tensor, enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; - if (all_nvfp4_per_token_activation) { + if (all_nvfp4_row_scaled_activation) { quantization_method = QuantizationMethod::FUSED_NVFP4; } if (!disable_bulk_allocation) { @@ -1552,7 +1550,7 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!all_nvfp4_per_token_activation && !input_shape.empty() && + if (!all_nvfp4_row_scaled_activation && !input_shape.empty() && input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { @@ 
-1563,7 +1561,7 @@ std::vector split_quantize(const at::Tensor &tensor, }); quantization_method = QuantizationMethod::UNFUSED; } - if (!all_nvfp4_per_token_activation && !contiguous_data_and_scale) { + if (!all_nvfp4_row_scaled_activation && !contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 3975c01fa5..34f6de658d 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,7 +120,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->per_token_activation || + if (nvfp4_quantizer_cpp->row_scaled_activation || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; @@ -358,7 +358,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->per_token_activation || + if (nvfp4_quantizer_cpp->row_scaled_activation || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0b9119924e..be63bec98a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ 
b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,7 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); - this->per_token_activation = quantizer.attr("per_token_activation").cast(); + this->row_scaled_activation = quantizer.attr("row_scaled_activation").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1748,9 +1748,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - NVTE_CHECK(!this->per_token_activation || rowwise_usage, - "Per-token NVFP4 quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !this->per_token_activation; + const bool rowwise_amax_is_row_scaled = this->row_scaled_activation; + NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, + "Row-scaled NVFP4 quantization requires rowwise usage."); + const bool columnwise_usage = this->columnwise_usage && !this->row_scaled_activation; const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1764,7 +1765,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); - const int64_t amax_rows = this->per_token_activation ? static_cast(flat_first_dim) : 1; + const int64_t amax_rows = rowwise_amax_is_row_scaled ? 
static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); @@ -1810,6 +1811,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["rowwise_amax_is_row_scaled"] = py::cast(rowwise_amax_is_row_scaled); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1838,6 +1840,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["rowwise_amax_is_row_scaled"] = py::cast(rowwise_amax_is_row_scaled); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1870,6 +1873,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + out_cpp.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1897,9 +1901,9 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional rowwise_amax; std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; - NVTE_CHECK(!this->per_token_activation || rowwise_usage, - "Per-token NVFP4 grouped quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !this->per_token_activation; + NVTE_CHECK(!this->row_scaled_activation || rowwise_usage, + "Row-scaled NVFP4 grouped quantization requires rowwise usage."); + const bool columnwise_usage = this->columnwise_usage && !this->row_scaled_activation; const 
int64_t total_data_elements = total_elements / 2; @@ -2046,9 +2050,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(!this->per_token_activation || rowwise_usage, - "Per-token NVFP4 quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !this->per_token_activation; + const bool rowwise_amax_is_row_scaled = this->row_scaled_activation; + NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, + "Row-scaled NVFP4 quantization requires rowwise usage."); + const bool columnwise_usage = this->columnwise_usage && !this->row_scaled_activation; + tensor.attr("_rowwise_amax_is_row_scaled") = py::cast(rowwise_amax_is_row_scaled); // Coerce row-wise data if (rowwise_usage) { @@ -2066,10 +2072,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } - if (!amax_rowwise) { + const int64_t amax_rows = rowwise_amax_is_row_scaled ? static_cast(flat_first_dim) : 1; + if (!amax_rowwise || amax_rowwise->numel() != amax_rows) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - const int64_t amax_rows = - this->per_token_activation ? 
static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({amax_rows}, opts); @@ -2153,6 +2158,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + out_cpp.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2261,16 +2267,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); - if (this->per_token_activation) { - NVTE_CHECK(!this->with_rht, "Per-token NVFP4 activation does not support RHT."); + if (this->row_scaled_activation) { + out.set_rowwise_amax_is_row_scaled(true); + NVTE_CHECK(!this->with_rht, "Row-scaled NVFP4 quantization does not support RHT."); NVTE_CHECK(!this->with_2d_quantization, - "Per-token NVFP4 activation does not support 2D quantization."); + "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(!this->stochastic_rounding, - "Per-token NVFP4 activation does not support stochastic rounding."); + "Row-scaled NVFP4 quantization does not support stochastic rounding."); NVTE_CHECK(!this->with_amax_reduction, - "Per-token NVFP4 activation does not support amax reduction."); - NVTE_CHECK(cols % 16 == 0, "Per-token NVFP4 activation requires last dim divisible by 16."); - quant_config.set_nvfp4_per_token_activation(true); + "Row-scaled NVFP4 quantization does not support amax reduction."); + NVTE_CHECK(cols % 16 == 0, "Row-scaled NVFP4 quantization requires last dim divisible by 16."); + } else { + out.set_rowwise_amax_is_row_scaled(false); } // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT @@ -2339,7 +2347,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, 
TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax && !this->per_token_activation) { + if (compute_amax && !this->row_scaled_activation) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2440,8 +2448,8 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { - NVTE_CHECK(!this->per_token_activation, - "quantize_with_amax is not supported for per-token NVFP4 activation."); + NVTE_CHECK(!this->row_scaled_activation, + "quantize_with_amax is not supported for row-scaled NVFP4 quantization."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); auto output_rowwise_amax_ptr = out.get_amax().data_ptr; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index e13554a98c..5e0310b4ce 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -134,6 +134,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); + const bool rowwise_amax_is_row_scaled = tensor.attr("_rowwise_amax_is_row_scaled").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -163,6 +164,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + ret.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); // Quantizer state quantizer->set_quantization_params(&ret); diff --git 
a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 36c50309f5..252e819dc7 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,18 +350,18 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), - per_token_activation: bool = False, + row_scaled_activation: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): - super().__init__(rowwise=rowwise, columnwise=columnwise and not per_token_activation) + super().__init__(rowwise=rowwise, columnwise=columnwise and not row_scaled_activation) self.internal = True self.dtype = dtype self.pow_2_scales = pow_2_scales self.eps = eps self.quant_tile_shape = quant_tile_shape - self.per_token_activation = per_token_activation + self.row_scaled_activation = row_scaled_activation self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -449,7 +449,7 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, - per_token_rowwise: bool = False, + rowwise_amax_is_row_scaled: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -491,7 +491,7 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: - if per_token_rowwise: + if rowwise_amax_is_row_scaled: global_amax = global_amax.to(torch.float32).view(m, 1, 1) global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) @@ -622,9 +622,9 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" ) - if self.per_token_activation: + if self.row_scaled_activation: raise ValueError( - "Per-token activation is only supported for NVFP4 (non-pow2) mode." + "Row-scaled activation is only supported for NVFP4 (non-pow2) mode." 
) # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined @@ -642,10 +642,10 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.with_rht else tensor.t().contiguous() ) - if self.per_token_activation: + if self.row_scaled_activation: if self.quant_tile_shape != (1, 16): raise ValueError( - "Per-token activation only supports NVFP4 1x16 tile shape, " + "Row-scaled activation only supports NVFP4 1x16 tile shape, " f"got {self.quant_tile_shape}" ) global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) @@ -674,7 +674,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - per_token_rowwise=self.per_token_activation, + rowwise_amax_is_row_scaled=self.row_scaled_activation, eps=self.eps, ) if transpose_scales: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 2cb6c21946..a1a63ea4ce 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1375,7 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, - per_token_activation=self.recipe.per_token_activation and idx % 3 != 1, + row_scaled_activation=self.recipe.row_scaled_activation and idx % 3 != 1, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] @@ -1390,7 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, - per_token_activation=False, + row_scaled_activation=False, ) for _ in range(self.num_quantizers) ] diff --git 
a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index d6674f752e..01b9686b00 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,8 +128,8 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for gradients.""" stochastic_rounding: bool - """Per-token activation quantization path.""" - per_token_activation: bool + """Row-scaled activation quantization path.""" + row_scaled_activation: bool """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -146,7 +146,7 @@ def __init__( with_post_rht_amax: bool = False, with_2d_quantization: bool = False, stochastic_rounding: bool = False, - per_token_activation: bool = False, + row_scaled_activation: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -157,7 +157,7 @@ def __init__( self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding - self.per_token_activation = per_token_activation + self.row_scaled_activation = row_scaled_activation self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -203,7 +203,7 @@ def copy(self) -> NVFP4Quantizer: with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, - per_token_activation=self.per_token_activation, + row_scaled_activation=self.row_scaled_activation, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -336,7 +336,7 @@ def make_empty( scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) # Allocate per tensor scale inverse. FP32 format. 
- amax_rows = flat_first_dim if self.per_token_activation else 1 + amax_rows = flat_first_dim if self.row_scaled_activation else 1 amax_rowwise = torch.zeros( amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory ) @@ -345,7 +345,7 @@ def make_empty( columnwise_data = None columnwise_scale_inv = None amax_columnwise = None - columnwise_usage = self.columnwise_usage and not self.per_token_activation + columnwise_usage = self.columnwise_usage and not self.row_scaled_activation if columnwise_usage: # enforce 2D shape to avoid [S, B, H] shape and B and be 1 # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero @@ -381,6 +381,7 @@ def make_empty( quantizer=self, requires_grad=requires_grad, with_gemm_swizzled_scales=False, + rowwise_amax_is_row_scaled=self.row_scaled_activation, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -441,6 +442,7 @@ def __new__( fp4_dtype: TE_DType, quantizer: Quantizer, with_gemm_swizzled_scales: bool, + rowwise_amax_is_row_scaled: bool = False, **kwargs, ): instance = super().__new__( @@ -454,6 +456,7 @@ def __new__( fp4_dtype, quantizer, with_gemm_swizzled_scales, + rowwise_amax_is_row_scaled, *args, **kwargs, ) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 2ea6ef958a..784e1a158d 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -662,10 +662,10 @@ def make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): - per_token_activation = getattr(quantizer, "per_token_activation", False) - columnwise_usage = columnwise_usage and not per_token_activation + row_scaled_activation = getattr(quantizer, "row_scaled_activation", False) + columnwise_usage = 
columnwise_usage and not row_scaled_activation total_amax_elements = ( - sum(math.prod(s[:-1]) for s in shape) if per_token_activation else num_tensors + sum(math.prod(s[:-1]) for s in shape) if row_scaled_activation else num_tensors ) if rowwise_usage: @@ -896,13 +896,13 @@ def split_into_quantized_tensors( cum += math.prod(scale_shape) columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets - nvfp4_per_token_amax_offsets = None - if recipe.nvfp4() and getattr(self.quantizer, "per_token_activation", False): + nvfp4_rowwise_amax_offsets = None + if recipe.nvfp4() and getattr(self.quantizer, "row_scaled_activation", False): cum = 0 - nvfp4_per_token_amax_offsets = [0] + nvfp4_rowwise_amax_offsets = [0] for i in range(self.num_tensors): cum += math.prod(self.tensor_shapes[i][:-1]) - nvfp4_per_token_amax_offsets.append(cum) + nvfp4_rowwise_amax_offsets.append(cum) for i in range(self.num_tensors): quantizer = self.quantizer @@ -1096,17 +1096,17 @@ def split_into_quantized_tensors( ) if self.amax is not None: - if nvfp4_per_token_amax_offsets is not None: - amax_start = nvfp4_per_token_amax_offsets[i] - amax_end = nvfp4_per_token_amax_offsets[i + 1] + if nvfp4_rowwise_amax_offsets is not None: + amax_start = nvfp4_rowwise_amax_offsets[i] + amax_end = nvfp4_rowwise_amax_offsets[i + 1] amax_rowwise = self.amax[amax_start:amax_end] else: amax_rowwise = self.amax[i : i + 1] if self.columnwise_amax is not None: - if nvfp4_per_token_amax_offsets is not None: - amax_start = nvfp4_per_token_amax_offsets[i] - amax_end = nvfp4_per_token_amax_offsets[i + 1] + if nvfp4_rowwise_amax_offsets is not None: + amax_start = nvfp4_rowwise_amax_offsets[i] + amax_end = nvfp4_rowwise_amax_offsets[i + 1] amax_columnwise = self.columnwise_amax[amax_start:amax_end] else: amax_columnwise = self.columnwise_amax[i : i + 1] @@ -1128,6 +1128,7 @@ def split_into_quantized_tensors( fp4_dtype=quantizer.dtype, quantizer=quantizer, 
with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + rowwise_amax_is_row_scaled=getattr(quantizer, "row_scaled_activation", False), ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 70699ad71a..ec06839dfb 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -97,6 +97,8 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # Whether scaling factors are in the swizzled format expected by # GEMM _with_gemm_swizzled_scales: bool + # Whether rowwise amax stores one value per tensor row + _rowwise_amax_is_row_scaled: bool def __new__( cls, @@ -109,6 +111,7 @@ def __new__( fp4_dtype: TE_DType, quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, + rowwise_amax_is_row_scaled: bool = False, *args, fake_dtype: Optional[torch.dtype] = None, **kwargs, @@ -128,6 +131,7 @@ def __new__( instance._amax_rowwise = amax_rowwise instance._amax_columnwise = amax_columnwise instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + instance._rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled return instance @@ -152,6 +156,8 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("FP4 dtype mismatch in copy_from_storage") if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: raise RuntimeError("Scale layout mismatch in copy_from_storage") + if self._rowwise_amax_is_row_scaled != src._rowwise_amax_is_row_scaled: + raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): if dst is not None and src_tensor is not None: @@ -176,6 +182,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp4_dtype": self._fp4_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": 
self._with_gemm_swizzled_scales, + "rowwise_amax_is_row_scaled": self._rowwise_amax_is_row_scaled, "fake_dtype": self._dtype, } @@ -308,6 +315,7 @@ def view(self, shape: torch.Size): quantizer=self._quantizer, fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + rowwise_amax_is_row_scaled=self._rowwise_amax_is_row_scaled, fake_dtype=self._dtype, ) From 66622e80fb1a6b19cb7f949b0947f76c6fba9426 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 17:45:56 -0700 Subject: [PATCH 32/45] Further refactor Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 16 ++--- tests/cpp/operator/test_dequantize_nvfp4.cu | 20 +++--- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 26 ++++---- .../nvfp4/test_nvfp4_quantize_exact.py | 62 +++++++++++-------- tests/pytorch/test_recipe.py | 12 ++-- .../common/cast/dispatch/quantize.cuh | 2 +- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 4 +- .../quantize_transpose_nvfp4_tuned_1D.cuh | 4 +- .../common/include/transformer_engine/gemm.h | 2 +- .../pytorch/cpp_extensions/gemm.py | 17 ++--- transformer_engine/pytorch/csrc/common.h | 3 +- .../pytorch/csrc/extensions/activation.cpp | 4 +- .../pytorch/csrc/extensions/bias.cpp | 2 +- .../pytorch/csrc/extensions/cast.cpp | 32 ++++++---- .../pytorch/csrc/extensions/normalization.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 31 +++++----- .../custom_recipes/quantization_nvfp4.py | 18 +++--- transformer_engine/pytorch/quantization.py | 4 +- .../pytorch/tensor/grouped_tensor.py | 3 + .../pytorch/tensor/nvfp4_tensor.py | 18 +++--- .../tensor/storage/grouped_tensor_storage.py | 19 ++++-- 21 files changed, 168 insertions(+), 135 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 34101e8572..21fb93b428 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -562,7 +562,7 @@ template void 
performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, - const bool row_scaled_activation = false) { + const bool rowwise_amax_is_row_scaled = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -589,7 +589,7 @@ void performTest(float (*OP)(const float), const size_t scales_stride_t = blocks_X_t; Tensor input("input", shape, itype); - Tensor output("output", shape, otype, true, !row_scaled_activation, NVTE_NVFP4_1D_SCALING); + Tensor output("output", shape, otype, true, !rowwise_amax_is_row_scaled, NVTE_NVFP4_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -602,7 +602,7 @@ void performTest(float (*OP)(const float), const float amax = 448.0f * 6.0f * 8.0f; std::vector ref_rowwise_amax; bool use_2d_quantization = false; - if (row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { output.set_tensor_amax_shape({rows}); output.set_rowwise_amax_is_row_scaled(true); compute_ref(OP, @@ -681,7 +681,7 @@ void performTest(float (*OP)(const float), // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, - false, !row_scaled_activation); + false, !rowwise_amax_is_row_scaled); size_t scale_mismatches_num = 0; compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), @@ -689,14 +689,14 @@ void performTest(float (*OP)(const float), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); - if (!row_scaled_activation) { + if (!rowwise_amax_is_row_scaled) { compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, scale_mismatches_num); } - if (row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { compare_rowwise_amax(output, ref_rowwise_amax); } } @@ -747,7 +747,7 @@ 
TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); - const bool row_scaled_activation = std::get<4>(GetParam()); + const bool rowwise_amax_is_row_scaled = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -765,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, row_scaled_activation); + performTest(OP, tensor_dims, use_fast_math, rowwise_amax_is_row_scaled); ); } diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 87dd0d3508..52a72a7e9b 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -90,7 +90,7 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // against a CPU reference computed from the quantized data. template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, - const bool row_scaled_activation) { + const bool rowwise_amax_is_row_scaled) { using namespace test; DType otype = TypeInfo::dtype; @@ -99,7 +99,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { quantized.set_tensor_amax_shape({rows}); quantized.set_rowwise_amax_is_row_scaled(true); } else if (rows > 0 && cols > 0) { @@ -143,7 +143,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. 
template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, - const bool row_scaled_activation) { + const bool rowwise_amax_is_row_scaled) { using namespace test; DType otype = TypeInfo::dtype; @@ -152,7 +152,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { quantized_compact.set_tensor_amax_shape({rows}); quantized_compact.set_rowwise_amax_is_row_scaled(true); } else if (rows > 0 && cols > 0) { @@ -174,7 +174,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { quantized_swizzled.set_tensor_amax_shape({rows}); quantized_swizzled.set_rowwise_amax_is_row_scaled(true); } else { @@ -185,7 +185,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including zero-init data). 
quantized_compact.to_cpu(); - if (row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { quantized_swizzled.copy_tensor_amax_from(quantized_compact); } else { quantized_swizzled.set_tensor_amax(quantized_compact.amax()); @@ -256,11 +256,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); - const bool row_scaled_activation = std::get<2>(GetParam()); + const bool rowwise_amax_is_row_scaled = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, row_scaled_activation); + tensor_size.first, tensor_size.second, rowwise_amax_is_row_scaled); ); } @@ -294,11 +294,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); - const bool row_scaled_activation = std::get<2>(GetParam()); + const bool rowwise_amax_is_row_scaled = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, row_scaled_activation); + tensor_size.first, tensor_size.second, rowwise_amax_is_row_scaled); ); } diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index a5052e9726..23512c9991 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -27,7 +27,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, - row_scaled_activation: bool = False, + rowwise_amax_is_row_scaled: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -58,7 +58,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - row_scaled_activation=row_scaled_activation, + 
rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -118,11 +118,11 @@ def check_nvfp4_gemm_versus_reference( x_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=not row_scaled_activation, + columnwise=not rowwise_amax_is_row_scaled, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -178,7 +178,7 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - if row_scaled_activation: + if rowwise_amax_is_row_scaled: layout = ("T" if transa else "N") + ("T" if transb else "N") y_native = general_gemm( w_nvfp4_native, @@ -248,7 +248,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - row_scaled_activation=True, + rowwise_amax_is_row_scaled=True, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -338,7 +338,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - row_scaled_activation=True, + rowwise_amax_is_row_scaled=True, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -416,7 +416,9 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize( + "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] +) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -427,13 +429,13 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, - row_scaled_activation: bool, + rowwise_amax_is_row_scaled: bool, 
): - if row_scaled_activation: + if rowwise_amax_is_row_scaled: if accumulate: pytest.skip("Row-scaled NVFP4 GEMM output rescale does not support accumulation") if is_x_columnwise: - pytest.skip("Row-scaled NVFP4 GEMM output rescale requires rowwise activation usage") + pytest.skip("Row-scaled NVFP4 GEMM output rescale requires rowwise RHS usage") check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, @@ -445,7 +447,7 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 4be06bba42..7c56fc1c07 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -31,7 +31,7 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, - row_scaled_activation: bool = False, + rowwise_amax_is_row_scaled: bool = False, ) -> None: te_dtype = tex.DType.kFloat4E2M1 @@ -53,7 +53,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -82,11 +82,11 @@ def check_quantization_nvfp4_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not row_scaled_activation), + columnwise=(return_transpose and not rowwise_amax_is_row_scaled), pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -119,7 +119,7 @@ def 
check_quantization_nvfp4_versus_reference( torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not row_scaled_activation: + if return_transpose and not rowwise_amax_is_row_scaled: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) # Compare only the valid portion of transpose scale tensors @@ -165,7 +165,9 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize( + "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] +) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -174,9 +176,9 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, - row_scaled_activation: bool, + rowwise_amax_is_row_scaled: bool, ) -> None: - if row_scaled_activation and with_2d_quantization: + if rowwise_amax_is_row_scaled and with_2d_quantization: pytest.skip("Row-scaled NVFP4 does not support 2D quantization") check_quantization_nvfp4_versus_reference( @@ -187,7 +189,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) @@ -204,7 +206,9 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize( + "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] +) def test_nvfp4_quantization_extrema_versus_reference( 
x_dtype: torch.dtype, M: int, @@ -212,7 +216,7 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, - row_scaled_activation: bool, + rowwise_amax_is_row_scaled: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -234,7 +238,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) if use_cpp_allocator: @@ -261,11 +265,11 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not row_scaled_activation), + columnwise=(return_transpose and not rowwise_amax_is_row_scaled), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -286,7 +290,7 @@ def test_nvfp4_quantization_extrema_versus_reference( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not row_scaled_activation: + if return_transpose and not rowwise_amax_is_row_scaled: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] @@ -313,14 +317,16 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize( + "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] +) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, 
use_cpp_allocator: bool, - row_scaled_activation: bool, + rowwise_amax_is_row_scaled: bool, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -356,7 +362,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) if use_cpp_allocator: @@ -383,11 +389,11 @@ def test_nvfp4_quantization_boundary_values( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not row_scaled_activation), + columnwise=(return_transpose and not rowwise_amax_is_row_scaled), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -409,7 +415,7 @@ def test_nvfp4_quantization_boundary_values( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not row_scaled_activation: + if return_transpose and not rowwise_amax_is_row_scaled: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] @@ -435,14 +441,16 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize( + "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] +) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, - row_scaled_activation: bool, + rowwise_amax_is_row_scaled: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -464,7 +472,7 
@@ def test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) if use_cpp_allocator: @@ -491,11 +499,11 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not row_scaled_activation), + columnwise=(return_transpose and not rowwise_amax_is_row_scaled), pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - row_scaled_activation=row_scaled_activation, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) @@ -518,7 +526,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not row_scaled_activation: + if return_transpose and not rowwise_amax_is_row_scaled: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index ae93765183..81a51335b1 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -522,19 +522,21 @@ def test_nvfp4_row_scaled_quantizer_roles(): mode="forward", num_quantizers=3, ).make_quantizers() - assert [q.row_scaled_activation for q in forward_quantizers] == [True, False, True] + assert [q.rowwise_amax_is_row_scaled for q in forward_quantizers] == [True, False, True] backward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="backward", num_quantizers=2, ).make_quantizers() - assert [q.row_scaled_activation for q in backward_quantizers] == [False, False] + assert [q.rowwise_amax_is_row_scaled for q in backward_quantizers] == [False, False] @pytest.mark.skipif(not fp4_available, 
reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("row_scaled_activation", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize( + "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] +) @pytest.mark.parametrize( "M, N", [ @@ -550,8 +552,8 @@ def test_nvfp4_row_scaled_quantizer_roles(): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, row_scaled_activation, M, N): - q = NVFP4Quantizer(row_scaled_activation=row_scaled_activation) +def test_fp4_dequantize(dtype, rowwise_amax_is_row_scaled, M, N): + q = NVFP4Quantizer(rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) dequantized_tensor = starting_tensor.dequantize() diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 1cdfc7da2b..9ed199d798 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -252,7 +252,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); NVTE_CHECK(!output_tensor->rowwise_amax_is_row_scaled, - "Row-scaled NVFP4 quantization is only supported for forward quantization."); + "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index cd31a32dd6..50253c5629 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1332,7 +1332,7 @@ void quantize_transpose(const Tensor 
&input, const Tensor *noop, Tensor *output, // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to // return the transposed data. // TODO(Frank): Is there a better way to do this? - bool return_transpose = output->has_columnwise_data() && !rowwise_amax_is_row_scaled; + bool return_transpose = output->has_columnwise_data(); if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); @@ -1354,6 +1354,8 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); NVTE_CHECK(!rowwise_amax_is_row_scaled || output->amax.dptr != nullptr, "Row-scaled NVFP4 quantization requires rowwise amax."); + NVTE_CHECK(!rowwise_amax_is_row_scaled || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index be72dfa80f..9bdb4bb433 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -695,7 +695,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, // If transposed output is allocated, return the transposed data // Otherwise, it's not necesary to return the transposed data. 
- const bool return_transpose = output->has_columnwise_data() && !rowwise_amax_is_row_scaled; + const bool return_transpose = output->has_columnwise_data(); checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -708,6 +708,8 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); NVTE_CHECK(!rowwise_amax_is_row_scaled || output->amax.dptr != nullptr, "Row-scaled NVFP4 quantization requires rowwise amax."); + NVTE_CHECK(!rowwise_amax_is_row_scaled || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index bf9394c988..9fe692dd2d 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -440,7 +440,7 @@ void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTens /*! \brief Grouped Scaled Bias add for grouped GEMM outputs. * * output[row,col] += bias[col] * scale[row], where biases are per-group -* and scales are per-token (per-row across all groups). +* and scales are per-row across all groups. * Requires uniform last-dimension across all output tensors and bias tensors. 
*/ void nvte_grouped_scaled_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index e31743eb4b..12e2fcb386 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -90,22 +90,17 @@ def _nvfp4_row_scaled_gemm_inputs( B_metadata = B.get_metadata() assert B._rowwise_amax_is_row_scaled - if B._amax_rowwise is not None: - activation_amax = B._amax_rowwise - assert activation_amax.numel() == 0 or activation_amax.numel() > 1 - B_metadata["amax_rowwise"] = activation_amax.new_ones(1) - else: - activation_amax = B._amax_columnwise - assert activation_amax is not None - assert activation_amax.numel() == 0 or activation_amax.numel() > 1 - B_metadata["amax_columnwise"] = activation_amax.new_ones(1) + rhs_rowwise_amax = B._amax_rowwise + assert rhs_rowwise_amax is not None + assert rhs_rowwise_amax.numel() == 0 or rhs_rowwise_amax.numel() > 1 + B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1) B_metadata["rowwise_amax_is_row_scaled"] = False - assert activation_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 + assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 return ( NVFP4TensorStorage(**A_metadata), NVFP4TensorStorage(**B_metadata), - (activation_amax * weight_amax).view(-1, 1), + (rhs_rowwise_amax * weight_amax).view(-1, 1), ) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8593dbc331..9bbbc270d8 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,7 +320,8 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; - bool row_scaled_activation; + // Whether tensors emitted by this quantizer store one rowwise amax per tensor row. 
+ bool rowwise_amax_is_row_scaled; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 68d016f7a9..fff2ba1edc 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,7 +42,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->row_scaled_activation || + if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; @@ -155,7 +155,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->row_scaled_activation || + if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index c64f84bf3d..2e92d2eb80 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,7 +152,7 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 
quantizer"); - if (nvfp4_quantizer_cpp->row_scaled_activation || + if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index abd0ce5189..263f3bdb7b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -798,7 +798,7 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const bool rowwise_amax_is_row_scaled = quantizer_cpp_list[0]->row_scaled_activation; + const bool rowwise_amax_is_row_scaled = quantizer_cpp_list[0]->rowwise_amax_is_row_scaled; NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); const auto columnwise_usage = @@ -1284,7 +1284,14 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } - if (quantizer.row_scaled_activation) { + const bool rowwise_amax_is_row_scaled = output_list.front().get_rowwise_amax_is_row_scaled(); + NVTE_CHECK(std::all_of(output_list.begin(), output_list.end(), + [rowwise_amax_is_row_scaled](const TensorWrapper &output) { + return output.get_rowwise_amax_is_row_scaled() == + rowwise_amax_is_row_scaled; + }), + "All NVFP4 split-quantize outputs must use the same rowwise amax scaling mode."); + if (rowwise_amax_is_row_scaled) { NVTE_CHECK(!quantizer.with_rht, "Row-scaled NVFP4 split quantize does not support RHT."); NVTE_CHECK(!quantizer.with_2d_quantization, "Row-scaled NVFP4 split quantize does not support 2D quantization."); @@ -1299,7 +1306,6 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const size_t cols = input_ndim > 0 ? 
input_list[i].size(input_ndim - 1) : 1; NVTE_CHECK(cols % 16 == 0, "Row-scaled NVFP4 split quantize requires split inner dim divisible by 16."); - output_list[i].set_rowwise_amax_is_row_scaled(true); QuantizationConfigWrapper quant_config; nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config, stream); } @@ -1400,10 +1406,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, "NVFP4 split-quantize does not support 2D quantization"); NVTE_CHECK(!quantizer.with_amax_reduction, "NVFP4 split-quantize does not support amax reduction"); + const bool rowwise_amax_is_row_scaled = output_list.front().get_rowwise_amax_is_row_scaled(); // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; - if (quantizer.row_scaled_activation) { + if (rowwise_amax_is_row_scaled) { NVTE_CHECK(input_last_dim % 16 == 0, "Row-scaled NVFP4 split-quantize requires inner dim to be multiple of 16."); } else { @@ -1485,19 +1492,20 @@ std::vector split_quantize(const at::Tensor &tensor, [](const py::handle &quantizer) -> bool { return detail::IsNVFP4Quantizers(quantizer.ptr()); }); - const bool all_nvfp4_row_scaled_activation = + const bool all_nvfp4_rowwise_amax_is_row_scaled = all_nvfp4_quantizers && - std::all_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), - [](const std::unique_ptr &quantizer) -> bool { - return static_cast(quantizer.get())->row_scaled_activation; - }); + std::all_of( + quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) -> bool { + return static_cast(quantizer.get())->rowwise_amax_is_row_scaled; + }); // Choose implementation for allocating and populating tensors enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 }; enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; - if 
(all_nvfp4_row_scaled_activation) { + if (all_nvfp4_rowwise_amax_is_row_scaled) { quantization_method = QuantizationMethod::FUSED_NVFP4; } if (!disable_bulk_allocation) { @@ -1550,7 +1558,7 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!all_nvfp4_row_scaled_activation && !input_shape.empty() && + if (!all_nvfp4_rowwise_amax_is_row_scaled && !input_shape.empty() && input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { @@ -1561,7 +1569,7 @@ std::vector split_quantize(const at::Tensor &tensor, }); quantization_method = QuantizationMethod::UNFUSED; } - if (!all_nvfp4_row_scaled_activation && !contiguous_data_and_scale) { + if (!all_nvfp4_rowwise_amax_is_row_scaled && !contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 34f6de658d..3a38025a9e 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,7 +120,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->row_scaled_activation || + if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; @@ -358,7 +358,7 @@ std::vector 
rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->row_scaled_activation || + if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index be63bec98a..ee28852c32 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,7 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); - this->row_scaled_activation = quantizer.attr("row_scaled_activation").cast(); + this->rowwise_amax_is_row_scaled = quantizer.attr("rowwise_amax_is_row_scaled").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1748,10 +1748,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - const bool rowwise_amax_is_row_scaled = this->row_scaled_activation; + const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !this->row_scaled_activation; + const bool columnwise_usage = 
this->columnwise_usage && !rowwise_amax_is_row_scaled; const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1901,9 +1901,10 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional rowwise_amax; std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; - NVTE_CHECK(!this->row_scaled_activation || rowwise_usage, + const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; + NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !this->row_scaled_activation; + const bool columnwise_usage = this->columnwise_usage && !rowwise_amax_is_row_scaled; const int64_t total_data_elements = total_elements / 2; @@ -1912,7 +1913,10 @@ std::pair NVFP4Quantizer::create_grouped_tenso const auto scale_shape = get_scale_shape(logical_shape_vec, false); const int64_t total_scale_elements = static_cast(product(scale_shape)); rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); - rowwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + const int64_t amax_elements = rowwise_amax_is_row_scaled + ? static_cast(logical_first_dim) + : static_cast(num_tensors); + rowwise_amax = at::empty({amax_elements}, float_opts); } if (columnwise_usage) { @@ -1970,6 +1974,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + kwargs["rowwise_amax_is_row_scaled"] = py::cast(rowwise_amax_is_row_scaled); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2050,10 +2055,10 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - const bool rowwise_amax_is_row_scaled = this->row_scaled_activation; + const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !this->row_scaled_activation; + const bool columnwise_usage = this->columnwise_usage && !rowwise_amax_is_row_scaled; tensor.attr("_rowwise_amax_is_row_scaled") = py::cast(rowwise_amax_is_row_scaled); // Coerce row-wise data @@ -2267,8 +2272,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); - if (this->row_scaled_activation) { - out.set_rowwise_amax_is_row_scaled(true); + const bool rowwise_amax_is_row_scaled = out.get_rowwise_amax_is_row_scaled(); + if (rowwise_amax_is_row_scaled) { NVTE_CHECK(!this->with_rht, "Row-scaled NVFP4 quantization does not support RHT."); NVTE_CHECK(!this->with_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -2277,8 +2282,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou NVTE_CHECK(!this->with_amax_reduction, "Row-scaled NVFP4 quantization does not support amax reduction."); NVTE_CHECK(cols % 16 == 0, "Row-scaled NVFP4 quantization requires last dim divisible by 16."); - } else { - out.set_rowwise_amax_is_row_scaled(false); } // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT 
@@ -2347,7 +2350,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax && !this->row_scaled_activation) { + if (compute_amax && !rowwise_amax_is_row_scaled) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2448,7 +2451,7 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { - NVTE_CHECK(!this->row_scaled_activation, + NVTE_CHECK(!out.get_rowwise_amax_is_row_scaled(), "quantize_with_amax is not supported for row-scaled NVFP4 quantization."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 252e819dc7..4fb763c600 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,18 +350,18 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), - row_scaled_activation: bool = False, + rowwise_amax_is_row_scaled: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): - super().__init__(rowwise=rowwise, columnwise=columnwise and not row_scaled_activation) + super().__init__(rowwise=rowwise, columnwise=columnwise and not rowwise_amax_is_row_scaled) self.internal = True self.dtype = dtype self.pow_2_scales = pow_2_scales self.eps = eps self.quant_tile_shape = quant_tile_shape - self.row_scaled_activation = row_scaled_activation + self.rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -622,10 +622,8 @@ def _quantize(self, tensor: 
torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" ) - if self.row_scaled_activation: - raise ValueError( - "Row-scaled activation is only supported for NVFP4 (non-pow2) mode." - ) + if self.rowwise_amax_is_row_scaled: + raise ValueError("Row-scaled NVFP4 is only supported for NVFP4 (non-pow2) mode.") # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) @@ -642,10 +640,10 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.with_rht else tensor.t().contiguous() ) - if self.row_scaled_activation: + if self.rowwise_amax_is_row_scaled: if self.quant_tile_shape != (1, 16): raise ValueError( - "Row-scaled activation only supports NVFP4 1x16 tile shape, " + "Row-scaled NVFP4 only supports NVFP4 1x16 tile shape, " f"got {self.quant_tile_shape}" ) global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) @@ -674,7 +672,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - rowwise_amax_is_row_scaled=self.row_scaled_activation, + rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, eps=self.eps, ) if transpose_scales: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index a1a63ea4ce..f5143ef789 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1375,7 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, - row_scaled_activation=self.recipe.row_scaled_activation and idx % 3 != 1, + rowwise_amax_is_row_scaled=self.recipe.row_scaled_activation and idx % 3 != 1, ) return [_make_quantizer(idx) for idx in 
range(self.num_quantizers)] @@ -1390,7 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, - row_scaled_activation=False, + rowwise_amax_is_row_scaled=False, ) for _ in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index ab0c7484fc..99dee437cd 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -92,6 +92,7 @@ def __new__( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, + rowwise_amax_is_row_scaled: bool = False, ): if ( shapes is not None @@ -164,6 +165,7 @@ def __new__( scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) return instance @@ -195,6 +197,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales + dst.rowwise_amax_is_row_scaled = src.rowwise_amax_is_row_scaled def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 01b9686b00..6ebcdd1db4 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,8 +128,8 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for 
gradients.""" stochastic_rounding: bool - """Row-scaled activation quantization path.""" - row_scaled_activation: bool + """Row-scaled NVFP4 quantization path.""" + rowwise_amax_is_row_scaled: bool """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -146,7 +146,7 @@ def __init__( with_post_rht_amax: bool = False, with_2d_quantization: bool = False, stochastic_rounding: bool = False, - row_scaled_activation: bool = False, + rowwise_amax_is_row_scaled: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -157,7 +157,7 @@ def __init__( self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding - self.row_scaled_activation = row_scaled_activation + self.rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -203,7 +203,7 @@ def copy(self) -> NVFP4Quantizer: with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, - row_scaled_activation=self.row_scaled_activation, + rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -335,8 +335,8 @@ def make_empty( scale_inv = torch.empty( scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) - # Allocate per tensor scale inverse. FP32 format. - amax_rows = flat_first_dim if self.row_scaled_activation else 1 + # Allocate global amax metadata. Row-scaled NVFP4 stores one value per row. 
+ amax_rows = flat_first_dim if self.rowwise_amax_is_row_scaled else 1 amax_rowwise = torch.zeros( amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory ) @@ -345,7 +345,7 @@ def make_empty( columnwise_data = None columnwise_scale_inv = None amax_columnwise = None - columnwise_usage = self.columnwise_usage and not self.row_scaled_activation + columnwise_usage = self.columnwise_usage and not self.rowwise_amax_is_row_scaled if columnwise_usage: # enforce 2D shape to avoid [S, B, H] shape and B and be 1 # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero @@ -381,7 +381,7 @@ def make_empty( quantizer=self, requires_grad=requires_grad, with_gemm_swizzled_scales=False, - rowwise_amax_is_row_scaled=self.row_scaled_activation, + rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, ) def calibrate(self, tensor: torch.Tensor) -> None: diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 784e1a158d..8801102e43 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -72,6 +72,7 @@ def _initialize_storage_fields( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, + rowwise_amax_is_row_scaled: bool = False, ) -> None: """ Initialize a GroupedTensor. @@ -147,6 +148,7 @@ def _initialize_storage_fields( # Used as a convenience. 
instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + instance.rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled def __new__( cls, @@ -172,6 +174,7 @@ def __new__( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, + rowwise_amax_is_row_scaled: bool = False, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -197,6 +200,7 @@ def __new__( requires_grad=requires_grad, stride=stride, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) return instance @@ -326,6 +330,7 @@ def clear(self) -> None: self.columnwise_scale_inv_offsets = None self.tensor_shapes = [] self.fake_dtype = torch.float32 + self.rowwise_amax_is_row_scaled = False def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -494,6 +499,7 @@ def copy(self) -> "GroupedTensorStorage": scale_inv_offsets=self.scale_inv_offsets, columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, ) @staticmethod @@ -604,6 +610,7 @@ def make_grouped_tensor( scale = None scale_inv_offsets = None columnwise_scale_inv_offsets = None + rowwise_amax_is_row_scaled = False if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -662,10 +669,10 @@ def make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): - row_scaled_activation = getattr(quantizer, "row_scaled_activation", False) - columnwise_usage = columnwise_usage and not row_scaled_activation + rowwise_amax_is_row_scaled = quantizer.rowwise_amax_is_row_scaled + columnwise_usage = columnwise_usage and not rowwise_amax_is_row_scaled 
total_amax_elements = ( - sum(math.prod(s[:-1]) for s in shape) if row_scaled_activation else num_tensors + sum(math.prod(s[:-1]) for s in shape) if rowwise_amax_is_row_scaled else num_tensors ) if rowwise_usage: @@ -784,6 +791,7 @@ def make_grouped_tensor( with_gemm_swizzled_scales=( quantizer.optimize_for_gemm if quantizer is not None else False ), + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -897,7 +905,8 @@ def split_into_quantized_tensors( columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets nvfp4_rowwise_amax_offsets = None - if recipe.nvfp4() and getattr(self.quantizer, "row_scaled_activation", False): + rowwise_amax_is_row_scaled = self.rowwise_amax_is_row_scaled + if recipe.nvfp4() and rowwise_amax_is_row_scaled: cum = 0 nvfp4_rowwise_amax_offsets = [0] for i in range(self.num_tensors): @@ -1128,7 +1137,7 @@ def split_into_quantized_tensors( fp4_dtype=quantizer.dtype, quantizer=quantizer, with_gemm_swizzled_scales=quantizer.optimize_for_gemm, - rowwise_amax_is_row_scaled=getattr(quantizer, "row_scaled_activation", False), + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, ) result.append(tensor) From 94b05e37152ac176636faddb7901ed5c51902c94 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 17:52:54 -0700 Subject: [PATCH 33/45] Clean up bias Signed-off-by: Ziang Li --- transformer_engine/pytorch/cpp_extensions/gemm.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 12e2fcb386..0c195963b3 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -245,19 +245,16 @@ def general_gemm( gemm_args[4] = fp32_out # out gemm_args[5] = None # quantization_params gemm_args[6] = 
TE_DType[torch.float32] # out_dtype + gemm_args[7] = None # bias out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) out_2d = out.reshape(-1, out.shape[-1]) assert rowwise_global_scales.dtype == torch.float32 and out.dtype == torch.float32 assert rowwise_global_scales.numel() == out_2d.shape[0] + out_2d.mul_(rowwise_global_scales) if bias is not None: - bias_cast = bias.to(dtype=torch.float32) - out_2d.sub_(bias_cast) - out_2d.mul_(rowwise_global_scales) - out_2d.add_(bias_cast) - else: - out_2d.mul_(rowwise_global_scales) + out_2d.add_(bias.to(dtype=torch.float32)) if requested_out is not None: requested_out.copy_(out.to(dtype=requested_out.dtype)) From 6c10ed2871b4f51140a517e58c6dbf2764ee44fa Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 18:18:53 -0700 Subject: [PATCH 34/45] Clean up cast Signed-off-by: Ziang Li --- .../pytorch/csrc/extensions/cast.cpp | 86 +++++-------------- 1 file changed, 23 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 263f3bdb7b..538f5282b2 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -799,10 +799,7 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool rowwise_amax_is_row_scaled = quantizer_cpp_list[0]->rowwise_amax_is_row_scaled; - NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, - "Row-scaled NVFP4 quantization requires rowwise usage."); - const auto columnwise_usage = - quantizer_cpp_list[0]->columnwise_usage && !rowwise_amax_is_row_scaled; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable 
based on optimize_for_gemm; @@ -880,8 +877,10 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - const size_t amax_size = - rowwise_amax_is_row_scaled ? 4 * flat_first_dim(rowwise_data_shapes[i]) : 4; + size_t amax_size = 4; + if (rowwise_amax_is_row_scaled) { + amax_size *= flat_first_dim(rowwise_data_shapes[i]); + } buffer_size = offset + amax_size; } @@ -895,9 +894,10 @@ std::tuple, std::vector, bool> bulk_alloc data_offsets[i], torch::kUInt8)); rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - const std::vector amax_shape = - rowwise_amax_is_row_scaled ? std::vector{flat_first_dim(rowwise_data_shapes[i])} - : std::vector{1}; + std::vector amax_shape{1}; + if (rowwise_amax_is_row_scaled) { + amax_shape = {flat_first_dim(rowwise_data_shapes[i])}; + } amax_rowwise_list.emplace_back( make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } @@ -1284,34 +1284,6 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } - const bool rowwise_amax_is_row_scaled = output_list.front().get_rowwise_amax_is_row_scaled(); - NVTE_CHECK(std::all_of(output_list.begin(), output_list.end(), - [rowwise_amax_is_row_scaled](const TensorWrapper &output) { - return output.get_rowwise_amax_is_row_scaled() == - rowwise_amax_is_row_scaled; - }), - "All NVFP4 split-quantize outputs must use the same rowwise amax scaling mode."); - if (rowwise_amax_is_row_scaled) { - NVTE_CHECK(!quantizer.with_rht, "Row-scaled NVFP4 split quantize does not support RHT."); - NVTE_CHECK(!quantizer.with_2d_quantization, - "Row-scaled NVFP4 split quantize does not support 2D quantization."); - NVTE_CHECK(!quantizer.stochastic_rounding, - "Row-scaled NVFP4 split quantize does not support stochastic rounding."); - - for 
(size_t i = 0; i < num_tensors; i++) { - if (input_list[i].numel() == 0) { - continue; - } - const size_t input_ndim = input_list[i].ndim(); - const size_t cols = input_ndim > 0 ? input_list[i].size(input_ndim - 1) : 1; - NVTE_CHECK(cols % 16 == 0, - "Row-scaled NVFP4 split quantize requires split inner dim divisible by 16."); - QuantizationConfigWrapper quant_config; - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config, stream); - } - return; - } - // In this case without RHT, the rowwise and colwise quantization are fused // we don't need separate rng states for rowwise and colwise bool need_separate_rng_states = false; @@ -1406,17 +1378,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, "NVFP4 split-quantize does not support 2D quantization"); NVTE_CHECK(!quantizer.with_amax_reduction, "NVFP4 split-quantize does not support amax reduction"); - const bool rowwise_amax_is_row_scaled = output_list.front().get_rowwise_amax_is_row_scaled(); // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? 
input.size(input.ndim() - 1) : 1; - if (rowwise_amax_is_row_scaled) { - NVTE_CHECK(input_last_dim % 16 == 0, - "Row-scaled NVFP4 split-quantize requires inner dim to be multiple of 16."); - } else { - NVTE_CHECK(input_last_dim % 128 == 0, - "NVFP4 multi-quantize requires inner dim to be multiple of 128."); - } + NVTE_CHECK(input_last_dim % 128 == 0, + "NVFP4 multi-quantize requires inner dim to be multiple of 128."); // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); @@ -1488,26 +1454,12 @@ std::vector split_quantize(const at::Tensor &tensor, for (size_t i = 0; i < num_splits; i++) { quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); } - const bool all_nvfp4_quantizers = std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsNVFP4Quantizers(quantizer.ptr()); - }); - const bool all_nvfp4_rowwise_amax_is_row_scaled = - all_nvfp4_quantizers && - std::all_of( - quantizer_cpp_list.begin(), quantizer_cpp_list.end(), - [](const std::unique_ptr &quantizer) -> bool { - return static_cast(quantizer.get())->rowwise_amax_is_row_scaled; - }); // Choose implementation for allocating and populating tensors enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 }; enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; - if (all_nvfp4_rowwise_amax_is_row_scaled) { - quantization_method = QuantizationMethod::FUSED_NVFP4; - } if (!disable_bulk_allocation) { if (std::all_of(quantizer_list.begin(), quantizer_list.end(), [](const py::handle &quantizer) -> bool { @@ -1519,9 +1471,17 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsMXFP8Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_MXFP8; - } else if (all_nvfp4_quantizers) { + } else if (std::all_of(quantizer_list.begin(), 
quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsNVFP4Quantizers(quantizer.ptr()); + })) { allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + if (static_cast(quantizer_cpp_list.front().get()) + ->rowwise_amax_is_row_scaled) { + quantization_method = QuantizationMethod::UNFUSED; + } else { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } } @@ -1558,7 +1518,7 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!all_nvfp4_rowwise_amax_is_row_scaled && !input_shape.empty() && + if (quantization_method == QuantizationMethod::FUSED_NVFP4 && !input_shape.empty() && input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { @@ -1569,7 +1529,7 @@ std::vector split_quantize(const at::Tensor &tensor, }); quantization_method = QuantizationMethod::UNFUSED; } - if (!all_nvfp4_rowwise_amax_is_row_scaled && !contiguous_data_and_scale) { + if (!contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; } From aa519d1f45e80872b080b93c7833446e6520dfcc Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 20:16:51 -0700 Subject: [PATCH 35/45] Avoid silently disable column wise Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 6 +- .../nvfp4/test_nvfp4_quantize_exact.py | 58 ++++++++++--------- tests/pytorch/test_recipe.py | 5 +- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 3 - .../pytorch/cpp_extensions/gemm.py | 11 +++- .../pytorch/csrc/extensions/cast.cpp | 5 ++ transformer_engine/pytorch/csrc/quantizer.cpp | 24 +++++--- .../custom_recipes/quantization_nvfp4.py | 9 ++- 
.../pytorch/tensor/nvfp4_tensor.py | 8 ++- .../tensor/storage/grouped_tensor_storage.py | 10 +++- 10 files changed, 91 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 23512c9991..1a23ec0096 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -53,7 +53,7 @@ def check_nvfp4_gemm_versus_reference( x_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, - columnwise=True, + columnwise=not rowwise_amax_is_row_scaled, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, @@ -243,7 +243,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, - columnwise=True, + columnwise=False, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, @@ -333,7 +333,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( x_row_scaled_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, - columnwise=True, + columnwise=False, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 7c56fc1c07..4f898feee3 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -16,6 +16,19 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +def maybe_skip_row_scaled_unsupported_quantization( + rowwise_amax_is_row_scaled: bool, + return_transpose: bool, + with_2d_quantization: bool = False, +) -> None: + if not rowwise_amax_is_row_scaled: + return + if return_transpose: + pytest.skip("Row-scaled NVFP4 does not support columnwise usage") + if with_2d_quantization: + pytest.skip("Row-scaled NVFP4 does not support 2D quantization") + + def unpack_fp4(x: torch.Tensor) -> torch.Tensor: repeated = x.repeat_interleave(2, 
dim=1) repeated[:, 0::2] &= 0x0F @@ -33,6 +46,10 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization: bool, rowwise_amax_is_row_scaled: bool = False, ) -> None: + maybe_skip_row_scaled_unsupported_quantization( + rowwise_amax_is_row_scaled, return_transpose, with_2d_quantization + ) + te_dtype = tex.DType.kFloat4E2M1 # Setup device and random seed @@ -82,7 +99,7 @@ def check_quantization_nvfp4_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not rowwise_amax_is_row_scaled), + columnwise=return_transpose, pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, @@ -119,7 +136,7 @@ def check_quantization_nvfp4_versus_reference( torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not rowwise_amax_is_row_scaled: + if return_transpose: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) # Compare only the valid portion of transpose scale tensors @@ -127,10 +144,6 @@ def check_quantization_nvfp4_versus_reference( sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) - elif return_transpose: - assert qx_t is None - assert sx_t is None - assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -178,9 +191,6 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization: bool, rowwise_amax_is_row_scaled: bool, ) -> None: - if rowwise_amax_is_row_scaled and with_2d_quantization: - pytest.skip("Row-scaled NVFP4 does not support 2D quantization") - check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, @@ -218,6 +228,8 @@ def test_nvfp4_quantization_extrema_versus_reference( use_cpp_allocator: bool, rowwise_amax_is_row_scaled: bool, ): + maybe_skip_row_scaled_unsupported_quantization(rowwise_amax_is_row_scaled, 
return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -265,7 +277,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not rowwise_amax_is_row_scaled), + columnwise=return_transpose, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), @@ -290,16 +302,12 @@ def test_nvfp4_quantization_extrema_versus_reference( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not rowwise_amax_is_row_scaled: + if return_transpose: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) - elif return_transpose: - assert qx_t is None - assert sx_t is None - assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -333,6 +341,8 @@ def test_nvfp4_quantization_boundary_values( many potential bin edges within each 16-element microblock. Validates native vs reference byte-for-byte and scale parity. 
""" + maybe_skip_row_scaled_unsupported_quantization(rowwise_amax_is_row_scaled, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -389,7 +399,7 @@ def test_nvfp4_quantization_boundary_values( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not rowwise_amax_is_row_scaled), + columnwise=return_transpose, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), @@ -415,16 +425,12 @@ def test_nvfp4_quantization_boundary_values( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not rowwise_amax_is_row_scaled: + if return_transpose: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) - elif return_transpose: - assert qx_t is None - assert sx_t is None - assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -452,6 +458,8 @@ def test_nvfp4_quantization_noncontiguous_inputs( use_cpp_allocator: bool, rowwise_amax_is_row_scaled: bool, ): + maybe_skip_row_scaled_unsupported_quantization(rowwise_amax_is_row_scaled, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -499,7 +507,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=(return_transpose and not rowwise_amax_is_row_scaled), + columnwise=return_transpose, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), @@ -526,15 +534,11 @@ def test_nvfp4_quantization_noncontiguous_inputs( sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - if return_transpose and not rowwise_amax_is_row_scaled: + if 
return_transpose: torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) - elif return_transpose: - assert qx_t is None - assert sx_t is None - assert qx_amax_t is None torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 81a51335b1..b85abc9829 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -553,7 +553,10 @@ def test_nvfp4_row_scaled_quantizer_roles(): ], ) def test_fp4_dequantize(dtype, rowwise_amax_is_row_scaled, M, N): - q = NVFP4Quantizer(rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled) + q = NVFP4Quantizer( + columnwise=not rowwise_amax_is_row_scaled, + rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) dequantized_tensor = starting_tensor.dequantize() diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 50253c5629..0313c1b026 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -657,7 +657,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { @@ -680,7 +679,6 @@ __global__ void __launch_bounds__(THREADS_NUM) 
const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { @@ -1214,7 +1212,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 0c195963b3..d8d67c1376 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -223,6 +223,11 @@ def general_gemm( assert ( quantization_params is None ), "Row-scaled NVFP4 GEMM currently does not support output quantization." + assert ub is None, "Row-scaled NVFP4 GEMM currently does not support CommOverlap." + assert ( + extra_output is None + ), "Row-scaled NVFP4 GEMM currently does not support extra output." + assert not bulk_overlap, "Row-scaled NVFP4 GEMM currently does not support bulk overlap." assert out is None or ( isinstance(out, torch.Tensor) and not is_custom(out) ), "Row-scaled NVFP4 GEMM currently supports only plain torch.Tensor outputs." 
@@ -315,7 +320,11 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): + row_scaled_b = [_is_nvfp4_row_scaled_tensor(tensor) for tensor in B] + if any(row_scaled_b): + assert all( + row_scaled_b + ), "Row-scaled NVFP4 grouped GEMM requires all B tensors to be row-scaled." assert layout[1] == "N", "Row-scaled NVFP4 grouped GEMM currently supports N-layout B only." if grad: raise RuntimeError( diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 538f5282b2..0518ead41a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -800,6 +800,11 @@ std::tuple, std::vector, bool> bulk_alloc const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool rowwise_amax_is_row_scaled = quantizer_cpp_list[0]->rowwise_amax_is_row_scaled; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + if (rowwise_amax_is_row_scaled) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 bulk allocation does not support columnwise usage."); + } const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index ee28852c32..6428c84cae 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1749,9 +1749,11 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; - 
NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, - "Row-scaled NVFP4 quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !rowwise_amax_is_row_scaled; + if (rowwise_amax_is_row_scaled) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 quantization does not support columnwise usage."); + } const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1902,9 +1904,11 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; - NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, - "Row-scaled NVFP4 grouped quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !rowwise_amax_is_row_scaled; + if (rowwise_amax_is_row_scaled) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 grouped quantization does not support columnwise usage."); + } const int64_t total_data_elements = total_elements / 2; @@ -2056,9 +2060,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; - NVTE_CHECK(!rowwise_amax_is_row_scaled || rowwise_usage, - "Row-scaled NVFP4 quantization requires rowwise usage."); - const bool columnwise_usage = this->columnwise_usage && !rowwise_amax_is_row_scaled; + if (rowwise_amax_is_row_scaled) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 quantization does not support columnwise usage."); + } tensor.attr("_rowwise_amax_is_row_scaled") = py::cast(rowwise_amax_is_row_scaled); // Coerce row-wise data diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 4fb763c600..fe0dda0b5f 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -354,7 +354,14 @@ def __init__( with_rht: bool = False, with_random_sign_mask: bool = True, ): - super().__init__(rowwise=rowwise, columnwise=columnwise and not rowwise_amax_is_row_scaled) + if rowwise_amax_is_row_scaled: + if not rowwise: + raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") + if columnwise: + raise ValueError( + "Row-scaled NVFP4 reference quantization does not support columnwise usage." + ) + super().__init__(rowwise=rowwise, columnwise=columnwise) self.internal = True self.dtype = dtype diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 6ebcdd1db4..4678a9e1a1 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -319,6 +319,11 @@ def make_empty( f"Incorrect shape {shape} for NVFP4. 
Tensor dims must be divisible by" f" {NVFP4_BLOCK_SCALING_SIZE}" ) + if self.rowwise_amax_is_row_scaled: + if not self.rowwise_usage: + raise ValueError("Row-scaled NVFP4 quantization requires rowwise usage.") + if self.columnwise_usage: + raise ValueError("Row-scaled NVFP4 quantization does not support columnwise usage.") # Allocate FP4 data data = None @@ -345,8 +350,7 @@ def make_empty( columnwise_data = None columnwise_scale_inv = None amax_columnwise = None - columnwise_usage = self.columnwise_usage and not self.rowwise_amax_is_row_scaled - if columnwise_usage: + if self.columnwise_usage: # enforce 2D shape to avoid [S, B, H] shape and B and be 1 # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero shape_2d = tuple([flat_first_dim, shape[-1]]) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 8801102e43..c70f4d1d10 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -670,7 +670,15 @@ def make_grouped_tensor( amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): rowwise_amax_is_row_scaled = quantizer.rowwise_amax_is_row_scaled - columnwise_usage = columnwise_usage and not rowwise_amax_is_row_scaled + if rowwise_amax_is_row_scaled: + if not rowwise_usage: + raise ValueError( + "Row-scaled NVFP4 grouped quantization requires rowwise usage." + ) + if columnwise_usage: + raise ValueError( + "Row-scaled NVFP4 grouped quantization does not support columnwise usage." 
+ ) total_amax_elements = ( sum(math.prod(s[:-1]) for s in shape) if rowwise_amax_is_row_scaled else num_tensors ) From 90a97a4905d10d62c3f9fbe993508b34ab0faf9e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 20:35:23 -0700 Subject: [PATCH 36/45] Clean up Signed-off-by: Ziang Li --- transformer_engine/common/cast/dispatch/quantize.cuh | 2 -- .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 2 -- transformer_engine/pytorch/cpp_extensions/gemm.py | 2 -- .../pytorch/tensor/storage/grouped_tensor_storage.py | 11 ++--------- 4 files changed, 2 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9ed199d798..a538890503 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -104,8 +104,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, if (rowwise_amax_is_row_scaled) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); - NVTE_CHECK(output_tensor->has_data(), - "Row-scaled NVFP4 quantization requires rowwise output."); NVTE_CHECK(!output_tensor->has_columnwise_data(), "Row-scaled NVFP4 quantization does not produce columnwise output."); nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 0313c1b026..7500483c70 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -105,8 +105,6 @@ void launch_compute_rowwise_amax(const int num_rows, const int num_cols, const I const float *noop = nullptr) { if (num_rows == 0 || num_cols == 0) return; - NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for row-scaled amax 
computation, got ", - num_cols); dim3 grid(num_rows); dim3 block(ROWWISE_AMAX_BLOCK_SIZE); diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d8d67c1376..a85698c36a 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -89,10 +89,8 @@ def _nvfp4_row_scaled_gemm_inputs( A_metadata["rowwise_amax_is_row_scaled"] = False B_metadata = B.get_metadata() - assert B._rowwise_amax_is_row_scaled rhs_rowwise_amax = B._amax_rowwise assert rhs_rowwise_amax is not None - assert rhs_rowwise_amax.numel() == 0 or rhs_rowwise_amax.numel() > 1 B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1) B_metadata["rowwise_amax_is_row_scaled"] = False diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index c70f4d1d10..da15288da8 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -712,9 +712,7 @@ def make_grouped_tensor( columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) - columnwise_amax = torch.empty( - total_amax_elements, dtype=torch.float32, device=device - ) + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().float8_block_scaling(): if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8) @@ -1121,12 +1119,7 @@ def split_into_quantized_tensors( amax_rowwise = self.amax[i : i + 1] if self.columnwise_amax is not None: - if nvfp4_rowwise_amax_offsets is not None: - amax_start = nvfp4_rowwise_amax_offsets[i] - amax_end = nvfp4_rowwise_amax_offsets[i + 1] - amax_columnwise = self.columnwise_amax[amax_start:amax_end] - else: - amax_columnwise = self.columnwise_amax[i : i + 1] + amax_columnwise = 
self.columnwise_amax[i : i + 1] if quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage From 600b4cd24d06607393c3a1ffbe0f88f600331ec3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 20:45:44 -0700 Subject: [PATCH 37/45] `is_quantizable` returns false Signed-off-by: Ziang Li --- tests/pytorch/test_recipe.py | 2 ++ transformer_engine/pytorch/tensor/nvfp4_tensor.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index b85abc9829..746a66f2f3 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -523,6 +523,8 @@ def test_nvfp4_row_scaled_quantizer_roles(): num_quantizers=3, ).make_quantizers() assert [q.rowwise_amax_is_row_scaled for q in forward_quantizers] == [True, False, True] + assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) + assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) backward_quantizers = NVFP4BlockScalingRecipeState( recipe, diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 4678a9e1a1..f8fcc67d40 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -218,6 +218,8 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" + if self.rowwise_amax_is_row_scaled: + return False if inp.ndim < 2: return False if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: From cc9a2103eaa60ffd6cc77f71db0f1f3468d353d5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 21:05:15 -0700 Subject: [PATCH 38/45] Error out grouped gemm Signed-off-by: Ziang Li --- transformer_engine/pytorch/cpp_extensions/gemm.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 
a85698c36a..6795d635d8 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.grouped_tensor_storage import GroupedTensorStorage from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm @@ -489,6 +490,13 @@ def general_grouped_gemm_for_grouped_tensor( if is_discrete_in and is_discrete_out: raise ValueError("Both A and out are discrete. This is not supported yet.") + if ( + (isinstance(A, GroupedTensorStorage) and A.rowwise_amax_is_row_scaled) + or (isinstance(B, GroupedTensorStorage) and B.rowwise_amax_is_row_scaled) + or (isinstance(out, GroupedTensorStorage) and out.rowwise_amax_is_row_scaled) + ): + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if is_discrete_out: # wgrad case. 
grouped_gemm_impl = tex.te_general_grouped_gemm_for_discrete_out From 39f96c1acf0093d2757629cc160860b4c1e2e231 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 21:19:25 -0700 Subject: [PATCH 39/45] Tighten test Signed-off-by: Ziang Li --- tests/pytorch/test_recipe.py | 4 ++++ tests/pytorch/test_sanity.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 746a66f2f3..b7762444c8 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -561,8 +561,12 @@ def test_fp4_dequantize(dtype, rowwise_amax_is_row_scaled, M, N): ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) + assert starting_tensor._rowwise_amax_is_row_scaled == rowwise_amax_is_row_scaled + assert starting_tensor._amax_rowwise.numel() == (M if rowwise_amax_is_row_scaled else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) + assert new_tensor._rowwise_amax_is_row_scaled == rowwise_amax_is_row_scaled + assert new_tensor._amax_rowwise.numel() == (M if rowwise_amax_is_row_scaled else 1) torch.testing.assert_close( new_tensor._rowwise_data, starting_tensor._rowwise_data, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 760f9659d0..c811342df5 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -578,10 +578,10 @@ def test_sanity_grouped_linear( if fp8_recipe.nvfp4(): if not getattr(fp8_recipe, "row_scaled_activation", False): pytest.skip("NVFP4 not supported for grouped linear") + if single_param: + pytest.skip("Row-scaled NVFP4 does not support GroupedTensor grouped linear") if dtype == torch.float16: pytest.skip("FP16 output for NVFP4 not supported") - if backward_override is None and dtype != torch.bfloat16: - pytest.skip("NVFP4 grouped default backward requires BF16 grad output") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, 
recipe=fp8_recipe): From 4d34527031681307952e048f8216f06f88c9adda Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 21:48:12 -0700 Subject: [PATCH 40/45] Rename verbose rowwise_amax_is_row_scaled Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 18 +++--- tests/cpp/operator/test_dequantize_nvfp4.cu | 26 ++++----- tests/cpp/test_common.h | 4 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 26 ++++----- .../nvfp4/test_nvfp4_quantize_exact.py | 56 ++++++++----------- tests/pytorch/test_recipe.py | 22 ++++---- .../common/cast/dispatch/quantize.cuh | 10 ++-- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 10 ++-- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 16 +++--- .../quantize_transpose_nvfp4_tuned_1D.cuh | 19 +++---- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2 +- transformer_engine/common/common.h | 4 +- .../transformer_engine/transformer_engine.h | 6 +- .../common/transformer_engine.cpp | 4 +- .../common/transpose/cast_transpose.h | 4 +- ...quantize_transpose_vector_blockwise_fp4.cu | 25 ++++----- .../pytorch/cpp_extensions/gemm.py | 12 ++-- transformer_engine/pytorch/csrc/common.h | 2 +- .../pytorch/csrc/extensions/activation.cpp | 4 +- .../pytorch/csrc/extensions/bias.cpp | 2 +- .../pytorch/csrc/extensions/cast.cpp | 21 ++++--- .../pytorch/csrc/extensions/normalization.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 43 +++++++------- .../pytorch/csrc/type_converters.cpp | 4 +- .../custom_recipes/quantization_nvfp4.py | 16 +++--- transformer_engine/pytorch/quantization.py | 4 +- .../pytorch/tensor/grouped_tensor.py | 6 +- .../pytorch/tensor/nvfp4_tensor.py | 20 +++---- .../tensor/storage/grouped_tensor_storage.py | 28 +++++----- .../tensor/storage/nvfp4_tensor_storage.py | 12 ++-- 30 files changed, 207 insertions(+), 223 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 21fb93b428..1f37520bc7 100644 --- 
a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -562,7 +562,7 @@ template void performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, - const bool rowwise_amax_is_row_scaled = false) { + const bool row_scaled_nvfp4 = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -589,7 +589,7 @@ void performTest(float (*OP)(const float), const size_t scales_stride_t = blocks_X_t; Tensor input("input", shape, itype); - Tensor output("output", shape, otype, true, !rowwise_amax_is_row_scaled, NVTE_NVFP4_1D_SCALING); + Tensor output("output", shape, otype, true, !row_scaled_nvfp4, NVTE_NVFP4_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -602,9 +602,9 @@ void performTest(float (*OP)(const float), const float amax = 448.0f * 6.0f * 8.0f; std::vector ref_rowwise_amax; bool use_2d_quantization = false; - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { output.set_tensor_amax_shape({rows}); - output.set_rowwise_amax_is_row_scaled(true); + output.set_row_scaled_nvfp4(true); compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -681,7 +681,7 @@ void performTest(float (*OP)(const float), // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, - false, !rowwise_amax_is_row_scaled); + false, !row_scaled_nvfp4); size_t scale_mismatches_num = 0; compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), @@ -689,14 +689,14 @@ void performTest(float (*OP)(const float), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); - if (!rowwise_amax_is_row_scaled) { + if (!row_scaled_nvfp4) { compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t, 
scales_stride_t, scale_mismatches_num); } - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { compare_rowwise_amax(output, ref_rowwise_amax); } } @@ -747,7 +747,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); - const bool rowwise_amax_is_row_scaled = std::get<4>(GetParam()); + const bool row_scaled_nvfp4 = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -765,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, rowwise_amax_is_row_scaled); + performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4); ); } diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 52a72a7e9b..ec405b1d90 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -90,7 +90,7 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // against a CPU reference computed from the quantized data. 
template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, - const bool rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4) { using namespace test; DType otype = TypeInfo::dtype; @@ -99,9 +99,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { quantized.set_tensor_amax_shape({rows}); - quantized.set_rowwise_amax_is_row_scaled(true); + quantized.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { @@ -143,7 +143,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, - const bool rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4) { using namespace test; DType otype = TypeInfo::dtype; @@ -152,9 +152,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { quantized_compact.set_tensor_amax_shape({rows}); - quantized_compact.set_rowwise_amax_is_row_scaled(true); + quantized_compact.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { @@ -174,9 +174,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rowwise_amax_is_row_scaled) { + if 
(row_scaled_nvfp4) { quantized_swizzled.set_tensor_amax_shape({rows}); - quantized_swizzled.set_rowwise_amax_is_row_scaled(true); + quantized_swizzled.set_row_scaled_nvfp4(true); } else { quantized_swizzled.set_tensor_amax(0.0f); } @@ -185,7 +185,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including zero-init data). quantized_compact.to_cpu(); - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { quantized_swizzled.copy_tensor_amax_from(quantized_compact); } else { quantized_swizzled.set_tensor_amax(quantized_compact.amax()); @@ -256,11 +256,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); - const bool rowwise_amax_is_row_scaled = std::get<2>(GetParam()); + const bool row_scaled_nvfp4 = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, rowwise_amax_is_row_scaled); + tensor_size.first, tensor_size.second, row_scaled_nvfp4); ); } @@ -294,11 +294,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); - const bool rowwise_amax_is_row_scaled = std::get<2>(GetParam()); + const bool row_scaled_nvfp4 = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, rowwise_amax_is_row_scaled); + tensor_size.first, tensor_size.second, row_scaled_nvfp4); ); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 61684e8e40..b2a7da89cf 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -327,8 +327,8 @@ class Tensor { 
tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } - void set_rowwise_amax_is_row_scaled(bool rowwise_amax_is_row_scaled) { - tensor_.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); + void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { + tensor_.set_row_scaled_nvfp4(row_scaled_nvfp4); } void to_cpu() const; diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 1a23ec0096..b939336275 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -27,7 +27,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -53,12 +53,12 @@ def check_nvfp4_gemm_versus_reference( x_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, - columnwise=not rowwise_amax_is_row_scaled, + columnwise=not row_scaled_nvfp4, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -118,11 +118,11 @@ def check_nvfp4_gemm_versus_reference( x_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=not rowwise_amax_is_row_scaled, + columnwise=not row_scaled_nvfp4, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -178,7 +178,7 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - if rowwise_amax_is_row_scaled: + if row_scaled_nvfp4: layout = ("T" if transa else "N") + ("T" if transb else "N") y_native = 
general_gemm( w_nvfp4_native, @@ -248,7 +248,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - rowwise_amax_is_row_scaled=True, + row_scaled_nvfp4=True, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -338,7 +338,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - rowwise_amax_is_row_scaled=True, + row_scaled_nvfp4=True, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -416,9 +416,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize( - "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] -) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -429,9 +427,9 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, - rowwise_amax_is_row_scaled: bool, + row_scaled_nvfp4: bool, ): - if rowwise_amax_is_row_scaled: + if row_scaled_nvfp4: if accumulate: pytest.skip("Row-scaled NVFP4 GEMM output rescale does not support accumulation") if is_x_columnwise: @@ -447,7 +445,7 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 4f898feee3..0824a5e7bc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -17,11 +17,11 @@ def maybe_skip_row_scaled_unsupported_quantization( - rowwise_amax_is_row_scaled: bool, + row_scaled_nvfp4: bool, return_transpose: bool, with_2d_quantization: bool = False, ) -> None: - if not 
rowwise_amax_is_row_scaled: + if not row_scaled_nvfp4: return if return_transpose: pytest.skip("Row-scaled NVFP4 does not support columnwise usage") @@ -44,10 +44,10 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, ) -> None: maybe_skip_row_scaled_unsupported_quantization( - rowwise_amax_is_row_scaled, return_transpose, with_2d_quantization + row_scaled_nvfp4, return_transpose, with_2d_quantization ) te_dtype = tex.DType.kFloat4E2M1 @@ -70,7 +70,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -103,7 +103,7 @@ def check_quantization_nvfp4_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -178,9 +178,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize( - "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] -) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -189,7 +187,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, - rowwise_amax_is_row_scaled: bool, + row_scaled_nvfp4: bool, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -199,7 +197,7 @@ def test_quantization_block_tiling_versus_reference( 
swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) @@ -216,9 +214,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] -) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -226,9 +222,9 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, - rowwise_amax_is_row_scaled: bool, + row_scaled_nvfp4: bool, ): - maybe_skip_row_scaled_unsupported_quantization(rowwise_amax_is_row_scaled, return_transpose) + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) te_dtype = tex.DType.kFloat4E2M1 @@ -250,7 +246,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -281,7 +277,7 @@ def test_nvfp4_quantization_extrema_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -325,23 +321,21 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] -) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", 
"nvfp4_row_scaled"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, - rowwise_amax_is_row_scaled: bool, + row_scaled_nvfp4: bool, ): """ Stress rounding/threshold behavior by placing values just below/above many potential bin edges within each 16-element microblock. Validates native vs reference byte-for-byte and scale parity. """ - maybe_skip_row_scaled_unsupported_quantization(rowwise_amax_is_row_scaled, return_transpose) + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) te_dtype = tex.DType.kFloat4E2M1 @@ -372,7 +366,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -403,7 +397,7 @@ def test_nvfp4_quantization_boundary_values( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -447,18 +441,16 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] -) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, - rowwise_amax_is_row_scaled: bool, + row_scaled_nvfp4: bool, ): - maybe_skip_row_scaled_unsupported_quantization(rowwise_amax_is_row_scaled, return_transpose) + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) te_dtype = tex.DType.kFloat4E2M1 @@ -480,7 +472,7 @@ def 
test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -511,7 +503,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index b7762444c8..5f5221af76 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -522,7 +522,7 @@ def test_nvfp4_row_scaled_quantizer_roles(): mode="forward", num_quantizers=3, ).make_quantizers() - assert [q.rowwise_amax_is_row_scaled for q in forward_quantizers] == [True, False, True] + assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) @@ -531,14 +531,12 @@ def test_nvfp4_row_scaled_quantizer_roles(): mode="backward", num_quantizers=2, ).make_quantizers() - assert [q.rowwise_amax_is_row_scaled for q in backward_quantizers] == [False, False] + assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "rowwise_amax_is_row_scaled", [False, True], ids=["nvfp4", "nvfp4_row_scaled"] -) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize( "M, N", [ @@ -554,19 +552,19 @@ def test_nvfp4_row_scaled_quantizer_roles(): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, rowwise_amax_is_row_scaled, M, N): +def test_fp4_dequantize(dtype, row_scaled_nvfp4, M, N): q = 
NVFP4Quantizer( - columnwise=not rowwise_amax_is_row_scaled, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + columnwise=not row_scaled_nvfp4, + row_scaled_nvfp4=row_scaled_nvfp4, ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) - assert starting_tensor._rowwise_amax_is_row_scaled == rowwise_amax_is_row_scaled - assert starting_tensor._amax_rowwise.numel() == (M if rowwise_amax_is_row_scaled else 1) + assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) - assert new_tensor._rowwise_amax_is_row_scaled == rowwise_amax_is_row_scaled - assert new_tensor._amax_rowwise.numel() == (M if rowwise_amax_is_row_scaled else 1) + assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) torch.testing.assert_close( new_tensor._rowwise_data, starting_tensor._rowwise_data, diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a538890503..123362ce10 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -100,8 +100,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t rows = input_tensor->flat_first_dim(); int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); - const bool rowwise_amax_is_row_scaled = output_tensor->rowwise_amax_is_row_scaled; - if (rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; + if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(!output_tensor->has_columnwise_data(), @@ -134,7 +134,7 @@ void quantize_fwd_helper(const NVTETensor input, 
NVTETensor output, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*rowwise_amax_is_row_scaled=*/rowwise_amax_is_row_scaled, + /*row_scaled_nvfp4=*/row_scaled_nvfp4, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } @@ -249,7 +249,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); - NVTE_CHECK(!output_tensor->rowwise_amax_is_row_scaled, + NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -277,7 +277,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*rowwise_amax_is_row_scaled=*/false, /*noop_tensor=*/noop_tensor->data, + /*row_scaled_nvfp4=*/false, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 4013a10276..d549a050ee 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,7 +34,7 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const bool rowwise_amax_is_row_scaled, + const float *const tensor_amax, const bool row_scaled_nvfp4, const size_t N, const size_t M, const size_t scale_stride, 
const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -64,7 +64,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = rowwise_amax_is_row_scaled ? tensor_amax[y] : tensor_amax[0]; + float amax = row_scaled_nvfp4 ? tensor_amax[y] : tensor_amax[0]; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -91,7 +91,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; - const bool rowwise_amax_is_row_scaled = input.rowwise_amax_is_row_scaled; + const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -105,7 +105,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t threads = 512; const size_t blocks = DIVUP(total, threads); const size_t num_scale_tiles_X = DIVUP(Mread, static_cast(4)); - NVTE_CHECK(!rowwise_amax_is_row_scaled || input.amax.numel() == N, + NVTE_CHECK(!row_scaled_nvfp4 || input.amax.numel() == N, "Row-scaled NVFP4 dequantization requires one rowwise amax per row."); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -116,7 +116,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), rowwise_amax_is_row_scaled, N, Mread, + reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, N, Mread, input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh 
b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 7500483c70..4b0d4df81a 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -238,7 +238,7 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / template + bool ROW_SCALED_NVFP4> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -639,7 +639,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } float block_scale_inverse; - if constexpr (ROWWISE_AMAX_IS_ROW_SCALED) { + if constexpr (ROW_SCALED_NVFP4) { // 2. Compute E4M3 scaling factor const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; @@ -1320,8 +1320,8 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - const bool rowwise_amax_is_row_scaled = output->rowwise_amax_is_row_scaled; - NVTE_CHECK(!rowwise_amax_is_row_scaled || !use_2d_quantization, + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; + NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); // If transposed output is allocated, return the transposed data. 
Otherwise, it's not necesary to @@ -1347,9 +1347,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - NVTE_CHECK(!rowwise_amax_is_row_scaled || output->amax.dptr != nullptr, + NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, "Row-scaled NVFP4 quantization requires rowwise amax."); - NVTE_CHECK(!rowwise_amax_is_row_scaled || !output->has_columnwise_data(), + NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), "Row-scaled NVFP4 quantization does not produce columnwise output."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { @@ -1433,11 +1433,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION(rowwise_amax_is_row_scaled, ROWWISE_AMAX_IS_ROW_SCALED, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { auto kernel = quantize_transpose_nvfp4_kernel; + ROW_SCALED_NVFP4>; if constexpr (use_2d_quantization) { kernel = quantize_transpose_nvfp4_2D_kernel +template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -315,7 +315,7 @@ __device__ __forceinline__ void rowwise_scaling( nvfp4_scale_t S_dec_b_fp8; scaling_coeff_type SFcoefficient; - if constexpr (ROWWISE_AMAX_IS_ROW_SCALED) { + if constexpr (ROW_SCALED_NVFP4) { const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; const 
float S_enc_rowwise_block = row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) @@ -361,7 +361,7 @@ __device__ __forceinline__ void rowwise_scaling( } template + bool ROW_SCALED_NVFP4> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -582,7 +582,7 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); @@ -691,7 +691,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; - const bool rowwise_amax_is_row_scaled = output->rowwise_amax_is_row_scaled; + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data // Otherwise, it's not necesary to return the transposed data. 
@@ -706,9 +706,9 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - NVTE_CHECK(!rowwise_amax_is_row_scaled || output->amax.dptr != nullptr, + NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, "Row-scaled NVFP4 quantization requires rowwise amax."); - NVTE_CHECK(!rowwise_amax_is_row_scaled || !output->has_columnwise_data(), + NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), "Row-scaled NVFP4 quantization does not produce columnwise output."); if (return_transpose) { @@ -800,12 +800,11 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_fast_math, USE_FAST_MATH, TRANSFORMER_ENGINE_SWITCH_CONDITION( - rowwise_amax_is_row_scaled, ROWWISE_AMAX_IS_ROW_SCALED, + row_scaled_nvfp4, ROW_SCALED_NVFP4, TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; + RETURN_TRANSPOSE, ROW_SCALED_NVFP4>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 0af00e3ace..5e3df9e25f 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -227,7 +227,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz continue; } if (param_type == NVTETensorParam::kNVTERowwiseAmaxIsRowScaled) { - chunk.set_rowwise_amax_is_row_scaled(source.get_rowwise_amax_is_row_scaled()); + chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); continue; } auto param = 
source.get_parameter(param_type); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 7efdc42b58..d9d5b4baae 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -177,7 +177,7 @@ struct Tensor { * * Only meaningful for NVFP4 tensors. */ - bool rowwise_amax_is_row_scaled = false; + bool row_scaled_nvfp4 = false; /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -205,7 +205,7 @@ struct Tensor { columnwise_scale_inv.clear(); scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; - rowwise_amax_is_row_scaled = false; + row_scaled_nvfp4 = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 00989fef54..38f60ae6db 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -766,8 +766,8 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val)); } - void set_rowwise_amax_is_row_scaled(bool rowwise_amax_is_row_scaled) { - const auto val = static_cast(rowwise_amax_is_row_scaled); + void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { + const auto val = static_cast(row_scaled_nvfp4); nvte_set_tensor_param_v2(tensor_, kNVTERowwiseAmaxIsRowScaled, &val, sizeof(val)); } @@ -807,7 +807,7 @@ class TensorWrapper { return static_cast(val); } - bool get_rowwise_amax_is_row_scaled() const { + bool get_row_scaled_nvfp4() const { uint8_t val = 0; nvte_get_tensor_param_v2(tensor_, kNVTERowwiseAmaxIsRowScaled, &val, sizeof(val), nullptr); return static_cast(val); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 
e78f3d90ef..aaf3bdd6ee 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -853,7 +853,7 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; case kNVTERowwiseAmaxIsRowScaled: - t.rowwise_amax_is_row_scaled = static_cast(*reinterpret_cast(buf)); + t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); @@ -936,7 +936,7 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); break; case kNVTERowwiseAmaxIsRowScaled: - *reinterpret_cast(buf) = static_cast(t->rowwise_amax_is_row_scaled); + *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 1a91cdc298..c462b30147 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -67,8 +67,8 @@ void quantize_transpose_vector_blockwise_fp4( SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, - const NVTETensor rng_state_tensor, const bool use_2d_quantization, - const bool rowwise_amax_is_row_scaled, const SimpleTensor &noop_tensor, cudaStream_t stream); + const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, + const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git 
a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index d1b476f3d2..cf9821f1a9 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -316,7 +316,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowScaledNVFP4> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -511,15 +511,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // Step 2.4: Compute scale const size_t row_idx = block_idx_y * kTileDim + r_s; float row_global_encode_scale = global_encode_scale; - if constexpr (kRowwiseAmaxIsRowScaled) { + if constexpr (kRowScaledNVFP4) { row_global_encode_scale = row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; } - const float row_global_encode_scale_multiplier = kRowwiseAmaxIsRowScaled - ? row_global_encode_scale * fp4_max_inv - : global_encode_scale_multiplier; + const float row_global_encode_scale_multiplier = + kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; const float row_global_decode_scale = - kRowwiseAmaxIsRowScaled ? 1.0f / row_global_encode_scale : global_decode_scale; + kRowScaledNVFP4 ? 
1.0f / row_global_encode_scale : global_decode_scale; ScaleType scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); @@ -720,8 +719,8 @@ void quantize_transpose_vector_blockwise_fp4( SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, - const NVTETensor rng_state_tensor, const bool use_2d_quantization, - const bool rowwise_amax_is_row_scaled, const SimpleTensor& noop_tensor, cudaStream_t stream) { + const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, + const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -734,9 +733,9 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || !use_2d_quantization, "2D block quantization is only supported when return_identity is true."); - NVTE_CHECK(!rowwise_amax_is_row_scaled || (return_identity && !return_transpose), + NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), "Row-scaled NVFP4 quantization only supports rowwise quantization."); - NVTE_CHECK(!rowwise_amax_is_row_scaled || !use_2d_quantization, + NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; @@ -818,14 +817,14 @@ void quantize_transpose_vector_blockwise_fp4( use_2d_quantization, kIs2DBlockScaling, TRANSFORMER_ENGINE_SWITCH_CONDITION( - rowwise_amax_is_row_scaled, kRowwiseAmaxIsRowScaled, + row_scaled_nvfp4, kRowScaledNVFP4, size_t smem_bytes = kSMemSize * sizeof(InputType); auto kernel = block_scaled_1d_cast_transpose_kernel< kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float, InputType, OutputType, ScaleType, kSwizzledScale, kApplyStochasticRounding, kIs2DBlockScaling, - kRowwiseAmaxIsRowScaled>; + kRowScaledNVFP4>; if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -843,7 +842,7 @@ void quantize_transpose_vector_blockwise_fp4( row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, - noop_ptr);) // kRowwiseAmaxIsRowScaled + noop_ptr);) // kRowScaledNVFP4 ) // kIs2DBlockScaling ) // kApplyStochasticRounding ) // kSwizzledScale diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6795d635d8..891c764dc9 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -73,7 +73,7 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool: """Whether tensor carries row-scaled NVFP4 global amax metadata.""" - return isinstance(tensor, NVFP4TensorStorage) and tensor._rowwise_amax_is_row_scaled + return isinstance(tensor, NVFP4TensorStorage) and tensor._row_scaled_nvfp4 def _nvfp4_row_scaled_gemm_inputs( @@ -87,13 +87,13 @@ def _nvfp4_row_scaled_gemm_inputs( weight_amax = A._amax_rowwise if transa else A._amax_columnwise assert weight_amax is not None and weight_amax.numel() == 1 A_metadata["amax_rowwise" if transa else "amax_columnwise"] = 
weight_amax.new_ones(1) - A_metadata["rowwise_amax_is_row_scaled"] = False + A_metadata["row_scaled_nvfp4"] = False B_metadata = B.get_metadata() rhs_rowwise_amax = B._amax_rowwise assert rhs_rowwise_amax is not None B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1) - B_metadata["rowwise_amax_is_row_scaled"] = False + B_metadata["row_scaled_nvfp4"] = False assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 return ( @@ -491,9 +491,9 @@ def general_grouped_gemm_for_grouped_tensor( raise ValueError("Both A and out are discrete. This is not supported yet.") if ( - (isinstance(A, GroupedTensorStorage) and A.rowwise_amax_is_row_scaled) - or (isinstance(B, GroupedTensorStorage) and B.rowwise_amax_is_row_scaled) - or (isinstance(out, GroupedTensorStorage) and out.rowwise_amax_is_row_scaled) + (isinstance(A, GroupedTensorStorage) and A.row_scaled_nvfp4) + or (isinstance(B, GroupedTensorStorage) and B.row_scaled_nvfp4) + or (isinstance(out, GroupedTensorStorage) and out.row_scaled_nvfp4) ): raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 9bbbc270d8..d8f193f8cb 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -321,7 +321,7 @@ class NVFP4Quantizer : public Quantizer { bool with_2d_quantization; bool stochastic_rounding; // Whether tensors emitted by this quantizer store one rowwise amax per tensor row. 
- bool rowwise_amax_is_row_scaled; + bool row_scaled_nvfp4; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index fff2ba1edc..cab9fab30a 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,7 +42,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; @@ -155,7 +155,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 2e92d2eb80..4a78dde388 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,7 +152,7 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast 
to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 0518ead41a..71a2ada3ec 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -798,9 +798,9 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const bool rowwise_amax_is_row_scaled = quantizer_cpp_list[0]->rowwise_amax_is_row_scaled; + const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); NVTE_CHECK(!columnwise_usage, "Row-scaled NVFP4 bulk allocation does not support columnwise usage."); @@ -883,7 +883,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); size_t amax_size = 4; - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { amax_size *= flat_first_dim(rowwise_data_shapes[i]); } buffer_size = offset + amax_size; @@ -900,7 +900,7 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); std::vector amax_shape{1}; - if (rowwise_amax_is_row_scaled) { + if (row_scaled_nvfp4) { amax_shape = {flat_first_dim(rowwise_data_shapes[i])}; } amax_rowwise_list.emplace_back( @@ -984,10 +984,10 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? 
py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back( - NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, - amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales, rowwise_amax_is_row_scaled)); + tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, amax_rowwise, amax_columnwise, + fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, row_scaled_nvfp4)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -1004,7 +1004,7 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - tensor_wrapper.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); + tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -1481,8 +1481,7 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_NVFP4; - if (static_cast(quantizer_cpp_list.front().get()) - ->rowwise_amax_is_row_scaled) { + if (static_cast(quantizer_cpp_list.front().get())->row_scaled_nvfp4) { quantization_method = QuantizationMethod::UNFUSED; } else { quantization_method = QuantizationMethod::FUSED_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 3a38025a9e..4887b59c28 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,7 +120,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if 
(detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; @@ -358,7 +358,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->rowwise_amax_is_row_scaled || + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6428c84cae..8f2de325ae 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,7 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); - this->rowwise_amax_is_row_scaled = quantizer.attr("rowwise_amax_is_row_scaled").cast(); + this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1748,8 +1748,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are 
divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; - if (rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, "Row-scaled NVFP4 quantization does not support columnwise usage."); @@ -1767,7 +1767,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); - const int64_t amax_rows = rowwise_amax_is_row_scaled ? static_cast(flat_first_dim) : 1; + const int64_t amax_rows = row_scaled_nvfp4 ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); @@ -1813,7 +1813,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - kwargs["rowwise_amax_is_row_scaled"] = py::cast(rowwise_amax_is_row_scaled); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1842,7 +1842,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - kwargs["rowwise_amax_is_row_scaled"] = py::cast(rowwise_amax_is_row_scaled); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), 
kwargs.ptr()); @@ -1875,7 +1875,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - out_cpp.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); + out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1903,8 +1903,8 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional rowwise_amax; std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; - const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; - if (rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, "Row-scaled NVFP4 grouped quantization does not support columnwise usage."); @@ -1917,9 +1917,8 @@ std::pair NVFP4Quantizer::create_grouped_tenso const auto scale_shape = get_scale_shape(logical_shape_vec, false); const int64_t total_scale_elements = static_cast(product(scale_shape)); rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); - const int64_t amax_elements = rowwise_amax_is_row_scaled - ? static_cast(logical_first_dim) - : static_cast(num_tensors); + const int64_t amax_elements = row_scaled_nvfp4 ? static_cast(logical_first_dim) + : static_cast(num_tensors); rowwise_amax = at::empty({amax_elements}, float_opts); } @@ -1978,7 +1977,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; - kwargs["rowwise_amax_is_row_scaled"] = py::cast(rowwise_amax_is_row_scaled); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2059,13 +2058,13 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - const bool rowwise_amax_is_row_scaled = this->rowwise_amax_is_row_scaled; - if (rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, "Row-scaled NVFP4 quantization does not support columnwise usage."); } - tensor.attr("_rowwise_amax_is_row_scaled") = py::cast(rowwise_amax_is_row_scaled); + tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); // Coerce row-wise data if (rowwise_usage) { @@ -2083,7 +2082,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } - const int64_t amax_rows = rowwise_amax_is_row_scaled ? static_cast(flat_first_dim) : 1; + const int64_t amax_rows = row_scaled_nvfp4 ? 
static_cast(flat_first_dim) : 1; if (!amax_rowwise || amax_rowwise->numel() != amax_rows) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel @@ -2169,7 +2168,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - out_cpp.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); + out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2278,8 +2277,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); - const bool rowwise_amax_is_row_scaled = out.get_rowwise_amax_is_row_scaled(); - if (rowwise_amax_is_row_scaled) { + const bool row_scaled_nvfp4 = out.get_row_scaled_nvfp4(); + if (row_scaled_nvfp4) { NVTE_CHECK(!this->with_rht, "Row-scaled NVFP4 quantization does not support RHT."); NVTE_CHECK(!this->with_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -2356,7 +2355,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax && !rowwise_amax_is_row_scaled) { + if (compute_amax && !row_scaled_nvfp4) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2457,7 +2456,7 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { - NVTE_CHECK(!out.get_rowwise_amax_is_row_scaled(), + NVTE_CHECK(!out.get_row_scaled_nvfp4(), "quantize_with_amax is not supported for row-scaled NVFP4 quantization."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); diff 
--git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 5e0310b4ce..37ab0b0535 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -134,7 +134,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); - const bool rowwise_amax_is_row_scaled = tensor.attr("_rowwise_amax_is_row_scaled").cast(); + const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -164,7 +164,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - ret.set_rowwise_amax_is_row_scaled(rowwise_amax_is_row_scaled); + ret.set_row_scaled_nvfp4(row_scaled_nvfp4); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index fe0dda0b5f..12f8ef8f5b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,11 +350,11 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): - if rowwise_amax_is_row_scaled: + if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") if columnwise: @@ -368,7 +368,7 @@ def __init__( 
self.pow_2_scales = pow_2_scales self.eps = eps self.quant_tile_shape = quant_tile_shape - self.rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled + self.row_scaled_nvfp4 = row_scaled_nvfp4 self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -456,7 +456,7 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -498,7 +498,7 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: - if rowwise_amax_is_row_scaled: + if row_scaled_nvfp4: global_amax = global_amax.to(torch.float32).view(m, 1, 1) global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) @@ -629,7 +629,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" ) - if self.rowwise_amax_is_row_scaled: + if self.row_scaled_nvfp4: raise ValueError("Row-scaled NVFP4 is only supported for NVFP4 (non-pow2) mode.") # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined @@ -647,7 +647,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.with_rht else tensor.t().contiguous() ) - if self.rowwise_amax_is_row_scaled: + if self.row_scaled_nvfp4: if self.quant_tile_shape != (1, 16): raise ValueError( "Row-scaled NVFP4 only supports NVFP4 1x16 tile shape, " @@ -679,7 +679,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, + row_scaled_nvfp4=self.row_scaled_nvfp4, eps=self.eps, ) if transpose_scales: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index f5143ef789..e9f009d93d 100644 --- a/transformer_engine/pytorch/quantization.py +++ 
b/transformer_engine/pytorch/quantization.py @@ -1375,7 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, - rowwise_amax_is_row_scaled=self.recipe.row_scaled_activation and idx % 3 != 1, + row_scaled_nvfp4=self.recipe.row_scaled_activation and idx % 3 != 1, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] @@ -1390,7 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, - rowwise_amax_is_row_scaled=False, + row_scaled_nvfp4=False, ) for _ in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 99dee437cd..f28f972b58 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -92,7 +92,7 @@ def __new__( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, ): if ( shapes is not None @@ -165,7 +165,7 @@ def __new__( scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) return instance @@ -197,7 +197,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales - dst.rowwise_amax_is_row_scaled = 
src.rowwise_amax_is_row_scaled + dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index f8fcc67d40..7f92bbd7fa 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -129,7 +129,7 @@ class NVFP4Quantizer(Quantizer): stochastic_rounding: bool """Row-scaled NVFP4 quantization path.""" - rowwise_amax_is_row_scaled: bool + row_scaled_nvfp4: bool """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -146,7 +146,7 @@ def __init__( with_post_rht_amax: bool = False, with_2d_quantization: bool = False, stochastic_rounding: bool = False, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -157,7 +157,7 @@ def __init__( self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding - self.rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled + self.row_scaled_nvfp4 = row_scaled_nvfp4 self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -203,7 +203,7 @@ def copy(self) -> NVFP4Quantizer: with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, - rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, + row_scaled_nvfp4=self.row_scaled_nvfp4, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -218,7 +218,7 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: def is_quantizable(self, inp: torch.Tensor) -> bool: 
"""Returns whether or not given inp can be quantized""" - if self.rowwise_amax_is_row_scaled: + if self.row_scaled_nvfp4: return False if inp.ndim < 2: return False @@ -321,7 +321,7 @@ def make_empty( f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" f" {NVFP4_BLOCK_SCALING_SIZE}" ) - if self.rowwise_amax_is_row_scaled: + if self.row_scaled_nvfp4: if not self.rowwise_usage: raise ValueError("Row-scaled NVFP4 quantization requires rowwise usage.") if self.columnwise_usage: @@ -343,7 +343,7 @@ def make_empty( scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) # Allocate global amax metadata. Row-scaled NVFP4 stores one value per row. - amax_rows = flat_first_dim if self.rowwise_amax_is_row_scaled else 1 + amax_rows = flat_first_dim if self.row_scaled_nvfp4 else 1 amax_rowwise = torch.zeros( amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory ) @@ -387,7 +387,7 @@ def make_empty( quantizer=self, requires_grad=requires_grad, with_gemm_swizzled_scales=False, - rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, + row_scaled_nvfp4=self.row_scaled_nvfp4, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -448,7 +448,7 @@ def __new__( fp4_dtype: TE_DType, quantizer: Quantizer, with_gemm_swizzled_scales: bool, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, **kwargs, ): instance = super().__new__( @@ -462,7 +462,7 @@ def __new__( fp4_dtype, quantizer, with_gemm_swizzled_scales, - rowwise_amax_is_row_scaled, + row_scaled_nvfp4, *args, **kwargs, ) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index da15288da8..3cdff471de 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -72,7 +72,7 @@ def _initialize_storage_fields( requires_grad: bool = False, stride: 
Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, ) -> None: """ Initialize a GroupedTensor. @@ -148,7 +148,7 @@ def _initialize_storage_fields( # Used as a convenience. instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales - instance.rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled + instance.row_scaled_nvfp4 = row_scaled_nvfp4 def __new__( cls, @@ -174,7 +174,7 @@ def __new__( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -200,7 +200,7 @@ def __new__( requires_grad=requires_grad, stride=stride, with_gemm_swizzled_scales=with_gemm_swizzled_scales, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) return instance @@ -330,7 +330,7 @@ def clear(self) -> None: self.columnwise_scale_inv_offsets = None self.tensor_shapes = [] self.fake_dtype = torch.float32 - self.rowwise_amax_is_row_scaled = False + self.row_scaled_nvfp4 = False def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -499,7 +499,7 @@ def copy(self) -> "GroupedTensorStorage": scale_inv_offsets=self.scale_inv_offsets, columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, - rowwise_amax_is_row_scaled=self.rowwise_amax_is_row_scaled, + row_scaled_nvfp4=self.row_scaled_nvfp4, ) @staticmethod @@ -610,7 +610,7 @@ def make_grouped_tensor( scale = None scale_inv_offsets = None columnwise_scale_inv_offsets = None - rowwise_amax_is_row_scaled = False + row_scaled_nvfp4 = False if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -669,8 +669,8 @@ def 
make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): - rowwise_amax_is_row_scaled = quantizer.rowwise_amax_is_row_scaled - if rowwise_amax_is_row_scaled: + row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 + if row_scaled_nvfp4: if not rowwise_usage: raise ValueError( "Row-scaled NVFP4 grouped quantization requires rowwise usage." @@ -680,7 +680,7 @@ def make_grouped_tensor( "Row-scaled NVFP4 grouped quantization does not support columnwise usage." ) total_amax_elements = ( - sum(math.prod(s[:-1]) for s in shape) if rowwise_amax_is_row_scaled else num_tensors + sum(math.prod(s[:-1]) for s in shape) if row_scaled_nvfp4 else num_tensors ) if rowwise_usage: @@ -797,7 +797,7 @@ def make_grouped_tensor( with_gemm_swizzled_scales=( quantizer.optimize_for_gemm if quantizer is not None else False ), - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -911,8 +911,8 @@ def split_into_quantized_tensors( columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets nvfp4_rowwise_amax_offsets = None - rowwise_amax_is_row_scaled = self.rowwise_amax_is_row_scaled - if recipe.nvfp4() and rowwise_amax_is_row_scaled: + row_scaled_nvfp4 = self.row_scaled_nvfp4 + if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 nvfp4_rowwise_amax_offsets = [0] for i in range(self.num_tensors): @@ -1138,7 +1138,7 @@ def split_into_quantized_tensors( fp4_dtype=quantizer.dtype, quantizer=quantizer, with_gemm_swizzled_scales=quantizer.optimize_for_gemm, - rowwise_amax_is_row_scaled=rowwise_amax_is_row_scaled, + row_scaled_nvfp4=row_scaled_nvfp4, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py 
b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index ec06839dfb..a0b0b86eb7 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -98,7 +98,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # GEMM _with_gemm_swizzled_scales: bool # Whether rowwise amax stores one value per tensor row - _rowwise_amax_is_row_scaled: bool + _row_scaled_nvfp4: bool def __new__( cls, @@ -111,7 +111,7 @@ def __new__( fp4_dtype: TE_DType, quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, - rowwise_amax_is_row_scaled: bool = False, + row_scaled_nvfp4: bool = False, *args, fake_dtype: Optional[torch.dtype] = None, **kwargs, @@ -131,7 +131,7 @@ def __new__( instance._amax_rowwise = amax_rowwise instance._amax_columnwise = amax_columnwise instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales - instance._rowwise_amax_is_row_scaled = rowwise_amax_is_row_scaled + instance._row_scaled_nvfp4 = row_scaled_nvfp4 return instance @@ -156,7 +156,7 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("FP4 dtype mismatch in copy_from_storage") if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: raise RuntimeError("Scale layout mismatch in copy_from_storage") - if self._rowwise_amax_is_row_scaled != src._rowwise_amax_is_row_scaled: + if self._row_scaled_nvfp4 != src._row_scaled_nvfp4: raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): @@ -182,7 +182,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp4_dtype": self._fp4_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, - "rowwise_amax_is_row_scaled": self._rowwise_amax_is_row_scaled, + "row_scaled_nvfp4": self._row_scaled_nvfp4, "fake_dtype": self._dtype, } @@ -315,7 +315,7 @@ def view(self, shape: 
torch.Size): quantizer=self._quantizer, fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, - rowwise_amax_is_row_scaled=self._rowwise_amax_is_row_scaled, + row_scaled_nvfp4=self._row_scaled_nvfp4, fake_dtype=self._dtype, ) From 9676563a56af50ca78d08c1f27e645f5d7f172cc Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 22:02:08 -0700 Subject: [PATCH 41/45] Clean up Signed-off-by: Ziang Li --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 +- transformer_engine/common/common.h | 4 ++-- .../common/include/transformer_engine/transformer_engine.h | 6 +++--- transformer_engine/common/transformer_engine.cpp | 4 ++-- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions/cast.cpp | 7 ++++++- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 2 +- .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 2 +- 8 files changed, 17 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 5e3df9e25f..28218e2b43 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -226,7 +226,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz chunk.set_with_gemm_swizzled_scales(source.get_with_gemm_swizzled_scales()); continue; } - if (param_type == NVTETensorParam::kNVTERowwiseAmaxIsRowScaled) { + if (param_type == NVTETensorParam::kNVTERowScaledNVFP4) { chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); continue; } diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index d9d5b4baae..12479f2a9c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -173,7 +173,7 @@ struct Tensor { * Only meaningful for MXFP8 and NVFP4. */ bool with_gemm_swizzled_scales = false; - /*! 
\brief Whether rowwise NVFP4 amax is one value per tensor row. + /*! \brief Whether NVFP4 rowwise amax metadata is row-scaled. * * Only meaningful for NVFP4 tensors. */ @@ -189,7 +189,7 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales - sizeof(uint8_t) // kNVTERowwiseAmaxIsRowScaled + sizeof(uint8_t) // kNVTERowScaledNVFP4 }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 38f60ae6db..e9a6f4f735 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -72,7 +72,7 @@ enum NVTETensorParam { kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */ - kNVTERowwiseAmaxIsRowScaled = 8, /*!< Whether rowwise amax is one value per tensor row */ + kNVTERowScaledNVFP4 = 8, /*!< Whether an NVFP4 tensor uses row scaling */ kNVTENumTensorParams }; @@ -768,7 +768,7 @@ class TensorWrapper { void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { const auto val = static_cast(row_scaled_nvfp4); - nvte_set_tensor_param_v2(tensor_, kNVTERowwiseAmaxIsRowScaled, &val, sizeof(val)); + nvte_set_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val)); } // Parameter getters @@ -809,7 +809,7 @@ class TensorWrapper { bool get_row_scaled_nvfp4() const { uint8_t val = 0; - nvte_get_tensor_param_v2(tensor_, kNVTERowwiseAmaxIsRowScaled, &val, sizeof(val), nullptr); + nvte_get_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val), nullptr); return static_cast(val); } diff --git 
a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index aaf3bdd6ee..1a52d76019 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -852,7 +852,7 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTEWithGEMMSwizzledScales: t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; - case kNVTERowwiseAmaxIsRowScaled: + case kNVTERowScaledNVFP4: t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); break; default: @@ -935,7 +935,7 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTEWithGEMMSwizzledScales: *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); break; - case kNVTERowwiseAmaxIsRowScaled: + case kNVTERowScaledNVFP4: *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); break; default: diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index d8f193f8cb..8f5b8294e8 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,7 +320,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; - // Whether tensors emitted by this quantizer store one rowwise amax per tensor row. + // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. 
bool row_scaled_nvfp4; int rht_matrix_random_sign_mask_t; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 71a2ada3ec..9e1f381bfe 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1481,7 +1481,12 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_NVFP4; - if (static_cast(quantizer_cpp_list.front().get())->row_scaled_nvfp4) { + const bool has_row_scaled_nvfp4 = + std::any_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) { + return static_cast(quantizer.get())->row_scaled_nvfp4; + }); + if (has_row_scaled_nvfp4) { quantization_method = QuantizationMethod::UNFUSED; } else { quantization_method = QuantizationMethod::FUSED_NVFP4; diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 7f92bbd7fa..63cbd6b7e8 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,7 +128,7 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for gradients.""" stochastic_rounding: bool - """Row-scaled NVFP4 quantization path.""" + """Whether emitted NVFP4 tensors store one FP32 amax per row.""" row_scaled_nvfp4: bool """RHT matrix random sign mask""" diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index a0b0b86eb7..8a066ce3f6 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -97,7 +97,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # Whether scaling factors are in the swizzled format expected by # GEMM 
_with_gemm_swizzled_scales: bool - # Whether rowwise amax stores one value per tensor row + # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool def __new__( From 0187d80f7170a578228a045fa40cb22b3b090ae7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 22:28:03 -0700 Subject: [PATCH 42/45] Explicitly handle both gemm input and error out Signed-off-by: Ziang Li --- .../pytorch/cpp_extensions/gemm.py | 30 +++++++------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 891c764dc9..a46b303d73 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -208,9 +208,11 @@ def general_gemm( "beta": beta, } - if not _is_nvfp4_row_scaled_tensor(B): + if not _is_nvfp4_row_scaled_tensor(A) and not _is_nvfp4_row_scaled_tensor(B): out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) else: + if _is_nvfp4_row_scaled_tensor(A): + raise NotImplementedError("Row-scaled NVFP4 GEMM does not support row-scaled A.") assert layout[1] == "N", "Row-scaled NVFP4 GEMM currently supports N-layout B only." if grad: raise RuntimeError( @@ -319,25 +321,10 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - row_scaled_b = [_is_nvfp4_row_scaled_tensor(tensor) for tensor in B] - if any(row_scaled_b): - assert all( - row_scaled_b - ), "Row-scaled NVFP4 grouped GEMM requires all B tensors to be row-scaled." - assert layout[1] == "N", "Row-scaled NVFP4 grouped GEMM currently supports N-layout B only." - if grad: - raise RuntimeError( - "Row-scaled NVFP4 grouped GEMM currently supports fprop only. " - "Backward NVFP4 gradient quantizers should use scalar global amax." - ) - assert not gelu, "Row-scaled NVFP4 grouped GEMM currently does not support fused GELU." 
- assert ( - not accumulate - ), "Row-scaled NVFP4 grouped GEMM currently does not support accumulation." + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in A): + raise NotImplementedError("Row-scaled NVFP4 grouped GEMM does not support row-scaled A.") + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): assert D_dtype is None, "Row-scaled NVFP4 grouped GEMM currently does not support D_dtype." - assert all( - q is None for q in quantization_params - ), "Row-scaled NVFP4 grouped GEMM currently does not support output quantization." if single_output: assert ( m_splits is not None @@ -358,11 +345,14 @@ def general_grouped_gemm( gemm_out, _, _, _ = general_gemm( A[i], B[i], - quantization_params=None, + quantization_params=quantization_params[i], out_dtype=out_views[i].dtype, + gelu=gelu, + accumulate=accumulate, layout=layout, bias=bias[i] if use_bias else None, use_split_accumulator=use_split_accumulator, + grad=grad, ) out_views[i].copy_(gemm_out) if single_output: From ee740193b3b4d5b38b341a16566815ab4ff65c69 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 5 May 2026 22:42:29 -0700 Subject: [PATCH 43/45] Minor Signed-off-by: Ziang Li --- transformer_engine/pytorch/cpp_extensions/gemm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index a46b303d73..5fc2a53889 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -340,13 +340,16 @@ def general_grouped_gemm( else: out_views = out for i in range(num_gemms): + if out_views[i] is None: + raise ValueError("Row-scaled NVFP4 grouped GEMM requires pre-allocated outputs.") if out_views[i].numel() == 0: continue - gemm_out, _, _, _ = general_gemm( + general_gemm( A[i], B[i], quantization_params=quantization_params[i], out_dtype=out_views[i].dtype, + out=out_views[i], gelu=gelu, accumulate=accumulate, 
layout=layout, @@ -354,7 +357,6 @@ def general_grouped_gemm( use_split_accumulator=use_split_accumulator, grad=grad, ) - out_views[i].copy_(gemm_out) if single_output: out = out_init return out, grad_bias, gelu_input From 01a32ef2537c2e455f47fd3a593227e902d4fc31 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 6 May 2026 15:37:49 -0700 Subject: [PATCH 44/45] Nits and lint Signed-off-by: Ziang Li --- .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 9 ++++----- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 +++ transformer_engine/pytorch/cpp_extensions/gemm.py | 12 +++++------- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 2 +- .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 4b0d4df81a..9e4aef5a1c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -16,7 +16,6 @@ #include #include -#include #include #include "../../common.h" @@ -68,7 +67,9 @@ __launch_bounds__(BLOCK_SIZE) const IType *__restrict__ input, float *__restrict__ output_rowwise_amax, const float *__restrict__ noop) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000) + NVTE_DEVICE_ERROR("SM 10.0+ is required."); +#else if (noop != nullptr && noop[0] == 1.0f) { return; } @@ -88,10 +89,8 @@ __launch_bounds__(BLOCK_SIZE) } const float thread_max = abs_max_2x_to_float(thread_amax_2x); - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; const float row_amax = - BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + reduce_max(thread_max, threadIdx.x / THREADS_PER_WARP); if (threadIdx.x == 0) { output_rowwise_amax[row_idx] = row_amax; diff 
--git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 144aea1a07..8589d7045d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -318,6 +318,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + NVTE_CHECK(!inputA->row_scaled_nvfp4 && !inputB->row_scaled_nvfp4, + "cuBLAS GEMM does not support row-scaled NVFP4 inputs."); + // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); const int A1 = inputA->flat_last_dim(); diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 5fc2a53889..edf2c1e1c2 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -340,8 +340,6 @@ def general_grouped_gemm( else: out_views = out for i in range(num_gemms): - if out_views[i] is None: - raise ValueError("Row-scaled NVFP4 grouped GEMM requires pre-allocated outputs.") if out_views[i].numel() == 0: continue general_gemm( @@ -482,11 +480,11 @@ def general_grouped_gemm_for_grouped_tensor( if is_discrete_in and is_discrete_out: raise ValueError("Both A and out are discrete. 
This is not supported yet.") - if ( - (isinstance(A, GroupedTensorStorage) and A.row_scaled_nvfp4) - or (isinstance(B, GroupedTensorStorage) and B.row_scaled_nvfp4) - or (isinstance(out, GroupedTensorStorage) and out.row_scaled_nvfp4) - ): + if isinstance(A, GroupedTensorStorage) and A.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if isinstance(B, GroupedTensorStorage) and B.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if isinstance(out, GroupedTensorStorage) and out.row_scaled_nvfp4: raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") if is_discrete_out: diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 63cbd6b7e8..285a7f030a 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -462,8 +462,8 @@ def __new__( fp4_dtype, quantizer, with_gemm_swizzled_scales, - row_scaled_nvfp4, *args, + row_scaled_nvfp4=row_scaled_nvfp4, **kwargs, ) return instance diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 8a066ce3f6..e51acb71e5 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -111,9 +111,9 @@ def __new__( fp4_dtype: TE_DType, quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, - row_scaled_nvfp4: bool = False, *args, fake_dtype: Optional[torch.dtype] = None, + row_scaled_nvfp4: bool = False, **kwargs, ): if cls is NVFP4TensorStorage: From afc99ad0c126b708d1eea7572e7f89d43a201531 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 6 May 2026 21:54:33 -0700 Subject: [PATCH 45/45] Minor fix A100 ci Signed-off-by: Ziang Li --- tests/pytorch/utils.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 824147b31e..8ca796c268 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -166,7 +166,8 @@ def recipe_id(fp8_recipe: Optional[Recipe]) -> str: """Readable pytest id for FP8/FP4 recipes.""" if fp8_recipe is None: return "None" - if fp8_recipe.nvfp4() and getattr(fp8_recipe, "row_scaled_activation", False): + nvfp4 = getattr(fp8_recipe, "nvfp4", None) + if nvfp4 is not None and nvfp4() and getattr(fp8_recipe, "row_scaled_activation", False): return "NVFP4RowScaledBlockScaling" return type(fp8_recipe).__name__