[All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls #2964
@@ -1333,4 +1333,124 @@ void fused_attn_arbitrary_seqlen_bwd(
    NVTE_ERROR("Unexpected workspace_size.");
  }
}
std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
                                 size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
                                 size_t head_dim_v, bool is_training, bool return_max_logit,
                                 float p_dropout, NVTE_QKV_Layout qkv_layout,
                                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                 NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                                 int64_t window_size_right, bool bottom_right_diagonal,
                                 DType qkv_dtype, cudnnHandle_t handle) {
  const auto b = static_cast<int64_t>(batch);
  const auto h = static_cast<int64_t>(num_attn_heads);
  const auto sq = static_cast<int64_t>(max_seqlen_q);
  const auto skv = static_cast<int64_t>(max_seqlen_kv);

  const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
  const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
  const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
  const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
  const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
  const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
  const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

  const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0;
  const int64_t max_t_q = is_ragged_q ? b * sq : 0;
  const int64_t max_t_kv = is_ragged_kv ? b * skv : 0;
  const int64_t num_pages_k = is_paged_kv ? b : 0;
  const int64_t num_pages_v = is_paged_kv ? b : 0;
  const int64_t page_size_k = is_paged_kv ? skv : 0;
  const int64_t page_size_v = is_paged_kv ? skv : 0;
  const int64_t max_pages_per_seq_k = is_paged_kv ? 1 : 0;
  const int64_t max_pages_per_seq_v = is_paged_kv ? 1 : 0;
  const int64_t bias_b = has_bias ? b : 0;
  const int64_t bias_h = has_bias ? h : 0;
  const int64_t bias_sq = has_bias ? sq : 0;
  const int64_t bias_skv = has_bias ? skv : 0;

  const NVTE_QKV_Format o_format = q_format;
  size_t workspace_size = 0;
  try {
    fused_attn::fused_attn_arbitrary_seqlen_fwd_impl(
        b, h, static_cast<int64_t>(num_gqa_groups), sq, skv, static_cast<int64_t>(head_dim_qk),
        static_cast<int64_t>(head_dim_v), max_b, max_t_q, max_t_kv, num_pages_k, num_pages_v,
        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
        bias_skv, is_training, return_max_logit,
        /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type,
        softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
        /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr,
        /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr,
        /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr,
        /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr,
        /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr,
        /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr,
        get_cudnn_fe_dtype(qkv_dtype),
        /*workspace=*/nullptr, &workspace_size,
        /*stream=*/static_cast<cudaStream_t>(0), handle);
    return "";
  } catch (const std::exception &e) {
    return e.what();
  } catch (...) {
    return "is_supported_f16_fwd: unknown failure.";
  }
}
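Both probes (the forward one above and the backward one below) follow the same pattern: rebuild the cuDNN frontend graph in workspace-query mode, with every device pointer null and /*workspace=*/nullptr so only the required size is computed, then translate a successful build into an empty string and a failed one into the exception message. A minimal self-contained sketch of that pattern, assuming the impl treats a null workspace as a size-only query (toy names, not the real TransformerEngine entry points):

#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for fused_attn_arbitrary_seqlen_fwd_impl: with a null workspace it
// only builds the graph and reports the required size; an unsupported config
// surfaces as an exception from the graph build.
void toy_impl(bool config_supported, void *workspace, size_t *workspace_size) {
  if (!config_supported) throw std::runtime_error("no cuDNN engine for this config");
  if (workspace == nullptr) {  // query mode: no kernel launch, just the size
    *workspace_size = 1024;
    return;
  }
  // ... execution path would run the graph here ...
}

// Mirrors is_supported_f16_fwd: "" means supported, otherwise the reason.
std::string toy_is_supported(bool config_supported) {
  size_t workspace_size = 0;
  try {
    toy_impl(config_supported, /*workspace=*/nullptr, &workspace_size);
    return "";
  } catch (const std::exception &e) {
    return e.what();
  } catch (...) {
    return "unknown failure";
  }
}

int main() {
  std::cout << (toy_is_supported(true).empty() ? "backend usable" : "fallback") << '\n';
  std::cout << toy_is_supported(false) << '\n';  // prints the failure reason
}

Presumably the point of the refactor named in the PR title: nvte_get_fused_attn_backend can ask cudnn-frontend itself whether the graph builds, rather than hand-maintaining checks that mirror cuDNN's support surface.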
std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
                                 size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
                                 size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout,
                                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                 NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                                 int64_t window_size_right, bool bottom_right_diagonal,
                                 bool deterministic, DType qkv_dtype, cudnnHandle_t handle) {
  const auto b = static_cast<int64_t>(batch);
  const auto h = static_cast<int64_t>(num_attn_heads);
  const auto sq = static_cast<int64_t>(max_seqlen_q);
  const auto skv = static_cast<int64_t>(max_seqlen_kv);

  const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
  const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
  const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
  const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
  const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

  const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0;
  const int64_t max_t_q = is_ragged_q ? b * sq : 0;
  const int64_t max_t_kv = is_ragged_kv ? b * skv : 0;
  const int64_t bias_b = has_bias ? b : 0;
  const int64_t bias_h = has_bias ? h : 0;
  const int64_t bias_sq = has_bias ? sq : 0;
  const int64_t bias_skv = has_bias ? skv : 0;

  const NVTE_QKV_Format o_format = q_format;
  const NVTE_QKV_Format do_format = o_format;
  const NVTE_QKV_Layout dqkv_layout = qkv_layout;
  size_t workspace_size = 0;
  try {
    fused_attn::fused_attn_arbitrary_seqlen_bwd_impl(
        b, h, static_cast<int64_t>(num_gqa_groups), sq, skv, static_cast<int64_t>(head_dim_qk),
        static_cast<int64_t>(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq,
        bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, dqkv_layout,
        bias_type, mask_type, softmax_type, window_size_left, window_size_right,
        bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr,
        /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr,
        /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr,
        /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr,
        /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr,
        /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr,
        /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr,
        /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr,
        get_cudnn_fe_dtype(qkv_dtype),
        /*workspace=*/nullptr, &workspace_size,
        /*stream=*/static_cast<cudaStream_t>(0), handle);
    return "";
  } catch (const std::exception &e) {
    return e.what();
  } catch (...) {
    return "is_supported_f16_bwd: unknown failure.";
  }
}

}  // namespace transformer_engine
Contributor review comment (on the o_format / do_format / dqkv_layout derivations above):

The o_format = q_format assumption can mismatch the actual forward call. is_supported_f16_fwd derives o_format from q_format, but nvte_fused_attn_fwd accepts a separate o_format parameter that is not forwarded into nvte_get_fused_attn_backend. When the caller uses an output format different from the query format, such as returning BSHD output from an SBHD_BSHD_BSHD layout, the probe builds a cuDNN graph for the wrong o_format. If the graph with o_format = q_format is accepted but the configuration with the actual o_format is not (or vice versa), nvte_get_fused_attn_backend produces an incorrect backend decision, causing an error when the actual kernel is invoked.
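A self-contained illustration of the reviewer's point, with a toy enum standing in for NVTE_QKV_Format (hypothetical names; the real fix, presumably threading the caller's o_format through nvte_get_fused_attn_backend into the probe, is not part of this diff):

#include <cassert>

// Toy stand-in for NVTE_QKV_Format.
enum class Format { SBHD, BSHD, THD };

// Mirrors the probe's derivation: const NVTE_QKV_Format o_format = q_format;
Format probe_o_format(Format q_format) { return q_format; }

int main() {
  Format q_format = Format::SBHD;         // e.g. query side of SBHD_BSHD_BSHD
  Format actual_o_format = Format::BSHD;  // caller requests BSHD output
  // The probe validates a graph for SBHD output while the real forward call
  // will be launched with BSHD output, so the support answer can be wrong:
  assert(probe_o_format(q_format) != actual_o_format);
}

Accepting the actual o_format as a parameter of is_supported_f16_fwd, instead of deriving it from q_format, would keep the probed graph and the launched graph in agreement.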