Skip to content

[PyTorch] Preserve fprop operands for dequantized backward override#3141

Open
negvet wants to merge 4 commits into
NVIDIA:mainfrom
negvet:fix_dequantized_override_save_original_input
Open

[PyTorch] Preserve fprop operands for dequantized backward override#3141
negvet wants to merge 4 commits into
NVIDIA:mainfrom
negvet:fix_dequantized_override_save_original_input

Conversation

@negvet

@negvet negvet commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator

Description

Follow-up to #2644, which introduced NVTE_BACKWARD_OVERRIDE=high_precision|dequantized.

high_precision is intended to use original unquantized tensor in backward, while dequantized is intended to use dequantized tensor from the forward-quantized one. However, save_original_input=True could override the dequantized behavior in Linear and GroupedLinear, causing backward to use the original input instead of the fprop-quantized operand.

This PR makes the override semantics explicit:

  • backward_override="high_precision" forces save_original_input=True
  • backward_override="dequantized" forces save_original_input=False

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from ksivaman as a code owner June 23, 2026 13:19
@negvet

negvet commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

cc @zianglih

@negvet

negvet commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci L0 L1

@greptile-apps

greptile-apps Bot commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR fixes a semantic conflict in te.Linear and te.GroupedLinear where a user-set save_original_input=True could silently override the backward_override="dequantized" behavior, causing backward to use the original input tensor instead of the fprop-quantized operand. The fix adds an elif backward_override == "dequantized": save_original_input = False guard to make the two override modes mutually exclusive.

  • linear.py / grouped_linear.py: Two-line change each — when the dequantized override is active, save_original_input is forced to False regardless of the module-level setting; the existing high_precision branch continues to force it to True.
  • test_backward_override.py: Four new parametrised tests covering both dequantized (verifies fprop-quantized rowwise-only operand is saved, and that gradients match the reference) and high_precision (verifies original plain tensor is saved) for both Linear and GroupedLinear.

Confidence Score: 5/5

Safe to merge — the two-line change in each module has no effect outside FP8 mode and its interaction with the downstream backward logic has been verified.

The fix is minimal and scoped: it only activates when FP8 is enabled and the recipe explicitly sets backward_override=dequantized. Both modules already null-out backward_override on the non-FP8 path, so non-FP8 callers are completely unaffected. The backward logic in grouped_linear.py uses if ctx.save_original_input / elif ctx.backward_override == dequantized branches that are now correctly exclusive. Four new parametrised tests — including bit-exact gradient comparisons — give strong coverage of both override modes for both modules.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Adds elif backward_override == "dequantized": save_original_input = False immediately after the existing high_precision branch in _linear_forward_impl, correctly forcing exclusive semantics.
transformer_engine/pytorch/module/grouped_linear.py Same two-line fix in _GroupedLinear.forward for the non-grouped-tensor path; the change sits after the existing high_precision branch and before the quantizer configuration and tensor-save logic that depends on save_original_input.
tests/pytorch/test_backward_override.py Four new parametrised tests: dequantized + Linear, dequantized + GroupedLinear (bit-exact gradient comparison against reference), high_precision + Linear, high_precision + GroupedLinear (saved-operand type assertions). Also adds the previously-missing GroupedLinear + high_precision coverage.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Linear / GroupedLinear forward] --> B{fp8 active?}
    B -- No --> C[backward_override = None]
    B -- Yes --> D[backward_override = recipe.backward_override]
    C --> E[save_original_input = module setting]
    D --> F{backward_override?}
    F -- high_precision --> G[save_original_input = True]
    F -- dequantized --> H[save_original_input = False]
    F -- None --> E
    G --> I[Save original inp tensor for re-quantisation in backward]
    H --> J[Save fprop-quantized QuantizedTensorStorage rowwise-only layout]
    E --> K{module.save_original_input?}
    K -- True --> I
    K -- False --> J
    I --> L[backward wgrad: re-split + requantize from original]
    J --> M[backward wgrad: dequantize from fprop quantized tensor]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[Linear / GroupedLinear forward] --> B{fp8 active?}
    B -- No --> C[backward_override = None]
    B -- Yes --> D[backward_override = recipe.backward_override]
    C --> E[save_original_input = module setting]
    D --> F{backward_override?}
    F -- high_precision --> G[save_original_input = True]
    F -- dequantized --> H[save_original_input = False]
    F -- None --> E
    G --> I[Save original inp tensor for re-quantisation in backward]
    H --> J[Save fprop-quantized QuantizedTensorStorage rowwise-only layout]
    E --> K{module.save_original_input?}
    K -- True --> I
    K -- False --> J
    I --> L[backward wgrad: re-split + requantize from original]
    J --> M[backward wgrad: dequantize from fprop quantized tensor]
Loading

Reviews (2): Last reviewed commit: "Merge branch 'main' into fix_dequantized..." | Re-trigger Greptile

Comment thread tests/pytorch/test_backward_override.py
@zianglih

Copy link
Copy Markdown
Contributor

Thanks for the fix!

root and others added 2 commits June 25, 2026 06:30
@negvet

negvet commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci L0 L1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants