3 changes: 3 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -75,6 +75,7 @@ def impl_test_self_attn(

if not is_fused_attn_kernel_available(
is_training,
batch,
dtype,
dtype,
QKVLayout.BS3HD,
@@ -227,6 +228,7 @@ def test_cross_attn(

if not is_fused_attn_kernel_available(
is_training,
batch,
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
@@ -368,6 +370,7 @@ def impl_test_context_parallel_attn(
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
is_training,
batch,
dtype,
dtype,
qkv_layout,
6 changes: 4 additions & 2 deletions tests/jax/test_fused_attn.py
@@ -444,8 +444,9 @@ def _check_configs(self):
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)

self.backend = FusedAttnHelper(
self.backend, message = FusedAttnHelper(
self.is_training,
self.batch_size,
self.dtype,
self.dtype,
self.qkv_layout,
@@ -460,9 +461,10 @@ def _check_configs(self):
self.head_dim_qk,
self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size,
self.attn_mask_type.is_bottom_right(),
).get_fused_attn_backend()
if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("Unsupported inputs combination or device compute capability.")
pytest.skip(message)

if (
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
407 changes: 122 additions & 285 deletions transformer_engine/common/fused_attn/fused_attn.cpp

Large diffs are not rendered by default.

transformer_engine/common/fused_attn/fused_attn_arbitrary_seqlen.cu
@@ -1333,4 +1333,124 @@ void fused_attn_arbitrary_seqlen_bwd(
NVTE_ERROR("Unexpected workspace_size.");
}
}

std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, bool is_training, bool return_max_logit,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal,
DType qkv_dtype, cudnnHandle_t handle) {
const auto b = static_cast<int64_t>(batch);
const auto h = static_cast<int64_t>(num_attn_heads);
const auto sq = static_cast<int64_t>(max_seqlen_q);
const auto skv = static_cast<int64_t>(max_seqlen_kv);

const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0;
const int64_t max_t_q = is_ragged_q ? b * sq : 0;
const int64_t max_t_kv = is_ragged_kv ? b * skv : 0;
const int64_t num_pages_k = is_paged_kv ? b : 0;
const int64_t num_pages_v = is_paged_kv ? b : 0;
const int64_t page_size_k = is_paged_kv ? skv : 0;
const int64_t page_size_v = is_paged_kv ? skv : 0;
const int64_t max_pages_per_seq_k = is_paged_kv ? 1 : 0;
const int64_t max_pages_per_seq_v = is_paged_kv ? 1 : 0;
const int64_t bias_b = has_bias ? b : 0;
const int64_t bias_h = has_bias ? h : 0;
const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
Contributor

P1 o_format probe hardcoded as q_format, mismatching callers that use a different output format

is_supported_f16_fwd derives o_format = q_format (line 1372) and uses it to build and cache the cuDNN graph. However, nvte_fused_attn_fwd accepts an independent o_format parameter that is never forwarded into nvte_get_fused_attn_backend. When a caller uses an output format different from the query format (e.g. BSHD output from an SBHD_BSHD_BSHD layout), the cached graph is built for q_format, not the actual o_format. If cuDNN accepts the wrong graph but rejects the real one, or vice versa, the backend check produces an incorrect decision, causing an unexpected error when the actual kernel is invoked.

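A minimal sketch of the direction this comment suggests, assuming the probe gains an explicit output-format parameter; the extra `o_format` argument and the plumbing through nvte_get_fused_attn_backend are hypothetical and not part of this PR:

// Hypothetical signature (suggestion only), declared as in the header below
// inside namespace transformer_engine: the caller supplies the actual output
// format instead of the probe assuming o_format == q_format.
std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
                                 size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
                                 size_t head_dim_v, bool is_training, bool return_max_logit,
                                 float p_dropout, NVTE_QKV_Layout qkv_layout,
                                 NVTE_QKV_Format o_format,  // forwarded from the caller
                                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                 NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                                 int64_t window_size_right, bool bottom_right_diagonal,
                                 DType qkv_dtype, cudnnHandle_t handle);
// In the body, drop `const NVTE_QKV_Format o_format = q_format;` and pass the
// caller-provided o_format straight to fused_attn_arbitrary_seqlen_fwd_impl.

For this to help, nvte_get_fused_attn_backend would also need to accept and forward the output format, which is a wider API change than this PR makes.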

size_t workspace_size = 0;
try {
fused_attn::fused_attn_arbitrary_seqlen_fwd_impl(
b, h, static_cast<int64_t>(num_gqa_groups), sq, skv, static_cast<int64_t>(head_dim_qk),
static_cast<int64_t>(head_dim_v), max_b, max_t_q, max_t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
bias_skv, is_training, return_max_logit,
/*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
/*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr,
/*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr,
/*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr,
/*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr,
/*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr,
/*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr,
get_cudnn_fe_dtype(qkv_dtype),
/*workspace=*/nullptr, &workspace_size,
/*stream=*/static_cast<cudaStream_t>(0), handle);
return "";
} catch (const std::exception &e) {
return e.what();
} catch (...) {
return "is_supported_f16_fwd: unknown failure.";
}
}

std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, DType qkv_dtype, cudnnHandle_t handle) {
const auto b = static_cast<int64_t>(batch);
const auto h = static_cast<int64_t>(num_attn_heads);
const auto sq = static_cast<int64_t>(max_seqlen_q);
const auto skv = static_cast<int64_t>(max_seqlen_kv);

const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0;
const int64_t max_t_q = is_ragged_q ? b * sq : 0;
const int64_t max_t_kv = is_ragged_kv ? b * skv : 0;
const int64_t bias_b = has_bias ? b : 0;
const int64_t bias_h = has_bias ? h : 0;
const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
Comment on lines +1426 to +1428

Contributor

P1 Backward probe hardcodes dqkv_layout = qkv_layout and do_format = o_format, diverging from the actual backward call

is_supported_f16_bwd sets dqkv_layout = qkv_layout and do_format = o_format (where o_format is already fixed to q_format). The actual backward call nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters that are never threaded through nvte_get_fused_attn_backend. When activation and gradient layouts differ, the probe builds and caches a cuDNN graph for a configuration that is never used at runtime. More critically, if the real dqkv_layout is unsupported but the probe's assumed qkv_layout is accepted, the backend check returns NVTE_F16_arbitrary_seqlen and the actual backward pass fails; conversely, if the assumed qkv_layout is rejected but the true dqkv_layout would be supported, backend selection falsely returns NVTE_No_Backend. The same assumption exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.

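As a hedged sketch of the suggested fix, assuming the backward probe takes the gradient layout and dO format from the caller; the two extra parameters are hypothetical and would likewise have to be threaded through nvte_get_fused_attn_backend:

// Hypothetical signature (suggestion only), inside namespace transformer_engine:
// dqkv_layout and do_format come from the caller rather than being assumed
// equal to qkv_layout / q_format.
std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
                                 size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
                                 size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout,
                                 NVTE_QKV_Layout dqkv_layout,  // actual gradient layout
                                 NVTE_QKV_Format do_format,    // actual dO format
                                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                                 NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                                 int64_t window_size_right, bool bottom_right_diagonal,
                                 bool deterministic, DType qkv_dtype, cudnnHandle_t handle);
// In the body, delete the `do_format = o_format` and `dqkv_layout = qkv_layout`
// assignments and pass the caller-provided values to
// fused_attn_arbitrary_seqlen_bwd_impl.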

size_t workspace_size = 0;
try {
fused_attn::fused_attn_arbitrary_seqlen_bwd_impl(
b, h, static_cast<int64_t>(num_gqa_groups), sq, skv, static_cast<int64_t>(head_dim_qk),
static_cast<int64_t>(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq,
bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, dqkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr,
/*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr,
/*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr,
/*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr,
/*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr,
/*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr,
/*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr,
/*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr,
get_cudnn_fe_dtype(qkv_dtype),
/*workspace=*/nullptr, &workspace_size,
/*stream=*/static_cast<cudaStream_t>(0), handle);
return "";
} catch (const std::exception &e) {
return e.what();
} catch (...) {
return "is_supported_f16_bwd: unknown failure.";
}
}

} // namespace transformer_engine
transformer_engine/common/fused_attn/fused_attn_arbitrary_seqlen.h
@@ -13,6 +13,8 @@

#include <cudnn.h>

#include <string>

#include "common/common.h"
#include "transformer_engine/fused_attn.h"

@@ -47,6 +49,29 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);

// check if a given configuration is supported for F16/BF16 forward;
// if it is, cache the graph built for this config, and return an empty string;
// if not, return a diagnostic message in the form of a string.
std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, bool is_training, bool return_max_logit,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal,
DType qkv_dtype, cudnnHandle_t handle);

// check if a given configuration is supported for F16/BF16 backward;
// if it is, cache the graph built for this config, and return an empty string;
// if not, return a diagnostic message in the form of a string.
std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, DType qkv_dtype, cudnnHandle_t handle);

} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
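
Both probes declared in this header follow the same contract: an empty return string means the configuration is supported and the cuDNN graph has already been built and cached, while a non-empty string carries the reason it is not. A minimal caller-side sketch of that contract follows; the wrapper function and its include path are assumptions for illustration, only is_supported_f16_fwd and its return convention come from this PR.

#include <cudnn.h>

#include <string>

#include "fused_attn_arbitrary_seqlen.h"  // assumed include path for the probe declaration

namespace transformer_engine {

// Hypothetical helper: returns true when the F16/BF16 arbitrary-seqlen forward
// backend accepts this configuration, and surfaces the cuDNN diagnostic
// otherwise (e.g. so a caller can log it or a test can pass it to pytest.skip).
bool f16_fwd_config_supported(size_t b, size_t h, size_t hg, size_t sq, size_t skv,
                              size_t d_qk, size_t d_v, bool is_training, bool return_max_logit,
                              float p_dropout, NVTE_QKV_Layout layout, NVTE_Bias_Type bias,
                              NVTE_Mask_Type mask, NVTE_Softmax_Type softmax,
                              int64_t win_left, int64_t win_right, bool bottom_right,
                              DType dtype, cudnnHandle_t handle, std::string *why_not) {
  const std::string msg =
      is_supported_f16_fwd(b, h, hg, sq, skv, d_qk, d_v, is_training, return_max_logit,
                           p_dropout, layout, bias, mask, softmax, win_left, win_right,
                           bottom_right, dtype, handle);
  if (why_not != nullptr) *why_not = msg;  // empty string == supported, graph already cached
  return msg.empty();
}

}  // namespace transformer_engine
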
87 changes: 87 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1324,4 +1324,91 @@ void fused_attn_fp8_bwd(
return;
}
}

std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, bool is_training, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype,
NVTEScalingMode scaling_mode, cudnnHandle_t handle) {
const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
size_t workspace_size = 0;
try {
fused_attn::fused_attn_fp8_fwd_impl(
static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
static_cast<int64_t>(num_gqa_groups), static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv), static_cast<int64_t>(head_dim_qk),
static_cast<int64_t>(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout,
qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, bottom_right_diagonal,
/*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr,
/*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr,
/*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr,
/*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr,
/*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr,
/*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr,
/*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(qkv_dtype), get_cudnn_fe_dtype(o_dtype),
scaling_mode,
/*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET,
/*workspace=*/nullptr, &workspace_size,
/*stream=*/static_cast<cudaStream_t>(0), handle);
return "";
} catch (const std::exception& e) {
return e.what();
} catch (...) {
return "is_supported_fp8_fwd: unknown failure.";
}
}

std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, DType qkv_dtype, DType o_dtype,
NVTEScalingMode scaling_mode, cudnnHandle_t handle) {
const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype);
const cudnn_frontend::DataType_t do_t = o_t;
const cudnn_frontend::DataType_t dqkv_t = qkv_t;
size_t workspace_size = 0;
try {
fused_attn::fused_attn_fp8_bwd_impl(
static_cast<int64_t>(batch), static_cast<int64_t>(num_attn_heads),
static_cast<int64_t>(num_gqa_groups), static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv), static_cast<int64_t>(head_dim_qk),
static_cast<int64_t>(head_dim_v), /*scaling_factor=*/1.0f, p_dropout, qkv_layout,
/*o_format=*/qkv_format, /*do_format=*/qkv_format, /*dqkv_layout=*/qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
deterministic,
/*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr,
/*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr,
/*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr,
/*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDescaleQ=*/nullptr,
/*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleO=*/nullptr,
/*devPtrDescaledO=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrDescaledP=*/nullptr,
/*devPtrScaleS=*/nullptr, /*devPtrScaledP=*/nullptr, /*devPtrScaledQ=*/nullptr,
/*devPtrScaledK=*/nullptr, /*devPtrScaledV=*/nullptr, /*devPtrAmaxdP=*/nullptr,
/*devPtrAmaxdQ=*/nullptr, /*devPtrAmaxdK=*/nullptr, /*devPtrAmaxdV=*/nullptr,
/*devPtrQ_t=*/nullptr, /*devPtrK_t=*/nullptr, /*devPtrdO_f16=*/nullptr,
/*devPtrdO_t=*/nullptr, /*devPtrDescaleQ_t=*/nullptr, /*devPtrDescaleK_t=*/nullptr,
/*devPtrDescaledO_t=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr,
/*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr,
/*devPtrDropoutOffset=*/nullptr, qkv_t, o_t, do_t, dqkv_t, scaling_mode,
/*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET,
/*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET,
/*workspace=*/nullptr, &workspace_size,
/*stream=*/static_cast<cudaStream_t>(0), handle);
return "";
} catch (const std::exception& e) {
return e.what();
} catch (...) {
return "is_supported_fp8_bwd: unknown failure.";
}
}

} // namespace transformer_engine
26 changes: 26 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn_fp8.h
@@ -8,6 +8,8 @@
* \brief Functions for fused attention for FP8
*/

#include <string>

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

@@ -39,4 +41,28 @@ void fused_attn_fp8_bwd(
const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

// check if a given configuration is supported for FP8 forward;
// if it is, cache the graph built for this config, and return an empty string;
// if not, return a diagnostic message in the form of a string.
std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, bool is_training, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype,
NVTEScalingMode scaling_mode, cudnnHandle_t handle);

// check if a given configuration is supported for FP8 backward;
// if it is, cache the graph built for this config, and return an empty string;
// if not, return a diagnostic message in the form of a string.
std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, DType qkv_dtype, DType o_dtype,
NVTEScalingMode scaling_mode, cudnnHandle_t handle);
} // namespace transformer_engine
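
For completeness, a hedged sketch of how a dispatcher might choose between the F16/BF16 and FP8 probes based on the activation dtype; the dispatcher, its name, and the dtype test are illustrative assumptions, only the two probe signatures above come from this PR.

#include <cudnn.h>

#include <string>

#include "fused_attn_arbitrary_seqlen.h"  // assumed include paths for the two probes
#include "fused_attn_fp8.h"

namespace transformer_engine {

// Hypothetical dispatcher: run the FP8 probe for FP8 activations and the
// F16/BF16 probe otherwise; an empty diagnostic still means "supported, graph
// cached" in both cases.
std::string probe_fused_attn_fwd(size_t b, size_t h, size_t hg, size_t sq, size_t skv,
                                 size_t d_qk, size_t d_v, bool is_training,
                                 bool return_max_logit, float p_dropout, NVTE_QKV_Layout layout,
                                 NVTE_Bias_Type bias, NVTE_Mask_Type mask,
                                 NVTE_Softmax_Type softmax, int64_t win_left, int64_t win_right,
                                 bool bottom_right, DType qkv_dtype, DType o_dtype,
                                 NVTEScalingMode scaling_mode, cudnnHandle_t handle) {
  const bool is_fp8 =
      (qkv_dtype == DType::kFloat8E4M3 || qkv_dtype == DType::kFloat8E5M2);
  if (is_fp8) {
    return is_supported_fp8_fwd(b, h, hg, sq, skv, d_qk, d_v, is_training, p_dropout, layout,
                                bias, mask, softmax, win_left, win_right, bottom_right,
                                qkv_dtype, o_dtype, scaling_mode, handle);
  }
  return is_supported_f16_fwd(b, h, hg, sq, skv, d_qk, d_v, is_training, return_max_logit,
                              p_dropout, layout, bias, mask, softmax, win_left, win_right,
                              bottom_right, qkv_dtype, handle);
}

}  // namespace transformer_engine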