[JAX][Common] Enable cuDNN fused attn backend for NO_MASK + bidirectional SWA #2961

Merged
KshitijLakhani merged 6 commits into NVIDIA:main from KshitijLakhani:klakhani/fix/no-mask-swa-attn
May 6, 2026

@KshitijLakhani (Collaborator) commented May 5, 2026

Description

Enable the cuDNN fused attn backend for NO_MASK with a right-sided (bidirectional) sliding window when using cuDNN 9.6+.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • [] Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Enabled cuDNN fused attn backend for right sided window for NO_MASK when using cuDNN 9.6+ in TE common
  • No changes needed in TE PyT, as existing tests already exercise this path. Currently they use a non-cuDNN fused attn backend for NO_MASK + right-side window + cuDNN 9.6+; this PR ensures they use the cuDNN fused attn backend instead.
  • TE JAX fused attn did not have any bidirectional window tests - so added those
  • Added a helper in the TE JAX fused attn tests to pick the window size based on the cuDNN version
    • If bidirectional is available, run only those (no unidirectional tests run)
    • If only unidirectional available, run only those
    • If none available, skip
  • Added window size to the warning string reported when falling back to unfused attn
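
The window-selection rules listed above can be sketched roughly as follows. This is a minimal illustration, not the helper added in the PR: the function name, the `(major, minor)` version tuple, and the string mask names are all assumptions, and the left-only fallback for other mask types on cuDNN 9.6+ is inferred rather than stated in the description.

```python
from typing import Optional, Tuple

# Hypothetical stand-in for the test helper; names and signature are assumptions.
BIDIRECTIONAL_MASKS = {"NO_MASK", "PADDING_MASK"}


def pick_swa_window(
    cudnn_version: Tuple[int, int], mask_type: str, s_kv: int
) -> Optional[Tuple[int, int]]:
    """Return a (left, right) window for the SWA test, or None to signal a skip."""
    if cudnn_version < (9, 2):
        # Neither unidirectional nor bidirectional windows are available: skip.
        return None
    left = s_kv // 10
    if cudnn_version >= (9, 6) and mask_type in BIDIRECTIONAL_MASKS:
        # Bidirectional support available: use an asymmetric two-sided window.
        return (left, left + 5)
    # Only unidirectional (left-only) windows are exercised.
    return (left, 0)
```

A test body would call this once per parametrized case and skip when it returns `None`.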

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…so add a helper to pick window based on cuDNN version support in fused_attn.cpp

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this May 5, 2026
@KshitijLakhani (Collaborator, Author) commented

/te-ci L0 L1 L2

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani marked this pull request as ready for review May 6, 2026 16:59
@greptile-apps (Bot, Contributor) commented May 6, 2026

Greptile Summary

This PR enables the cuDNN fused attention backend for the NO_MASK + bidirectional sliding-window attention (SWA) case when running cuDNN 9.6+. A single-line addition in fused_attn.cpp adds NVTE_NO_MASK to the 9.6 sliding-window condition block, while a new test helper selects the appropriate window size based on the runtime cuDNN version, and the JAX fallback warning is extended to include window_size.

  • fused_attn.cpp: Adds NVTE_NO_MASK to the cuDNN ≥ 9.6 second-branch sliding-window gate alongside the already-present PADDING_MASK / PADDING_CAUSAL_MASK / CAUSAL_BOTTOM_RIGHT_MASK entries. The placement and surrounding constraints (max_seqlen_q ≤ max_seqlen_kv, NO_BIAS, dropout == 0) match the pattern for the other mask types in that branch.
  • test_fused_attn.py: Introduces _get_swa_window_size_for_test to pick a version-appropriate window; tests skip on cuDNN < 9.2 and use an asymmetric bidirectional window (s_kv // 10, s_kv // 10 + 5) for NO_MASK / PADDING_MASK on 9.6+. One minor gap: on cuDNN ≥ 9.6, the left-only SWA path (right = 0) for NO_MASK is no longer exercised by these tests.
  • transformer.py: Trivial diagnostic improvement — adds window_size to the unfused-attention fallback warning.
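
As a rough illustration of the diagnostic change in the last bullet, a fallback warning that carries the window size might look like the sketch below. The function name and message wording are hypothetical; only the idea of including `window_size` in the warning comes from the PR.

```python
import warnings


def warn_unfused_fallback(attn_mask_type: str, window_size: tuple) -> str:
    """Emit a fallback warning that includes the window size (hypothetical wording)."""
    msg = (
        "Fused attention is not available for "
        f"attn_mask_type={attn_mask_type}, window_size={window_size}; "
        "falling back to unfused attention."
    )
    warnings.warn(msg)
    return msg
```

Seeing the window in the message makes it easier to tell whether a fallback was caused by an unsupported window rather than by the mask type alone.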

Confidence Score: 4/5

Safe to merge; the C++ backend change is a one-line, well-scoped addition that mirrors the pattern already used for PADDING_MASK in the same block, and the JAX warning change is purely cosmetic.

The core backend logic change is minimal and follows the established pattern for existing mask types in the 9.6 SWA branch. The test helper is readable and the cuDNN version gating is correct. The only notable gap is that on cuDNN ≥ 9.6 the existing SWA tests for NO_MASK now exclusively exercise the bidirectional path, leaving the left-only (right=0) path untested on that version family.

tests/jax/test_fused_attn.py: the _get_swa_window_size_for_test helper shifts NO_MASK SWA tests from left-only to bidirectional on cuDNN ≥ 9.6, narrowing coverage of the left-only path on newer cuDNN versions.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_attn/fused_attn.cpp | Adds NVTE_NO_MASK to the cuDNN 9.6+ sliding-window second branch, enabling bidirectional SWA for NO_MASK via the fused backend; placement and conditions look correct. |
| tests/jax/test_fused_attn.py | Adds _get_swa_window_size_for_test helper that version-gates bidirectional window selection; replaces inline tuple logic in _test_forward and test_backward. |
| transformer_engine/jax/flax/transformer.py | One-liner improvement: adds window_size to the fallback-to-unfused warning for easier diagnostics. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend] --> B{cuDNN version?}
    B -->|< 9.2| C{SWA window?}
    C -->|left=-1, right=-1 or 0| D[Allow]
    C -->|other| E[Reject SWA]
    B -->|>= 9.2| F{Sliding window branch}
    F -->|left=-1, right=-1, NO_MASK| G[Allow full attn]
    F -->|left=any, right=0, NO_MASK/CAUSAL/...| H{Format check}
    H -->|BSHD/SBHD| I[Allow left-only SWA]
    H -->|other| J[Reject]
    B -->|>= 9.6| K{9.6 sliding window}
    K -->|left=-1, right=-1 or 0| L[Allow - any mask]
    K -->|left=any, right>=0 or -1| M{Mask type?}
    M -->|CAUSAL_BOTTOM_RIGHT + sm_arch checks| N[Allow bidir SWA]
    M -->|NO_MASK NEW| N
    M -->|PADDING_MASK| N
    M -->|PADDING_CAUSAL_MASK| N
    M -->|PADDING_CAUSAL_BOTTOM_RIGHT + sm_arch checks| N
    M -->|CAUSAL_MASK or other| O[Reject bidir SWA]
    N --> P{max_seqlen_q <= max_seqlen_kv + NO_BIAS + dropout=0}
    P -->|Yes| Q[flag_arb = true - cuDNN fused backend]
    P -->|No| R[Reject]
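
The cuDNN ≥ 9.6 portion of the flowchart can be read as the simplified predicate below. It models only what the flowchart shows, under stated assumptions: mask types are plain strings, the function name is hypothetical, and the sm_arch checks attached to the bottom-right causal masks (along with any layout checks in the real fused_attn.cpp) are deliberately omitted.

```python
# Mask types the flowchart routes to "Allow bidir SWA" in the 9.6 branch.
# Assumption: sm_arch checks for the *_BOTTOM_RIGHT masks are not modeled here.
BIDIR_SWA_MASKS = {
    "NO_MASK",  # newly added by this PR
    "PADDING_MASK",
    "PADDING_CAUSAL_MASK",
    "CAUSAL_BOTTOM_RIGHT_MASK",
    "PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
}


def flowchart_allows_swa_9_6(
    mask_type: str,
    left: int,
    right: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    has_bias: bool,
    dropout: float,
) -> bool:
    """Simplified reading of the cuDNN >= 9.6 sliding-window gate in the flowchart."""
    # First branch: unbounded left with right of -1 or 0 is allowed for any mask.
    if left == -1 and right in (-1, 0):
        return True
    # Second branch: any left, right >= 0 or right == -1.
    if right < -1:
        return False
    return (
        mask_type in BIDIR_SWA_MASKS
        and max_seqlen_q <= max_seqlen_kv
        and not has_bias
        and dropout == 0.0
    )
```

For example, `flowchart_allows_swa_9_6("NO_MASK", 12, 17, 128, 128, False, 0.0)` takes the second branch and passes all three constraints, which is the case this PR enables.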

Comments Outside Diff (1)

  1. tests/jax/test_fused_attn.py, lines 1078-1080

    P2 Left-only SWA no longer tested for NO_MASK on cuDNN ≥ 9.6

    Before this PR, SWA tests always used (s_kv // 10, 0) regardless of cuDNN version, so the left-only path was covered on every version. After this change, on cuDNN ≥ 9.6 the test switches to a bidirectional window for NO_MASK and PADDING_MASK, meaning the finite-left / zero-right combination is only validated on cuDNN 9.2–9.5 for those mask types. If the new bidirectional path introduced a regression in the left-only path on 9.6+ (different code path in the 9.6 second branch when right == 0), no test in this class would catch it. Consider adding a separate parametrize value or an additional assertion that covers (left, 0) explicitly on cuDNN ≥ 9.6.
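
One way to act on this suggestion is to have the helper return every window worth covering instead of a single one. The sketch below is an assumption-laden illustration, not code from the PR: the function name, the `(major, minor)` version tuple, and the string mask names are hypothetical.

```python
from typing import List, Tuple


def swa_windows_to_cover(
    cudnn_version: Tuple[int, int], mask_type: str, s_kv: int
) -> List[Tuple[int, int]]:
    """Return all (left, right) windows a test could exercise for this configuration."""
    if cudnn_version < (9, 2):
        return []  # nothing supported: the caller should skip
    left = s_kv // 10
    windows = [(left, 0)]  # left-only path stays covered on every supported version
    if cudnn_version >= (9, 6) and mask_type in {"NO_MASK", "PADDING_MASK"}:
        windows.append((left, left + 5))  # bidirectional path on 9.6+
    return windows
```

A test could loop over the returned list, or feed it into a parametrization, so that the `(left, 0)` case is still validated on cuDNN ≥ 9.6.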

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

@KshitijLakhani (Collaborator, Author) commented

Failures in the CI are unrelated to attention in any way so safe to merge.

@KshitijLakhani KshitijLakhani merged commit 4b6923d into NVIDIA:main May 6, 2026
2 checks passed
@KshitijLakhani KshitijLakhani deleted the klakhani/fix/no-mask-swa-attn branch May 6, 2026 18:08

Labels

attention, enhancement (New feature or request)


3 participants