From 3afce1f133112d162cf66f680b83a7cd8d360ab0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 2 Feb 2026 16:45:50 -0800 Subject: [PATCH 01/39] Add NVTE_KEEP_BACKWARD_UNQUANTIZED Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 4 +- .../pytorch/module/grouped_linear.py | 36 +++-- .../pytorch/module/layernorm_linear.py | 80 +++++++--- .../pytorch/module/layernorm_mlp.py | 147 +++++++++++------- transformer_engine/pytorch/module/linear.py | 65 +++++--- .../pytorch/ops/basic/basic_linear.py | 48 ++++-- .../pytorch/ops/basic/quantize.py | 6 +- .../ops/fused/backward_activation_bias.py | 7 +- .../fused/forward_linear_bias_activation.py | 18 ++- .../ops/fused/forward_linear_bias_add.py | 18 ++- .../ops/fused/forward_linear_scale_add.py | 18 ++- .../ops/fused/userbuffers_forward_linear.py | 49 +++++- transformer_engine/pytorch/ops/fuser.py | 16 +- transformer_engine/pytorch/quantization.py | 5 + 14 files changed, 375 insertions(+), 142 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 841cdf04ca..4a2140718d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,9 +1135,11 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..874eadeb36 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,6 +96,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -286,6 +289,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -294,7 +298,11 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, weights[0], biases[0]) + ): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() @@ -318,6 +326,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + 
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -333,7 +343,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: if ctx.use_bias: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -384,7 +394,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -395,13 +405,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + weights_for_dgrad = weights if use_fp8_bwd else origin_weights + if use_fp8_bwd: + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, @@ -415,7 +427,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -442,7 +454,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -516,7 +528,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not ctx.fp8 + and not use_fp8_bwd ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..28842fc315 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,6 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -200,7 +201,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and not keep_backward_unquantized, + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data 
input_quantizer.set_usage(columnwise=False) @@ -213,6 +217,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -236,6 +241,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -409,13 +415,14 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and not keep_backward_unquantized: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -427,7 +434,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -439,7 +446,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -466,7 +473,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -493,6 +500,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -515,7 +523,11 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, ln_weight, ln_bias, weight, bias) + ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -592,6 +604,15 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -601,23 +622,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad 
compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -628,7 +649,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -665,7 +686,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None and use_quantized_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -703,18 +724,22 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -730,12 +755,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight if use_quantized_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -782,7 +808,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, 
MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -794,7 +824,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -820,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -836,7 +866,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -862,7 +892,9 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -870,7 +902,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..2b3a72b803 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,6 +232,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -350,8 +351,10 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = fc1_weight.requires_grad and ( - (is_grad_enabled and not checkpoint) or is_recomputation + backwards_needs_fc1_input = ( + fc1_weight.requires_grad + and ((is_grad_enabled and not checkpoint) or is_recomputation) + and not keep_backward_unquantized ) device = inp.device @@ -394,6 +397,7 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom ) @@ -415,6 +419,7 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not 
is_recomputation: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -611,6 +616,10 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) + act_out_hp = act_out + if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: + act_out_hp = activation_func(fc1_out, None, **act_params) + # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -686,22 +695,30 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out) - ln_out = None + clear_tensor_data(ln_out_to_save) + ln_out_to_save = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out) - act_out = None + clear_tensor_data(act_out_to_save) + act_out_to_save = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + inputmat, + mu, + rsigma, + ln_out_to_save, + fc1_out, + fc1_out_without_bias, + act_out_to_save, ) # Scatter intermediate/activation tensors saved for the backward pass @@ -714,9 +731,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out, + ln_out_to_save, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out, + act_out_to_save, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -744,13 +761,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out, + ln_out_to_save, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out, + act_out_to_save, fc2_weight_final, fc2_weight, fc2_bias, @@ -798,6 +815,7 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -826,8 +844,12 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad( + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + ) ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -996,6 +1018,16 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + 
ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1015,7 +1047,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1029,7 +1061,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1042,7 +1074,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1057,7 +1089,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1066,7 +1098,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1103,7 +1135,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 + not use_fp8_bwd and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1112,20 +1144,23 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc2_weight_quantizer is not None + and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM + fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight, + fc2_weight_for_dgrad, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if fc2_dgrad_gemm_gelu_fusion or ctx.debug + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1157,7 +1192,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and 
isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1170,7 +1209,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1193,14 +1232,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1209,7 +1248,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1219,7 +1258,9 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision + "quantization_params": ( + ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1256,8 +1297,8 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - ctx.fp8 - and ctx.fp8_recipe.float8_block_scaling() + use_fp8_bwd + and fp8_recipe_bwd.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. 
@@ -1277,12 +1318,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None: + if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not ctx.fp8 + assert not use_fp8_bwd fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1292,13 +1333,10 @@ def fc2_wgrad_gemm( fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) elif ( - _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and ctx.fp8 + _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[2] + dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1308,18 +1346,16 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[1] + activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if ctx.fp8: + if use_fp8_bwd: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or ctx.fp8_recipe.custom() + or fp8_recipe_bwd.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) @@ -1347,16 +1383,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -1364,8 +1400,10 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc1_weight_quantizer is not None + and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1380,12 +1418,13 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM + fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight, + fc1_weight_for_dgrad, dact, out=gemm_out, 
out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer, + quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1434,7 +1473,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1444,7 +1483,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1466,7 +1505,9 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc1_grad_weight_quantizer, + "quantization_params": ( + ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..b4bad849c1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,6 +129,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -443,6 +446,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -479,7 +483,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): + if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -536,6 +540,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -545,23 +558,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout 
ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -575,7 +588,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -594,6 +607,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None + and use_quantized_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -623,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -649,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -690,20 +704,22 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -720,12 +736,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 if use_quantized_bwd else 
weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -774,7 +791,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -784,7 +801,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -796,7 +817,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -816,7 +837,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -825,7 +846,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -851,7 +872,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -859,7 +882,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb1..a9a6895112 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,14 @@ def 
pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad + keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -420,6 +422,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +462,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -510,7 +515,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +550,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +622,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +988,9 @@ def op_forward( grad_output_quantizer = 
self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1007,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1017,16 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = self.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index d126b554b5..cc26022d0e 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,7 +57,11 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = fp8_enabled and self._quantize_backward + quantize_backward = ( + fp8_enabled + and self._quantize_backward + and not FP8GlobalStateManager.keep_backward_unquantized() + ) # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..59e9af14f4 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe +from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,7 +105,10 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None: + if recipe is None or ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..0a28d00706 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = 
prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +112,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +122,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..41ae096e54 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +109,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -115,10 +119,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py 
b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..b06f5ad36a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +90,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -96,10 +100,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 6ef9bf083b..8c04fca17c 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,6 +94,7 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -126,6 +127,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -200,7 +203,10 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -216,7 +222,10 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -227,7 +236,10 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Construct output tensor if needed @@ -257,14 +269,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -311,6 +332,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -340,6 +364,7 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -352,10 +377,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + 
linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 7fe6ea37ed..035233fb55 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -109,6 +109,10 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -120,7 +124,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None: + if prev_op is not None and not keep_backward_unquantized: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -286,7 +290,15 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - if func_ctx.is_first_module and not _is_graph_capturing(): + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) + if ( + func_ctx.is_first_module + and not keep_backward_unquantized + and not _is_graph_capturing() + ): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..9806871ef6 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -430,6 +430,11 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL + @classmethod + def keep_backward_unquantized(cls) -> bool: + """Should backward skip FP8 quantization and use high precision""" + return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" From 72149be265539dc732cf8656e4ed2d21ecde374c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:49:22 +0000 Subject: [PATCH 02/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/ops/fuser.py | 6 +----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2b3a72b803..8e8749b237 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1332,9 +1332,7 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif ( - _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd - ): + elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and 
use_fp8_bwd: # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 035233fb55..a692bc9487 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -294,11 +294,7 @@ def backward( FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.keep_backward_unquantized() ) - if ( - func_ctx.is_first_module - and not keep_backward_unquantized - and not _is_graph_capturing() - ): + if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From 927d482136a3f297813f7bdb3b36d678e44faf6c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:36:13 -0800 Subject: [PATCH 03/39] Disable ub and clean up Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 9 ++-- .../pytorch/module/layernorm_mlp.py | 13 ++--- transformer_engine/pytorch/module/linear.py | 17 +++---- .../ops/fused/userbuffers_forward_linear.py | 49 +++---------------- 4 files changed, 25 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 28842fc315..66e67522f6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -608,6 +608,7 @@ def backward( use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -622,23 +623,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8e8749b237..5d72508d0d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1023,6 +1023,7 @@ def backward( use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if 
keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -1074,7 +1075,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1098,7 +1099,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1192,11 +1193,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1209,7 +1206,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b4bad849c1..a03e9ac4d5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -544,6 +544,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -558,23 +559,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # 
-------------------------------------------------- @@ -801,11 +802,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -817,7 +814,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 8c04fca17c..6ef9bf083b 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,7 +94,6 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -127,8 +126,6 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. - keep_backward_unquantized: bool, default = `False` - Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -203,10 +200,7 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -222,10 +216,7 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -236,10 +227,7 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage( - rowwise=True, - columnwise=input_requires_grad and not keep_backward_unquantized, - ) + weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = weight_quantizer(w) # Construct output tensor if needed @@ -269,23 +257,14 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if ( - w is not weight - and with_quantized_compute - and is_quantized_tensor(w) - and not keep_backward_unquantized - ): + if w is not weight and with_quantized_compute and is_quantized_tensor(w): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if ( - with_quantized_compute - and is_quantized_tensor(x_local) - and not keep_backward_unquantized - ): + if with_quantized_compute and is_quantized_tensor(x_local): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -332,9 +311,6 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() - ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -364,7 +340,6 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -377,18 +352,10 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) - linear_op_ctx.save_for_backward(saved_input, saved_weight) - linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized - ) + mark_activation_offload(x_local) + 
linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer From cc85b606cf31717ccb7684b21125e858505413d0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:37:57 -0800 Subject: [PATCH 04/39] Drop fuser changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/fuser.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a692bc9487..7fe6ea37ed 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -109,10 +109,6 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -124,7 +120,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None and not keep_backward_unquantized: + if prev_op is not None: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -290,11 +286,7 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) - if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): + if func_ctx.is_first_module and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From fe24f95c16d8c5a46b363f612afbcbc7fd676b6d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:56:43 -0800 Subject: [PATCH 05/39] Replace use_quantized_bwd with use_fp8_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 19 +++++++------ .../pytorch/module/layernorm_mlp.py | 27 +++++++++---------- transformer_engine/pytorch/module/linear.py | 23 ++++++++-------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 66e67522f6..b759c152ec 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -606,7 +606,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -650,7 +649,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -687,7 +686,7 @@ def 
backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_quantized_bwd: + if ctx.input_quantizer is not None and use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -726,7 +725,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage) ): @@ -740,7 +739,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -756,13 +755,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_quantized_bwd else origin_weight + weight_for_dgrad = weight if use_fp8_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -851,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -894,7 +893,7 @@ def backward( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5d72508d0d..1414bb4afa 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1020,7 +1020,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True @@ -1062,7 +1061,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1090,7 +1089,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad 
and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1146,7 +1145,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): @@ -1161,7 +1160,7 @@ def backward( grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1229,14 +1228,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1256,7 +1255,7 @@ def backward( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad @@ -1315,7 +1314,7 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu @@ -1396,7 +1395,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc1_weight_quantizer is not None and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): @@ -1419,7 +1418,7 @@ def fc2_wgrad_gemm( dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1468,7 +1467,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1478,7 +1477,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1501,7 +1500,7 @@ def fc2_wgrad_gemm( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git 
a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a03e9ac4d5..6ecc647626 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -542,7 +542,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -589,7 +588,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -608,7 +607,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and use_quantized_bwd + and use_fp8_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -638,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -664,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -706,7 +705,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -720,7 +719,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -737,13 +736,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 if use_quantized_bwd else weight + weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -792,7 +791,7 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -834,7 +833,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -870,7 +869,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad From 5ca361584796e6010768f8c91ee9b265a379f8bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:57:32 +0000 Subject: [PATCH 06/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b759c152ec..bdfeff056b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -892,9 +892,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1414bb4afa..c5f7051fa1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1499,9 +1499,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6ecc647626..1ce4fac445 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -868,9 +868,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", 
False) From 5ba76747ab50fc5cd8cccd3e5bfa9fcf53fe58bb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:30:04 -0800 Subject: [PATCH 07/39] Ignore keep_backward_unquantized if delayed scaling Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 1 + transformer_engine/pytorch/module/linear.py | 1 + transformer_engine/pytorch/quantization.py | 3 +++ 3 files changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 874eadeb36..0ccacd9b17 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,6 +98,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1ce4fac445..49b78382d2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,6 +131,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9806871ef6..e8f6dafdb5 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -433,6 +433,9 @@ def with_high_precision_init_val(cls) -> bool: @classmethod def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" + recipe = cls.get_fp8_recipe() + if recipe.delayed(): + return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) @classmethod From 02b7b2ae23f01942968e59eda24a47d74ee832a3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:39:02 -0800 Subject: [PATCH 08/39] Refactor ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/quantization.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0ccacd9b17..9e2eb60ea5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,7 +98,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 49b78382d2..0bf560c7b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,7 +131,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and 
FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index e8f6dafdb5..aab7ed2d1c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -434,7 +434,8 @@ def with_high_precision_init_val(cls) -> bool: def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" recipe = cls.get_fp8_recipe() - if recipe.delayed(): + if recipe is not None and recipe.delayed(): + # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) From 01a7de026f92e7bb9e8f1e8b8e6f51b7da1c668a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:13:57 -0800 Subject: [PATCH 09/39] Add back missing ctx.debug Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 10 +++++----- transformer_engine/pytorch/module/linear.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bdfeff056b..fd458a34b4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -850,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c5f7051fa1..a98ecfb903 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1089,7 +1089,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1228,14 +1228,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1467,7 +1467,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if 
use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1477,7 +1477,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0bf560c7b7..930fbe061d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -638,7 +638,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -664,7 +664,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -792,7 +792,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -834,7 +834,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: From bf904aab91dad9d2a515dc249400b9282e65ce09 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:43:45 -0800 Subject: [PATCH 10/39] Refactor changes under fused Signed-off-by: Ziang Li --- .../ops/fused/backward_activation_bias.py | 7 ++----- .../ops/fused/forward_linear_bias_activation.py | 17 +++++++++++------ .../ops/fused/forward_linear_bias_add.py | 17 +++++++++++------ .../ops/fused/forward_linear_scale_add.py | 17 +++++++++++------ 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 59e9af14f4..4ab082d32b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager +from transformer_engine.pytorch.quantization import Recipe from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,10 +105,7 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None or ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ): + if recipe is None: return ops # Scan through ops, fusing if possible 
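Illustrative usage sketch, not part of any patch in this series: the NVTE_KEEP_BACKWARD_UNQUANTIZED switch is read through FP8GlobalStateManager.keep_backward_unquantized() when the module runs its forward pass, so it only needs to be set beforehand, and it is ignored under delayed scaling (see PATCH 07). The snippet below is only a sketch, assuming a CUDA device with Transformer Engine installed and a current-scaling recipe class named Float8CurrentScaling; layer and batch sizes are arbitrary.

    import os
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Float8CurrentScaling  # assumed recipe class

    # Read via FP8GlobalStateManager.keep_backward_unquantized() at module forward time
    os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

    layer = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16).cuda()
    x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
        y = layer(x)  # forward GEMM uses quantized operands; the original input is saved
    y.sum().backward()  # dgrad/wgrad GEMMs consume the saved high-precision tensors

With the flag set, forward compute stays quantized while the backward GEMMs fall back to high precision, which is the behavior these patches wire through the module- and ops-level code paths.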
diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 0a28d00706..6e7c85988f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,12 +122,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 41ae096e54..f3b4533848 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,12 +119,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index b06f5ad36a..53e7327873 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,12 +100,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized 
else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From b449fc4516f5e3146d13f99d2377158788de385c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:44:30 -0800 Subject: [PATCH 11/39] Clean up Signed-off-by: Ziang Li --- .../pytorch/ops/fused/forward_linear_bias_activation.py | 6 ------ .../pytorch/ops/fused/forward_linear_bias_add.py | 6 ------ .../pytorch/ops/fused/forward_linear_scale_add.py | 6 ------ 3 files changed, 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 6e7c85988f..2458d4d072 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -127,12 +127,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index f3b4533848..efa543e555 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -124,12 +124,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 53e7327873..2804534968 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -105,12 +105,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From de3acaf7e11c79cc072face5d3fc8431be84fec6 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:11:07 -0800 Subject: [PATCH 12/39] Refactor high-precision overwrite if keep_backward_unquantized Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 
17 ++++++++++------- .../pytorch/module/layernorm_linear.py | 10 ++++++++-- .../pytorch/module/layernorm_mlp.py | 14 +++++++++++--- transformer_engine/pytorch/module/linear.py | 5 ++++- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9e2eb60ea5..859e648579 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -406,13 +406,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - weights_for_dgrad = weights if use_fp8_bwd else origin_weights - if use_fp8_bwd: - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights_for_dgrad: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + # weights_for_dgrad = weights if use_fp8_bwd else origin_weights + # if use_fp8_bwd: + weights_for_dgrad = weights + if keep_backward_unquantized: + weights_for_dgrad = origin_weights + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fd458a34b4..70d8936ce3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -415,7 +415,10 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + ln_out_to_save = ln_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -755,7 +758,10 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_fp8_bwd else origin_weight + # weight_for_dgrad = weight if use_fp8_bwd else origin_weight + weight_for_dgrad = weight + if keep_backward_unquantized: + weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a98ecfb903..a8e0bda73d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -695,8 +695,13 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - act_out_to_save = act_out_hp if keep_backward_unquantized else act_out + ln_out_to_save = ln_out + act_out_to_save = act_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + act_out_to_save = act_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + # act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1152,7 
+1157,10 @@ def backward( ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight + fc2_weight_for_dgrad = fc2_weight + if keep_backward_unquantized: + fc2_weight_for_dgrad = origin_fc2_weight + # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 930fbe061d..496bfd45b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -737,7 +737,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight + weight_for_dgrad = weight_fp8 + if keep_backward_unquantized: + weight_for_dgrad = weight + # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From fe65d34213cfa6061459e5a04ab2ce4610865535 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:14:22 -0800 Subject: [PATCH 13/39] Clean up Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_mlp.py | 3 --- transformer_engine/pytorch/module/linear.py | 1 - 4 files changed, 8 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 859e648579..e782f20cc6 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -406,8 +406,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # weights_for_dgrad = weights if use_fp8_bwd else origin_weights - # if use_fp8_bwd: weights_for_dgrad = weights if keep_backward_unquantized: weights_for_dgrad = origin_weights diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 70d8936ce3..e3aab9b304 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -418,7 +418,6 @@ def forward( ln_out_to_save = ln_out if keep_backward_unquantized: ln_out_to_save = ln_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -758,7 +757,6 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - # weight_for_dgrad = weight if use_fp8_bwd else origin_weight weight_for_dgrad = weight if keep_backward_unquantized: weight_for_dgrad = origin_weight diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a8e0bda73d..6107c7d377 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -700,8 +700,6 @@ def _forward( if keep_backward_unquantized: ln_out_to_save = ln_out_hp act_out_to_save = act_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - # act_out_to_save = act_out_hp if keep_backward_unquantized 
else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1160,7 +1158,6 @@ def backward( fc2_weight_for_dgrad = fc2_weight if keep_backward_unquantized: fc2_weight_for_dgrad = origin_fc2_weight - # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 496bfd45b7..10ea095c16 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -740,7 +740,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = weight_fp8 if keep_backward_unquantized: weight_for_dgrad = weight - # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From 59aaf6b7875202f19f4180e5057a07df418668cd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 10:56:41 -0800 Subject: [PATCH 14/39] Drop redundant fp8_recipe_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6107c7d377..9406c0c7ef 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1023,7 +1023,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -1249,7 +1248,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): + if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1299,7 +1298,7 @@ def fc2_wgrad_gemm( if fc2_bias_grad is None: if ( use_fp8_bwd - and fp8_recipe_bwd.float8_block_scaling() + and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. 
@@ -1333,9 +1332,14 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd: + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and use_fp8_bwd + ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1345,7 +1349,9 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision @@ -1354,7 +1360,7 @@ def fc2_wgrad_gemm( # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or fp8_recipe_bwd.custom() + or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) From 44da62593ef2476d80691f79f652ec907333870f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:57:29 +0000 Subject: [PATCH 15/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9406c0c7ef..863a70e5e8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1335,7 +1335,7 @@ def fc2_wgrad_gemm( elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and use_fp8_bwd - ): + ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None From 0f5879380fcdb9a9c90d0fa73d6de3edfb646df0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:02:24 -0800 Subject: [PATCH 16/39] Drop redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 863a70e5e8..add32c0ba9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1388,16 +1388,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = 
tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- From 192fbad0501fb967bb02c5e545343726a2dbaff1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:07:16 -0800 Subject: [PATCH 17/39] Drop more redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e3aab9b304..60c4e1d8b2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -812,11 +812,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -828,7 +824,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) From 0dd12689957868370d0f17890cbb743361bf134a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:25:01 -0800 Subject: [PATCH 18/39] Drop redundant delayed scaling changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +----- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e782f20cc6..7e6773043d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -299,11 +299,7 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, weights[0], biases[0]) - ): + if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index add32c0ba9..5f8de6159e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -847,12 +847,8 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad( + if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias - ) ): _first_fp8_module = 
FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 10ea095c16..535d2e75e5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -484,7 +484,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): + if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 216621d01a3021a63e1c6f102817113ec46edd0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:25:49 +0000 Subject: [PATCH 19/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5f8de6159e..6a88848236 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -848,7 +848,7 @@ def _forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() From ab8749bb120ce73f6009d285c2c2c84c7890590b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 12:01:36 -0800 Subject: [PATCH 20/39] Drop unneeded backwards_needs_fc1_input Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6a88848236..44028aebcc 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -351,10 +351,8 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = ( - fc1_weight.requires_grad - and ((is_grad_enabled and not checkpoint) or is_recomputation) - and not keep_backward_unquantized + backwards_needs_fc1_input = fc1_weight.requires_grad and ( + (is_grad_enabled and not checkpoint) or is_recomputation ) device = inp.device From 58810837b3d3794c4c66b2994c767418ec2b9e8d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 14:01:43 -0800 Subject: [PATCH 21/39] Drop and disallow LayerNormMLP implementation Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 44028aebcc..8d78ceab86 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ 
b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -233,6 +233,7 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -395,7 +396,6 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not keep_backward_unquantized and not custom ) @@ -417,7 +417,6 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out - ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -614,10 +613,6 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) - act_out_hp = act_out - if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: - act_out_hp = activation_func(fc1_out, None, **act_params) - # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -693,33 +688,22 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out - act_out_to_save = act_out - if keep_backward_unquantized: - ln_out_to_save = ln_out_hp - act_out_to_save = act_out_hp ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out_to_save) - ln_out_to_save = None + clear_tensor_data(ln_out) + ln_out = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out_to_save) - act_out_to_save = None + clear_tensor_data(act_out) + act_out = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, - mu, - rsigma, - ln_out_to_save, - fc1_out, - fc1_out_without_bias, - act_out_to_save, + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out ) # Scatter intermediate/activation tensors saved for the backward pass @@ -732,9 +716,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out_to_save, + ln_out, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out_to_save, + act_out, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -762,13 +746,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out_to_save, + ln_out, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out_to_save, + act_out, fc2_weight_final, fc2_weight, fc2_bias, @@ -816,7 +800,6 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -1015,15 +998,6 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad - 
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1043,7 +1017,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1057,7 +1031,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc2_grad_output_quantizer is not None: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1085,7 +1059,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1131,7 +1105,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not use_fp8_bwd + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1140,25 +1114,20 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ( - use_fp8_bwd - and ctx.fc2_weight_quantizer is not None - and isinstance(ctx.fc2_weight, QuantizedTensorStorage) + if ctx.fc2_weight_quantizer is not None and isinstance( + ctx.fc2_weight, QuantizedTensorStorage ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight - if keep_backward_unquantized: - fc2_weight_for_dgrad = origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight_for_dgrad, + fc2_weight, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd + if fc2_dgrad_gemm_gelu_fusion or ctx.debug else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1226,14 +1195,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1242,7 +1211,7 @@ def backward( # Whether to 
set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1252,9 +1221,7 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None - ), # wgrad in high precision + "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1291,7 +1258,7 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - use_fp8_bwd + ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): @@ -1312,12 +1279,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc1_grad_output_quantizer is not None: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not use_fp8_bwd + assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1328,7 +1295,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and use_fp8_bwd + and ctx.fp8 ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( @@ -1350,7 +1317,7 @@ def fc2_wgrad_gemm( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if use_fp8_bwd: + if ctx.fp8: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) @@ -1399,10 +1366,8 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ( - use_fp8_bwd - and ctx.fc1_weight_quantizer is not None - and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1417,13 +1382,12 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM - fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight_for_dgrad, + fc1_weight, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1472,7 +1436,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1482,7 +1446,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if 
isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1504,7 +1468,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) From 431f0c8fd3c643380fefa3f4ca923d59ada5bcea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:02:31 +0000 Subject: [PATCH 22/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8d78ceab86..da236e7be0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -233,7 +233,9 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() - assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + assert ( + not keep_backward_unquantized + ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: From 937e34b10585058383293c649cf2a5841813e7a9 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:10:10 -0800 Subject: [PATCH 23/39] Move interface changes to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 67 +++++++++++++++++-- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 6 +- .../pytorch/ops/basic/quantize.py | 2 +- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- transformer_engine/pytorch/quantization.py | 39 +++++++---- 11 files changed, 99 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..a36b743f3b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,11 @@ from pydantic.dataclasses import dataclass +def _default_quantize_backward() -> bool: + """Default backward quantization setting.""" + return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -181,6 +186,11 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
Delayed scaling + always quantizes backward; setting this to False is not supported. Notes ----- @@ -204,6 +214,8 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -216,7 +228,9 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -230,6 +244,10 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -242,6 +260,10 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -257,7 +279,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -284,12 +308,18 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -298,7 +328,9 @@ def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -327,6 +359,10 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -379,7 +415,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -428,6 +466,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ # Configuration envvars @@ -443,6 +485,8 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -474,6 +518,8 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -505,12 +551,23 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __repr__(self) -> str: - return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"qfactory={self.qfactory}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" + ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 7e6773043d..a7d7bc8948 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,7 +96,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 60c4e1d8b2..4173c76216 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,7 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index da236e7be0..82e7d868b4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,7 +232,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 535d2e75e5..76ff5dd1d4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,7 +129,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index a9a6895112..a362485a7e 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,7 +332,9 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) @@ -989,7 +991,7 @@ def op_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index cc26022d0e..e6c28b9fdc 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -60,7 +60,7 @@ def op_forward( quantize_backward = ( fp8_enabled and self._quantize_backward - and not FP8GlobalStateManager.keep_backward_unquantized() + and FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Quantize if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2458d4d072..80cb5647d7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -93,7 +93,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index efa543e555..cf29140a20 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -87,7 +87,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 2804534968..0caae13af9 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -66,7 +66,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get extra input tensor for add operation diff --git 
a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index aab7ed2d1c..fb0553056a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,6 +87,21 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) +def _validate_recipe_quantization_flags(recipe: Recipe) -> None: + """Validate forward/backward quantization flags on a recipe.""" + quantize_forward = getattr(recipe, "quantize_forward", True) + quantize_backward = getattr(recipe, "quantize_backward", True) + if not quantize_forward and quantize_backward: + raise ValueError( + "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + ) + if recipe.delayed() and not quantize_backward: + raise ValueError( + "Invalid recipe configuration: delayed scaling does not support " + "quantize_backward=False." + ) + + def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -430,15 +445,6 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL - @classmethod - def keep_backward_unquantized(cls) -> bool: - """Should backward skip FP8 quantization and use high precision""" - recipe = cls.get_fp8_recipe() - if recipe is not None and recipe.delayed(): - # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used - return False - return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" @@ -851,16 +857,21 @@ def autocast( are reduced at the end of each training step. """ - if enabled: - check_recipe_support(recipe) + fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + if enabled or calibrating: + _validate_recipe_quantization_flags(fp8_recipe) + quantize_forward = getattr(fp8_recipe, "quantize_forward", True) + effective_enabled = enabled and quantize_forward + if effective_enabled: + check_recipe_support(fp8_recipe) # Save current state so we always restore it on exit. 
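# Usage sketch for the quantize_forward/quantize_backward flags wired into
# autocast() above: a minimal illustration assuming this patch series is applied,
# reusing Float8CurrentScaling, te.Linear, and the autocast() defined in
# transformer_engine/pytorch/quantization.py; sizes and shapes are arbitrary
# placeholders, not values taken from the repository tests.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.quantization import autocast

# Same effect as exporting NVTE_KEEP_BACKWARD_UNQUANTIZED=1, but explicit:
# forward GEMMs run in FP8 while dgrad/wgrad GEMMs stay in high precision.
recipe = Float8CurrentScaling(quantize_backward=False)

layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

with autocast(enabled=True, recipe=recipe):
    y = layer(x)
y.sum().backward()

# _validate_recipe_quantization_flags() rejects quantize_forward=False combined
# with quantize_backward=True, and rejects quantize_backward=False for delayed
# scaling recipes.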
fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=enabled, + enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=recipe, + fp8_recipe=fp8_recipe, fp8_group=amax_reduction_group, _graph=_graph, ) @@ -868,7 +879,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From 0d26127d2d90370bfedfe834fb3d7e10ac4e07ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:11:01 +0000 Subject: [PATCH 24/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- transformer_engine/pytorch/module/linear.py | 4 +++- transformer_engine/pytorch/ops/basic/basic_linear.py | 8 ++++---- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 ++-- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 ++-- .../pytorch/ops/fused/forward_linear_scale_add.py | 4 ++-- 8 files changed, 22 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a7d7bc8948..9aad36a868 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,7 +96,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4173c76216..3016d41c5f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,7 +141,9 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 82e7d868b4..8e6a189843 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,7 +232,9 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 
76ff5dd1d4..c8feddf5af 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,7 +129,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index a362485a7e..ba7de55f69 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,8 +332,8 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) @@ -990,8 +990,8 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 80cb5647d7..2bccabb306 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,8 +92,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index cf29140a20..03e3bff6f3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,8 +86,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not 
FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 0caae13af9..8cebcec53a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,8 +65,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get extra input tensor for add operation From 0135366a68fc8add8988c540b60ec96cbf25b723 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:43:08 -0800 Subject: [PATCH 25/39] Move ub overrides to fwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 14 ++++++++------ transformer_engine/pytorch/module/linear.py | 13 +++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3016d41c5f..f39fb45608 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -539,6 +539,14 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug + # keep_backward_unquantized overrides + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... # ------------------------------------------------------ @@ -610,12 +618,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c8feddf5af..3ed78e85da 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -493,6 +493,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -545,12 +552,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None From 1de3c64a524e1d5127ab94e0eaf54037461cc7bd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:44:22 -0800 Subject: [PATCH 26/39] Remove duplication Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index a36b743f3b..85b232c26b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -262,8 +262,6 @@ class Float8CurrentScaling(Recipe): fp8_mha: bool = False quantize_forward: bool = True quantize_backward: bool = field(default_factory=_default_quantize_backward) - quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." From 04d35430cdd4537056f2dd18d4e62275a133e245 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:59:39 -0800 Subject: [PATCH 27/39] Simplify use_fp8_bwd logic in bwd Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 19 +++++++++---- .../pytorch/module/layernorm_linear.py | 25 ++++++++--------- transformer_engine/pytorch/module/linear.py | 28 +++++++++---------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9aad36a868..38e3ceef9a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -310,6 +310,14 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -326,7 +334,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -342,7 +349,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: if ctx.use_bias: grad_output_mats = 
torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -393,7 +400,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -427,7 +434,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -454,7 +461,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -528,7 +535,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not use_fp8_bwd + and not ctx.fp8 ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f39fb45608..1ef8536e4f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -541,7 +541,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -617,7 +617,6 @@ def backward( origin_weight.main_grad = main_grad keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -655,7 +654,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -692,7 +691,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_fp8_bwd: + if ctx.input_quantizer is not None and ctx.fp8: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -731,7 +730,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage) ): @@ -739,13 +738,13 @@ def backward( # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer 
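# The hunks above and below follow one pattern: when the recipe requests an
# unquantized backward, the forward pass clears the FP8 and Userbuffers flags it
# saves on ctx, so the backward code can branch on ctx.fp8 alone instead of
# carrying a separate use_fp8_bwd flag through every branch. Standalone sketch of
# that idea, with SimpleNamespace standing in for the autograd ctx (illustration
# only, not TransformerEngine code).
from types import SimpleNamespace

def save_backward_state(fp8: bool, keep_backward_unquantized: bool) -> SimpleNamespace:
    ctx = SimpleNamespace(fp8=fp8, ub_overlap_ag=True, ub_bulk_dgrad=True, ub_bulk_wgrad=True)
    if keep_backward_unquantized:
        # Applied once at save time; equivalent to checking
        # "ctx.fp8 and not keep_backward_unquantized" at every use site in backward.
        ctx.fp8 = False
        ctx.ub_overlap_ag = False
        ctx.ub_bulk_dgrad = False
        ctx.ub_bulk_wgrad = False
    return ctx

ctx = save_backward_state(fp8=True, keep_backward_unquantized=True)
assert not ctx.fp8  # backward takes the high-precision dgrad/wgrad path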
- if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -769,7 +768,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -854,14 +853,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -870,7 +869,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -896,7 +895,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -904,7 +903,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3ed78e85da..a97ba398e0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -495,6 +495,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -551,7 +552,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_pop(f"{nvtx_label}.fsdp_gather") keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -592,7 +592,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -611,7 +611,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and 
ctx.grad_output_quantizer is not None - and use_fp8_bwd + and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -641,7 +641,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -667,7 +667,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -709,7 +709,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -717,13 +717,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -748,7 +748,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -797,7 +797,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -839,7 +839,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -848,7 +848,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -874,7 +874,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), 
"accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -882,7 +882,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, From 454976eaeb1520ad075d0f3dbc2de736108ea0cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:00:24 +0000 Subject: [PATCH 28/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 38e3ceef9a..54caabdb7e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -310,7 +310,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + # keep_backward_unquantized overrides if keep_backward_unquantized: ctx.fp8 = ctx.fp8 and not keep_backward_unquantized From f7794c94eb301e466db5c1c0b311bf054977caa6 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 14:28:06 -0800 Subject: [PATCH 29/39] Set grad quantizers to none if keep bwd unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 3 +++ .../pytorch/module/layernorm_linear.py | 11 +++++++---- transformer_engine/pytorch/module/linear.py | 13 ++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 54caabdb7e..73dc81ad41 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -318,6 +318,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1ef8536e4f..4de6afa38b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -546,6 +546,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -654,7 +657,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -744,7 +747,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -768,7 +771,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -895,7 +898,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a97ba398e0..1fd2fcba8d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -500,6 +500,10 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -592,7 +596,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -611,7 +615,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -723,7 +726,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -748,7 +751,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -874,7 +877,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) From 58db8ea72fd2a52a8c4fabe324c487352919fa35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:28:55 +0000 Subject: [PATCH 30/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fd2fcba8d..3e8c4c146f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -503,7 +503,6 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None - # ------------------------------------------------------ # Cached state for backward pass is ready... 
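After the changes above, keep_backward_unquantized clears both ctx.fp8 and the grad quantizers during forward, so the backward GEMMs pick the unquantized path without any extra branching at the call sites. A minimal sketch of that pattern, using hypothetical simplified names rather than the TE source:

    # Simplified illustration only; not the Transformer Engine implementation.
    def forward_setup(ctx, fp8: bool, keep_backward_unquantized: bool, grad_quantizers: tuple):
        ctx.fp8 = fp8 and not keep_backward_unquantized
        (ctx.grad_input_quantizer,
         ctx.grad_weight_quantizer,
         ctx.grad_output_quantizer) = grad_quantizers
        if keep_backward_unquantized:
            # Grad quantizers are dropped, so every later
            # "quantization_params=ctx.grad_*_quantizer" becomes a no-op.
            ctx.grad_input_quantizer = None
            ctx.grad_weight_quantizer = None
            ctx.grad_output_quantizer = None

    def dgrad_gemm(ctx, weight, grad_output, gemm):
        # With the quantizers nulled, one call site covers both paths.
        return gemm(weight, grad_output, quantization_params=ctx.grad_input_quantizer)
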
From 9d0b6547427e4e6f2c969d84431e665c762577b1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 17:28:04 -0800 Subject: [PATCH 31/39] Drop delayed scaling change Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4de6afa38b..26b14c2d8a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -527,11 +527,7 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, ln_weight, ln_bias, weight, bias) - ): + if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 004cb455f8a39abe5c75f13c8b3a0fb5b179d664 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:29:24 -0800 Subject: [PATCH 32/39] Simplify env var logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 85b232c26b..55010499ec 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,11 +11,6 @@ from pydantic.dataclasses import dataclass -def _default_quantize_backward() -> bool: - """Default backward quantization setting.""" - return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - - class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -215,7 +210,7 @@ def scaling_factor_compute(amax: Tensor, fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -261,7 +256,7 @@ class Float8CurrentScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -317,7 +312,7 @@ class MXFP8BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -484,7 +479,7 @@ class NVFP4BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -560,7 +555,7 @@ class CustomRecipe(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __repr__(self) -> str: return ( From 9baccfd65dde556099fe8ded69160de09342c3a4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:41:01 -0800 Subject: [PATCH 33/39] Move validation check to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 18 ++++++++++++++++++ transformer_engine/pytorch/quantization.py | 17 ----------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 55010499ec..673df45f4c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -214,6 +214,12 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + assert ( + not self.quantize_backward + ), "Delayed scaling does not support quantize_backward=False." def __repr__(self) -> str: return ( @@ -260,6 +266,9 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -316,6 +325,9 @@ class MXFP8BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -393,6 +405,9 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -484,6 +499,9 @@ class NVFP4BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
# Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index fb0553056a..bbffe51eec 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,21 +87,6 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) -def _validate_recipe_quantization_flags(recipe: Recipe) -> None: - """Validate forward/backward quantization flags on a recipe.""" - quantize_forward = getattr(recipe, "quantize_forward", True) - quantize_backward = getattr(recipe, "quantize_backward", True) - if not quantize_forward and quantize_backward: - raise ValueError( - "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - ) - if recipe.delayed() and not quantize_backward: - raise ValueError( - "Invalid recipe configuration: delayed scaling does not support " - "quantize_backward=False." - ) - - def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -858,8 +843,6 @@ def autocast( """ fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - if enabled or calibrating: - _validate_recipe_quantization_flags(fp8_recipe) quantize_forward = getattr(fp8_recipe, "quantize_forward", True) effective_enabled = enabled and quantize_forward if effective_enabled: From 207eb5a7d2319d4e12a016faa37a0438c4c8ce27 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:55:28 -0800 Subject: [PATCH 34/39] Simplify effective_enabled Signed-off-by: Ziang Li --- transformer_engine/pytorch/quantization.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index bbffe51eec..00196c584f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -842,11 +842,9 @@ def autocast( are reduced at the end of each training step. """ - fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - quantize_forward = getattr(fp8_recipe, "quantize_forward", True) - effective_enabled = enabled and quantize_forward + effective_enabled = enabled and getattr(recipe, "quantize_forward", True) if effective_enabled: - check_recipe_support(fp8_recipe) + check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() @@ -854,7 +852,7 @@ def autocast( FP8GlobalStateManager.autocast_enter( enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=fp8_recipe, + fp8_recipe=recipe, fp8_group=amax_reduction_group, _graph=_graph, ) From 15117b1d545660aa3a9ceae82fb0bd4c4191ed44 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:56:33 -0800 Subject: [PATCH 35/39] Fix inverted assertion logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 673df45f4c..f03e9b24d6 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -217,9 +217,7 @@ def __post_init__(self) -> None: assert not ( not self.quantize_forward and self.quantize_backward ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
- assert ( - not self.quantize_backward - ), "Delayed scaling does not support quantize_backward=False." + assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False." def __repr__(self) -> str: return ( From 3fc5270e82689136aa58d12deb98b1012bdc11a0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:33:38 -0800 Subject: [PATCH 36/39] Simplify changes under ops Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/basic_linear.py | 4 ---- transformer_engine/pytorch/ops/basic/quantize.py | 11 ++++++----- .../ops/fused/forward_linear_bias_activation.py | 7 ++----- .../pytorch/ops/fused/forward_linear_bias_add.py | 7 ++----- .../pytorch/ops/fused/forward_linear_scale_add.py | 7 ++----- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index ba7de55f69..16b7bcb7c5 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1020,11 +1020,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None saved_weight = self.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index e6c28b9fdc..6e90e33846 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,11 +57,12 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = ( - fp8_enabled - and self._quantize_backward - and FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + quantize_backward = fp8_enabled and self._quantize_backward + + # Recipe quantize overrides + if FP8GlobalStateManager.get_fp8_recipe() is not None: + quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2bccabb306..860407904c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,11 +122,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 03e3bff6f3..0729291d55 100644 --- 
a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,11 +119,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 8cebcec53a..dfdd11a231 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,11 +100,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 9201d1926d44099f556745613b555869b70b08a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:34:39 +0000 Subject: [PATCH 37/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/quantize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 6e90e33846..33062d5b88 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -61,8 +61,12 @@ def op_forward( # Recipe quantize overrides if FP8GlobalStateManager.get_fp8_recipe() is not None: - quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward - quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + quantize_forward = ( + quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + ) + quantize_backward = ( + quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Quantize if needed out = input_ From 1e0f1d2deb435facb7c28bbc4374036db930c91b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:52:01 -0800 Subject: [PATCH 38/39] Simplify ctx.keep_backward_unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/module/grouped_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4a2140718d..a878f2ace2 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -1135,8 +1135,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 73dc81ad41..abe6df6875 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -336,7 +336,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -415,7 +414,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 26b14c2d8a..187fd70f92 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -615,8 +615,6 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -760,7 +758,7 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3e8c4c146f..7d960102ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -554,8 +554,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -743,7 +741,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, From 253873a4560b2c2a2c909918cc3ee26500e5b43d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 15:07:48 -0800 Subject: [PATCH 39/39] Fix missing attribute Signed-off-by: Ziang Li --- 
transformer_engine/common/recipe/__init__.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 2 files changed, 3 insertions(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index f03e9b24d6..d534ad883b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -382,6 +382,8 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8e6a189843..ac10534012 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -784,6 +784,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
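End-to-end, the backward pass can be kept unquantized either through the environment variable or explicitly on a recipe. A minimal usage sketch, assuming the entry points shown in the diffs above are exported as transformer_engine.pytorch.autocast, transformer_engine.pytorch.Linear, and transformer_engine.common.recipe.Float8CurrentScaling (exact export paths may differ in the tree being patched):

    import os
    # The recipe defaults read NVTE_KEEP_BACKWARD_UNQUANTIZED at class definition
    # time, so the variable must be set before transformer_engine is imported.
    os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Float8CurrentScaling

    # Equivalent explicit form: quantize_backward=False. quantize_forward stays True,
    # since the recipe asserts reject an unquantized forward with a quantized backward.
    recipe = Float8CurrentScaling(quantize_backward=False)

    linear = te.Linear(1024, 1024).cuda()
    x = torch.randn(8, 1024, device="cuda", requires_grad=True)
    with te.autocast(enabled=True, recipe=recipe):
        y = linear(x)
    y.sum().backward()  # dgrad/wgrad GEMMs run unquantized under this setting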