From 3ba991b447561f04fdad46f12b40ce87cf2deb0b Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Fri, 30 Jan 2026 16:20:09 -0800 Subject: [PATCH 1/2] fix subchannel fp8 + sp Signed-off-by: Chen Cui --- transformer_engine/pytorch/distributed.py | 6 +++++- transformer_engine/pytorch/utils.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..fd519249a3 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1100,7 +1100,11 @@ def _start_all_gather_fp8_blockwise( # Fall back to high-precision all-gather if FP8 is not supported if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1: - out = torch.empty(out_shape, dtype=dtype, device=device) + # Dequantize if input is already quantized + if isinstance(inp, Float8BlockwiseQTensorStorage): + inp = inp.dequantize() + # Use dtype from actual input tensor (may differ from initial guess after dequantize) + out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device) torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) out = quantizer(out) return out, None diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 47af9fabe1..3edc0d7c77 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -452,6 +452,13 @@ def assert_dim_for_all_gather( ) -> None: """Assert that tensor dimensions are supported for all-gather""" if with_all_gather: + # Float8BlockQuantizer has a fallback path in gather_along_first_dim + # that handles non-quantizable local tensors by all-gathering in + # high precision first, then quantizing the result. + from .tensor import Float8BlockQuantizer + + if isinstance(quantizer, Float8BlockQuantizer): + return assert quantizer.is_quantizable(tensor), ( "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ ) From 637ba0f8cf91805e5963c565fe14b4339693873f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 2 Feb 2026 19:59:13 +0000 Subject: [PATCH 2/2] Support sequence-parallel all-gather with small inputs Perform all-gather in high-precision if the input tensor is too small to quantize. Signed-off-by: Tim Moon --- transformer_engine/pytorch/distributed.py | 21 ++++++++++++------- .../pytorch/module/layernorm_linear.py | 2 -- .../pytorch/module/layernorm_mlp.py | 2 -- transformer_engine/pytorch/module/linear.py | 2 -- transformer_engine/pytorch/utils.py | 17 --------------- 5 files changed, 13 insertions(+), 31 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fd519249a3..2208ff720c 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1100,10 +1100,9 @@ def _start_all_gather_fp8_blockwise( # Fall back to high-precision all-gather if FP8 is not supported if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1: - # Dequantize if input is already quantized - if isinstance(inp, Float8BlockwiseQTensorStorage): - inp = inp.dequantize() - # Use dtype from actual input tensor (may differ from initial guess after dequantize) + warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") + if isinstance(inp, QuantizedTensorStorage): + inp = inp.dequantize() # Dequantize if needed out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device) torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) out = quantizer(out) @@ -1342,10 +1341,13 @@ def _all_gather_nvfp4( and quantizer is not None and not quantizer.is_quantizable(inp) ): + warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") + if isinstance(inp, QuantizedTensorStorage): + inp = inp.dequantize() # Dequantize if needed out = torch.empty( out_shape, - dtype=dtype, - device=device, + dtype=inp.dtype, + device=inp.device, memory_format=torch.contiguous_format, ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group) @@ -1509,10 +1511,13 @@ def _all_gather_mxfp8( and quantizer is not None and not quantizer.is_quantizable(inp) ): + warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") + if isinstance(inp, QuantizedTensorStorage): + inp = inp.dequantize() # Dequantize if needed out = torch.empty( out_shape, - dtype=dtype, - device=device, + dtype=inp.dtype, + device=inp.device, memory_format=torch.contiguous_format, ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..f79bc91c0a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -29,7 +29,6 @@ from ..quantization import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, - assert_dim_for_all_gather, cast_if_needed, clear_tensor_data, divide, @@ -158,7 +157,6 @@ def forward( inputmat = inp if fp8: assert_dim_for_fp8_exec(inputmat, weight) - assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..d9f046aa38 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -40,7 +40,6 @@ init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, - assert_dim_for_all_gather, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -331,7 +330,6 @@ def _forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) - assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..d7283cc047 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -34,7 +34,6 @@ requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, - assert_dim_for_all_gather, nvtx_range_pop, nvtx_range_push, get_nvtx_range_context, @@ -175,7 +174,6 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) - assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 3edc0d7c77..0a74c75edd 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -447,23 +447,6 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: ) -def assert_dim_for_all_gather( - tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer -) -> None: - """Assert that tensor dimensions are supported for all-gather""" - if with_all_gather: - # Float8BlockQuantizer has a fallback path in gather_along_first_dim - # that handles non-quantizable local tensors by all-gathering in - # high precision first, then quantizing the result. - from .tensor import Float8BlockQuantizer - - if isinstance(quantizer, Float8BlockQuantizer): - return - assert quantizer.is_quantizable(tensor), ( - "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ - ) - - def is_bf16_compatible() -> bool: """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability to enforce sm_80 or higher.