From 3d0fcd78b19bced347efeea53246b4c82a523766 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Mon, 27 Apr 2026 02:27:46 -0700
Subject: [PATCH 1/2] enable head dim 256 for FA4

Signed-off-by: Xin Yao
---
 tests/pytorch/attention/test_attention.py    |  2 +
 .../dot_product_attention/backends.py        |  2 +
 .../attention/dot_product_attention/utils.py | 51 +++++++++++--------
 3 files changed, 33 insertions(+), 22 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index c9ea791444..6d4f3b45a9 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -372,6 +372,8 @@ def test_dpa_num_splits(dtype, model_configs, model):
     "fa4_base_1": ModelConfig(4, 128, 16, 64),
     "fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
     "fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
+    # head_dim=256 (SM100 only via dedicated kernel; flash-attn-4 > 4.0.0b10)
+    "fa4_base_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
     # GQA
     "fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
     "fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 4104820a1c..61e7651207 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -167,8 +167,10 @@
     from flash_attn.cute.interface import (  # pylint: disable=ungrouped-imports,no-name-in-module
         flash_attn_func as flash_attn_func_v4,
         flash_attn_varlen_func as flash_attn_varlen_func_v4,
+        _validate_head_dims as _fa4_validate_head_dims,
     )
 
+    fa_utils.v4_validate_head_dims = _fa4_validate_head_dims
     fa_utils.set_flash_attention_4_params()
 
 # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index ed87423534..996c6fac37 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -149,6 +149,9 @@ class FlashAttentionUtils:
     v4_installation_steps = """\
 pip install flash-attn-4==4.0.0b8 nvidia-cutlass-dsl[cu13]"""
     v4_warning_printed = False
+    # Set by backends.py if FA4 is installed; calls flash_attn.cute.interface._validate_head_dims
+    # which raises AssertionError for unsupported (head_dim, head_dim_v) combinations.
+    v4_validate_head_dims = None
 
     @staticmethod
     def set_flash_attention_version():
@@ -792,21 +795,24 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_flash_attention_3 = False
 
-    if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
-        # FA4 head dimension support is architecture-dependent
-        # (matches _validate_head_dims in flash_attn.cute.interface):
-        # SM90: head_dim <= 256 and head_dim_v <= 256
-        # SM100/110: head_dim <= 128 and head_dim_v <= 128,
-        # OR DeepSeek MLA shape (head_dim=192, head_dim_v=128)
-        # SM80/120: constrained by shared memory (~256 max in practice)
-        _fa4_hdim_ok = True
-        if (10, 0) <= device_compute_capability < (12, 0):
-            _is_standard = head_dim_qk <= 128 and head_dim_v <= 128
-            _is_deepseek = head_dim_qk == 192 and head_dim_v == 128
-            _fa4_hdim_ok = _is_standard or _is_deepseek
-        else:
-            _fa4_hdim_ok = head_dim_qk <= 256 and head_dim_v <= 256
-        if not _fa4_hdim_ok:
+    if (
+        use_flash_attention_4
+        and FlashAttentionUtils.v4_is_installed
+        and FlashAttentionUtils.v4_validate_head_dims is not None
+    ):
+        # Defer to FA4's own _validate_head_dims to keep TE in sync with FA4 supported shapes
+        # (e.g., (256, 256) on SM100, (192, 128) DeepSeek, (64, 512) MLA-absorbed).
+        # The function asserts on unsupported combinations; SM80/SM120 have no validation branch
+        # in FA4 so the call passes through silently for those archs.
+        _fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
+        try:
+            FlashAttentionUtils.v4_validate_head_dims(
+                head_dim_qk,
+                head_dim_v,
+                device_compute_capability[0],
+                _fa4_alignment,
+            )
+        except AssertionError:
             logger.debug(
                 "Disabling FlashAttention 4 due to unsupported head dimensions. "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
@@ -815,13 +821,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 device_compute_capability[0] * 10 + device_compute_capability[1],
             )
             use_flash_attention_4 = False
-        # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128).
-        # FlashAttentionBackwardSm100 computes dK_reduce_ncol = gcd(32, tile_hdim // 2)
-        # based on Q/K head_dim but reuses it for dV TMEM load atoms. When
-        # (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are misaligned.
-        # See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
-        elif (
-            _fa4_hdim_ok
+        # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128) for the
+        # standard (non-dedicated) kernel path. FlashAttentionBackwardSm100 computes
+        # dK_reduce_ncol = gcd(32, tile_hdim // 2) based on Q/K head_dim but reuses it for
+        # dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
+        # misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
+        # not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
+        if (
+            use_flash_attention_4
             and is_training
             and head_dim_qk != head_dim_v
             and head_dim_qk >= 128

From 8aa524258eb8f4d50e9bec66da0aed59f0dc2c6f Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Tue, 5 May 2026 19:44:24 -0700
Subject: [PATCH 2/2] update CI, fix lint, resolve comments

Signed-off-by: Xin Yao
---
 qa/L3_pytorch_FA_versions_test/test.sh       |  4 +--
 tests/pytorch/attention/test_attention.py    | 34 ++++++++++++++-----
 .../attention/dot_product_attention/utils.py |  7 ++--
 3 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh
index 642eb93b06..1ce90412fb 100644
--- a/qa/L3_pytorch_FA_versions_test/test.sh
+++ b/qa/L3_pytorch_FA_versions_test/test.sh
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
 export FLASH_ATTN_CUDA_ARCHS=$sm_arch
 if [ $sm_arch -gt 90 ]
 then
-    FA_versions=(2.8.3 4.0.0b8)
+    FA_versions=(2.8.3 4.0.0b11)
 elif [ $sm_arch -eq 90 ]
 then
-    FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8)
+    FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b11)
 fi
 
 for fa_version in "${FA_versions[@]}"
diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 6d4f3b45a9..8a32582065 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -372,8 +372,6 @@ def test_dpa_num_splits(dtype, model_configs, model):
     "fa4_base_1": ModelConfig(4, 128, 16, 64),
     "fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
     "fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
-    # head_dim=256 (SM100 only via dedicated kernel; flash-attn-4 > 4.0.0b10)
-    "fa4_base_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
     # GQA
     "fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
     "fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
@@ -386,12 +384,36 @@ def test_dpa_num_splits(dtype, model_configs, model):
 @pytest.mark.skipif(
     not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
 )
-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
 @pytest.mark.parametrize("model", model_configs_fa4_base.keys())
 def test_dpa_fa4_base(dtype, model_configs, model):
-    """Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
+    """Test DotProductAttention with FA4: base configs, GQA, num_splits"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
+
+
+# head_dim=256 is supported only on SM100 via FA4's dedicated kernel
+# (flash_attn/cute/sm100_hd256_2cta_fmha_*.py), available in flash-attn-4 > 4.0.0b10.
+# On other architectures, _validate_head_dims rejects (256, 256), FA4 is disabled, and
+# the test would silently fall back to another backend — defeating the purpose. Gate
+# explicitly so the CI signal is unambiguous.
+model_configs_fa4_hdim256 = {
+    "fa4_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
+}
+
+
+@pytest.mark.skipif(
+    not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
+)
+@pytest.mark.skipif(
+    get_device_compute_capability() != (10, 0),
+    reason="FA4 head_dim=256 dedicated kernel is SM100-only.",
+)
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("model_configs", [model_configs_fa4_hdim256])
+@pytest.mark.parametrize("model", model_configs_fa4_hdim256.keys())
+def test_dpa_fa4_hdim256(dtype, model_configs, model):
+    """Test DotProductAttention with FA4: head_dim=256 dedicated kernel on SM100"""
     test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
 
 
@@ -411,7 +433,6 @@ def test_dpa_fa4_base(dtype, model_configs, model):
 @pytest.mark.skipif(
     not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
 )
-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
 @pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
 def test_dpa_fa4_mla(dtype, model_configs, model):
@@ -438,7 +459,6 @@ def test_dpa_fa4_mla(dtype, model_configs, model):
 @pytest.mark.skipif(
     not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
 )
-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
 @pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
 def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
@@ -462,7 +482,6 @@ def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
 @pytest.mark.skipif(
     not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
 )
-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
 @pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
 def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
@@ -488,7 +507,6 @@ def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
 @pytest.mark.skipif(
     not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
 )
-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
 @pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 996c6fac37..c98287bb2c 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -7,7 +7,7 @@
 """
 import math
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
 import logging
 import functools
@@ -147,11 +147,11 @@ class FlashAttentionUtils:
     fa4_version = PkgVersion("0")
     use_v4 = False
     v4_installation_steps = """\
-pip install flash-attn-4==4.0.0b8 nvidia-cutlass-dsl[cu13]"""
+pip install flash-attn-4==4.0.0b11 nvidia-cutlass-dsl[cu13]"""
     v4_warning_printed = False
     # Set by backends.py if FA4 is installed; calls flash_attn.cute.interface._validate_head_dims
     # which raises AssertionError for unsupported (head_dim, head_dim_v) combinations.
-    v4_validate_head_dims = None
+    v4_validate_head_dims: Callable = None
 
     @staticmethod
     def set_flash_attention_version():
@@ -806,6 +806,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         # in FA4 so the call passes through silently for those archs.
         _fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
         try:
+            # pylint: disable-next=not-callable
             FlashAttentionUtils.v4_validate_head_dims(
                 head_dim_qk,
                 head_dim_v,
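
Reviewer note (not part of the patch): the core change in patch 1 is that TE stops mirroring FA4's
per-architecture head-dim table and instead probes flash_attn.cute.interface._validate_head_dims
directly, treating an AssertionError as "unsupported". A minimal standalone sketch of that pattern
follows. The four-argument call (head_dim, head_dim_v, sm_major, alignment) and the alignment rule
are taken from the patch itself; the helper name fa4_supports_head_dims and the example dims are
hypothetical.

import torch

try:
    # Same import location the patch uses in backends.py.
    from flash_attn.cute.interface import _validate_head_dims
except ImportError:  # flash-attn-4 not installed
    _validate_head_dims = None


def fa4_supports_head_dims(head_dim_qk: int, head_dim_v: int, dtype: torch.dtype) -> bool:
    """Return True if FA4's own validator accepts (head_dim_qk, head_dim_v) on this GPU."""
    if _validate_head_dims is None or not torch.cuda.is_available():
        return False
    sm_major = torch.cuda.get_device_capability(0)[0]
    # Alignment rule as in the patch: 16 bytes divided by the element size of the QKV dtype.
    alignment = 16 // torch.empty(0, dtype=dtype).element_size()
    try:
        _validate_head_dims(head_dim_qk, head_dim_v, sm_major, alignment)
    except AssertionError:
        # FA4 rejects this shape on this architecture; caller should fall back.
        return False
    return True


if __name__ == "__main__":
    # Per this patch, (256, 256) is expected to pass on SM100 with flash-attn-4 > 4.0.0b10.
    print(fa4_supports_head_dims(256, 256, torch.bfloat16))

The advantage of probing the validator rather than re-encoding the support matrix is that TE stays
correct automatically as FA4 adds or removes supported (head_dim, head_dim_v) combinations.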