diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 50c5de1db7..39efabc598 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -75,6 +75,7 @@ def impl_test_self_attn( if not is_fused_attn_kernel_available( is_training, + batch, dtype, dtype, QKVLayout.BS3HD, @@ -227,6 +228,7 @@ def test_cross_attn( if not is_fused_attn_kernel_available( is_training, + batch, dtype, dtype, QKVLayout.BSHD_BS2HD, @@ -368,6 +370,7 @@ def impl_test_context_parallel_attn( def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( is_training, + batch, dtype, dtype, qkv_layout, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..88c485db81 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -444,8 +444,9 @@ def _check_configs(self): "is either BSHD_BSHD_BSHD or THD_THD_THD" ) - self.backend = FusedAttnHelper( + self.backend, message = FusedAttnHelper( self.is_training, + self.batch_size, self.dtype, self.dtype, self.qkv_layout, @@ -460,9 +461,10 @@ def _check_configs(self): self.head_dim_qk, self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, + self.attn_mask_type.is_bottom_right(), ).get_fused_attn_backend() if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - pytest.skip("Unsupported inputs combination or device compute capability.") + pytest.skip(message) if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d2eb1a831c..628bce1b54 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -225,304 +225,135 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } } +namespace { + +// per-thread storage for the diagnostic string +// re-used (cleared + 
re-populated) on every call to nvte_get_fused_attn_backend on this thread +thread_local std::string fused_attn_backend_message_buffer; + +void set_message(const char **message, const std::string &reason) { + if (message == nullptr) return; + fused_attn_backend_message_buffer = reason; + *message = fused_attn_backend_message_buffer.c_str(); +} + +} // namespace + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, + const char **message) { using namespace transformer_engine; - NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); + set_message(message, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - NVTE_QKV_Layout_Group 
layout_group = nvte_get_qkv_layout_group(qkv_layout); - auto cudnn_runtime_version = cudnnGetVersion(); - // For ragged offsets we only support 32-bit prior to cuDNN 9.5 - // Only used when THD format is requested. + cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const auto cudnn_runtime_version = cudnnGetVersion(); + + // THD + 64-bit ragged offsets require cuDNN >= 9.5 const bool requires_64bit_ragged_offset = (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); - const bool supported_ragged_offset_size = - (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - ( - // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - 
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: d_qk=192, d_v=128 - (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && - head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && - // pre-9.21: {bshd, sbhd}, {vanilla} - // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} - ((cudnn_runtime_version < 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || - (cudnn_runtime_version >= 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && - !requires_64bit_ragged_offset && - // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && !return_max_logit) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { - bool flag_arb = false; - if ( - // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // architecture - ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || - (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || - (cudnn_runtime_version >= 90700 && sm_arch_ >= 100)) && - // sequence length - ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || - (cudnn_runtime_version >= 90000)) && - // number of heads - ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 8907)) && - // head dimension - // multiples of 8 - (head_dim_qk % 8 == 0 && 
head_dim_v % 8 == 0 && - // <= 128 - ((head_dim_qk <= 128 && head_dim_v <= 128) || - // 9.1: <= 256 + Hopper + fprop - // 9.5: <= 256 + Hopper + bprop - (head_dim_qk <= 256 && head_dim_v <= 256 && - ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || - (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || - // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 - (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && - layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10.2: any head_dim + any arch + fprop + paged - // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91002 && - (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || - (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || - // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged - (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && - // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed - (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && - head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && - head_dim_qk != head_dim_v))) && - // bias type - ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (cudnn_runtime_version >= 8906 && - (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - (bias_type == NVTE_Bias_Type::NVTE_ALIBI && - attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - sm_arch_ >= 90) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - (cudnn_runtime_version >= 90000 && - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && - // mask type - // pre-8.9.6: causal - ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} - (cudnn_runtime_version >= 8906 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.1: adds thd + {padding, padding_causal} - (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90300 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} - (cudnn_runtime_version >= 90500 && - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 
64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90600 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} - // for any q_format/kv_format, and paged/non-paged - (cudnn_runtime_version >= 90700 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv)))) && - // bias + mask combination - (!(cudnn_runtime_version >= 8906 && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - // qkv format - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD || - (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && - ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || - q_format == NVTE_QKV_Format::NVTE_BHSD || - (q_format == NVTE_QKV_Format::NVTE_THD && 
sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || - kv_format == NVTE_QKV_Format::NVTE_BHSD || - (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && - cudnn_runtime_version >= 90700)) && - // sliding window - // pre-9.2: full attn, causal - ((cudnn_runtime_version < 90200 && window_size_left == -1 && - (window_size_right == -1 || window_size_right == 0)) || - // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} - (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && window_size_right == -1 && - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q == max_seqlen_kv)) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || - // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} - (cudnn_runtime_version >= 90600 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && - (window_size_right >= 0 || window_size_right == -1) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - // TODO(cyang): fix bug for BRCM + cross-attention on sm100 - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700)))) || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - (sm_arch_ < 100 || 
(sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700))))) && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0)))) && - // check 64-bit ragged offset support - (supported_ragged_offset_size) && - // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && - // softmax type - // pre-9.13.1: vanilla - // 9.13.1+: vanilla, off-by-one, learnable - (cudnn_runtime_version >= 91301 || - (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && - // determinism on Blackwell - // pre-9.18.1: fwd: deterministic; bwd: non-deterministic - // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic - (sm_arch_ < 100 || - (sm_arch_ >= 100 && (!is_training || - (is_training && !deterministic && - (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || - (is_training && deterministic && cudnn_runtime_version >= 91801 && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { - flag_arb = true; + if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { + set_message(message, + "Configuration requires 64-bit ragged offsets, which require " + "cuDNN >= 9.5."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // THD requires padding-style mask + if (qkv_format == NVTE_QKV_Format::NVTE_THD && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_message(message, + "THD format requires PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT mask."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + const bool is_fp8 = + (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2); + const bool is_f16_or_bf16 = + (q_dtype == NVTEDType::kNVTEFloat16 
|| q_dtype == NVTEDType::kNVTEBFloat16); + + if (is_fp8) { + if (return_max_logit) { + set_message(message, "FP8 fused attention does not support return_max_logit=True."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (flag_arb) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && + qkv_format != NVTE_QKV_Format::NVTE_BHSD) { + set_message(message, "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + + std::to_string(static_cast(qkv_format)) + "."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (cudnn_runtime_version < 8900 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; + const DType qkv_t = static_cast(q_dtype); + const DType o_t = static_cast(o_dtype); + std::string fwd_reason = is_supported_fp8_fwd( + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, + handle); + if (!fwd_reason.empty()) { + set_message(message, fwd_reason); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-causal) and " - "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " - " Please upgrade your cuDNN version if possible." 
- << std::endl; + if (is_training) { + std::string bwd_reason = is_supported_fp8_bwd( + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, + scaling_mode, handle); + if (!bwd_reason.empty()) { + set_message(message, bwd_reason); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } } - if ((cudnn_runtime_version <= 91500) && is_training && + return NVTE_Fused_Attn_Backend::NVTE_FP8; + } + + if (is_f16_or_bf16) { + if (cudnn_runtime_version <= 91500 && is_training && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && (max_seqlen_kv % 128 != 0) && cuda_graph && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-padding)," - " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" - " backward fused attention with graph capture requires cuDNN 9.15.1+. " - "Please upgrade your cuDNN version if possible." - << std::endl; + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. 
Please upgrade cuDNN."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + const DType qkv_t = static_cast(q_dtype); + std::string fwd_reason = is_supported_f16_fwd( + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, handle); + if (!fwd_reason.empty()) { + set_message(message, fwd_reason); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && sm_arch_ == 120) { - if (cudnn_runtime_version < 91801) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of sm_arch_ == 120 and cudnn_runtime_version < " - "91801 is not supported. " - << " Please upgrade your cuDNN version if possible." << std::endl; - } else if (deterministic && is_training) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Deterministic fused attention on SM120 is not supported." - << std::endl; - } else { - // Known missing support for T3HD/TH3D layouts on SM120 - const bool is_t3hd_or_th3d = - (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D); - if (is_t3hd_or_th3d) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM120 is not supported. " - << " Please consider using other THD layouts if possible." 
<< std::endl; - } + if (is_training) { + std::string bwd_reason = is_supported_f16_bwd( + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, handle); + if (!bwd_reason.empty()) { + set_message(message, bwd_reason); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - return backend; + + set_message(message, "Unsupported QKV dtype qkv_dtype=" + std::to_string(q_dtype) + " ."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } // NVTE fused attention FWD with separate Q, K and V @@ -607,11 +438,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + const NVTEDType O_type = static_cast(output_O->data.dtype); + const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, false); + is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, + window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, + /*deterministic=*/false, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { fused_attn_arbitrary_seqlen_fwd( @@ -688,11 +522,14 @@ void 
nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + const NVTEDType O_type = static_cast(input_O->data.dtype); + const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph, deterministic); + /*is_training=*/true, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, + attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right, bottom_right_diagonal, /*return_max_logit=*/false, + cuda_graph, deterministic, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { size_t i = 0; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..3a2b296ffc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1333,4 +1333,124 @@ void fused_attn_arbitrary_seqlen_bwd( NVTE_ERROR("Unexpected workspace_size."); } } + +std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, bool return_max_logit, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + 
DType qkv_dtype, cudnnHandle_t handle) { + const auto b = static_cast(batch); + const auto h = static_cast(num_attn_heads); + const auto sq = static_cast(max_seqlen_q); + const auto skv = static_cast(max_seqlen_kv); + + const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); + const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0; + const int64_t max_t_q = is_ragged_q ? b * sq : 0; + const int64_t max_t_kv = is_ragged_kv ? b * skv : 0; + const int64_t num_pages_k = is_paged_kv ? b : 0; + const int64_t num_pages_v = is_paged_kv ? b : 0; + const int64_t page_size_k = is_paged_kv ? skv : 0; + const int64_t page_size_v = is_paged_kv ? skv : 0; + const int64_t max_pages_per_seq_k = is_paged_kv ? 1 : 0; + const int64_t max_pages_per_seq_v = is_paged_kv ? 1 : 0; + const int64_t bias_b = has_bias ? b : 0; + const int64_t bias_h = has_bias ? h : 0; + const int64_t bias_sq = has_bias ? sq : 0; + const int64_t bias_skv = has_bias ? 
skv : 0; + + const NVTE_QKV_Format o_format = q_format; + + size_t workspace_size = 0; + try { + fused_attn::fused_attn_arbitrary_seqlen_fwd_impl( + b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), + static_cast(head_dim_v), max_b, max_t_q, max_t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, + bias_skv, is_training, return_max_logit, + /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, + /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr, + /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, + /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, + /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(qkv_dtype), + /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return ""; + } catch (const std::exception &e) { + return e.what(); + } catch (...) 
{ + return "is_supported_f16_fwd: unknown failure."; + } +} + +std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, cudnnHandle_t handle) { + const auto b = static_cast(batch); + const auto h = static_cast(num_attn_heads); + const auto sq = static_cast(max_seqlen_q); + const auto skv = static_cast(max_seqlen_kv); + + const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0; + const int64_t max_t_q = is_ragged_q ? b * sq : 0; + const int64_t max_t_kv = is_ragged_kv ? b * skv : 0; + const int64_t bias_b = has_bias ? b : 0; + const int64_t bias_h = has_bias ? h : 0; + const int64_t bias_sq = has_bias ? sq : 0; + const int64_t bias_skv = has_bias ? 
skv : 0; + + const NVTE_QKV_Format o_format = q_format; + const NVTE_QKV_Format do_format = o_format; + const NVTE_QKV_Layout dqkv_layout = qkv_layout; + + size_t workspace_size = 0; + try { + fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( + b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), + static_cast(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq, + bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, + /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, + /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, + /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr, + /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr, + /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, + /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(qkv_dtype), + /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return ""; + } catch (const std::exception &e) { + return e.what(); + } catch (...) 
{ + return "is_supported_f16_bwd: unknown failure."; + } +} + } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 8f79b5bb4a..fe94d0c10c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -13,6 +13,8 @@ #include +#include + #include "common/common.h" #include "transformer_engine/fused_attn.h" @@ -47,6 +49,29 @@ void fused_attn_arbitrary_seqlen_bwd( const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); +// check if a given configuration is supported for F16/BF16 forward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. +std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, bool return_max_logit, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, cudnnHandle_t handle); + +// check if a given configuration is supported for F16/BF16 backward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. 
+std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, cudnnHandle_t handle); + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index eab1ae02e6..f4064a8d34 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1324,4 +1324,91 @@ void fused_attn_fp8_bwd( return; } } + +std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + size_t workspace_size = 0; + try { + fused_attn::fused_attn_fp8_fwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout, + qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, 
/*devPtrV=*/nullptr, + /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, + /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, + /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, + /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, + /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(qkv_dtype), get_cudnn_fe_dtype(o_dtype), + scaling_mode, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return ""; + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "is_supported_fp8_fwd: unknown failure."; + } +} + +std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); + const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); + const cudnn_frontend::DataType_t do_t = o_t; + const cudnn_frontend::DataType_t dqkv_t = qkv_t; + size_t workspace_size = 0; + try { + fused_attn::fused_attn_fp8_bwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), /*scaling_factor=*/1.0f, p_dropout, qkv_layout, + /*o_format=*/qkv_format, 
/*do_format=*/qkv_format, /*dqkv_layout=*/qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr, + /*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, + /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, + /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDescaleQ=*/nullptr, + /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleO=*/nullptr, + /*devPtrDescaledO=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrDescaledP=*/nullptr, + /*devPtrScaleS=*/nullptr, /*devPtrScaledP=*/nullptr, /*devPtrScaledQ=*/nullptr, + /*devPtrScaledK=*/nullptr, /*devPtrScaledV=*/nullptr, /*devPtrAmaxdP=*/nullptr, + /*devPtrAmaxdQ=*/nullptr, /*devPtrAmaxdK=*/nullptr, /*devPtrAmaxdV=*/nullptr, + /*devPtrQ_t=*/nullptr, /*devPtrK_t=*/nullptr, /*devPtrdO_f16=*/nullptr, + /*devPtrdO_t=*/nullptr, /*devPtrDescaleQ_t=*/nullptr, /*devPtrDescaleK_t=*/nullptr, + /*devPtrDescaledO_t=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, + /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, + /*devPtrDropoutOffset=*/nullptr, qkv_t, o_t, do_t, dqkv_t, scaling_mode, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return ""; + } catch (const std::exception& e) { + return e.what(); + } catch (...) 
{ + return "is_supported_fp8_bwd: unknown failure."; + } +} + } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index b9660128ca..01c7561402 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -8,6 +8,8 @@ * \brief Functions for fused attention for FP8 */ +#include + #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -39,4 +41,28 @@ void fused_attn_fp8_bwd( const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +// check if a given configuration is supported for FP8 forward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. +std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle); + +// check if a given configuration is supported for FP8 backward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. 
+std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index d9d2786623..227afed24e 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ +#include + #include "stdint.h" #include "transformer_engine.h" @@ -196,32 +198,41 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * - * \param[in] is_training Whether the model is in training mode. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] softmax_type The attention softmax type. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. 
- * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] return_max_logit Whether to produce Max along with Stats. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] deterministic Whether determinism is required or not. + * \param[in] is_training Whether the model is in training mode. + * \param[in] batch_size Batch size. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] o_dtype The data type of Tensor O. + * \param[in] scaling_mode Scaling mode of attention. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the + * bottom right corner of the softmax matrix. + * \param[in] return_max_logit Whether to produce Max along with Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. + * \param[out] message Empty string on success, otherwise a diagnostic string + * describing why the configuration was rejected. 
*/ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); + bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, + const char **message); /*! \brief Compute dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index ef7687e3e9..950bb7778f 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -82,6 +82,13 @@ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_<NVTEScalingMode>(m, "NVTEScalingMode", pybind11::module_local()) \ + .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) \ + .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) \ + .value("NVTE_BLOCK_SCALING_1D", NVTEScalingMode::NVTE_BLOCK_SCALING_1D) \ + .value("NVTE_BLOCK_SCALING_2D", NVTEScalingMode::NVTE_BLOCK_SCALING_2D) \ + .value("NVTE_NVFP4_1D_SCALING", NVTEScalingMode::NVTE_NVFP4_1D_SCALING) \ + .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_INVALID_SCALING); \ pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f54a043fd2..ac6cf8975c 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -325,6 +325,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str): def is_fused_attn_kernel_available( is_training, + batch_size, q_dtype, kv_dtype, qkv_layout, @@ -339,15 +340,25 @@ def is_fused_attn_kernel_available( head_dim_qk, head_dim_v, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, ): """ - To check whether the fused attention kernel is supported + To check whether the fused attention kernel is supported. 
+ + If ``bottom_right_diagonal`` is None, it is derived from the mask type, matching the + convention used everywhere else in JAX TE (see ``_FusedAttnConfig`` constructions). """ window_size_tuple = (-1, -1) if window_size is None else window_size def make_helper(attn_mask_type): + bottom_right = ( + attn_mask_type.is_bottom_right() + if bottom_right_diagonal is None + else bottom_right_diagonal + ) return tex.FusedAttnHelper( is_training, + batch_size, q_dtype, kv_dtype, qkv_layout, @@ -362,6 +373,7 @@ def make_helper(attn_mask_type): head_dim_qk, head_dim_v, window_size_tuple, + bottom_right, ) return make_helper(attn_mask_type).is_fused_attn_kernel_available() diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 489bfde997..2a533c3f3e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -16,7 +16,7 @@ from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax -from transformer_engine_jax import NVTE_Fused_Attn_Backend +from transformer_engine_jax import NVTE_Fused_Attn_Backend, NVTEScalingMode from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, @@ -108,6 +108,7 @@ class FusedAttnHelper: """ is_training: bool + batch_size: int q_dtype: jnp.dtype kv_dtype: jnp.dtype qkv_layout: QKVLayout @@ -122,17 +123,27 @@ class FusedAttnHelper: head_dim_qk: int head_dim_v: int window_size: Tuple[int, int] + bottom_right_diagonal: bool def is_fused_attn_kernel_available(self): """Check if there is available fused attention kernel""" - return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend + backend, _ = self.get_fused_attn_backend() + return backend != NVTE_Fused_Attn_Backend.NVTE_No_Backend def get_fused_attn_backend(self): - """Get the fused attention kernel backend""" + """Get the fused attention kernel backend. 
+ + Returns a ``(backend, message)`` tuple. ``message`` is empty on success, otherwise a + diagnostic string describing why the configuration was rejected when backend = NVTE_No_Backend. + """ + q_type = jax_dtype_to_te_dtype(self.q_dtype) return transformer_engine_jax.get_fused_attn_backend( self.is_training, - jax_dtype_to_te_dtype(self.q_dtype), + self.batch_size, + q_type, jax_dtype_to_te_dtype(self.kv_dtype), + q_type, + NVTEScalingMode.NVTE_INVALID_SCALING, self.qkv_layout.value, self.attn_bias_type.value, self.attn_mask_type.value, @@ -146,6 +157,7 @@ def get_fused_attn_backend(self): self.head_dim_v, self.window_size[0], self.window_size[1], + self.bottom_right_diagonal, not self.is_non_deterministic_allowed(), ) @@ -335,8 +347,10 @@ def abstract( out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper( + input_batch = reduce(operator.mul, batch_shape) + backend, message = FusedAttnHelper( config.is_training, + input_batch, q_dtype, k_dtype, config.qkv_layout, @@ -351,6 +365,7 @@ def abstract( q_head_dim, v_head_dim, config.window_size, + config.bottom_right_diagonal, ).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: @@ -369,7 +384,7 @@ def abstract( ) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: - raise ValueError(f"Unsupported {backend=}") + raise ValueError(f"Unsupported backend: {message}") softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..1e8d99c3d8 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "common/common.h" @@ -146,12 +147,15 @@ 
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); -NVTE_Fused_Attn_Backend GetFusedAttnBackend( - bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool deterministic); +// Returns (backend, message). `message` is empty on success, otherwise a diagnostic string +// describing why the configuration was rejected when backend = NVTE_No_Backend. +std::tuple<NVTE_Fused_Attn_Backend, std::string> GetFusedAttnBackend( + bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index ed136d7b9e..5cd3265c3e 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -11,18 +11,21 @@ namespace transformer_engine { namespace jax { -NVTE_Fused_Attn_Backend GetFusedAttnBackend( - bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t 
kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool deterministic) { +std::tuple<NVTE_Fused_Attn_Backend, std::string> GetFusedAttnBackend( + bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic) { + const char *message = nullptr; auto backend = nvte_get_fused_attn_backend( - is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); - return backend; + is_training, batch_size, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), + static_cast<NVTEDType>(o_dtype), scaling_mode, qkv_layout, bias_type, mask_type, softmax_type, + dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, + v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, &message); + return {backend, message ? 
std::string(message) : std::string()}; } /* @@ -262,10 +265,11 @@ static void FusedAttnForwardImpl( auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, input_batch, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), + static_cast<NVTEDType>(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, + softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, /*message=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -538,10 +542,11 @@ static void FusedAttnBackwardImpl( NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, input_batch, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), + static_cast<NVTEDType>(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, + softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, /*message=*/nullptr); 
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..2d55abedc6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -206,6 +206,14 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .export_values(); + pybind11::enum_<NVTEScalingMode>(m, "NVTEScalingMode", pybind11::module_local()) + .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) + .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + .value("NVTE_BLOCK_SCALING_1D", NVTEScalingMode::NVTE_BLOCK_SCALING_1D) + .value("NVTE_BLOCK_SCALING_2D", NVTEScalingMode::NVTE_BLOCK_SCALING_2D) + .value("NVTE_NVFP4_1D_SCALING", NVTEScalingMode::NVTE_NVFP4_1D_SCALING) + .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_INVALID_SCALING); + pybind11::enum_<JAXX_Quantize_Layout>(m, "JAXX_Quantize_Layout", pybind11::module_local()) .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE) .value("COLWISE", JAXX_Quantize_Layout::COLWISE) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a2e7920843..184547aa92 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -748,6 +748,8 @@ def __call__( enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1")) sequence_dim = 0 if self.transpose_batch_sequence else 1 + batch_dim = 1 - sequence_dim + batch_size = query.shape[batch_dim] seqlen_q = query.shape[sequence_dim] if qkv_layout == QKVLayout.BS3HD: seqlen_kv = seqlen_q @@ -763,6 +765,7 @@ def __call__( has_fused_attn_kernel = is_fused_attn_kernel_available( # This needs to be fixed: TE-Jax has historically correlated 
training mode with deterministic mode. not deterministic, + batch_size, input_dtype, # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient input_dtype, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7df5daabe5..52bb687851 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1225,10 +1225,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if fp8 and fp8_meta["recipe"].fp8_dpa: q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type - fused_attention_backend = tex.get_fused_attn_backend( + fused_attention_backend, reject_message = tex.get_fused_attn_backend( is_training, + batch_size, q_type, kv_type, + q_type, + tex.NVTEScalingMode.NVTE_INVALID_SCALING, QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], @@ -1242,12 +1245,16 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt head_dim_v, window_size[0], window_size[1], + bottom_right_diagonal, return_max_logit, cuda_graph, deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: - logger.debug("Disabling FusedAttention as no backend supports the provided input") + logger.debug( + "Disabling FusedAttention: %s", + reject_message, + ) use_fused_attention = False fused_attention_backend = None # Filter: Determinism diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..205e7eb834 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -75,12 +75,16 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T * Attention 
**************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, +// Returns (backend, reason). `reason` is empty on success, otherwise a diagnostic string +// describing why the configuration was rejected when backend = NVTE_No_Backend. +std::tuple<NVTE_Fused_Attn_Backend, std::string> get_fused_attn_backend( + bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, + const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); + int64_t window_size_right, bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, + bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 7e8018b3fd..41dcd3301a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -40,18 +40,22 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s namespace transformer_engine::pytorch { // get the fused attention backend -NVTE_Fused_Attn_Backend get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, +std::tuple<NVTE_Fused_Attn_Backend, std::string> get_fused_attn_backend( + bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, + const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout 
qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + int64_t window_size_right, bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, + bool deterministic) { + const char *message = nullptr; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, - bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, deterministic); - return fused_attention_backend; + is_training, batch_size, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), + static_cast<NVTEDType>(o_dtype), scaling_mode, qkv_layout, bias_type, attn_mask_type, + softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + head_dim_qk, head_dim_v, window_size_left, window_size_right, bottom_right_diagonal, + return_max_logit, cuda_graph, deterministic, &message); + return {fused_attention_backend, message ? std::string(message) : std::string()}; } // helper function for S and dP quantizers