
[PyTorch/Common] Remove legacy FP8DS implementation #2959

Open
cyanguwa wants to merge 10 commits into NVIDIA:main from cyanguwa:remove_fp8_v0

Conversation

@cyanguwa
Collaborator

@cyanguwa cyanguwa commented May 5, 2026

Description

This PR removes the legacy FP8 Delayed Scaling path that dates back to TE 1.6.0. That path supports only the T3HD layout with max_seq_len <= 512, head_dim = 64, and a padding mask. cudnn-frontend is removing its pre-FORT hand-written FMHA kernels (MR2829), hence the removal of this FP8 implementation here. General THD support for FP8 will be added in future PRs.
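For reference, here is a minimal, hypothetical sketch of the support envelope the removed path enforced, per the description above. The function name, argument names, and string values are illustrative, not TE's actual API:

```python
# Hypothetical sketch (not TE source): the support envelope of the legacy
# FP8 delayed-scaling path removed by this PR.
def legacy_fp8_ds_supported(qkv_layout: str, max_seq_len: int,
                            head_dim: int, mask_type: str) -> bool:
    # The removed path handled exactly this combination and nothing else.
    return (qkv_layout == "t3hd"
            and max_seq_len <= 512
            and head_dim == 64
            and mask_type == "padding")
```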

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

See Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cyanguwa and others added 3 commits April 30, 2026 16:50
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa marked this pull request as ready for review May 5, 2026 19:31
@greptile-apps
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR removes the legacy FP8 Delayed Scaling implementation, dating to TE 1.6.0, that supported only the T3HD layout with max_seq_len ≤ 512, head_dim = 64, and padding mask. The removal is motivated by cudnn-frontend dropping its pre-FORT hand-written FMHA kernels.

  • ~1,700 lines deleted from fused_attn_fp8.cu: the entire v0 cuDNN 8.x hand-written FP8 graph-API implementation (fused_attn_fp8_fwd_impl / fused_attn_fp8_bwd_impl) is removed, and the remaining v1 functions are renamed to be the canonical entry points.
  • ZInv tensor (devPtrZInv / input_ZInv) is removed throughout the C++ call chain (fused_attn.cpp, fused_attn_fp8.cu, fused_attn_fp8.h) because it was only needed by the v0 T3HD path.
  • Python-side CP parallel code (context_parallel.py) and test setup (test_attention.py) are simplified by eliminating all qkv_layout == "t3hd" FP8 special-casing and the associated NVTE_FUSED_ATTN_FE_VER env-var branching (see the sketch after this list).
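As a rough illustration of that last point, here is a hypothetical Python sketch of the aux-tensors unification in context_parallel.py. The function name and exact list contents are illustrative; the real code builds these lists inline:

```python
# Hypothetical sketch (names and list contents illustrative): after this PR,
# the FP8 and non-FP8 CP paths construct aux_tensors identically.
def build_aux_tensors(softmax_lse, rng_state):
    # Before this PR, the FP8 T3HD path instead built a 3-element list,
    # [softmax_lse, softmax_lse, rng_state], duplicating softmax_lse.
    return [softmax_lse, rng_state]
```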

Confidence Score: 5/5

Safe to merge — a focused deletion of the v0 T3HD FP8 path with consistent cleanup across C++, CUDA, and Python layers.

All changes are pure removal or renaming of code that was gated behind the T3HD layout. The v1 functions are promoted to canonical names and the call sites, parameter lists, aux-tensor packing, and tests are updated uniformly. The T3HD layout itself remains valid for F16/BF16 paths; only the FP8 subpath is dropped. No new logic is introduced.

No files require special attention. The deletions in fused_attn_fp8.cu are the bulk of the change and they remove a single, well-isolated v0 implementation that was already superseded by the v1 functions.

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: Removes ~1,700 lines of the v0 hand-written cuDNN graph-API FP8 implementation; renames the v1 functions to canonical names and drops the devPtrZInv parameter throughout.
  • transformer_engine/common/fused_attn/fused_attn.cpp: Removes the T3HD condition from FP8 backend selection; simplifies backend assignment (the cudnn_runtime_version >= 8900 guard was already subsumed by the >= 90201 inner conditions); removes ZInv from BWD aux tensor extraction.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Removes all T3HD FP8 special-casing (the duplicate softmax_lse in 3-element aux_tensors lists) from CP P2P forward/backward; non-FP8 and FP8 paths now share identical aux_tensors construction.
  • transformer_engine/pytorch/csrc/extensions/attention.cpp: Removes the T3HD-specific ZInv tensor allocation from the aux_ctx_tensors pack; Max tensor allocation is now guarded solely by return_max_logit (sketched after this list).
  • transformer_engine/common/fused_attn/fused_attn_fp8.h: Updates the fused_attn_fp8_bwd declaration to remove the input_ZInv parameter, matching the new implementation signature.
  • transformer_engine/common/fused_attn/utils.cu: Removes the cu_seqlens_to_offsets CUDA kernel and its header declaration; this kernel was used exclusively by the now-deleted T3HD v0 FP8 backward path.
  • tests/pytorch/attention/test_attention.py: Removes NVTE_FUSED_ATTN_FE_VER-based model branching; all FP8 model configs are now tested unconditionally with the v1 backend, hardening test coverage.
  • transformer_engine/common/include/transformer_engine/fused_attn.h: Removes the now-stale backend support matrix table from the public API docstring and updates ZInv references to generic 'softmax stats' descriptions.
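A hypothetical Python rendering of the aux-tensor packing change in attention.cpp noted above (the real code is C++; names are illustrative, and the slot ordering follows the flowchart below):

```python
# Hypothetical sketch (names illustrative): aux_ctx_tensors packing after this
# PR. The T3HD path previously also allocated a ZInv tensor in this pack; now
# only return_max_logit gates the optional Max tensor.
def pack_aux_ctx_tensors(softmax_stats, rng_state, max_logit=None,
                         return_max_logit=False):
    aux = [softmax_stats]        # S (softmax stats)
    if return_max_logit:         # optional Max, caller-requested only
        aux.append(max_logit)
    aux.append(rng_state)        # RNG state
    return aux
```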

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_get_fused_attn_backend()"] --> B{FP8 dtype?}
    B -- No --> C["F16/BF16 backend selection"]
    B -- Yes --> D{cuDNN version & layout checks}
    D -- "cuDNN >= 9.2.1 & BSHD/SBHD/BHSD" --> E["NVTE_FP8 backend"]
    D -- "layout = T3HD\n(or unsupported)" --> F["NVTE_No_Backend\n(T3HD FP8 no longer supported)"]
    E --> G["nvte_fused_attn_fwd / bwd"]
    G --> H{qkv_format?}
    H -- "BSHD / SBHD / BHSD" --> I["fused_attn_fp8_fwd_impl()\nfused_attn_fp8_bwd_impl()\n(formerly _v1 - now canonical)"]
    H -- "Other (e.g. THD)" --> J["NVTE_ERROR: unsupported format"]
    I --> K["Aux CTX tensors:\nS (softmax stats)\noptional Max\nrng_state"]
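To make the chart concrete, here is a hypothetical Python rendering of the FP8 branch of the backend selection. The real logic is C++ inside nvte_get_fused_attn_backend; the version constant and layout strings follow the chart, everything else is illustrative:

```python
# Hypothetical sketch of the FP8 backend-selection flow charted above.
def select_fp8_backend(cudnn_runtime_version: int, qkv_layout: str) -> str:
    # cuDNN >= 9.2.1 with a BSHD/SBHD/BHSD layout selects the FP8 backend.
    if cudnn_runtime_version >= 90201 and qkv_layout in ("bshd", "sbhd", "bhsd"):
        return "NVTE_FP8"
    # T3HD (and any other unsupported layout) no longer has an FP8 backend.
    return "NVTE_No_Backend"
```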

Reviews (6). Last reviewed commit: "address review: drop dead 8.9 FP8 guard ..."

cyanguwa and others added 5 commits May 5, 2026 13:37
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa changed the title from "[PyTorch/Common] Remove old, unused FP8 implementation" to "[PyTorch/Common] Remove legacy FP8DS implementation" May 5, 2026
cyanguwa added 2 commits May 5, 2026 15:11
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa added the 2.16.0 label May 5, 2026
@cyanguwa
Collaborator Author

cyanguwa commented May 5, 2026

/te-ci L0

@cyanguwa cyanguwa requested a review from sudhakarsingh27 May 5, 2026 22:21
