diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index c9ea791444..0f248a9f55 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -2579,28 +2579,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
 }
 param_types_fp8 = [torch.float16, torch.bfloat16]
-cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
-models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
-models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


 @pytest.mark.skipif(
-    (
-        get_cudnn_version() < (8, 9, 3)
-        if cudnn_frontend_version == 0
-        else get_cudnn_version() < (9, 2, 1)
-    ),
-    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
+    get_cudnn_version() < (9, 2, 1),
+    reason="cuDNN 9.2.1+ is required for FP8 fused attention.",
 )
 @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8)
-@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
+@pytest.mark.parametrize("model", model_configs_fp8)
 def test_custom_mha_fp8_vs_f16(dtype, model):
     """Test FP8 dot product attention implementations based on cuDNN frontend v0.9 and v1.0+.

     Each test compares results from a custom implementation of an FP8 MHA module, i.e.
     Custom_MHA_FP8(), to results from an F16 MHA implementation, i.e.
     transformer_engine.pytorch.attention.MultiHeadAttention.

-    Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""
+    Both paths take F16 input and output. QKV layout is bs3hd"""
     config = model_configs_fp8[model]
@@ -2609,7 +2602,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
-        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
+        qkv_layout="bs3hd",
         is_training=is_training,
         deterministic=_deterministic,
     )
@@ -2816,18 +2809,17 @@ def forward(
             quantization_params=qkv_quantizer,
             use_split_accumulator=_2X_ACC_FPROP,
         )
-        qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd"
-        o_format = "bshd" if cudnn_frontend_version == 1 else "thd"
+        qkv_layout = "bs3hd"
+        o_format = "bshd"
         qkv = qkv.view(-1, 3, h, d)
         qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
         torch.save(qkv_fp16, "qkv.pt")
-        if cudnn_frontend_version == 1:
-            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd
+        qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd
         # FMHA
-        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
-        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
-        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
+        q_data = qkv._data[:, :, 0, :, :]
+        k_data = qkv._data[:, :, 1, :, :]
+        v_data = qkv._data[:, :, 2, :, :]
         q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
         k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
         v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)
@@ -2849,7 +2841,7 @@ def forward(
             qkv_layout=qkv_layout,
             o_format=o_format,
             attn_bias_type="no_bias",
-            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
+            attn_mask_type=mask_type,
             rng_gen=None,
             o_quantizer=o_quantizer,
             s_quantizer=s_quantizer,
@@ -2916,9 +2908,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             do_format=ctx.o_format,
             dqkv_layout=ctx.qkv_layout,
             attn_bias_type="no_bias",
-            attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
+            attn_mask_type=ctx.mask_type,
         )
-        dim = 2 if cudnn_frontend_version == 1 else 1
+        dim = 2
         dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
         dqkv_shape = list(dq._data.shape)
         dqkv_shape.insert(dim, 3)
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 141767b803..c6a85ad448 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -255,35 +255,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) &&
       sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-      // 8.9: t3hd, max_s=512, d=64, padding
-      ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 &&
-        qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
-        max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
-        attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-       // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
-       (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
-        max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
-       // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
-       (cudnn_runtime_version >= 90700 &&
-        // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-        // sm90: fwd d<=256, bwd d=128 only
-        // sm100: fwd d<=128, bwd d<=128
-        ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
-         (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
-         (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
-        head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
-       // 9.21: d_qk=192, d_v=128
-       (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 &&
-        head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
+      (
+          // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
+          (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
+           max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
+          // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
+          (cudnn_runtime_version >= 90700 &&
+           // TODO (cyang): add is_training to nvte_get_fused_attn_backend
+           // sm90: fwd d<=256, bwd d=128 only
+           // sm100: fwd d<=128, bwd d<=128
+           ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
+            (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
+            (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
+           head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
+          // 9.21: d_qk=192, d_v=128
+          (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 &&
+           head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
       // pre-9.21: {bshd, sbhd}, {vanilla}
       // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable}
       ((cudnn_runtime_version < 92100 &&
@@ -295,14 +291,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       !requires_64bit_ragged_offset &&
       // 9.10.0: known bugs with SDPA FP8
       (cudnn_runtime_version != 91000) && !return_max_logit) {
-    if (cudnn_runtime_version >= 8900) {
-      backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
-    } else {
-      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
-      std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
-                   " Please upgrade your cuDNN version if possible."
-                << std::endl;
-    }
+    backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
     bool flag_m512 = false;
     bool flag_arb = false;
@@ -781,10 +770,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
     size_t i = 0;
     const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-    const Tensor *input_ZInv = nullptr;
-    if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-      input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-    }
     const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     const Tensor *input_SoftmaxOffset = nullptr;
     if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -798,10 +783,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                           qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format,
                           do_scale_inv_format, bias_type, attn_mask_type, softmax_type,
                           window_size_left, window_size_right, bottom_right_diagonal, deterministic,
-                          input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M,
-                          input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, output_dQ,
-                          output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q,
-                          input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
+                          input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S,
+                          input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV,
+                          output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
+                          input_rng_state, wkspace, stream, handle);
   } else {
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d97f388459..eab1ae02e6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -15,1648 +15,14 @@ namespace fused_attn { using namespace transformer_engine; -std::unordered_map tensor_name_to_uid = {{"Q", 1}, - {"K", 2}, - {"V", 3}, - {"O", 4}, - {"S", 5}, - {"B", 6}, - {"DROPOUT_SCALE", 7}, - {"S_CONST", 8}, - {"MNK_OVERRIDE", 9}, - {"dQ", 11}, - {"dK", 12}, - {"dV", 13}, - {"dO", 14}, - {"MASK_VAL", 15}, - {"dS", 16}, - {"O_SEQLEN", 17}, - {"M", 18}, - {"Z", 19}, - {"descaleQ", 20}, - {"descaleK", 21}, - {"descaleV", 22}, - {"descaleS", 23}, - {"scaleS", 24}, - {"amaxS", 25}, - {"amaxO", 26}, - {"QKV_RAGGED", 27}, - {"O_RAGGED", 28}, - {"K_TRANSPOSE", 29}, - {"AttnScale", 30}, - {"scaleO", 31}, - {"Z_INV", 32}, - {"descaleO", 33}, - {"descaledO", 34}, - {"descaledS", 35}, - {"descaledQ", 36}, - {"descaledK", 37}, - {"descaledV", 38}, - {"scaledS", 39}, - {"scaledQ", 40}, - {"scaledK", 41}, - {"scaledV", 42}, - {"amaxdS", 43}, - {"amaxdQ", 44}, - {"amaxdK", 45}, - {"amaxdV", 46}, - {"V_TRANSPOSE", 47}, - {"AttnScale_dS_K", 48}, - {"AttnScale_dSTranspose_Q", 49}, - {"DROPOUT_SCALE_dOVt_OdO", 50}, - {"DROPOUT_OFFSET", 51}, - {"DROPOUT_SEED", 52}, - {"VIRTUAL", 80}}; - -static cudnn_frontend::Tensor createAmax(const std::string& amax_tensor_name, - const cudnn_frontend::Tensor& prevBlockOutputTensor, - std::vector* ops) { - int64_t amax_dim[4] = {1, 1, 1, 1}; - int64_t amax_stride[4] = {1, 1, 1, 1}; - auto amaxTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name], amax_dim, - amax_stride, false, false); - - // Define the amax descriptor - auto reductionDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_AMAX) - .build(); - - // Create a reduction amax Node - auto reduction_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(amaxTensor) - .setreductionDesc(reductionDesc) - .build(); - ops->push_back(std::move(reduction_op)); - return amaxTensor; -} - -static cudnn_frontend::Tensor createScale(const cudnn_frontend::Tensor& prevBlockOutputTensor, - const std::string& scale_tensor_name, - cudnnDataType_t tensorType, bool isOutputVirtual, - bool isScaleByValue, - std::vector* ops, - const std::string& output_tensor_name = "") { - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - int64_t output_dim[4]; - int64_t output_stride[4]; - - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } - - auto scaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], - scale_dim, scale_stride, false, isScaleByValue); // is by value - - int64_t outputUID = - isOutputVirtual ? 
tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 5000 - : tensor_name_to_uid[output_tensor_name]; - auto afterScaleKTensor = tensor_create(tensorType, outputUID, output_dim, output_stride, - isOutputVirtual, false); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node - auto scale_op = - binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc); - - ops->push_back(std::move(scale_op)); - return afterScaleKTensor; -} - -static cudnn_frontend::Tensor createScale(const cudnn_frontend::Tensor& prevBlockOutputTensor, - const cudnn_frontend::Tensor& scaleTensor, - cudnnDataType_t tensorType, bool isOutputVirtual, - bool isScaleByValue, - std::vector* ops, - int UID_offset, - const std::string& output_tensor_name = "") { - int64_t output_dim[4]; - int64_t output_stride[4]; - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } - - int64_t outputUID = isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + UID_offset - : tensor_name_to_uid[output_tensor_name]; - auto afterScaleTensor = tensor_create(tensorType, outputUID, output_dim, output_stride, - isOutputVirtual, false); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node - auto scale_op = - binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); - - ops->push_back(std::move(scale_op)); - return afterScaleTensor; -} - -static cudnn_frontend::Tensor createScaleWithOffset( - const cudnn_frontend::Tensor& prevBlockOutputTensor, const std::string& scale_tensor_name, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, bool isOutputVirtual, bool isScaleByValue, - std::vector* ops, - std::shared_ptr offsetTensor, - const std::string& output_tensor_name = "") { - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - int64_t output_dim[4]; - int64_t output_stride[4]; - // If output tensor is dQ, dK, or dV, we need to generate QKV interleaved strides - if (output_tensor_name == "dQ" || output_tensor_name == "dK" || output_tensor_name == "dV") { - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - } - generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], - 0 /*s_kv = 0 for placeholder*/, output_dim[3], output_stride, layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - } else { - // Otherwise output dim and stride should be the same as prev block dim and stride - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } - } - - auto scaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], - scale_dim, scale_stride, false, isScaleByValue); // is by value - - cudnnDataType_t outputDataType = isOutputVirtual ? CUDNN_DATA_FLOAT : tensorType; - int64_t outputUID = - isOutputVirtual ? 
tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 - : tensor_name_to_uid[output_tensor_name]; - auto afterScaleTensor = - tensor_create_with_offset(outputDataType, outputUID, output_dim, output_stride, - isOutputVirtual, false, offsetTensor); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node - auto scale_op = - binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); - - ops->push_back(std::move(scale_op)); - return afterScaleTensor; -} - -static cudnn_frontend::Tensor createSoftmaxForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, std::vector* ops, - const cudnn_frontend::Tensor& prevBlockOutputTensor, bool isTraining) { - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterReduction_dim[4] = {b, h, s_q, 1}; - int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; - - // max (x) (M tensor) - auto afterMaxReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], afterReduction_dim, - afterReduction_stride, !isTraining, false); // not virtual if training is true, - // virtual if training is false - // x - max(x) - auto afterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // e^(x - max(x)) - auto afterExponentTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual; - // sum (e^(x - max(x))) (Z tensor) - auto zTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"], afterReduction_dim, - afterReduction_stride, true, false); // is virtual - // 1 / sum (e^(x - max(x))) (Z_INV tensor) - auto zInvTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], afterReduction_dim, - afterReduction_stride, !isTraining, false); // not virtual if training is true, - // virtual if training is false - // Final softmax output (After exponent * Z_INV) - auto beforeDropoutTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 153, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) - .build(); - - // Create a reduction max Node - auto reductionMax_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(afterMaxReductionTensor) - .setreductionDesc(reductionMaxDesc) - .build(); - - // Define the subtract descriptor - auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtract Node - auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, - afterSubtractionTensor, subtractDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node - auto exponent_op = unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node - auto reductionAdd_op = - 
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(afterExponentTensor) - .setyDesc(zTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // Define the reciprocal descriptor - auto reciprocalDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); - - // Create a reciprocal Node - auto reciprocal_op = unary_pw_op_create(zTensor, zInvTensor, reciprocalDesc); - - // Define the pw multiply descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply Node - auto mutliply_op = - binary_pw_op_create(afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc); - - ops->push_back(std::move(reductionMax_op)); - ops->push_back(std::move(subtract_op)); - ops->push_back(std::move(exponent_op)); - ops->push_back(std::move(reductionAdd_op)); - ops->push_back(std::move(reciprocal_op)); - ops->push_back(std::move(mutliply_op)); - - return beforeDropoutTensor; -} - -static cudnn_frontend::Tensor createDropoutForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, double probability, - std::vector* ops, - const cudnn_frontend::Tensor& beforeDropoutTensor) { - NVTE_CHECK(ops->size() > 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // Mask for the dropout - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - auto dropoutSeedTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], - scale_dim, scale_stride, false, false); // is by value - auto dropoutOffsetTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], - scale_dim, scale_stride, false, false); // is by value - - // After dropout tensor befor scale - auto beforeDropoutScaleTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(tensor_name_to_uid["VIRTUAL"] + 201) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual(true) - .setByValue(false) - .setReorderType(cudnn_frontend::TensorReordering_t::F16x16) - .build(); - // Scale after dropout - auto scaleDropoutTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], - scale_dim, scale_stride, false, true); // is by value - // After Scale - auto afterDropout_before_quan_S = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - // Create a rng Node - auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeedTensor) - .setOffsetDesc(dropoutOffsetTensor) - .setRngDesc(rngDesc) - .build(); - - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create(beforeDropoutTensor, dropoutMaskTensor, - beforeDropoutScaleTensor, maskMulDesc); - - // Define the multiply scale descriptor - 
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto scaleMul_op = binary_pw_op_create(beforeDropoutScaleTensor, scaleDropoutTensor, - afterDropout_before_quan_S, scaleMulDesc); - - ops->push_back(std::move(rng_op)); - ops->push_back(std::move(maskMul_op)); - ops->push_back(std::move(scaleMul_op)); - - return afterDropout_before_quan_S; -} - -static cudnn_frontend::Tensor createDropoutBackward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, double probability, - std::vector* ops, const cudnn_frontend::Tensor& beforeDropoutTensor, - const cudnn_frontend::Tensor& dropoutMaskTensor) { - NVTE_CHECK(ops->size() > 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - auto dropoutSeedTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], - scale_dim, scale_stride, false, false); // is by value - auto dropoutOffsetTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], - scale_dim, scale_stride, false, false); // is by value - - // After dropout tensor befor scale - auto beforeDropoutScaleTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(tensor_name_to_uid["VIRTUAL"] + 201) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual(true) - .setByValue(false) - .setReorderType(cudnn_frontend::TensorReordering_t::F16x16) - .build(); - // Scale after dropout (1 / (1 - p)) - auto scaleDropoutTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], - scale_dim, scale_stride, false, true); // is by value - // After Scale - auto afterDropout_before_quan_S = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - // Create a rng Node - auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeedTensor) - .setOffsetDesc(dropoutOffsetTensor) - .setRngDesc(rngDesc) - .build(); - - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create(beforeDropoutTensor, dropoutMaskTensor, - beforeDropoutScaleTensor, maskMulDesc); - - // Define the multiply scale descriptor - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto scaleMul_op = binary_pw_op_create(beforeDropoutScaleTensor, scaleDropoutTensor, - afterDropout_before_quan_S, scaleMulDesc); - - ops->push_back(std::move(rng_op)); - ops->push_back(std::move(maskMul_op)); - ops->push_back(std::move(scaleMul_op)); - - return afterDropout_before_quan_S; -} - -static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - std::vector* ops, - const cudnn_frontend::Tensor& dyTensor) { - NVTE_CHECK(ops->size() > 0, "Softmax backward constructed incorrectly as the first one"); - - int64_t dx_dim[4] = {b, h, 
s_q, s_kv}; - int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t M_Z_dim[4] = {b, h, s_q, 1}; - int64_t M_Z_stride[4] = {h * s_q, s_q, 1, 1}; - - // Creating all tensors - auto MTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], M_Z_dim, M_Z_stride, - false, false); // not virtual - auto ZInvTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], M_Z_dim, - M_Z_stride, false, false); // not virtual - auto dxAfterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252, dx_dim, dx_stride, true, - false); // is virtual - auto dxAfterExponentiation = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253, - dx_dim, dx_stride, true, false); // is virtual - auto dxBeforeDropout_QKt_Tensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254, dx_dim, dx_stride, true, - false); // is virtual - - // Creating all ops - // sub (dy - M) - auto subtractionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - auto subtractionOp = - binary_pw_op_create(dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node. (exp(dy - M)) - auto exponentOp = - unary_pw_op_create(dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc); - - // Define the pw multiply descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply Node - auto mutliplyOp = binary_pw_op_create(dxAfterExponentiation, ZInvTensor, - dxBeforeDropout_QKt_Tensor, multiplyDesc); - - ops->push_back(std::move(subtractionOp)); - ops->push_back(std::move(exponentOp)); - ops->push_back(std::move(mutliplyOp)); - - return dxBeforeDropout_QKt_Tensor; -} - -static cudnn_frontend::Tensor createQKBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, std::vector* ops, - const cudnn_frontend::Tensor& qTensor, const cudnn_frontend::Tensor& kTensor, - const cudnn_frontend::Tensor& mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - // Creates the necessary tensor descriptors - int64_t k_transpose_dim[4] = {b, h, d, s_kv}; - int64_t k_transpose_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_transpose_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - int64_t s_dim[4] = {b, h, s_q, s_kv}; - int64_t s_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - auto kTransposeTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K_TRANSPOSE"], - k_transpose_dim, k_transpose_stride, false, - false, QKVRaggedOffsetTensor); // is virtual - - // First GEMM output - auto afterQKTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1, s_dim, - s_stride, true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(-2000000) - .build(); - - // Create reshape node for K -> K.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(kTensor) - .setyDesc(kTransposeTensor) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(kTransposeTensor) - .setcMatDesc(afterQKTensor) - 
.setmOverrideDesc(mnkOverride) - .setnOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return afterQKTensor; -} - -static cudnn_frontend::Tensor createSVBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, std::vector* ops, - const cudnn_frontend::Tensor& softmaxTensor, const cudnn_frontend::Tensor& mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - NVTE_CHECK(ops->size() > 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - auto vTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V"], v_dim, v_stride, - false, false, QKVRaggedOffsetTensor); - // Second fprop GEMM output - auto oTensor = tensor_create(tensorType, tensor_name_to_uid["VIRTUAL"] + 300, o_dim, o_stride, - true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(softmaxTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(matmulOp)); - - return oTensor; -} - -static cudnn_frontend::Tensor createSdOBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, cudnnDataType_t tensorType, - std::vector* ops, - const cudnn_frontend::Tensor& softmaxTensor, - const cudnn_frontend::Tensor& dOTensor, - const cudnn_frontend::Tensor& mnkOverride) { - NVTE_CHECK(ops->size() > 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t s_dim_transpose[4] = {b, h, s_kv, s_q}; - int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv}; - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4] = {h * s_kv * d, d, h * d, 1}; - - auto sTransposeTensor = - tensor_create(tensorType, tensor_name_to_uid["VIRTUAL"] + 499, s_dim_transpose, - s_stride_transpose, true, false); // is virtual - // S.T * dO - auto dVTensor_before_dequan_S = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500, v_dim, v_stride, true, - false); // is virtual - - // Create reshape node for softmax -> softmax.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(softmaxTensor) - .setyDesc(sTransposeTensor) - .build(); - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(sTransposeTensor) - .setbMatDesc(dOTensor) - .setcMatDesc(dVTensor_before_dequan_S) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return dVTensor_before_dequan_S; -} - -static cudnn_frontend::Tensor createdOVBMM( - int64_t b, int64_t h, int64_t s_q, 
int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, std::vector* ops, - const cudnn_frontend::Tensor& dOTensor, const cudnn_frontend::Tensor& mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - // Creates the necessary tensor descriptors - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t v_transpose_dim[4] = {b, h, d, s_kv}; - int64_t v_transpose_stride[4]; - v_transpose_stride[0] = v_stride[0]; - v_transpose_stride[1] = v_stride[1]; - v_transpose_stride[2] = v_stride[3]; - v_transpose_stride[3] = v_stride[2]; - - int64_t s_dim[4] = {b, h, s_q, s_kv}; - int64_t s_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - auto vTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V"], v_dim, v_stride, - false, false, QKVRaggedOffsetTensor); - auto vTransposeTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V_TRANSPOSE"], - v_transpose_dim, v_transpose_stride, false, - false, QKVRaggedOffsetTensor); // is virtual - - // dO * V.T - auto afterdOVTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600, s_dim, - s_stride, true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create reshape node for V -> V.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(vTensor) - .setyDesc(vTransposeTensor) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dOTensor) - .setbMatDesc(vTransposeTensor) - .setcMatDesc(afterdOVTensor) - .setmOverrideDesc(mnkOverride) - .setnOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return afterdOVTensor; -} - -static cudnn_frontend::Tensor createdOAndORowReductionChain( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - std::vector* ops, const cudnn_frontend::Tensor& O_after_dequan, - const cudnn_frontend::Tensor& dO_after_dequan, - const cudnn_frontend::Tensor& dropoutScale_dOVt_OdO_Tensor) { - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - int64_t o_dim_row_sum[4] = {b, h, s_q, 1}; - int64_t o_dim_row_sum_stride[4] = {s_q * h, s_q, 1, 1}; - - auto O_dO_after_pointwise_multiply = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700, o_dim, o_stride, true, - false); // is virtual - auto O_dO_after_dropout_scale = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701, o_dim, o_stride, true, - false); // is virtual - auto O_dO_after_rowsum = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 702, o_dim_row_sum, - o_dim_row_sum_stride, true, false); // is virtual - - // Define the pw multiply descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply Node - auto mutliply_op = binary_pw_op_create(O_after_dequan, dO_after_dequan, - O_dO_after_pointwise_multiply, multiplyDesc); - - // Create multiply node with dropout scale - auto dropout_scale_multiply_op = - 
binary_pw_op_create(O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor, - O_dO_after_dropout_scale, multiplyDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(O_dO_after_dropout_scale) - .setyDesc(O_dO_after_rowsum) - .setreductionDesc(reductionAddDesc) - .build(); - - ops->push_back(std::move(mutliply_op)); - ops->push_back(std::move(dropout_scale_multiply_op)); - ops->push_back(std::move(reductionAdd_op)); - - return O_dO_after_rowsum; -} - -static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - std::vector* ops, const cudnn_frontend::Tensor& dS_after_dropout, - const cudnn_frontend::Tensor& AfterDropout_before_quan_S, - const cudnn_frontend::Tensor& O_dO_after_rowsum, const cudnn_frontend::Tensor& attnScale) { - int64_t o_dim[4] = {b, h, s_q, s_kv}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - auto dS_minus_O_dO = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 800, o_dim, - o_stride, true, false); // is virtual - auto AfterAttnScale_before_dS = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801, o_dim, o_stride, true, - false); // is virtual - auto S_mul_dS_minus_O_dO = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802, - o_dim, o_stride, true, false); // is virtual - - // Define the pw subtraction descriptor - auto subDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtraction Node - auto sub_op = binary_pw_op_create(dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc); - - // Define the pw multiplication descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // dS_minus_O_dO * attnScale - auto mutliply_attn_scale_op = - binary_pw_op_create(dS_minus_O_dO, attnScale, AfterAttnScale_before_dS, multiplyDesc); - - // AfterDropout_before_quan_S * AfterAttnScale_before_dS - auto mutliply_op = binary_pw_op_create(AfterDropout_before_quan_S, AfterAttnScale_before_dS, - S_mul_dS_minus_O_dO, multiplyDesc); - - ops->push_back(std::move(sub_op)); - ops->push_back(std::move(mutliply_attn_scale_op)); - ops->push_back(std::move(mutliply_op)); - - return S_mul_dS_minus_O_dO; -} - -static cudnn_frontend::Tensor createdSKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, std::vector* ops, - const cudnn_frontend::Tensor& dSTensor, - const cudnn_frontend::Tensor& kTensor, - const cudnn_frontend::Tensor& mnkOverride) { - // Creates the necessary tensor descriptors - int64_t after_dSK_dim[4] = {b, h, s_kv, d}; - int64_t after_dSK_stride[4] = {h * s_kv * d, d, h * d, 1}; - // dS * K - auto After_dS_K = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875, - after_dSK_dim, after_dSK_stride, true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dSTensor) - .setbMatDesc(kTensor) - .setcMatDesc(After_dS_K) - 
.setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(matmulOp)); - - return After_dS_K; -} - -static cudnn_frontend::Tensor createdSQBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, NVTE_QKV_Layout layout, - std::vector* ops, - const cudnn_frontend::Tensor& dSTensor, - const cudnn_frontend::Tensor& qTensor, - const cudnn_frontend::Tensor& mnkOverride) { - // Creates the necessary tensor descriptors - int64_t dS_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, dS_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t dS_transpose_dim[4] = {b, h, s_kv, s_q}; - int64_t dS_transpose_stride[4]; - dS_transpose_stride[0] = dS_stride[0]; - dS_transpose_stride[1] = dS_stride[1]; - dS_transpose_stride[2] = dS_stride[3]; - dS_transpose_stride[3] = dS_stride[2]; - - int64_t after_dSTranspose_Q_dim[4] = {b, h, s_kv, d}; - int64_t after_dSTranspose_Q_stride[4] = {h * s_kv * d, d, h * d, 1}; - - auto dSTransposeTensor = - tensor_create(CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650, dS_transpose_dim, - dS_transpose_stride, true, false); // is virtual - - // dS.T * Q - auto After_dSTranspose_Q = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651, after_dSTranspose_Q_dim, - after_dSTranspose_Q_stride, true, false); // is virtual - - // Create reshape node for V -> V.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(dSTensor) - .setyDesc(dSTransposeTensor) - .build(); - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dSTransposeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(After_dSTranspose_Q) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return After_dSTranspose_Q; -} - -// fused attention FWD FP8 with FE 0.9 -void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - bool isTraining, float attnScale, float dropoutProbability, - NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, - void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, - void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, - void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnnDataType_t tensorType, void* workspace_ptr, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle_) { - try { - FADescriptor descriptor{b, - h, - s_q, - s_kv, - d, - attnScale, - isTraining, - dropoutProbability, - layout, - NVTE_Bias_Type::NVTE_NO_BIAS, - NVTE_Mask_Type::NVTE_PADDING_MASK, - tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fa_fprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType& cache, const FADescriptor& descriptor) { - // If hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // Otherwise, build the op_graph and the plan. 
Then update cache - std::vector all_ops; - std::vector ops; - - NVTE_CHECK(dropoutProbability == 0.0f || isTraining, - "Dropout probability should be 0.0f for inference mode"); - NVTE_CHECK(dropoutProbability != 1.0f, "Dropout probability cannot be 1.0"); - - int64_t raggedDim[4] = {b + 1, 1, 1, 1}; - int64_t raggedStride[4] = {1, 1, 1, 1}; - // Create offset tensors - auto QKVOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], - raggedDim, raggedStride, false, false); - auto ORaggedOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], - raggedDim, raggedStride, false, false); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - // Create override tensors - auto seqlenMNKTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], - seqlen_dim, seqlen_stride, false, false); - - // Create shared ptrs to ragged offset tensors - // for multiple tensors to use ragged offset - std::shared_ptr QKVRaggedOffsetTensorPtr = - std::make_shared(std::move(QKVOffsetTensor)); - std::shared_ptr ORaggedOffsetTensorPtr = - std::make_shared(std::move(ORaggedOffsetTensor)); - - // Create Q and K tensors that are used in different places - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - - auto qTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["Q"], q_dim, q_stride, - false, false, QKVRaggedOffsetTensorPtr); - auto kTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K"], k_dim, k_stride, - false, false, QKVRaggedOffsetTensorPtr); - - // Q * K.T - auto afterQKTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, qTensor, - kTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // QK.T * attn scale - auto AfterAttnScale_before_dequan_Q_tensor = - createScale(afterQKTensor, // input tensor - "AttnScale", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - true, // scale is by value - &ops); - - // QK.T * attn scale * dequant_Q - auto AfterAttnScale_before_dequan_K_tensor = - createScale(AfterAttnScale_before_dequan_Q_tensor, // input tensor - "descaleQ", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // QK.T * attn scale * dequant_Q * dequant_K - auto AfterAttnScale_tensor = - createScale(AfterAttnScale_before_dequan_K_tensor, // input tensor - "descaleK", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - auto BeforeDropoutTensor = - createSoftmaxForward(b, h, s_q, s_kv, &ops, AfterAttnScale_tensor, isTraining); - - auto AfterDropout_before_quan_S = - createDropoutForward(b, h, s_q, s_kv, dropoutProbability, &ops, BeforeDropoutTensor); - - // Amax for S - createAmax("amaxS", BeforeDropoutTensor, &ops); - - // After softmax * dropout * scale S -> fp8 input to next bmm with V - auto AfterMultiplyDropout = createScale(AfterDropout_before_quan_S, // input tensor - "scaleS", // scale tensor - tensorType, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // After softmax * Dropout * V - auto OTensor_before_dequan_S_tensor = - createSVBMM(b, h, s_q, s_kv, d, layout, 
tensorType, &ops, AfterMultiplyDropout, - seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // O * dequant_S - auto OTensor_before_dequan_V_tensor = - createScale(OTensor_before_dequan_S_tensor, // input tensor - "descaleS", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_V - auto OTensor_before_quan_O_tensor = - createScale(OTensor_before_dequan_V_tensor, // input tensor - "descaleV", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_V * scale O - auto OTensor = createScaleWithOffset(OTensor_before_quan_O_tensor, // input tensor - "scaleO", // scale tensor - layout, // qkv layout - tensorType, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - ORaggedOffsetTensorPtr, // ragged offset - "O"); - - // Amax for O - createAmax("amaxO", OTensor_before_quan_O_tensor, &ops); - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_fprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; // end of get_plan - - auto plan = get_plan(fa_fprop_cache, descriptor); - size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - - // Exit to request upper level API to allocate memory if needed - if (workspace_ptr == nullptr) { - *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. 
- NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - - int32_t* qkv_ragged_offset = - reinterpret_cast(reinterpret_cast(workspace_ptr) + wkspace_size); - int32_t* o_ragged_offset = reinterpret_cast(reinterpret_cast(workspace_ptr) + - wkspace_size + (b + 1) * sizeof(int32_t)); - int32_t* actual_seqlens_q = reinterpret_cast( - reinterpret_cast(workspace_ptr) + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); - // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV - dim3 blockDims(128); - dim3 gridDims((b + blockDims.x) / blockDims.x); - cu_seqlens_to_offsets<<>>( - b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, - o_ragged_offset); - NVTE_CHECK_CUDA(cudaGetLastError()); - void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); - void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); - void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); - - float dropoutScale = 1.0f / (1.0f - dropoutProbability); - - std::set> data_ptrs; - data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleK"], devPtrDescaleK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleV"], devPtrDescaleV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleS"], devPtrDescaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleS"], devPtrScaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleO"], devPtrScaleO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxO"], devPtrAmaxO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxS"], devPtrAmaxS)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); - - // If training, then we need to write out M and Z_INV - if (isTraining) { - data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); - } - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs) - .build(); - - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException& e) { - struct cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - // This example is only for GH100 cards (cudnn Version >= 8900) - if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) && - (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH || - e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { - std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; - } else { - 
std::cout << "[ERROR] Exception " << e.what() << std::endl; - } - } -} - -// fused attention BWD FP8 with FE 0.9 -void fused_attn_fp8_bwd_impl( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, float attnScale, - float dropoutProbability, NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, - void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledS, - void* devPtrScaleS, void* devPtrScaledS, void* devPtrScaledQ, void* devPtrScaledK, - void* devPtrScaledV, void* devPtrAmaxdS, void* devPtrAmaxdQ, void* devPtrAmaxdK, - void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnnDataType_t tensorType, void* workspace_ptr, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle_) { - try { - FADescriptor descriptor{b, - h, - s_q, - s_kv, - d, - attnScale, - false, - dropoutProbability, - layout, - NVTE_Bias_Type::NVTE_NO_BIAS, - NVTE_Mask_Type::NVTE_PADDING_MASK, - tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fa_bprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType& cache, const FADescriptor& descriptor) { - // If hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // Otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - NVTE_CHECK(dropoutProbability != 1.0f, "Dropout probability cannot be 1.0"); - - int64_t raggedDim[4] = {b + 1, 1, 1, 1}; - int64_t raggedStride[4] = {1, 1, 1, 1}; - // Create offset tensors - auto QKVOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], - raggedDim, raggedStride, false, false); - auto ORaggedOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], - raggedDim, raggedStride, false, false); - - // Create shared ptrs to ragged offset tensors for multiple tensors - std::shared_ptr QKVRaggedOffsetTensorPtr = - std::make_shared(std::move(QKVOffsetTensor)); - std::shared_ptr ORaggedOffsetTensorPtr = - std::make_shared(std::move(ORaggedOffsetTensor)); - - // Create Q and K tensors that are used in different places - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - - auto qTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["Q"], q_dim, q_stride, - false, false, QKVRaggedOffsetTensorPtr); - auto kTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K"], k_dim, k_stride, - false, false, QKVRaggedOffsetTensorPtr); - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // Create attnScale tensor for multiple ops to use - auto attnScaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"], - scale_dim, scale_stride, false, true); // is by value - - // Create descale Q K dO dS global tensors since they are used in multiple places - auto descaleQTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"], - scale_dim, scale_stride, false, 
false); - auto descaleKTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"], - scale_dim, scale_stride, false, false); - auto descaledOTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"], - scale_dim, scale_stride, false, false); - auto descaledSTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"], - scale_dim, scale_stride, false, false); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - // Create MNK override tensor - auto seqlenMNKTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], - seqlen_dim, seqlen_stride, false, false); - - int64_t O_dim[4] = {b, h, s_q, d}; - int64_t O_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - // Create O and loss tensor - auto OTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["O"], O_dim, O_stride, - false, false, ORaggedOffsetTensorPtr); - // dO is used in multiple places and E5M2 - auto dOTensor = - tensor_create_with_offset(CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"], O_dim, O_stride, - false, false, ORaggedOffsetTensorPtr); - - // Q * K.T - auto afterQKTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, qTensor, - kTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // QK.T * attn scale - auto AfterAttnScale_before_dequan_Q_tensor = - createScale(afterQKTensor, // input tensor - attnScaleTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - true, // scale is by value - &ops, 1999 /*UID offset*/); - - // QK.T * attn scale * dequant_Q - auto AfterAttnScale_before_dequan_K_tensor = - createScale(AfterAttnScale_before_dequan_Q_tensor, // input tensor - descaleQTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2000 /*UID offset*/); - - // QK.T * attn scale * dequant_Q * dequant_K - auto AfterAttnScale_tensor = - createScale(AfterAttnScale_before_dequan_K_tensor, // input tensor - descaleKTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2001 /*UID offset*/); - - auto beforeDropout_QKt_Tensor = - createSoftmaxBackward(b, h, s_q, s_kv, &ops, AfterAttnScale_tensor); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - // mask for the dropout. 
Used in different places - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - auto AfterDropout_before_quan_S = createDropoutBackward( - b, h, s_q, s_kv, dropoutProbability, &ops, beforeDropout_QKt_Tensor, dropoutMaskTensor); - - // After softmax * scale S -> fp8 input to next bmm with V - auto AfterMultiply = createScale(AfterDropout_before_quan_S, // input tensor - "scaleS", // scale tensor - tensorType, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // After softmax * dO - auto dVTensor_before_dequan_S = createSdOBMM(b, h, s_q, s_kv, d, tensorType, &ops, - AfterMultiply, dOTensor, seqlenMNKTensor); - - // O * dequant_S - auto dVTensor_before_dequan_dO = createScale(dVTensor_before_dequan_S, // input tensor - "descaleS", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_dO - auto dVTensor_before_quan_dV = createScale(dVTensor_before_dequan_dO, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2002 /*UID offset*/); - - // O * dequant_S * dequant_dO * scale dV - auto dVTensor = createScaleWithOffset(dVTensor_before_quan_dV, // input tensor - "scaledV", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dV" /*Output tensor name*/); - - // Amax for dV - createAmax("amaxdV", dVTensor_before_quan_dV, &ops); - - auto dS_before_dequan_dO_Tensor = - createdOVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dOTensor, seqlenMNKTensor, - QKVRaggedOffsetTensorPtr); - - // dS * dequant_dO - auto dS_before_dequan_V = createScale(dS_before_dequan_dO_Tensor, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2003 /*UID offset*/); - - // O * dequant_S * dequant_dV - auto dS_after_dequan = createScale(dS_before_dequan_V, // input tensor - "descaleV", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // RNG Multiply - auto beforeDropoutScale_dOVt_Tensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // After dropout mask and scale - auto dS_after_dropout = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the multiply mask descriptor - auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create(dS_after_dequan, dropoutMaskTensor, - beforeDropoutScale_dOVt_Tensor, mulDesc); - - ops.push_back(std::move(maskMul_op)); - - // scale after dropout for dO and O chain - auto dropoutScale_dOVt_OdO_Tensor = - tensor_create(tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], scale_dim, - scale_stride, false, true); // is by value - - // Create a multiply dropout scale Node - auto mul_dropout_scale_op = binary_pw_op_create( - beforeDropoutScale_dOVt_Tensor, dropoutScale_dOVt_OdO_Tensor, dS_after_dropout, mulDesc); - - 
ops.push_back(std::move(mul_dropout_scale_op)); - - // O * dequant_O - auto O_after_dequan_Tensor = createScale(OTensor, // input tensor - "descaleO", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // dO * dequant_dO - auto dO_after_dequan_Tensor = createScale(dOTensor, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2004 /*UID offset*/); - - // row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)] - auto O_dO_after_rowsum = - createdOAndORowReductionChain(b, h, s_q, s_kv, d, layout, &ops, O_after_dequan_Tensor, - dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor); - - // (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale - auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain( - b, h, s_q, s_kv, d, layout, &ops, dS_after_dropout, AfterDropout_before_quan_S, - O_dO_after_rowsum, attnScaleTensor); - - // S_mul_dS_minus_O_dO * scaledS - auto S_mul_dS_minus_O_dO_after_quan_dS = - createScale(S_mul_dS_minus_O_dO, // input tensor - "scaledS", // scale tensor - CUDNN_DATA_FP8_E5M2, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // Amax for dS - createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops); - - // dS @ K - auto After_dS_K = createdSKBMM(b, h, s_q, s_kv, d, &ops, S_mul_dS_minus_O_dO_after_quan_dS, - kTensor, seqlenMNKTensor); - - // (dS * K) * descale dS - auto After_dS_K_before_dequan_K = createScale(After_dS_K, // input tensor - descaledSTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2006 /*UID offset*/); - - // (dS * K) * descale dS * descale K - auto After_dS_K_before_quan_dQ = createScale(After_dS_K_before_dequan_K, // input tensor - descaleKTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2007 /*UID offset*/); - - // (dS * K) * descale dS * descale K * scale dQ - auto dQ = createScaleWithOffset(After_dS_K_before_quan_dQ, // input tensor - "scaledQ", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dQ"); - - // Amax for dQ - createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops); - - // dS.T @ Q - auto After_dSTranspose_Q = - createdSQBMM(b, h, s_q, s_kv, d, layout, &ops, S_mul_dS_minus_O_dO_after_quan_dS, qTensor, - seqlenMNKTensor); - - // (dS.T * Q) * descale dS - auto After_dSTranspose_Q_before_dequan_Q = - createScale(After_dSTranspose_Q, // input tensor - descaledSTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2009 /*UID offset*/); - - // (dS.T * Q) * descale dS * descale Q - auto After_dSTranspose_Q_before_quan_dK = - createScale(After_dSTranspose_Q_before_dequan_Q, // input tensor - descaleQTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2010 /*UID offset*/); - - // (dS.T * Q) * descale dS * descale Q * scale dK - auto dK = createScaleWithOffset(After_dSTranspose_Q_before_quan_dK, // input tensor - "scaledK", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // 
output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dK"); - - // Amax for dK - createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops); - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_bprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fa_bprop_cache, descriptor); - size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - - // Exit to request upper level API to allocate memory if needed - if (workspace_ptr == nullptr) { - *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - - int32_t* qkv_ragged_offset = - reinterpret_cast(reinterpret_cast(workspace_ptr) + wkspace_size); - int32_t* o_ragged_offset = reinterpret_cast(reinterpret_cast(workspace_ptr) + - wkspace_size + (b + 1) * sizeof(int32_t)); - int32_t* actual_seqlens_q = reinterpret_cast( - reinterpret_cast(workspace_ptr) + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); - // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV - dim3 blockDims(128); - dim3 gridDims((b + blockDims.x) / blockDims.x); - cu_seqlens_to_offsets<<>>( - b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, - o_ragged_offset); - NVTE_CHECK_CUDA(cudaGetLastError()); - void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); - void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); - void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); - - std::set> data_ptrs; - float dropoutScale = 1.0f / (1.0f - dropoutProbability); - float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability; - data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V_TRANSPOSE"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dQ"], devPtrdQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dK"], devPtrdK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dV"], devPtrdV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dO"], devPtrdO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], - &dropoutScale_dOVt_OdO)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); - 
data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleK"], devPtrDescaleK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleV"], devPtrDescaleV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleS"], devPtrDescaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaledS"], devPtrDescaledS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleO"], devPtrDescaleO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaledO"], devPtrDescaledO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleS"], devPtrScaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledS"], devPtrScaledS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledQ"], devPtrScaledQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledK"], devPtrScaledK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledV"], devPtrScaledV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdS"], devPtrAmaxdS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdK"], devPtrAmaxdK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdV"], devPtrAmaxdV)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs) - .build(); - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException& e) { - struct cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - // This example is only for GH100 cards (cudnn Version >= 8900) - if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) && - (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH || - e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { - std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; - } else { - std::cout << "[ERROR] Exception " << e.what() << std::endl; - } - } -} - // fused attention FWD FP8 with FE 1.0+ -void fused_attn_fp8_fwd_impl_v1( +void fused_attn_fp8_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* 
devPtrDropoutOffset, @@ -2080,26 +446,26 @@ void fused_attn_fp8_fwd_impl_v1( } // fused attention BWD FP8 with FE 1.0+ -void fused_attn_fp8_bwd_impl_v1( +void fused_attn_fp8_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, - void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, - void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, - void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, - void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, - void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, - void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, - void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, - void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, void* workspace, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrO, + void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, void* devPtrdK, void* devPtrdV, + void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, + void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, + void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, + void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, + void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, + NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, void* workspace, size_t* workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2760,19 +1126,12 @@ void fused_attn_fp8_fwd( devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } void* devPtrM = nullptr; - void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype 
= DType::kFloat32; - if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - } Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -2788,11 +1147,6 @@ void fused_attn_fp8_fwd( int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; - devPtrZInv = nullptr; - if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrZInv = output_ZInv->data.dptr; - } Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { @@ -2819,25 +1173,17 @@ void fused_attn_fp8_fwd( NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( + fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, + devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); + NVTE_ERROR("FP8 fused attention only supports qkv_format=BSHD, SBHD, or BHSD.\n"); } if (workspace_size > 0) { @@ -2862,11 +1208,11 @@ void fused_attn_fp8_bwd( NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, - const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, - Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { + const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_S, + const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, const Tensor* output_dQ, + const Tensor* output_dK, const Tensor* output_dV, Tensor* output_dSoftmaxOffset, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, + Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2899,7 +1245,6 @@ void fused_attn_fp8_bwd( } void* devPtrM = input_M->data.dptr; - void* devPtrZInv = (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; @@ -2949,34 +1294,22 @@ void fused_attn_fp8_bwd( NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( + fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, - devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, + devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, + devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, + devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, qkv_scale_inv_format, do_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); - } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) 
{ - // remove this when cuDNN FE supports FP8 + THD - NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, - "ZInv tensor required for FP8 fused attention backward with T3HD layout."); - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + NVTE_ERROR("FP8 fused attention only supports dqkv_format=BSHD, SBHD, or BHSD.\n"); } if (workspace_size > 0) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index aaf5039eeb..b9660128ca 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -5,7 +5,7 @@ ************************************************************************/ /*! \file fused_attn_fp8.h - * \brief Functions for fused attention for FP8 with seqlen <= 512 + * \brief Functions for fused attention for FP8 */ #include "transformer_engine/fused_attn.h" @@ -34,9 +34,9 @@ void fused_attn_fp8_bwd( NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_S, + const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, const Tensor *output_dQ, + const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index f37eeb0c68..3e628b6581 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -411,20 +411,6 @@ cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDe return pw_op_created; } -// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q -__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q, - int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, - int32_t *o_ragged_offset) { - size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b) { - actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid]; - } - if (tid < b 
+ 1) { - qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d; - o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d; - } -} - // convert cu_seqlens to actual_seqlens __global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, int32_t const *const q_cu_seqlens, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index c3736a6c65..41656062a4 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -333,10 +333,6 @@ struct FADescriptor_v1 { } }; -__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q, - int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, - int32_t *o_ragged_offset); - __global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 912dc32d35..77193c0721 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -233,16 +233,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * - D = Dropout(S) * - O = D * Transpose(V) * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | | - | | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * * Notes: * * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` @@ -264,7 +254,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, - * e.g. M, ZInv, rng_state. + * e.g. softmax stats, optional Max, rng_state. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -311,16 +301,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso /*! \brief Compute the backward of the dot product attention with separate Q, K and V. 
* - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | | - | | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * * Notes: * * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` @@ -342,7 +322,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] S The S tensor. * \param[in,out] dP The gradient of the P tensor. * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, - * e.g. M, ZInv, rng_state. + * e.g. softmax stats, optional Max, rng_state. * \param[out] dQ The gradient of the Q tensor. * \param[out] dK The gradient of the K tensor. * \param[out] dV The gradient of the V tensor. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 7b10593acf..32eb1b597a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -980,10 +980,7 @@ def cp_p2p_fwd_fused_attn( ) if fp8: - if qkv_layout != "t3hd": - softmax_lse_per_step, rng_states = aux_ctx_tensors - else: - softmax_lse_per_step, _, rng_states = aux_ctx_tensors + softmax_lse_per_step, rng_states = aux_ctx_tensors else: softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None @@ -1169,17 +1166,7 @@ def cp_p2p_bwd_fused_attn( section, ): """Per-tile backward call of CP P2P with FusedAttention backend""" - if fp8: - if qkv_layout == "t3hd": - aux_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - step - 1], - ] - else: - aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] - else: - aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] max_seqlen_q_ = max_seqlen_q max_seqlen_kv_ = max_seqlen_kv @@ -1195,17 +1182,7 @@ def cp_p2p_bwd_fused_attn( attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" elif section == "upper-triangle": q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] - if fp8: - if qkv_layout == "t3hd": - aux_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - step - 1], - ] - else: - aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] - else: - aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] max_seqlen_q_ = max_seqlen_q // 2 cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 @@ -3223,10 +3200,7 @@ def forward( **fp8_meta_kwargs, ) if fp8: - if qkv_layout != "t3hd": - softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors else: softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: @@ -3588,17 +3562,10 @@ def 
backward(ctx, dout, *_args): out_part = out.select(seq_dim_o, i).contiguous() dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - if ctx.fp8 and ctx.qkv_layout == "t3hd": - aux_ctx_tensors = [ - softmax_lse_per_step[i], - softmax_lse_per_step[i], - rng_states[i], - ] - else: - aux_ctx_tensors = [ - softmax_lse_per_step[i], - rng_states[i], - ] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + rng_states[i], + ] fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} new_qkv_layout = ctx.qkv_layout diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 01e139da46..5bbec87a4b 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -257,12 +257,11 @@ def fused_attn_fwd( softmaxStats: torch.Tensor log(sum(e^(x - max(x)))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) + Max: torch.Tensor, only when return_max_logit is True shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor, only allocated for T3HD path - 1/sum(e^(x - max(x))), where x=Q*K.T + 3. if fused_attention_backend == FusedAttnBackend["FP8"] + softmaxStats: torch.Tensor + log(sum(e^(x - max(x)))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; @@ -472,7 +471,7 @@ def fused_attn_bwd( in torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] + e.g. aux_ctx_tensors = [S, Max, rng_state] fused_attention_backend : tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. 
cu_seqlens_q_padded : torch.Tensor, default = None diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index e6781bd58a..8b890e171f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -272,21 +272,17 @@ std::vector fused_attn_fwd( }; // allocate memory for nvte_aux_tensor_pack.tensors // f16_max512 : S [b, h, sq, skv] - // f16_arbitrary: - // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // fp8 : M [b, h, sq, 1], optional ZInv [b, h, sq, 1] (T3HD path), rng_state [2] + // f16_arbitrary: S [b, h, sq, 1]/[tq, h, 1], (optional) Max [b, h, sq, 1]/[tq, h, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // fp8 : S [b, h, sq, 1], rng_state [2] size_t i = 0; at::Tensor output_tensor; - // intermediate softmax tensor, S or M (for fp8) + // intermediate softmax stats tensor S output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); - // fp8 T3HD has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor - if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || - return_max_logit) { + // return_max_logit=true allocates Max after S + if (return_max_logit) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
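
With the FE 0.9 / T3HD path gone, the FP8 fused-attention forward packs the same auxiliary tensors as the F16 arbitrary-seqlen path: softmax statistics and the RNG state, with no ZInv tensor. Below is a minimal sketch, assuming fused_attn_fwd / fused_attn_bwd behave as documented in the docstrings above; the helper names are illustrative only and not part of the transformer_engine API.

    # Sketch only: helper names are hypothetical; shapes/dtypes follow the
    # fused_attn.py docstring above (softmax stats: [b, h, sq, 1], float32).
    from typing import List, Tuple
    import torch

    def unpack_fp8_fwd_aux(aux_ctx_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        # FP8 forward now returns exactly two auxiliaries:
        #   softmax_lse = log(sum(e^(x - max(x)))), x = Q*K.T
        #   rng_state   = [seed, offset] of the dropout RNG
        softmax_lse, rng_state = aux_ctx_tensors
        return softmax_lse, rng_state

    def pack_fp8_bwd_aux(softmax_lse: torch.Tensor, rng_state: torch.Tensor) -> List[torch.Tensor]:
        # Auxiliary list passed to the FP8 backward call; mirrors the
        # simplified aux_tensors construction in context_parallel.py above.
        return [softmax_lse, rng_state]

For example, a context-parallel caller that previously branched on qkv_layout == "t3hd" can now use pack_fp8_bwd_aux(softmax_lse, rng_states[step]) unconditionally for the FP8 backend.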