36 changes: 14 additions & 22 deletions tests/pytorch/attention/test_attention.py
@@ -2579,28 +2579,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
 }
 param_types_fp8 = [torch.float16, torch.bfloat16]
-cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
-models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
-models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


 @pytest.mark.skipif(
-    (
-        get_cudnn_version() < (8, 9, 3)
-        if cudnn_frontend_version == 0
-        else get_cudnn_version() < (9, 2, 1)
-    ),
-    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
+    get_cudnn_version() < (9, 2, 1),
+    reason="cuDNN 9.2.1+ is required for FP8 fused attention.",
 )
 @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8)
-@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
+@pytest.mark.parametrize("model", model_configs_fp8)
 def test_custom_mha_fp8_vs_f16(dtype, model):
     """Test FP8 dot product attention implementations based on cuDNN frontend
     v0.9 and v1.0+. Each test compares results from a custom implementation of
     an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
     implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
-    Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""
+    Both paths take F16 input and output. QKV layout is bs3hd"""

     config = model_configs_fp8[model]

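With the v0 path gone, the test reduces to a single comparison pattern: the same F16 inputs go through an FP8 attention path and an F16 reference path, and the outputs must agree within FP8-level tolerances. A minimal, self-contained sketch of that pattern (an e4m3 round trip stands in for Custom_MHA_FP8; shapes, dtype, and tolerances are illustrative, not the test's actual values):

```python
import torch
import torch.nn.functional as F

# Sketch only: an e4m3 quantize/dequantize round trip simulates the FP8
# path, and plain SDPA stands in for MultiHeadAttention.
b, s, h, d = 2, 128, 4, 64
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

def fake_fp8(x: torch.Tensor) -> torch.Tensor:
    return x.to(torch.float8_e4m3fn).to(x.dtype)  # lossy round trip

out_ref = F.scaled_dot_product_attention(q, k, v)
out_fp8 = F.scaled_dot_product_attention(fake_fp8(q), fake_fp8(k), fake_fp8(v))
torch.testing.assert_close(out_fp8, out_ref, atol=1e-1, rtol=1e-1)
```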
@@ -2609,7 +2602,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
-        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
+        qkv_layout="bs3hd",
         is_training=is_training,
         deterministic=_deterministic,
     )
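One behavioral note on the hunks above: parametrizing over `model_configs_fp8` iterates the dict's keys, so all eight `fp8_*` configs now run unconditionally instead of being split across `models_v0`/`models_v1` by `NVTE_FUSED_ATTN_FE_VER`, and the version gate collapses to one tuple comparison. A stub sketch of the remaining gate (`get_cudnn_version` is stubbed here; the real test imports it from the TE test utilities):

```python
import pytest

def get_cudnn_version() -> tuple:
    # Stand-in for the TE helper; returns (major, minor, patch).
    return (9, 3, 0)

# The single remaining gate after the frontend-v0 path removal.
requires_fp8_fused_attn = pytest.mark.skipif(
    get_cudnn_version() < (9, 2, 1),
    reason="cuDNN 9.2.1+ is required for FP8 fused attention.",
)
```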
@@ -2816,18 +2809,17 @@ def forward(
             quantization_params=qkv_quantizer,
             use_split_accumulator=_2X_ACC_FPROP,
         )
-        qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd"
-        o_format = "bshd" if cudnn_frontend_version == 1 else "thd"
+        qkv_layout = "bs3hd"
+        o_format = "bshd"
         qkv = qkv.view(-1, 3, h, d)
         qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
         torch.save(qkv_fp16, "qkv.pt")
-        if cudnn_frontend_version == 1:
-            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd
+        qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

         # FMHA
-        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
-        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
-        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
+        q_data = qkv._data[:, :, 0, :, :]
+        k_data = qkv._data[:, :, 1, :, :]
+        v_data = qkv._data[:, :, 2, :, :]
         q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
         k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
         v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)
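The slicing above now always assumes the packed bs3hd layout. A runnable sketch of the same slicing on a plain fp16 tensor (standing in for the quantized tensor whose `._data`/`make_like` the test manipulates; shapes are illustrative):

```python
import torch

b, max_s, h, d = 2, 512, 16, 128
qkv = torch.randn(b, max_s, 3, h, d, dtype=torch.float16)  # bs3hd

# Slice along the packing axis (dim 2): each result is a bshd view.
q, k, v = (qkv[:, :, i, :, :] for i in range(3))
assert q.shape == (b, max_s, h, d)
```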
@@ -2849,7 +2841,7 @@ def forward(
             qkv_layout=qkv_layout,
             o_format=o_format,
             attn_bias_type="no_bias",
-            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
+            attn_mask_type=mask_type,
             rng_gen=None,
             o_quantizer=o_quantizer,
             s_quantizer=s_quantizer,
@@ -2916,9 +2908,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             do_format=ctx.o_format,
             dqkv_layout=ctx.qkv_layout,
             attn_bias_type="no_bias",
-            attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
+            attn_mask_type=ctx.mask_type,
         )
-        dim = 2 if cudnn_frontend_version == 1 else 1
+        dim = 2
         dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
         dqkv_shape = list(dq._data.shape)
         dqkv_shape.insert(dim, 3)
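In backward, the packing axis is likewise fixed at `dim=2`. A sketch of the equivalent repacking with `torch.stack` (the test builds the shape explicitly and fills it, but the resulting bs3hd layout is the same; shapes are illustrative):

```python
import torch

b, s, h, d = 2, 512, 16, 128
dq = torch.randn(b, s, h, d, dtype=torch.float16)  # bshd gradients
dk = torch.randn_like(dq)
dv = torch.randn_like(dq)

# Insert the packing axis at dim=2 to recover the bs3hd gradient.
dqkv = torch.stack([dq, dk, dv], dim=2)
assert dqkv.shape == (b, s, 3, h, d)
```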
75 changes: 30 additions & 45 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -255,35 +255,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(

   if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) &&
       sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-      // 8.9: t3hd, max_s=512, d=64, padding
-      ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 &&
-        qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
-        max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
-        attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-       // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
-       (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
-        max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
-       // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
-       (cudnn_runtime_version >= 90700 &&
-        // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-        // sm90: fwd d<=256, bwd d=128 only
-        // sm100: fwd d<=128, bwd d<=128
-        ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
-         (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
-         (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
-        head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
-       // 9.21: d_qk=192, d_v=128
-       (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 &&
-        head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
+      (
+          // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
+          (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
+           max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
+          // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
+          (cudnn_runtime_version >= 90700 &&
+           // TODO (cyang): add is_training to nvte_get_fused_attn_backend
+           // sm90: fwd d<=256, bwd d=128 only
+           // sm100: fwd d<=128, bwd d<=128
+           ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
+            (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
+            (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
+           head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
+          // 9.21: d_qk=192, d_v=128
+          (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 &&
+           head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
       // pre-9.21: {bshd, sbhd}, {vanilla}
       // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable}
       ((cudnn_runtime_version < 92100 &&
@@ -295,14 +291,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       !requires_64bit_ragged_offset &&
       // 9.10.0: known bugs with SDPA FP8
       (cudnn_runtime_version != 91000) && !return_max_logit) {
-    if (cudnn_runtime_version >= 8900) {
-      backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
-    } else {
-      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
-      std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
-                   " Please upgrade your cuDNN version if possible."
-                << std::endl;
-    }
+    backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
     bool flag_m512 = false;
     bool flag_arb = false;
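For reviewers, a hypothetical Python mirror of the FP8 branch as it stands after these two hunks; the flattened signature and mask names are illustrative, and the C++ condition above remains the authoritative logic:

```python
def fp8_fused_attn_supported(cudnn, sm, d_qk, d_v, sq, skv, mask, training):
    # cuDNN 9.2.1+, sm90: d=128, seqlens % 128 == 0, {no_mask, causal}
    branch_921 = (cudnn >= 90201 and sm < 100 and sq % 128 == 0 and skv % 128 == 0
                  and d_qk == 128 and d_v == 128 and mask in ("no_mask", "causal"))
    # cuDNN 9.7+: fwd d<=256 (sm90) / d<=128 (sm100), padding masks allowed
    branch_97 = (cudnn >= 90700
                 and ((sm < 100 and not training and d_qk <= 256 and d_v <= 256)
                      or (sm < 100 and training and d_qk == 128 and d_v == 128)
                      or (sm >= 100 and d_qk <= 128 and d_v <= 128))
                 and d_qk % 16 == 0 and d_v % 16 == 0
                 and mask in ("no_mask", "causal", "padding", "padding_causal"))
    # cuDNN 9.21+, sm100: d_qk<=192, d_v<=128
    branch_921 = (cudnn >= 92100 and sm >= 100 and d_qk <= 192 and d_v <= 128
                  and d_qk % 16 == 0 and d_v % 16 == 0
                  and mask in ("no_mask", "causal", "causal_bottom_right"))
    return branch_921 or branch_97 or branch_921

# The old 8.9-era t3hd config (d=64, padding mask) no longer qualifies:
assert not fp8_fused_attn_supported(90201, 90, 64, 64, 512, 512, "padding", False)
```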
@@ -781,10 +770,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
     size_t i = 0;
     const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-    const Tensor *input_ZInv = nullptr;
-    if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-      input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-    }
     const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     const Tensor *input_SoftmaxOffset = nullptr;
     if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -798,10 +783,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                             qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format,
                             do_scale_inv_format, bias_type, attn_mask_type, softmax_type,
                             window_size_left, window_size_right, bottom_right_diagonal, deterministic,
-                            input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M,
-                            input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, output_dQ,
-                            output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q,
-                            input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
+                            input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S,
+                            input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV,
+                            output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
+                            input_rng_state, wkspace, stream, handle);
   } else {
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
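Since the t3hd-only ZInv tensor is no longer produced, the FP8 backward's aux-tensor unpacking shifts down one slot. A small sketch of the order implied by the `i++` sequence above:

```python
# Aux_CTX_Tensors order after this change (inferred from the C++ above).
softmax_type = "vanilla"  # illustrative value
aux_ctx_tensors = ["M", "rng_state"]          # tensors[0], tensors[1]
if softmax_type != "vanilla":
    aux_ctx_tensors.append("SoftmaxOffset")   # tensors[2]
```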