diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 8b727b1d43..1fb0108068 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -1060,6 +1060,30 @@ def check_dqkv(primitive, reference, pad, idx):
         assert_equal_collectives(target_hlo, self.coll_count_ref)
 
 
+def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tuple[int, int]:
+    """Pick a sliding-window size for SWA tests, gated on cuDNN version.
+
+    cuDNN < 9.2: skip (no SWA support).
+    cuDNN >= 9.2: left-only window (s_kv // 10, 0).
+    cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10 + 5) for the mask types whose
+                  bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK).
+                  Other mask types keep the left-only window: causal-family masks would
+                  collapse (W, W) -> (W, 0), hence not tested here.
+    """
+    cudnn_version = get_cudnn_version()
+    if cudnn_version < 90200:
+        pytest.skip("Sliding window attention requires cuDNN >= 9.2")
+    left_window_size = s_kv // 10
+    # choose asymmetric window size for testing
+    right_window_size = left_window_size + 5
+    if cudnn_version >= 90600 and attn_mask_type in (
+        AttnMaskType.NO_MASK,
+        AttnMaskType.PADDING_MASK,
+    ):
+        return (left_window_size, right_window_size)
+    return (left_window_size, 0)
+
+
 @pytest.mark.parametrize(
     "attn_mask_type",
     [
@@ -1330,9 +1354,7 @@ def _test_forward(
         This test is not intended to run automatically during CI as it is time-consuming
         It is kept for development and debugging
         """
-        window_size = None
-        if swa:
-            window_size = (s_kv // 10, 0)
+        window_size = _get_swa_window_size_for_test(s_kv, attn_mask_type) if swa else None
         runner = FusedAttnRunner(
             b,
             s_q,
@@ -1383,9 +1405,7 @@ def test_backward(
         """
         Test backward with parameterized configs
        """
-        window_size = None
-        if swa:
-            window_size = (s_kv // 10, 0)
+        window_size = _get_swa_window_size_for_test(s_kv, attn_mask_type) if swa else None
         runner = FusedAttnRunner(
             b,
             s_q,
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 141767b803..ae8ddbed69 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -469,6 +469,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           (sm_arch_ < 100 ||
            (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) ||
                                 cudnn_runtime_version > 90700)))) ||
+         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
          (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 513677e4a1..a2e7920843 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -788,7 +788,7 @@ def __call__(
                 "Fall back to the unfused attention.\n"
                 "Please try to update the cuDNN and TE to the latest version.\n"
                 f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
-                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
+                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n{self.window_size=}\n"
                 f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
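
For reference (not part of the patch): a minimal, self-contained sketch of how the new helper's version gating resolves. MaskKind and pick_swa_window are stand-in names replacing the test suite's AttnMaskType and get_cudnn_version(), so the expected windows can be checked without cuDNN installed.

# Illustrative only: mirrors the gating logic of _get_swa_window_size_for_test above,
# with the cuDNN version passed in explicitly instead of queried from the runtime.
from enum import Enum, auto
from typing import Optional, Tuple


class MaskKind(Enum):  # stand-in for AttnMaskType
    NO_MASK = auto()
    PADDING_MASK = auto()
    CAUSAL_MASK = auto()


def pick_swa_window(s_kv: int, mask: MaskKind, cudnn_version: int) -> Optional[Tuple[int, int]]:
    """Return the (left, right) window the test helper would pick, or None where it would skip."""
    if cudnn_version < 90200:
        return None  # the real helper calls pytest.skip() here
    left = s_kv // 10
    if cudnn_version >= 90600 and mask in (MaskKind.NO_MASK, MaskKind.PADDING_MASK):
        return (left, left + 5)  # asymmetric bidirectional window
    return (left, 0)  # left-only window (causal-family masks)


assert pick_swa_window(2048, MaskKind.NO_MASK, 90100) is None
assert pick_swa_window(2048, MaskKind.CAUSAL_MASK, 90600) == (204, 0)
assert pick_swa_window(2048, MaskKind.PADDING_MASK, 90600) == (204, 209)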