From f2c78e5edb2766fa39212858d094235f542506d9 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 5 May 2026 10:22:32 -0700 Subject: [PATCH 1/6] Enable right side of sliding window for cuDNN fused attn backend Signed-off-by: Kshitij Lakhani --- transformer_engine/common/fused_attn/fused_attn.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 141767b803..ae8ddbed69 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -469,6 +469,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && From 66a7031c239aa640b67d529ea4c2284a3d1eccd1 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 5 May 2026 15:58:39 -0700 Subject: [PATCH 2/6] Add window size in the warning string when falling back to unfused attn Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/flax/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 513677e4a1..a2e7920843 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -788,7 +788,7 @@ def __call__( "Fall back to the unfused attention.\n" "Please try to update the cuDNN and TE to the latest version.\n" f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" - f"{self.attention_dropout=}\n{self.num_attention_heads=}\n" + f"{self.attention_dropout=}\n{self.num_attention_heads=}\n{self.window_size=}\n" f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" ) From 134aac584e66958de60d0494c1891f73b021638f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 5 May 2026 16:00:28 -0700 Subject: [PATCH 3/6] Add a test for bidirectional asymmetric SWA testing in fused attn. Also add a helper to pick window based on cuDNN version support in fused_attn.cpp Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 8b727b1d43..adb24e6a21 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1059,6 +1059,25 @@ def check_dqkv(primitive, reference, pad, idx): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) +def _swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[int, int]: + """Pick a sliding-window size for SWA tests, gated on cuDNN version. + + cuDNN < 9.2: skip (no SWA support). + cuDNN >= 9.2: left-only window (s_kv // 10, 0). + cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10) for the mask types whose + bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK). + Other mask types keep the left-only window: causal-family masks would + collapse (W, W) -> (W, 0) under the causal AND, so a separate bidirectional + case adds no signal. + """ + cudnn_v = get_cudnn_version() + if cudnn_v < 90200: + pytest.skip("Sliding window attention requires cuDNN >= 9.2") + left = s_kv // 10 + right = left + 5 + if cudnn_v >= 90600 and attn_mask_type in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): + return (left, right) + return (left, 0) @pytest.mark.parametrize( "attn_mask_type", @@ -1330,9 +1349,7 @@ def _test_forward( This test is not intended to run automatically during CI as it is time-consuming It is kept for development and debugging """ - window_size = None - if swa: - window_size = (s_kv // 10, 0) + window_size = _swa_window_size_for_test(s_kv, attn_mask_type) if swa else None runner = FusedAttnRunner( b, s_q, @@ -1383,9 +1400,7 @@ def test_backward( """ Test backward with parameterized configs """ - window_size = None - if swa: - window_size = (s_kv // 10, 0) + window_size = _swa_window_size_for_test(s_kv, attn_mask_type) if swa else None runner = FusedAttnRunner( b, s_q, From dcbfc5110d1ec044c63051d11f42be0089f79c8b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 May 2026 23:27:59 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index adb24e6a21..fe6377c52a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1059,6 +1059,7 @@ def check_dqkv(primitive, reference, pad, idx): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) + def _swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[int, int]: """Pick a sliding-window size for SWA tests, gated on cuDNN version. @@ -1079,6 +1080,7 @@ def _swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[ return (left, right) return (left, 0) + @pytest.mark.parametrize( "attn_mask_type", [ From c13f399c1dcaa11567e0f61e7ae7104b4ec62e4c Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 5 May 2026 16:42:56 -0700 Subject: [PATCH 5/6] nit: Code clean up Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index fe6377c52a..dc5e6ede28 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1059,26 +1059,25 @@ def check_dqkv(primitive, reference, pad, idx): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) - -def _swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[int, int]: +def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[int, int]: """Pick a sliding-window size for SWA tests, gated on cuDNN version. cuDNN < 9.2: skip (no SWA support). cuDNN >= 9.2: left-only window (s_kv // 10, 0). - cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10) for the mask types whose + cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10 + 5) for the mask types whose bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK). Other mask types keep the left-only window: causal-family masks would - collapse (W, W) -> (W, 0) under the causal AND, so a separate bidirectional - case adds no signal. + collapse (W, W) -> (W, 0), hence not tested here. """ - cudnn_v = get_cudnn_version() - if cudnn_v < 90200: + cudnn_version = get_cudnn_version() + if cudnn_version < 90200: pytest.skip("Sliding window attention requires cuDNN >= 9.2") - left = s_kv // 10 - right = left + 5 - if cudnn_v >= 90600 and attn_mask_type in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): - return (left, right) - return (left, 0) + left_window_size = s_kv // 10 + # choose asymmetric window size for testing + right_window_size = left_window_size + 5 + if cudnn_version >= 90600 and attn_mask_type in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): + return (left_window_size, right_window_size) + return (left_window_size, 0) @pytest.mark.parametrize( @@ -1351,7 +1350,7 @@ def _test_forward( This test is not intended to run automatically during CI as it is time-consuming It is kept for development and debugging """ - window_size = _swa_window_size_for_test(s_kv, attn_mask_type) if swa else None + window_size = _get_swa_window_size_for_test(s_kv, attn_mask_type) if swa else None runner = FusedAttnRunner( b, s_q, @@ -1402,7 +1401,7 @@ def test_backward( """ Test backward with parameterized configs """ - window_size = _swa_window_size_for_test(s_kv, attn_mask_type) if swa else None + window_size = _get_swa_window_size_for_test(s_kv, attn_mask_type) if swa else None runner = FusedAttnRunner( b, s_q, From 3680fbcced872d727dc5230aa5bf9d30301a7be4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 16:59:31 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index dc5e6ede28..1fb0108068 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1059,6 +1059,7 @@ def check_dqkv(primitive, reference, pad, idx): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) + def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[int, int]: """Pick a sliding-window size for SWA tests, gated on cuDNN version. @@ -1075,7 +1076,10 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu left_window_size = s_kv // 10 # choose asymmetric window size for testing right_window_size = left_window_size + 5 - if cudnn_version >= 90600 and attn_mask_type in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK): + if cudnn_version >= 90600 and attn_mask_type in ( + AttnMaskType.NO_MASK, + AttnMaskType.PADDING_MASK, + ): return (left_window_size, right_window_size) return (left_window_size, 0)