Conversation

@zianglih commented Feb 3, 2026

Description

@HumansAnd

Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var to run a quantized fprop while keeping wgrad & dgrad in high precision.
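
A minimal usage sketch (assumptions: the te.fp8_autocast entry point — the sequence diagram below calls it te.autocast — and the Float8CurrentScaling recipe, since a review comment below notes that DelayedScaling asserts against quantize_backward=False):

    # Minimal sketch, not taken from this PR's tests. The env var must be set
    # before transformer_engine is imported, because it seeds the recipe default.
    import os
    os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.Float8CurrentScaling()  # quantize_backward now defaults to False

    layer = te.Linear(1024, 1024, bias=True).cuda()
    x = torch.randn(32, 1024, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)          # quantized fprop
    y.sum().backward()        # wgrad & dgrad stay in high precision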

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var (and quantize_forward/quantize_backward recipe flags) for quantized fprop with high-precision wgrad & dgrad.
  • Thread keep_backward_unquantized through Linear, GroupedLinear, LayerNormLinear, and the fusible ops (BasicLinear plus forward fusions), saving high-precision tensors for backward and disabling quantized backward compute.
  • Refactor TransformerEngineBaseModule.prepare_forward into an explicit prepare_forward()/end_forward() pair and add fast_setattr/module_setattr.
  • Migrate forward fusion functions to a unified fuse_forward_ops API registered in transformer_engine/pytorch/ops/fused/__init__.py.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zianglih and others added 2 commits February 2, 2026 16:45
greptile-apps bot commented Feb 3, 2026

Greptile Overview

Greptile Summary

This PR adds an opt-in path to run FP8/FP4 quantized forward while keeping backward in higher precision, primarily controlled through a new quantize_backward recipe flag (defaulted from NVTE_KEEP_BACKWARD_UNQUANTIZED). The change is threaded through TE modules (Linear/GroupedLinear/LayerNormLinear) and fusible ops (BasicLinear + several forward fusions) by saving high-precision activations/weights for backward and disabling quantized backward compute when requested. It also refactors TransformerEngineBaseModule.prepare_forward from a context manager to an explicit prepare_forward() + end_forward() pair and introduces fast_setattr/module_setattr for attribute assignment.
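
A rough sketch of the control flow described above; the names come from the summary, but the exact call sites and guards are assumptions rather than the PR's code:

    # Sketch only: how the recipe flag turns into keep_backward_unquantized.
    fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
    keep_backward_unquantized = (
        fp8_enabled
        and fp8_recipe is not None
        and not fp8_recipe.quantize_backward
    )
    # Forward: run the quantized GEMM as usual, but stash the high-precision
    # activations/weights for backward instead of the quantized copies.
    # Backward: skip FP8 preprocessing and compute wgrad/dgrad from the saved
    # high-precision tensors (quantized backward compute is disabled).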

In the ops fuser stack, forward fusion functions are migrated to a unified Class.fuse_forward_ops(ops, **unused) API and registered via transformer_engine/pytorch/ops/fused/__init__.py.
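
A placeholder outline of that hook; the signature follows the summary, while the class name and body are illustrative only:

    # Illustrative shape of the unified fusion hook; not the PR's implementation.
    class SomeForwardFusion:
        @staticmethod
        def fuse_forward_ops(ops, **unused):
            # Scan the op list, splice a fused op over any matching window,
            # and return the (possibly unchanged) list.
            return ops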

Confidence Score: 3/5

  • This PR is moderately safe to merge, but UB-enabled forward fusion likely violates the new keep-backward-unquantized contract.
  • Most linear/module paths consistently thread keep_backward_unquantized into saved tensors and disable quantized backward, and the fusion API refactors appear internally consistent. However, UserbuffersForwardLinear does not honor keep_backward_unquantized and will still save quantized tensors / keep quantized backward enabled, which breaks the feature semantics when UB overlap is active and could lead to incorrect behavior or unexpected precision/regressions in that configuration.
  • transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py

Important Files Changed

  • transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py — Refactors fusion registration into a static fuse_forward_ops; does not incorporate the new keep_backward_unquantized behavior, unlike the other fused linear ops (forward saves remain x_local/w, with_quantized_compute unchanged).
  • transformer_engine/pytorch/ops/basic/basic_linear.py — Threads keep_backward_unquantized into the functional forward and quantizer usage; saves high-precision tensors for backward when requested and disables quantized backward compute.
  • transformer_engine/pytorch/module/base.py — Refactors prepare_forward from a context manager to an explicit prepare_forward/end_forward pair, adds fast_setattr/module_setattr, and gates FP8 backward preprocessing via ctx.keep_backward_unquantized.
  • transformer_engine/pytorch/module/linear.py — Implements keep_backward_unquantized for the Linear autograd path (forces save_original_input, disables UB overlap/FP8 backward, uses full-precision weights for dgrad) and adapts to the new prepare_forward/end_forward API.
  • transformer_engine/common/recipe/__init__.py — Adds quantize_forward/quantize_backward flags to recipes, with the NVTE_KEEP_BACKWARD_UNQUANTIZED default and new validation/repr fields.
  • transformer_engine/pytorch/module/layernorm_mlp.py — Adds keep_backward_unquantized detection but currently hard-asserts it unsupported; refactors prepare_forward usage and quantizer settings.

Sequence Diagram

sequenceDiagram
    participant U as User code
    participant A as te.autocast/FP8GlobalStateManager
    participant M as TE Module (Linear/Grouped/LNLinear)
    participant F as OperationFuser (fusible ops)
    participant UB as UserbuffersForwardLinear
    participant BL as BasicLinear._functional_forward

    U->>A: enter autocast(recipe)
    A-->>M: fp8 enabled + recipe.quantize_backward

    U->>M: forward(inp)
    M->>M: prepare_forward(inp)
    Note over M: keep_backward_unquantized = fp8 && !recipe.quantize_backward

    alt BasicLinear / non-UB fused paths
        M->>F: run pipeline
        F->>BL: _functional_forward(..., keep_backward_unquantized)
        BL-->>F: output, x_local, w
        F->>F: save_for_backward(saved_input/saved_weight)
        Note over F: if keep_backward_unquantized, save high-precision tensors
    else Userbuffers fused forward selected
        M->>F: run pipeline
        F->>UB: fuser_forward(...)
        UB-->>F: output, x_local, w
        F->>F: save_for_backward(x_local, w)
        Note over UB,F: UB path does not incorporate keep_backward_unquantized
    end

    U-->>M: loss.backward()
    M->>F: autograd backward
    alt keep_backward_unquantized honored
        F->>BL: backward uses saved high-precision tensors
        Note over BL: quantized backward disabled
    else UB path
        F->>UB: backward expects quantized tensors / quantized compute
    end

    U->>A: exit autocast

greptile-apps bot left a comment: 6 files reviewed, no comments

@zianglih (Author) commented Feb 3, 2026

I'll work on potential unit test breakage.

greptile-apps bot left a comment: 5 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, 4 comments

zianglih and others added 2 commits February 3, 2026 09:56
greptile-apps bot left a comment: 4 files reviewed, 1 comment

greptile-apps bot left a comment: 5 files reviewed, 1 comment

greptile-apps bot left a comment: 5 files reviewed, 2 comments

    ln_out_return = None
    if return_layernorm_output or return_layernorm_output_gathered:
        ln_out_return = ln_out
    ln_out_hp = ln_out if keep_backward_unquantized else None
Review comment: storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation; verify this overhead is acceptable for your target models, especially during training with large batch sizes or long sequences.
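
For scale, a rough hypothetical sizing (shapes not from the PR): for an ln_out of shape [4096 tokens, 8192 hidden], the FP8 copy is 4096 × 8192 × 1 B = 32 MiB, and keeping a BF16 ln_out_hp alongside it adds another 64 MiB per layer, so the overhead grows linearly with tokens per microbatch and with layer count.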

greptile-apps bot left a comment: 4 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, 3 comments

zianglih and others added 2 commits February 5, 2026 13:59
greptile-apps bot left a comment: 6 files reviewed, 2 comments

greptile-apps bot left a comment: 4 files reviewed, no comments

greptile-apps bot left a comment: 5 files reviewed, 2 comments

    if enabled or calibrating:
        _validate_recipe_quantization_flags(fp8_recipe)
        quantize_forward = getattr(fp8_recipe, "quantize_forward", True)
        effective_enabled = enabled and quantize_forward
Author comment (zianglih): I am not very sure if we should disable when quantize_forward is false.

greptile-apps bot left a comment: 5 files reviewed, 1 comment

greptile-apps bot left a comment: 12 files reviewed, 2 comments

Comment on lines 57 to 60
# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
quantize_forward = fp8_enabled and self._quantize_forward
quantize_backward = fp8_enabled and self._quantize_backward
quantize_backward = (
Recipe None crash

FP8GlobalStateManager.get_fp8_recipe() can be None (e.g., if FP8 is enabled but no recipe was set), so ...get_fp8_recipe().quantize_backward will raise an AttributeError. This makes Quantize.forward() crash in that configuration; please guard for None (e.g., default to quantized backward when recipe is unset, or explicitly error with a clearer message).
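
One possible guard, sketched against the snippet quoted above (variable names are assumptions, not the PR's code):

    # Sketch: tolerate a missing recipe instead of dereferencing None.
    fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
    quantize_forward = fp8_enabled and self._quantize_forward
    quantize_backward = fp8_enabled and self._quantize_backward
    if fp8_recipe is not None:
        quantize_backward = quantize_backward and fp8_recipe.quantize_backward
    # else: recipe unset -> fall back to the module default (quantized backward).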

zianglih and others added 2 commits February 9, 2026 12:33
greptile-apps bot left a comment: 6 files reviewed, 3 comments

    assert not (
        not self.quantize_forward and self.quantize_backward
    ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
    assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
Review comment: this assertion prevents using NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe; when the env var is set, quantize_backward becomes False, so the assert fires and blocks the entire feature for this recipe type.
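
A softer alternative, sketched only (not code from the PR), would surface the conflict with an actionable error instead of a bare assert:

    # Sketch: explain the DelayedScaling + NVTE_KEEP_BACKWARD_UNQUANTIZED conflict.
    if not self.quantize_backward:
        raise ValueError(
            "DelayedScaling does not support quantize_backward=False "
            "(e.g. set via NVTE_KEEP_BACKWARD_UNQUANTIZED=1); "
            "use a current- or block-scaling recipe, or unset the variable."
        )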

Comment on lines +238 to +240
    assert (
        not keep_backward_unquantized
    ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
Review comment: this assertion makes LayerNormMLP crash immediately when NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is set; if this module does not support the feature, either implement it or fail more gracefully with a clear error message before reaching this point.
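
A hypothetical way to fail earlier and more clearly (placement and wording are assumptions):

    # Sketch: raise a catchable, descriptive error instead of asserting mid-forward.
    if keep_backward_unquantized:
        raise NotImplementedError(
            "NVTE_KEEP_BACKWARD_UNQUANTIZED is not supported by LayerNormMLP yet; "
            "unset the variable or use Linear/LayerNormLinear instead."
        )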

greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

transformer_engine/common/recipe/__init__.py
Float8BlockScaling is missing field declarations for quantize_forward and quantize_backward, but they are used in __post_init__ (line 407) and __repr__ (lines 425-426); this will cause an AttributeError at runtime. Suggested field declarations:

    fp8_mha: bool = False
    quantize_forward: bool = True
    quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1")

greptile-apps bot left a comment: 10 files reviewed, 1 comment

Comment on lines 1019 to +1026

      # Save state for backward pass
      if ctx.requires_grad:
    +     saved_input = input_ if keep_backward_unquantized else x_local
    +     saved_weight = self.weight if keep_backward_unquantized else w
          if is_cpu_offload_enabled():
    -         mark_activation_offload(x_local)
    -     ctx.save_for_backward(x_local, w)
    -     ctx.with_quantized_compute = with_quantized_compute
    +         mark_activation_offload(saved_input)
    +     ctx.save_for_backward(saved_input, saved_weight)
Unnecessary saved tensors

The forward path now saves saved_input/saved_weight whenever ctx.requires_grad is true, even when weight_requires_grad or input_requires_grad is false. In cases like frozen weights (common for finetuning) or when only one side needs grads, this saves extra tensors and can materially increase activation memory. The prior logic (if not weight_requires_grad: saved_input = None / if not input_requires_grad: saved_weight = None) avoided that.
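
A sketch of how the new selection could be combined with the previous requires-grad gating; variable names follow the diff above, but this is not the PR's code:

    # Sketch: keep the keep_backward_unquantized selection, but drop tensors
    # that backward will never read.
    saved_input = input_ if keep_backward_unquantized else x_local
    saved_weight = self.weight if keep_backward_unquantized else w
    if not weight_requires_grad:
        saved_input = None   # no wgrad -> the activation is not needed in backward
    if not input_requires_grad:
        saved_weight = None  # no dgrad -> the weight is not needed in backward
    if is_cpu_offload_enabled() and saved_input is not None:
        mark_activation_offload(saved_input)
    ctx.save_for_backward(saved_input, saved_weight)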

greptile-apps bot left a comment: 6 files reviewed, 1 comment

greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
Backward remains quantized
NVTE_KEEP_BACKWARD_UNQUANTIZED is implemented across the other linear paths by saving high-precision tensors for backward and setting ctx.with_quantized_compute = fp8 && !keep_backward_unquantized. In the Userbuffers forward fusion, the forward path never computes or propagates keep_backward_unquantized: it always saves x_local, w (potentially quantized) and sets linear_op_ctx.with_quantized_compute = with_quantized_compute, so UB-enabled execution still takes the quantized-backward path even when recipe.quantize_backward=False, i.e. when the feature is intended to disable quantized backward.
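
A rough sketch of what aligning the UB fusion would look like, reusing the names from this comment (recipe, input_, and linear_op.weight are assumptions, not a drop-in fix):

    # Sketch: mirror the other linear paths inside the UB fused forward.
    keep_backward_unquantized = with_quantized_compute and not recipe.quantize_backward
    saved_input = input_ if keep_backward_unquantized else x_local
    saved_weight = linear_op.weight if keep_backward_unquantized else w
    linear_op_ctx.save_for_backward(saved_input, saved_weight)
    linear_op_ctx.with_quantized_compute = (
        with_quantized_compute and not keep_backward_unquantized
    )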

@zianglih (Author) commented:
Currently, without NVTE_KEEP_BACKWARD_UNQUANTIZED, the unit tests are aligned with main:
te-2644.log
te-main.log
