From 43104fc8041ee908deeaaf1a05e1a98bb8108d05 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Fri, 30 Jan 2026 04:48:51 +0800 Subject: [PATCH 1/4] Add 2d quant for mxfp8 --- tests/pytorch/test_mxfp8_2d_quantize.py | 432 ++++++++++++++++++ .../common/cast/dispatch/quantize.cuh | 4 +- .../common/cast/mxfp8/quantize_mxfp8.cuh | 149 ++++-- transformer_engine/common/common.h | 4 +- .../transformer_engine/transformer_engine.h | 9 + transformer_engine/common/recipe/__init__.py | 19 +- .../common/transformer_engine.cpp | 6 + transformer_engine/pytorch/csrc/common.h | 2 + transformer_engine/pytorch/csrc/quantizer.cpp | 2 + transformer_engine/pytorch/quantization.py | 30 +- .../pytorch/tensor/mxfp8_tensor.py | 6 + 11 files changed, 608 insertions(+), 55 deletions(-) create mode 100644 tests/pytorch/test_mxfp8_2d_quantize.py diff --git a/tests/pytorch/test_mxfp8_2d_quantize.py b/tests/pytorch/test_mxfp8_2d_quantize.py new file mode 100644 index 0000000000..c76214277c --- /dev/null +++ b/tests/pytorch/test_mxfp8_2d_quantize.py @@ -0,0 +1,432 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Unit tests for MXFP8 2D block scaling quantization. +MXFP8 2D scaling: 32x32 blocks share a single scaling factor, rowwise and colwise scales are identical. +""" + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.common.recipe import MXFP8BlockScaling, QParams + + +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) + +# MXFP8 constants +MXFP8_BLOCK_SIZE = 32 +FP8_E4M3_MAX = 448.0 + + +def float_to_e8m0(amax: torch.Tensor) -> torch.Tensor: + """ + Convert absolute maximum values to E8M0 biased exponent (scale inverse). + + This mimics the GPU implementation in ptx::float_to_e8m0: + 1. Compute val = amax / FP8_MAX (same as amax * max_norm_rcp) + 2. Extract the biased exponent from the IEEE754 FP32 representation + 3. Round up if there's any mantissa (ceil behavior) + + E8M0 format: 8-bit unsigned integer representing 2^(value - 127) + """ + # Compute val = amax / FP8_MAX (same as GPU: amax * max_norm_rcp) + val = amax.to(torch.float32) / FP8_E4M3_MAX + + # Reinterpret float32 bits as int32 + val_u32 = val.view(torch.int32) + + # Extract biased exponent (bits 30:23) - GPU does: (val_u32 >> 23) and truncates to uint8 + exponent = ((val_u32 >> 23) & 0xFF).to(torch.int32) + + # Extract mantissa (bits 22:0) + mantissa = val_u32 & 0x7FFFFF + + # Round up condition from GPU: + # if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) + round_up = (mantissa > 0) & (exponent != 254) & ~((exponent == 0) & (mantissa <= 0x400000)) + exponent = exponent + round_up.to(torch.int32) + + # Handle special cases (GPU handles these before the main logic) + # val == 0 -> return 0 + exponent = torch.where(val == 0, torch.zeros_like(exponent), exponent) + + return exponent.to(torch.uint8) + + +def e8m0_to_scale_inv(e8m0: torch.Tensor) -> torch.Tensor: + """Convert E8M0 biased exponent back to scale inverse (float).""" + return torch.pow(2.0, e8m0.to(torch.float32) - 127) + + +def quantize_mxfp8_2d_reference( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reference implementation of MXFP8 2D block scaling quantization. + + For 2D scaling, each 32x32 block shares a single E8M0 scale factor. 
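+
+    A minimal worked sketch for a single block, using the helpers defined above
+    in this file (values are illustrative only):
+
+        >>> block = torch.full((32, 32), 3.0)                     # block amax = 3.0
+        >>> e8m0 = float_to_e8m0(block.abs().amax().reshape(1))   # rounds up to 120
+        >>> scale_inv = e8m0_to_scale_inv(e8m0)                    # 2**(120 - 127) = 0.0078125
+        >>> qblock = (block / scale_inv).to(torch.float8_e4m3fn)  # every element becomes 384.0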
+ + Args: + x: Input tensor of shape (M, N), assumes M and N are multiples of 32 + + Returns: + qx_rowwise: Quantized data in row-major order + scale_rowwise: E8M0 scale inverses for rowwise (shape: M x ceil(N/32)) + qx_colwise: Quantized data in column-major order + scale_colwise: E8M0 scale inverses for colwise (shape: ceil(M/32) x N) + """ + M, N = x.shape + device = x.device + dtype = x.dtype + + # Pad to multiples of 32 if needed + pad_M = (MXFP8_BLOCK_SIZE - M % MXFP8_BLOCK_SIZE) % MXFP8_BLOCK_SIZE + pad_N = (MXFP8_BLOCK_SIZE - N % MXFP8_BLOCK_SIZE) % MXFP8_BLOCK_SIZE + if pad_M > 0 or pad_N > 0: + x = torch.nn.functional.pad(x, (0, pad_N, 0, pad_M), mode='constant', value=0.0) + + M_padded, N_padded = x.shape + num_block_rows = M_padded // MXFP8_BLOCK_SIZE + num_block_cols = N_padded // MXFP8_BLOCK_SIZE + + # Reshape to expose 32x32 blocks + x_blocks = x.view( + num_block_rows, MXFP8_BLOCK_SIZE, + num_block_cols, MXFP8_BLOCK_SIZE + ).permute(0, 2, 1, 3) # (num_block_rows, num_block_cols, 32, 32) + + # Compute amax for each 32x32 block + block_amax = torch.amax(torch.abs(x_blocks.to(torch.float32)), dim=(-1, -2)) # (num_block_rows, num_block_cols) + + # Convert to E8M0 scale inverse + block_scale_e8m0 = float_to_e8m0(block_amax) # (num_block_rows, num_block_cols) + block_scale_inv = e8m0_to_scale_inv(block_scale_e8m0) # (num_block_rows, num_block_cols) + + # Expand scale to match input dimensions for quantization + # For rowwise: each row in a block uses the same scale, scale shape is (M, num_block_cols) + scale_rowwise = block_scale_e8m0.repeat_interleave(MXFP8_BLOCK_SIZE, dim=0) # (M_padded, num_block_cols) + + # For colwise: each column in a block uses the same scale, scale shape is (num_block_rows, N) + scale_colwise = block_scale_e8m0.repeat_interleave(MXFP8_BLOCK_SIZE, dim=1) # (num_block_rows, N_padded) + + # Compute scale inverse for quantization (broadcast over 32x32 blocks) + scale_inv_expanded = block_scale_inv.unsqueeze(-1).unsqueeze(-1) # (num_block_rows, num_block_cols, 1, 1) + scale_inv_expanded = scale_inv_expanded.expand(-1, -1, MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE) + + # Quantize: x_quantized = round(x / scale_inv) clamped to FP8 range + x_blocks_float = x_blocks.to(torch.float32) + x_scaled = x_blocks_float / scale_inv_expanded + + # Convert to FP8 (using PyTorch's float8_e4m3fn) + x_quantized = x_scaled.to(torch.float8_e4m3fn) + + # Reshape back to original layout + # Rowwise: (M_padded, N_padded) + qx_rowwise = x_quantized.permute(0, 2, 1, 3).reshape(M_padded, N_padded) + + # Colwise: same data but transposed for column-major access + qx_colwise = x_quantized.permute(0, 2, 1, 3).reshape(M_padded, N_padded) + + # Remove padding from outputs + qx_rowwise = qx_rowwise[:M, :N] + qx_colwise = qx_colwise[:M, :N] + scale_rowwise = scale_rowwise[:M, :] + scale_colwise = scale_colwise[:, :N] + + return qx_rowwise, scale_rowwise, qx_colwise, scale_colwise + + +def check_mxfp8_2d_quantization_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + use_cpp_allocator: bool, +) -> None: + """ + Test MXFP8 2D quantization against CPU reference implementation. + + Verifies: + 1. scales match reference + 2. 32x32 blocks share the same scale + 3. 
rowwise and colwise quantized data match reference + """ + fp8_dtype = tex.DType.kFloat8E4M3 + + device = "cuda" + seed = 42 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create input tensor + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # GPU Quantization using MXFP8Quantizer with 2D scaling + quantizer = MXFP8Quantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + with_2d_quantization=True, + ) + + if use_cpp_allocator: + x_mxfp8 = quantizer(x) + else: + x_mxfp8 = quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_mxfp8 = quantizer.update_quantized(x, x_mxfp8) + + # Extract GPU results + assert x_mxfp8._rowwise_data is not None + assert x_mxfp8._columnwise_data is not None + assert x_mxfp8._rowwise_scale_inv is not None + assert x_mxfp8._columnwise_scale_inv is not None + + gpu_qx_rowwise = x_mxfp8._rowwise_data + gpu_scale_rowwise = x_mxfp8._rowwise_scale_inv + gpu_qx_colwise = x_mxfp8._columnwise_data + gpu_scale_colwise = x_mxfp8._columnwise_scale_inv + + # Reference Quantization + ref_qx_rowwise, ref_scale_rowwise, ref_qx_colwise, ref_scale_colwise = \ + quantize_mxfp8_2d_reference(x) + + num_block_rows = (M + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE + num_block_cols = (N + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE + + # GPU scales may have padding, compare valid portion + gpu_scale_rowwise_valid = gpu_scale_rowwise[:M, :num_block_cols] + gpu_scale_colwise_valid = gpu_scale_colwise[:num_block_rows, :N] + + # 1. Verify scales match reference + torch.testing.assert_close( + gpu_scale_rowwise_valid, + ref_scale_rowwise, + atol=0, rtol=0, + ) + + # 2. Verify 32x32 blocks share the same scale + for bi in range(num_block_rows): + for bj in range(num_block_cols): + row_start = bi * MXFP8_BLOCK_SIZE + row_end = min((bi + 1) * MXFP8_BLOCK_SIZE, M) + col_start = bj * MXFP8_BLOCK_SIZE + col_end = min((bj + 1) * MXFP8_BLOCK_SIZE, N) + + # All rows in block should have same scale for this column block + block_rowwise_scales = gpu_scale_rowwise[row_start:row_end, bj] + assert torch.all(block_rowwise_scales == block_rowwise_scales[0]), ( + f"2D mode: Block ({bi},{bj}) rowwise scales should be identical" + ) + + # All columns in block should have same scale for this row block + block_colwise_scales = gpu_scale_colwise[bi, col_start:col_end] + assert torch.all(block_colwise_scales == block_colwise_scales[0]), ( + f"2D mode: Block ({bi},{bj}) colwise scales should be identical" + ) + + # Rowwise and colwise scales should match + assert block_rowwise_scales[0] == block_colwise_scales[0], ( + f"2D mode: Block ({bi},{bj}) rowwise and colwise scales should be equal, " + f"got rowwise={block_rowwise_scales[0]}, colwise={block_colwise_scales[0]}" + ) + + # 3. 
Verify rowwise and colwise quantized data match reference + # Convert FP8 tensors to uint8 for bitwise comparison + gpu_qx_rowwise_uint8 = gpu_qx_rowwise.view(torch.uint8)[:M, :N] + gpu_qx_colwise_uint8 = gpu_qx_colwise.view(torch.uint8)[:M, :N] + ref_qx_rowwise_uint8 = ref_qx_rowwise.view(torch.uint8) + + torch.testing.assert_close( + gpu_qx_rowwise_uint8, + ref_qx_rowwise_uint8, + atol=0, rtol=0, + ) + + torch.testing.assert_close( + gpu_qx_colwise_uint8, + ref_qx_rowwise_uint8, + atol=0, rtol=0, + ) + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +@pytest.mark.parametrize( + "M, N", + [ + # Full tile cases (multiples of 32) + (64, 64), + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 288), + (320, 320), + (352, 256), + # Larger sizes + (2048, 2048), + (1024, 2048), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_mxfp8_2d_quantization_versus_reference( + M: int, + N: int, + x_dtype: torch.dtype, + use_cpp_allocator: bool, +) -> None: + """Test MXFP8 2D quantization against reference implementation.""" + check_mxfp8_2d_quantization_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + use_cpp_allocator=use_cpp_allocator, + ) + + +# ============================================================================ +# Recipe Configuration Tests +# ============================================================================ + +class TestMXFP8BlockScalingRecipe: + """Tests for MXFP8BlockScaling recipe configuration.""" + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_default_recipe_has_qparams(self): + """Test that default MXFP8BlockScaling has QParams attributes.""" + mxfp8_recipe = MXFP8BlockScaling() + + # Verify QParams attributes exist + assert hasattr(mxfp8_recipe, 'fp8_quant_fwd_inp') + assert hasattr(mxfp8_recipe, 'fp8_quant_fwd_weight') + assert hasattr(mxfp8_recipe, 'fp8_quant_bwd_grad') + + # Verify they are QParams instances + assert isinstance(mxfp8_recipe.fp8_quant_fwd_inp, QParams) + assert isinstance(mxfp8_recipe.fp8_quant_fwd_weight, QParams) + assert isinstance(mxfp8_recipe.fp8_quant_bwd_grad, QParams) + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_default_2d_quantization_disabled(self): + """Test that 2D quantization is disabled by default.""" + mxfp8_recipe = MXFP8BlockScaling() + + # By default, 2D quantization should be disabled + assert mxfp8_recipe.enable_2d_quantization is False + + # QParams should reflect this + assert mxfp8_recipe.fp8_quant_fwd_inp.mxfp8_2d_quantization is False + assert mxfp8_recipe.fp8_quant_fwd_weight.mxfp8_2d_quantization is False + assert mxfp8_recipe.fp8_quant_bwd_grad.mxfp8_2d_quantization is False + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_2d_quantization_enabled_only_for_weight(self): + """Test that when 2D quantization is enabled, it only applies to weight.""" + # Create recipe with 2D quantization enabled + mxfp8_recipe = MXFP8BlockScaling(enable_2d_quantization=True) + + # enable_2d_quantization should be True + assert mxfp8_recipe.enable_2d_quantization is True + + # Only weight should have 2D quantization enabled + assert mxfp8_recipe.fp8_quant_fwd_inp.mxfp8_2d_quantization is False + assert mxfp8_recipe.fp8_quant_fwd_weight.mxfp8_2d_quantization is True + assert 
mxfp8_recipe.fp8_quant_bwd_grad.mxfp8_2d_quantization is False + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_qparams_default_values(self): + """Test that QParams have correct default values for MXFP8.""" + mxfp8_recipe = MXFP8BlockScaling() + + # Check default values for all QParams + for qparams in [ + mxfp8_recipe.fp8_quant_fwd_inp, + mxfp8_recipe.fp8_quant_fwd_weight, + mxfp8_recipe.fp8_quant_bwd_grad, + ]: + # These should use defaults for MXFP8 + assert qparams.power_2_scale is False # MXFP8 uses E8M0, inherently power of 2 + assert qparams.amax_epsilon == 0.0 + assert qparams.random_hadamard_transform is False + assert qparams.stochastic_rounding is False + assert qparams.fp4_2d_quantization is False # Not applicable to MXFP8 + assert qparams.mxfp8_2d_quantization is False # Default is False + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_recipe_repr_includes_2d_quantization(self): + """Test that recipe __repr__ includes 2D quantization status.""" + mxfp8_recipe_disabled = MXFP8BlockScaling(enable_2d_quantization=False) + mxfp8_recipe_enabled = MXFP8BlockScaling(enable_2d_quantization=True) + + repr_disabled = repr(mxfp8_recipe_disabled) + repr_enabled = repr(mxfp8_recipe_enabled) + + assert "enable_2d_quantization=False" in repr_disabled + assert "enable_2d_quantization=True" in repr_enabled + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_mxfp8_quantizer_respects_2d_flag(): + """Test that MXFP8Quantizer correctly uses the 2D quantization flag from recipe.""" + # Test with 2D disabled + quantizer_1d = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + with_2d_quantization=False, + ) + assert quantizer_1d.with_2d_quantization is False + + # Test with 2D enabled + quantizer_2d = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + with_2d_quantization=True, + ) + assert quantizer_2d.with_2d_quantization is True + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_mxfp8_recipe_state_creates_correct_quantizers(): + """Test that MXFP8BlockScalingRecipeState creates quantizers with correct 2D settings.""" + from transformer_engine.pytorch.quantization import MXFP8BlockScalingRecipeState + + # Test with 2D disabled + recipe_1d = MXFP8BlockScaling(enable_2d_quantization=False) + state_fwd_1d = MXFP8BlockScalingRecipeState( + recipe=recipe_1d, + mode="forward", + num_quantizers=3, # input, weight, output + ) + quantizers_1d = state_fwd_1d.make_quantizers() + + # All quantizers should have 2D disabled + for idx, q in enumerate(quantizers_1d): + assert q.with_2d_quantization is False, f"Quantizer {idx} should have 2D disabled" + + # Test with 2D enabled + recipe_2d = MXFP8BlockScaling(enable_2d_quantization=True) + state_fwd_2d = MXFP8BlockScalingRecipeState( + recipe=recipe_2d, + mode="forward", + num_quantizers=3, + ) + quantizers_2d = state_fwd_2d.make_quantizers() + + # Only weight (idx % 3 == 1) should have 2D enabled + for idx, q in enumerate(quantizers_2d): + if idx % 3 == 1: # weight + assert q.with_2d_quantization is True, f"Weight quantizer {idx} should have 2D enabled" + else: # input or output + assert q.with_2d_quantization is False, f"Non-weight quantizer {idx} should have 2D disabled" diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a02e7f4f07..f1e966f9e0 100644 --- 
a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -85,7 +85,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, Tensor *dummy_workspace_tensor = nullptr; mxfp8::quantize( *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); + dummy_workspace_tensor, quant_config_cpp.mxfp8_2d_quantization, stream); break; } case NVTE_NVFP4_1D_SCALING: { @@ -223,7 +223,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens case NVTE_MXFP8_1D_SCALING: { mxfp8::quantize( *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); + quant_config_cpp.mxfp8_2d_quantization, stream); break; } case NVTE_NVFP4_1D_SCALING: { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 70a68132ad..4ffa2340d3 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -45,7 +45,7 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 template + size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK, bool kIs2DBlockScaling> __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, @@ -163,6 +163,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[STAGES]; + // Shared memory to pass 2D block scales from colwise to rowwise pass + // THREADS_X = number of 32x32 blocks in X direction + __shared__ e8m0_t block_scales_2d[THREADS_X]; + initialize_barriers(mbar, is_master_thread); int parity = 0; @@ -176,7 +180,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) &mbar[0], is_master_thread); } -#pragma unroll + #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t buff = stage % BUFFS_NUM; const size_t next_stage = stage + 1; @@ -264,6 +268,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } + if constexpr (kIs2DBlockScaling) { +#pragma unroll + for (int i = 16; i > 0; i /= 2) { + thread_amax = fmaxf(thread_amax, __shfl_xor_sync(0xffffffff, thread_amax, i)); + } + } + // 2. 
Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); @@ -276,7 +287,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } else { scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; } - scales_colwise[scale_idx] = biased_exponent; + scales_colwise[scale_idx] = biased_exponent; + + // In 2D mode, save scale to shared memory for rowwise pass + // Each warp (processing one 32x32 block) writes one scale via lane 0 + if constexpr (kIs2DBlockScaling && ROWWISE_SCALING) { + if (thread_lane == 0) { + block_scales_2d[threadIdx.x / THREADS_PER_WARP] = biased_exponent; + } + } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; @@ -300,7 +319,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (ROWWISE_SCALING) { const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; + if constexpr (!kIs2DBlockScaling) { + thread_amax = 0.0f; + } float in_compute_rowwise[SCALE_DIM_X]; Vec in_cached[WAVES]; @@ -317,13 +338,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (!kIs2DBlockScaling) { #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } } } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + if constexpr (!kIs2DBlockScaling) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); @@ -342,25 +367,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { + if constexpr (!kIs2DBlockScaling) { + if (!out_of_bounds) { + if constexpr (std::is_same_v) { #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { #pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } } } } } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + if constexpr (!kIs2DBlockScaling) { + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } } } else { #pragma unroll @@ -397,17 +426,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { + if constexpr (!kIs2DBlockScaling) { + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); } in_compute_rowwise[j] = elt; } @@ -415,8 +446,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } // 2. 
Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + e8m0_t biased_exponent; + if constexpr (kIs2DBlockScaling && COLWISE_SCALING) { + // In 2D mode with both scaling directions, use scale from colwise pass + // Sync to ensure colwise writes to block_scales_2d are visible across warps + __syncthreads(); + e8m0_t scale_from_shmem; + if (thread_lane < THREADS_X) { + scale_from_shmem = block_scales_2d[thread_lane]; + } + // Broadcast: each thread gets scale from lane matching its tid_X_rowwise + biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise); + } else { + biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + } const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_X = scales_offset_X_rowwise; size_t scale_idx; @@ -427,7 +470,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; } if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; + scales_rowwise[scale_idx] = biased_exponent; } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -556,7 +599,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + Tensor *output, Tensor *dbias, Tensor *workspace, const bool use_2d_quantization, cudaStream_t stream) { using namespace quantize_kernel; checkCuDriverContext(stream); @@ -642,7 +685,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && !use_2d_quantization) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; @@ -774,11 +817,15 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } + if (use_2d_quantization) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + switch (scaling_type) { case ScalingType::ROWWISE: { auto kernel = quantize_mxfp8_kernel; + CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK, false>; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); @@ -793,7 +840,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, case ScalingType::COLWISE: { auto kernel = quantize_mxfp8_kernel; + CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK, false>; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); @@ -806,18 +853,22 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { - auto kernel = quantize_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + auto kernel = quantize_mxfp8_kernel; + 
NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ); break; } } diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 970b7aef6c..59d9072f8b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -411,6 +411,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool mxfp8_2d_quantization = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -420,7 +421,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // mxfp8_2d_quantization }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..fc9a8959b3 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,8 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Whether to use 2D block scaling for MXFP8 */ + kNVTEQuantizationConfigMXFP82DQuantization = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1046,6 +1048,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to use 2D block scaling for MXFP8 */ + void set_mxfp8_2d_quantization(bool mxfp8_2d_quantization) { + const auto val = static_cast(mxfp8_2d_quantization); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigMXFP82DQuantization, + &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. 
*/ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..fee3f2c81e 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -65,6 +65,8 @@ class QParams: amax_epsilon: optional minimum value of abs max random_hadamard_transform: whether to use random hadamard transform stochastic_rounding: whether to use stocastic rounding + fp4_2d_quantization: whether to use 2D block scaling for NVFP4 + mxfp8_2d_quantization: whether to use 2D block scaling for MXFP8 """ power_2_scale: bool = False @@ -72,6 +74,7 @@ class QParams: random_hadamard_transform: bool = False stochastic_rounding: bool = False fp4_2d_quantization: bool = False + mxfp8_2d_quantization: bool = False def __repr__(self) -> str: return ( @@ -79,7 +82,8 @@ def __repr__(self) -> str: f"amax_epsilon={self.amax_epsilon},\n" f"random_hadamard_transform={self.random_hadamard_transform},\n" f"stochastic_rounding={self.stochastic_rounding},\n" - f"fp4_2d_quantization={self.fp4_2d_quantization}\n)" + f"fp4_2d_quantization={self.fp4_2d_quantization},\n" + f"mxfp8_2d_quantization={self.mxfp8_2d_quantization}\n)" ) @@ -284,8 +288,13 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + enable_2d_quantization : bool, default = False + If set to `True`, 2D block scaling is used for weight tensors. """ + # Configuration envvars + enable_2d_quantization: bool = os.getenv("NVTE_MXFP8_ENABLE_2D_QUANTIZATION", "0") == "1" + margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False @@ -294,11 +303,17 @@ class MXFP8BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
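+        # 2D block scaling is applied only to the forward weight quantizer set up
+        # below; inputs and gradients keep the default 1x32 MXFP8 scaling.
+        # Illustrative usage sketch (fp8_autocast is the usual TE entry point, it is
+        # assumed here and not touched by this change):
+        #
+        #     recipe = MXFP8BlockScaling(enable_2d_quantization=True)
+        #     with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
+        #         out = model(inp)
+        #
+        # The same behavior can be enabled globally with NVTE_MXFP8_ENABLE_2D_QUANTIZATION=1.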
+ # Quantization params (same pattern as NVFP4BlockScaling) + self.fp8_quant_fwd_inp = QParams(mxfp8_2d_quantization=False) + self.fp8_quant_fwd_weight = QParams(mxfp8_2d_quantization=self.enable_2d_quantization) + self.fp8_quant_bwd_grad = QParams(mxfp8_2d_quantization=False) + def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"enable_2d_quantization={self.enable_2d_quantization}" ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 06971443dd..f3ea3d1e4f 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1059,6 +1059,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigMXFP82DQuantization: + bool_to_uint8(config_.mxfp8_2d_quantization, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1114,6 +1117,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigMXFP82DQuantization: + uint8_to_bool(buf, config_.mxfp8_2d_quantization); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bc22e03097..cee08ba96f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -264,6 +264,8 @@ class Float8BlockQuantizer : public Quantizer { class MXFP8Quantizer : public Quantizer { public: DType dtype; + // 2D block scaling + bool with_2d_quantization; explicit MXFP8Quantizer(const py::handle& quantizer); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..b537625875 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -853,6 +853,7 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vectordtype = quantizer.attr("dtype").cast(); + this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); } void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} @@ -1059,6 +1060,7 @@ void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); } + quant_config.set_mxfp8_2d_quantization(this->with_2d_quantization); NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..6d63901628 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1169,7 +1169,35 @@ def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. 
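+        # Forward quantizers are laid out per GEMM as (input, weight, output), as
+        # exercised by the new unit test, so `idx % 3 == 1` in the forward branch
+        # below selects the weight quantizer, the only one whose QParams may enable
+        # mxfp8_2d_quantization.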
from .tensor.mxfp8_tensor import MXFP8Quantizer - return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + if self.mode == "forward": + + def _make_quantizer(idx: int) -> MXFP8Quantizer: + qparams = ( + self.recipe.fp8_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp8_quant_fwd_inp + ) + return MXFP8Quantizer( + fp8_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_2d_quantization=qparams.mxfp8_2d_quantization, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + MXFP8Quantizer( + fp8_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_2d_quantization=self.recipe.fp8_quant_bwd_grad.mxfp8_2d_quantization, + ) + for _ in range(self.num_quantizers) + ] + + raise ValueError(f"Unknown mode: {self.mode}") class Float8BlockScalingRecipeState(RecipeState): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8dd2255d89..9000592b3c 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -35,15 +35,20 @@ class MXFP8Quantizer(Quantizer): dtype: TE_DType + """2D block scaling, only applicable for weights.""" + with_2d_quantization: bool + def __init__( self, fp8_dtype: TE_DType, *, rowwise: bool = True, columnwise: bool = True, + with_2d_quantization: bool = False, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) self.dtype = fp8_dtype + self.with_2d_quantization = with_2d_quantization def copy(self) -> MXFP8Quantizer: """Create shallow copy""" @@ -52,6 +57,7 @@ def copy(self) -> MXFP8Quantizer: fp8_dtype=self.dtype, rowwise=self.rowwise_usage, columnwise=self.columnwise_usage, + with_2d_quantization=self.with_2d_quantization, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm From be464ea03776060f3f1e84df297068e31b05f7d1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 20:50:32 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_mxfp8_2d_quantize.py | 150 ++++++++++-------- .../common/cast/mxfp8/quantize_mxfp8.cuh | 63 ++++---- 2 files changed, 114 insertions(+), 99 deletions(-) diff --git a/tests/pytorch/test_mxfp8_2d_quantize.py b/tests/pytorch/test_mxfp8_2d_quantize.py index c76214277c..0f70b10b50 100644 --- a/tests/pytorch/test_mxfp8_2d_quantize.py +++ b/tests/pytorch/test_mxfp8_2d_quantize.py @@ -26,35 +26,35 @@ def float_to_e8m0(amax: torch.Tensor) -> torch.Tensor: """ Convert absolute maximum values to E8M0 biased exponent (scale inverse). - + This mimics the GPU implementation in ptx::float_to_e8m0: 1. Compute val = amax / FP8_MAX (same as amax * max_norm_rcp) 2. Extract the biased exponent from the IEEE754 FP32 representation 3. 
Round up if there's any mantissa (ceil behavior) - + E8M0 format: 8-bit unsigned integer representing 2^(value - 127) """ # Compute val = amax / FP8_MAX (same as GPU: amax * max_norm_rcp) val = amax.to(torch.float32) / FP8_E4M3_MAX - + # Reinterpret float32 bits as int32 val_u32 = val.view(torch.int32) - + # Extract biased exponent (bits 30:23) - GPU does: (val_u32 >> 23) and truncates to uint8 exponent = ((val_u32 >> 23) & 0xFF).to(torch.int32) - + # Extract mantissa (bits 22:0) mantissa = val_u32 & 0x7FFFFF - + # Round up condition from GPU: # if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) round_up = (mantissa > 0) & (exponent != 254) & ~((exponent == 0) & (mantissa <= 0x400000)) exponent = exponent + round_up.to(torch.int32) - + # Handle special cases (GPU handles these before the main logic) # val == 0 -> return 0 exponent = torch.where(val == 0, torch.zeros_like(exponent), exponent) - + return exponent.to(torch.uint8) @@ -68,76 +68,83 @@ def quantize_mxfp8_2d_reference( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Reference implementation of MXFP8 2D block scaling quantization. - + For 2D scaling, each 32x32 block shares a single E8M0 scale factor. - + Args: x: Input tensor of shape (M, N), assumes M and N are multiples of 32 - + Returns: qx_rowwise: Quantized data in row-major order scale_rowwise: E8M0 scale inverses for rowwise (shape: M x ceil(N/32)) - qx_colwise: Quantized data in column-major order + qx_colwise: Quantized data in column-major order scale_colwise: E8M0 scale inverses for colwise (shape: ceil(M/32) x N) """ M, N = x.shape device = x.device dtype = x.dtype - + # Pad to multiples of 32 if needed pad_M = (MXFP8_BLOCK_SIZE - M % MXFP8_BLOCK_SIZE) % MXFP8_BLOCK_SIZE pad_N = (MXFP8_BLOCK_SIZE - N % MXFP8_BLOCK_SIZE) % MXFP8_BLOCK_SIZE if pad_M > 0 or pad_N > 0: - x = torch.nn.functional.pad(x, (0, pad_N, 0, pad_M), mode='constant', value=0.0) - + x = torch.nn.functional.pad(x, (0, pad_N, 0, pad_M), mode="constant", value=0.0) + M_padded, N_padded = x.shape num_block_rows = M_padded // MXFP8_BLOCK_SIZE num_block_cols = N_padded // MXFP8_BLOCK_SIZE - + # Reshape to expose 32x32 blocks - x_blocks = x.view( - num_block_rows, MXFP8_BLOCK_SIZE, - num_block_cols, MXFP8_BLOCK_SIZE - ).permute(0, 2, 1, 3) # (num_block_rows, num_block_cols, 32, 32) - + x_blocks = x.view(num_block_rows, MXFP8_BLOCK_SIZE, num_block_cols, MXFP8_BLOCK_SIZE).permute( + 0, 2, 1, 3 + ) # (num_block_rows, num_block_cols, 32, 32) + # Compute amax for each 32x32 block - block_amax = torch.amax(torch.abs(x_blocks.to(torch.float32)), dim=(-1, -2)) # (num_block_rows, num_block_cols) - + block_amax = torch.amax( + torch.abs(x_blocks.to(torch.float32)), dim=(-1, -2) + ) # (num_block_rows, num_block_cols) + # Convert to E8M0 scale inverse block_scale_e8m0 = float_to_e8m0(block_amax) # (num_block_rows, num_block_cols) block_scale_inv = e8m0_to_scale_inv(block_scale_e8m0) # (num_block_rows, num_block_cols) - + # Expand scale to match input dimensions for quantization # For rowwise: each row in a block uses the same scale, scale shape is (M, num_block_cols) - scale_rowwise = block_scale_e8m0.repeat_interleave(MXFP8_BLOCK_SIZE, dim=0) # (M_padded, num_block_cols) - + scale_rowwise = block_scale_e8m0.repeat_interleave( + MXFP8_BLOCK_SIZE, dim=0 + ) # (M_padded, num_block_cols) + # For colwise: each column in a block uses the same scale, scale shape is (num_block_rows, N) - scale_colwise = block_scale_e8m0.repeat_interleave(MXFP8_BLOCK_SIZE, dim=1) # 
(num_block_rows, N_padded) - + scale_colwise = block_scale_e8m0.repeat_interleave( + MXFP8_BLOCK_SIZE, dim=1 + ) # (num_block_rows, N_padded) + # Compute scale inverse for quantization (broadcast over 32x32 blocks) - scale_inv_expanded = block_scale_inv.unsqueeze(-1).unsqueeze(-1) # (num_block_rows, num_block_cols, 1, 1) + scale_inv_expanded = block_scale_inv.unsqueeze(-1).unsqueeze( + -1 + ) # (num_block_rows, num_block_cols, 1, 1) scale_inv_expanded = scale_inv_expanded.expand(-1, -1, MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE) - + # Quantize: x_quantized = round(x / scale_inv) clamped to FP8 range x_blocks_float = x_blocks.to(torch.float32) x_scaled = x_blocks_float / scale_inv_expanded # Convert to FP8 (using PyTorch's float8_e4m3fn) x_quantized = x_scaled.to(torch.float8_e4m3fn) - + # Reshape back to original layout # Rowwise: (M_padded, N_padded) qx_rowwise = x_quantized.permute(0, 2, 1, 3).reshape(M_padded, N_padded) - + # Colwise: same data but transposed for column-major access qx_colwise = x_quantized.permute(0, 2, 1, 3).reshape(M_padded, N_padded) - + # Remove padding from outputs qx_rowwise = qx_rowwise[:M, :N] qx_colwise = qx_colwise[:M, :N] scale_rowwise = scale_rowwise[:M, :] scale_colwise = scale_colwise[:, :N] - + return qx_rowwise, scale_rowwise, qx_colwise, scale_colwise @@ -149,7 +156,7 @@ def check_mxfp8_2d_quantization_versus_reference( ) -> None: """ Test MXFP8 2D quantization against CPU reference implementation. - + Verifies: 1. scales match reference 2. 32x32 blocks share the same scale @@ -176,9 +183,7 @@ def check_mxfp8_2d_quantization_versus_reference( if use_cpp_allocator: x_mxfp8 = quantizer(x) else: - x_mxfp8 = quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False - ) + x_mxfp8 = quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) x_mxfp8 = quantizer.update_quantized(x, x_mxfp8) # Extract GPU results @@ -193,12 +198,13 @@ def check_mxfp8_2d_quantization_versus_reference( gpu_scale_colwise = x_mxfp8._columnwise_scale_inv # Reference Quantization - ref_qx_rowwise, ref_scale_rowwise, ref_qx_colwise, ref_scale_colwise = \ + ref_qx_rowwise, ref_scale_rowwise, ref_qx_colwise, ref_scale_colwise = ( quantize_mxfp8_2d_reference(x) + ) num_block_rows = (M + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE num_block_cols = (N + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE - + # GPU scales may have padding, compare valid portion gpu_scale_rowwise_valid = gpu_scale_rowwise[:M, :num_block_cols] gpu_scale_colwise_valid = gpu_scale_colwise[:num_block_rows, :N] @@ -207,9 +213,10 @@ def check_mxfp8_2d_quantization_versus_reference( torch.testing.assert_close( gpu_scale_rowwise_valid, ref_scale_rowwise, - atol=0, rtol=0, + atol=0, + rtol=0, ) - + # 2. 
Verify 32x32 blocks share the same scale for bi in range(num_block_rows): for bj in range(num_block_cols): @@ -220,15 +227,15 @@ def check_mxfp8_2d_quantization_versus_reference( # All rows in block should have same scale for this column block block_rowwise_scales = gpu_scale_rowwise[row_start:row_end, bj] - assert torch.all(block_rowwise_scales == block_rowwise_scales[0]), ( - f"2D mode: Block ({bi},{bj}) rowwise scales should be identical" - ) + assert torch.all( + block_rowwise_scales == block_rowwise_scales[0] + ), f"2D mode: Block ({bi},{bj}) rowwise scales should be identical" # All columns in block should have same scale for this row block block_colwise_scales = gpu_scale_colwise[bi, col_start:col_end] - assert torch.all(block_colwise_scales == block_colwise_scales[0]), ( - f"2D mode: Block ({bi},{bj}) colwise scales should be identical" - ) + assert torch.all( + block_colwise_scales == block_colwise_scales[0] + ), f"2D mode: Block ({bi},{bj}) colwise scales should be identical" # Rowwise and colwise scales should match assert block_rowwise_scales[0] == block_colwise_scales[0], ( @@ -241,17 +248,19 @@ def check_mxfp8_2d_quantization_versus_reference( gpu_qx_rowwise_uint8 = gpu_qx_rowwise.view(torch.uint8)[:M, :N] gpu_qx_colwise_uint8 = gpu_qx_colwise.view(torch.uint8)[:M, :N] ref_qx_rowwise_uint8 = ref_qx_rowwise.view(torch.uint8) - + torch.testing.assert_close( gpu_qx_rowwise_uint8, ref_qx_rowwise_uint8, - atol=0, rtol=0, + atol=0, + rtol=0, ) torch.testing.assert_close( gpu_qx_colwise_uint8, ref_qx_rowwise_uint8, - atol=0, rtol=0, + atol=0, + rtol=0, ) @@ -298,6 +307,7 @@ def test_mxfp8_2d_quantization_versus_reference( # Recipe Configuration Tests # ============================================================================ + class TestMXFP8BlockScalingRecipe: """Tests for MXFP8BlockScaling recipe configuration.""" @@ -305,12 +315,12 @@ class TestMXFP8BlockScalingRecipe: def test_default_recipe_has_qparams(self): """Test that default MXFP8BlockScaling has QParams attributes.""" mxfp8_recipe = MXFP8BlockScaling() - + # Verify QParams attributes exist - assert hasattr(mxfp8_recipe, 'fp8_quant_fwd_inp') - assert hasattr(mxfp8_recipe, 'fp8_quant_fwd_weight') - assert hasattr(mxfp8_recipe, 'fp8_quant_bwd_grad') - + assert hasattr(mxfp8_recipe, "fp8_quant_fwd_inp") + assert hasattr(mxfp8_recipe, "fp8_quant_fwd_weight") + assert hasattr(mxfp8_recipe, "fp8_quant_bwd_grad") + # Verify they are QParams instances assert isinstance(mxfp8_recipe.fp8_quant_fwd_inp, QParams) assert isinstance(mxfp8_recipe.fp8_quant_fwd_weight, QParams) @@ -320,10 +330,10 @@ def test_default_recipe_has_qparams(self): def test_default_2d_quantization_disabled(self): """Test that 2D quantization is disabled by default.""" mxfp8_recipe = MXFP8BlockScaling() - + # By default, 2D quantization should be disabled assert mxfp8_recipe.enable_2d_quantization is False - + # QParams should reflect this assert mxfp8_recipe.fp8_quant_fwd_inp.mxfp8_2d_quantization is False assert mxfp8_recipe.fp8_quant_fwd_weight.mxfp8_2d_quantization is False @@ -334,10 +344,10 @@ def test_2d_quantization_enabled_only_for_weight(self): """Test that when 2D quantization is enabled, it only applies to weight.""" # Create recipe with 2D quantization enabled mxfp8_recipe = MXFP8BlockScaling(enable_2d_quantization=True) - + # enable_2d_quantization should be True assert mxfp8_recipe.enable_2d_quantization is True - + # Only weight should have 2D quantization enabled assert mxfp8_recipe.fp8_quant_fwd_inp.mxfp8_2d_quantization is False assert 
mxfp8_recipe.fp8_quant_fwd_weight.mxfp8_2d_quantization is True @@ -347,7 +357,7 @@ def test_2d_quantization_enabled_only_for_weight(self): def test_qparams_default_values(self): """Test that QParams have correct default values for MXFP8.""" mxfp8_recipe = MXFP8BlockScaling() - + # Check default values for all QParams for qparams in [ mxfp8_recipe.fp8_quant_fwd_inp, @@ -367,10 +377,10 @@ def test_recipe_repr_includes_2d_quantization(self): """Test that recipe __repr__ includes 2D quantization status.""" mxfp8_recipe_disabled = MXFP8BlockScaling(enable_2d_quantization=False) mxfp8_recipe_enabled = MXFP8BlockScaling(enable_2d_quantization=True) - + repr_disabled = repr(mxfp8_recipe_disabled) repr_enabled = repr(mxfp8_recipe_enabled) - + assert "enable_2d_quantization=False" in repr_disabled assert "enable_2d_quantization=True" in repr_enabled @@ -386,7 +396,7 @@ def test_mxfp8_quantizer_respects_2d_flag(): with_2d_quantization=False, ) assert quantizer_1d.with_2d_quantization is False - + # Test with 2D enabled quantizer_2d = MXFP8Quantizer( fp8_dtype=tex.DType.kFloat8E4M3, @@ -401,7 +411,7 @@ def test_mxfp8_quantizer_respects_2d_flag(): def test_mxfp8_recipe_state_creates_correct_quantizers(): """Test that MXFP8BlockScalingRecipeState creates quantizers with correct 2D settings.""" from transformer_engine.pytorch.quantization import MXFP8BlockScalingRecipeState - + # Test with 2D disabled recipe_1d = MXFP8BlockScaling(enable_2d_quantization=False) state_fwd_1d = MXFP8BlockScalingRecipeState( @@ -410,11 +420,11 @@ def test_mxfp8_recipe_state_creates_correct_quantizers(): num_quantizers=3, # input, weight, output ) quantizers_1d = state_fwd_1d.make_quantizers() - + # All quantizers should have 2D disabled for idx, q in enumerate(quantizers_1d): assert q.with_2d_quantization is False, f"Quantizer {idx} should have 2D disabled" - + # Test with 2D enabled recipe_2d = MXFP8BlockScaling(enable_2d_quantization=True) state_fwd_2d = MXFP8BlockScalingRecipeState( @@ -423,10 +433,12 @@ def test_mxfp8_recipe_state_creates_correct_quantizers(): num_quantizers=3, ) quantizers_2d = state_fwd_2d.make_quantizers() - + # Only weight (idx % 3 == 1) should have 2D enabled for idx, q in enumerate(quantizers_2d): if idx % 3 == 1: # weight assert q.with_2d_quantization is True, f"Weight quantizer {idx} should have 2D enabled" else: # input or output - assert q.with_2d_quantization is False, f"Non-weight quantizer {idx} should have 2D disabled" + assert ( + q.with_2d_quantization is False + ), f"Non-weight quantizer {idx} should have 2D disabled" diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 4ffa2340d3..da9e5ffd75 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -180,7 +180,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) &mbar[0], is_master_thread); } - #pragma unroll +#pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t buff = stage % BUFFS_NUM; const size_t next_stage = stage + 1; @@ -287,7 +287,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } else { scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; } - scales_colwise[scale_idx] = biased_exponent; + scales_colwise[scale_idx] = biased_exponent; // In 2D mode, save scale to shared memory for rowwise pass // Each warp (processing one 32x32 block) writes one scale via lane 0 @@ -378,7 +378,7 @@ __global__ void 
__launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int e = 0; e < PACK_SIZE; e += 2) { const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; + in_cached[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); } } @@ -431,7 +431,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); } @@ -470,7 +471,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; } if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; + scales_rowwise[scale_idx] = biased_exponent; } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -599,7 +600,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, const bool use_2d_quantization, cudaStream_t stream) { + Tensor *output, Tensor *dbias, Tensor *workspace, const bool use_2d_quantization, + cudaStream_t stream) { using namespace quantize_kernel; checkCuDriverContext(stream); @@ -817,15 +819,14 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } - if (use_2d_quantization) { - scaling_type = ScalingType::BIDIMENSIONAL; - } + if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; } switch (scaling_type) { case ScalingType::ROWWISE: { - auto kernel = quantize_mxfp8_kernel; + auto kernel = + quantize_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); @@ -838,9 +839,10 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { - auto kernel = quantize_mxfp8_kernel; + auto kernel = + quantize_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); @@ -854,21 +856,22 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } case ScalingType::BIDIMENSIONAL: { TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_2d_quantization, kIs2DBlockScaling, - - auto kernel = quantize_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ); + use_2d_quantization, kIs2DBlockScaling, + + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, + noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + 
NVTE_CHECK_CUDA(cudaGetLastError());); break; } } From 728efb2db961c65c8412048f0d58c68d7ac51ee0 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Tue, 10 Feb 2026 14:24:00 +0800 Subject: [PATCH 3/4] Add check for BIDIMENSIONAL scaling Signed-off-by: kunlunl --- qa/L0_pytorch_unittest/test.sh | 1 + .../common/cast/mxfp8/quantize_mxfp8.cuh | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a13dfada79..6ed18b0f75 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -51,6 +51,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.x NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8_2d_quantize.xml $TE_PATH/tests/pytorch/test_mxfp8_2d_quantize.py || test_fail "test_mxfp8_2d_quantize.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index da9e5ffd75..6434d4031f 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -291,7 +291,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // In 2D mode, save scale to shared memory for rowwise pass // Each warp (processing one 32x32 block) writes one scale via lane 0 - if constexpr (kIs2DBlockScaling && ROWWISE_SCALING) { + if constexpr (kIs2DBlockScaling) { + static_assert(ROWWISE_SCALING, "ROWWISE_SCALING must be true when using 2D block scaling"); if (thread_lane == 0) { block_scales_2d[threadIdx.x / THREADS_PER_WARP] = biased_exponent; } @@ -448,7 +449,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. 
Compute E8M0 scaling factor e8m0_t biased_exponent; - if constexpr (kIs2DBlockScaling && COLWISE_SCALING) { + if constexpr (kIs2DBlockScaling) { + static_assert(COLWISE_SCALING, "COLWISE_SCALING must be true when using 2D block scaling"); // In 2D mode with both scaling directions, use scale from colwise pass // Sync to ensure colwise writes to block_scales_2d are visible across warps __syncthreads(); @@ -819,7 +821,11 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } - if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; } + if (use_2d_quantization) { + scaling_type = ScalingType::BIDIMENSIONAL; + NVTE_CHECK(scaling_type == ScalingType::BIDIMENSIONAL, + "Scaling type must be BIDIMENSIONAL when using 2D block scaling"); + } switch (scaling_type) { case ScalingType::ROWWISE: { From b8fd2145f4561feed69844ce1519ad816b39b0d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 06:24:50 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 6434d4031f..0387045042 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -821,7 +821,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } - if (use_2d_quantization) { + if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; NVTE_CHECK(scaling_type == ScalingType::BIDIMENSIONAL, "Scaling type must be BIDIMENSIONAL when using 2D block scaling");
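
Usage sketch (illustrative, assuming an MXFP8-capable GPU): the 2D path added in this
series can be exercised directly through the Python quantizer, using only names that
appear in the new unit test.

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch import MXFP8Quantizer

    # Weight-shaped tensor; with_2d_quantization=True makes each 32x32 block share a
    # single E8M0 scale, so the rowwise and colwise scales of a block are identical.
    w = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
    quantizer = MXFP8Quantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=True,
        with_2d_quantization=True,
    )
    w_mxfp8 = quantizer(w)

    # Logically one rowwise scale column per 32 input columns and one colwise scale
    # row per 32 input rows (the GPU buffers may carry extra padding).
    print(w_mxfp8._rowwise_scale_inv.shape, w_mxfp8._columnwise_scale_inv.shape)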