From bf2ba25daabc28090bb31269c6642adc4b4db97f Mon Sep 17 00:00:00 2001 From: Chen Bing <1561804820@qq.com> Date: Tue, 24 Mar 2026 21:03:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8E=A5=E5=85=A5flashattention=202026.3.24?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitmodules | 3 + CMakeLists.txt | 11 + example/gpt2/main.cc | 10 +- example/gpt2/net.cc | 81 ++- example/gpt2/net.h | 3 +- example/llama3/main.cc | 5 +- example/llama3/net.cc | 5 +- example/llama3/net.h | 2 +- .../autograd/scaled_dot_product_attention.h | 44 ++ infini_train/include/nn/functional.h | 1 + .../nn/parallel/ddp/distributed_optimizer.h | 3 + .../nn/parallel/ddp/param_and_grad_buffer.h | 3 + .../autograd/scaled_dot_product_attention.cc | 75 +++ .../src/kernels/cuda/flash_attention.cu | 564 ++++++++++++++++++ .../src/kernels/cuda/ysyx.code-workspace | 8 + third_party/cudnn-frontend | 1 + 16 files changed, 792 insertions(+), 27 deletions(-) create mode 100644 infini_train/include/autograd/scaled_dot_product_attention.h create mode 100644 infini_train/src/autograd/scaled_dot_product_attention.cc create mode 100644 infini_train/src/kernels/cuda/flash_attention.cu create mode 100644 infini_train/src/kernels/cuda/ysyx.code-workspace create mode 160000 third_party/cudnn-frontend diff --git a/.gitmodules b/.gitmodules index 470cf466..59022b76 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "third_party/eigen"] path = third_party/eigen url = git@github.com:InfiniTensor/eigen-mirror.git +[submodule "third_party/cudnn-frontend"] + path = third_party/cudnn-frontend + url = https://github.com/NVIDIA/cudnn-frontend.git diff --git a/CMakeLists.txt b/CMakeLists.txt index df636b27..d39e67ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,6 +77,12 @@ if(USE_CUDA) find_package(CUDAToolkit REQUIRED) include_directories(${CUDAToolkit_INCLUDE_DIRS}) + # ========== cuDNN 库 ========== + find_library(CUDNN_LIBRARY cudnn REQUIRED) + message(STATUS "Found cuDNN at: ${CUDNN_LIBRARY}") + # ======================================== + + # CUDA compilation options set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") @@ -84,6 +90,9 @@ if(USE_CUDA) file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu) add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS}) + target_include_directories(infini_train_cuda_kernels PUBLIC + ${PROJECT_SOURCE_DIR}/third_party/cudnn-frontend/include) + set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90") target_link_libraries(infini_train_cuda_kernels @@ -92,6 +101,8 @@ if(USE_CUDA) CUDA::cudart CUDA::cublas CUDA::cuda_driver + CUDA::nvrtc + ${CUDNN_LIBRARY} ) if(USE_NCCL) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 8e28af52..dd1863da 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -80,6 +80,9 @@ DEFINE_string( precision_check, "", "precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH"); +//Flash attention parameters +DEFINE_bool(flash, false, "Whether to enable flash attention"); + // LoRA parameters DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)"); DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor"); @@ -189,16 +192,17 @@ void Train(const nn::parallel::Rank &rank) { // init the model, either from scratch or from OpenAI pretrained checkpoint GPT2Config model_config; std::shared_ptr model = nullptr; - if (!FLAGS_llmc_filepath.empty()) { - model = GPT2::FromLLMC(FLAGS_llmc_filepath); + model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash); } else if (kModelToConfigs.count(FLAGS_model)) { model_config = kModelToConfigs.at(FLAGS_model); + model_config.flash = FLAGS_flash; model = std::make_shared(model_config); } else { model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model)); + model_config.flash = FLAGS_flash; } - + model->To(device); utils::PrecisionChecker::BuildNameMap(model.get()); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index d000d1cf..3dfcf604 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -12,6 +12,7 @@ #include #include "glog/logging.h" +#include "gflags/gflags.h" #include "example/common/utils.h" #include "infini_train/include/device.h" @@ -105,24 +106,63 @@ CausalSelfAttention::Forward(const std::vectorView({B, T, local_n_head_, head_dim})->Transpose(1, 2); v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); - // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); - - // Get full tensor - // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) - y = (*modules_[kCProjLayerName])({y})[0]; - // (B, T, C) == (bs, seq_len, n_embd) - return {y}; + auto Flag_flash = config_.flash; + const bool use_flash_sdpa = Flag_flash && q->Dtype() == DataType::kBFLOAT16; + static bool logged_attention_path = false; + if (!logged_attention_path) { + LOG(ERROR) << "[GPT2][AttentionPath] flash_flag=" << Flag_flash + << ", q_dtype=" << kDataTypeToDesc.at(q->Dtype()) + << ", selected=" << (use_flash_sdpa ? "cuDNN_SDPA_flash" : "matmul_softmax_fallback"); + logged_attention_path = true; + } + + std::shared_ptr y; + + if (Flag_flash) { + // cuDNN SDPA path: causal masking should be enabled by `is_causal=true`. + // Do not pass the 0/1 tril mask as additive bias (it is not -inf mask). + auto q_flash = q; + auto k_flash = k; + auto v_flash = v; + if (q->Dtype() != DataType::kBFLOAT16) { + q_flash = std::make_shared(q->To(DataType::kBFLOAT16)); + k_flash = std::make_shared(k->To(DataType::kBFLOAT16)); + v_flash = std::make_shared(v->To(DataType::kBFLOAT16)); + } + y = nn::function::ScaledDotProductAttention(q_flash, k_flash, v_flash, nullptr, 0.0, true, std::nullopt, + false); + + if (y->Dtype() != q->Dtype()) { + y = std::make_shared(y->To(q->Dtype())); + } + // ensure expected layout: (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + + // Get full tensor + // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) + y = (*modules_[kCProjLayerName])({y})[0]; + // (B, T, C) == (bs, seq_len, n_embd) + return {y}; + } else { + // (B, h_l, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); + // (1, 1, T, T) + auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + // (B, h_l, T, T) + att = nn::function::Softmax(att, -1); + // (B, h_l, T, Dh) + auto y = att->Matmul(v); + // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + + // Get full tensor + // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) + y = (*modules_[kCProjLayerName])({y})[0]; + // (B, T, C) == (bs, seq_len, n_embd) + return {y}; + } } MLP::MLP(const GPT2Config &config) : CloneableModule(kType) { @@ -356,7 +396,7 @@ std::tuple DetermineAndCheckVersion(const std:: } } // namespace -std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { +std::shared_ptr GPT2::FromLLMC(const std::string &filepath, bool flash) { if (!std::filesystem::exists(filepath)) { LOG(FATAL) << "File not found: " << filepath; } @@ -384,7 +424,8 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { .original_vocab_size = vocab_size, .n_layer = n_layer, .n_head = n_head, - .n_embd = n_embd}); + .n_embd = n_embd, + .flash = flash}); LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head diff --git a/example/gpt2/net.h b/example/gpt2/net.h index 4faf5451..e429770a 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -19,6 +19,7 @@ struct GPT2Config { int64_t n_layer = 12; int64_t n_head = 12; int64_t n_embd = 768; + bool flash = false; }; class NewGELU : public infini_train::nn::CloneableModule { @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule { Forward(const std::vector> &x) override; static std::shared_ptr FromPretrained(ModelType model_type); - static std::shared_ptr FromLLMC(const std::string &filepath); + static std::shared_ptr FromLLMC(const std::string &filepath, bool flash = false); int GetChunkSize() const; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index acc20ac4..72db5bd7 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -84,6 +84,8 @@ DEFINE_string(lora_target_modules, "c_attn,c_proj,c_fc,c_fc2", "LoRA target modu DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training"); DEFINE_string(lora_load_path, "", "Path to load LoRA weights from"); +DEFINE_bool(flash, false, "Whether to enable flash attention"); + using namespace infini_train; namespace { @@ -170,8 +172,9 @@ void Train(const nn::parallel::Rank &rank) { LLaMA3Config model_config = LLaMA3Config(); std::shared_ptr model = nullptr; if (!FLAGS_llmc_filepath.empty()) { - model = LLaMA3::FromLLMC(FLAGS_llmc_filepath); + model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash); } else { + model_config.flash = FLAGS_flash; model = std::make_shared(model_config); } diff --git a/example/llama3/net.cc b/example/llama3/net.cc index a50fb831..c146af55 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -12,6 +12,7 @@ #include #include "glog/logging.h" +#include "gflags/gflags.h" #include "example/common/utils.h" #include "infini_train/include/device.h" @@ -457,7 +458,7 @@ constexpr int32_t kLLaMA3Magic = 20240803; constexpr int32_t kLLaMA3FP32Version = 3; } // namespace -std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { +std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath, bool flash) { if (!std::filesystem::exists(filepath)) { LOG(FATAL) << "File not found: " << filepath; } @@ -496,8 +497,10 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { .rope_theta = rope_theta, .use_scaled_rope = static_cast(use_scaled_rope), .norm_eps = norm_eps, + .flash = flash, .max_gen_batch_size = max_gen_bs}); + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); diff --git a/example/llama3/net.h b/example/llama3/net.h index 4496a68d..8338913d 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -179,7 +179,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule { Forward(const std::vector> &x) override; static std::shared_ptr FromPretrained(ModelType model_type); - static std::shared_ptr FromLLMC(const std::string &filepath); + static std::shared_ptr FromLLMC(const std::string &filepath, bool flash = false); int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } diff --git a/infini_train/include/autograd/scaled_dot_product_attention.h b/infini_train/include/autograd/scaled_dot_product_attention.h new file mode 100644 index 00000000..e48f900a --- /dev/null +++ b/infini_train/include/autograd/scaled_dot_product_attention.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class ScaledDotProductAttention : public Function { +public: + static constexpr char kType[] = "ScaledDotProductAttention"; + + ScaledDotProductAttention(double dropout_p, bool is_causal, + std::optional scale, bool enable_gqa) + : Function(kType), dropout_p_(dropout_p), is_causal_(is_causal), scale_(scale), + enable_gqa_(enable_gqa) {} + + std::vector> Forward( + const std::vector> &input_tensors) override; + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + + std::vector> Backward( + const std::vector> &grad_outputs) override; + +private: + double dropout_p_ = 0.0; + + bool is_causal_ = false; + std::optional scale_ = std::nullopt; + bool enable_gqa_ = false; + bool has_attn_mask_input_ = false; + std::shared_ptr forward_out_ = nullptr; + std::shared_ptr forward_lse_ = nullptr; + // Saved tensors for backward can be managed via Function's SaveForBackward helper +}; +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/functional.h b/infini_train/include/nn/functional.h index e4354fd1..87a91304 100644 --- a/infini_train/include/nn/functional.h +++ b/infini_train/include/nn/functional.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace infini_train { class Tensor; diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..81f640fc 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -31,6 +31,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartGradSync(); void FinishGradSync(); + // Forward microbatch boundary info to bucket groups. + void SetIsLastMicrobatch(bool is_last_microbatch); + void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h index c83fe9a5..b4a2aa9d 100644 --- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h +++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h @@ -70,6 +70,9 @@ class ParamAndGradBucketGroup { // When all params in a bucket group are ready, will call StartGradSync() void RegisterGradReady(const std::shared_ptr ¶meter); + // Mark whether current backward corresponds to the last microbatch in a gradient accumulation window. + void SetIsLastMicrobatch(bool is_last_microbatch); + // Start grad reduce void StartGradSync(); diff --git a/infini_train/src/autograd/scaled_dot_product_attention.cc b/infini_train/src/autograd/scaled_dot_product_attention.cc new file mode 100644 index 00000000..e15b1579 --- /dev/null +++ b/infini_train/src/autograd/scaled_dot_product_attention.cc @@ -0,0 +1,75 @@ +#include "infini_train/include/autograd/scaled_dot_product_attention.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> ScaledDotProductAttention::Forward( + const std::vector> &input_tensors) { + CHECK(input_tensors.size() == 3 || input_tensors.size() == 4); + const auto &q = input_tensors[0]; + const auto &k = input_tensors[1]; + const auto &v = input_tensors[2]; + const auto mask = input_tensors.size() == 4 ? input_tensors[3] : nullptr; + + auto device = q->GetDevice().type(); + // Call device kernel. Kernel name: ScaledDotProductAttentionForward + auto out_and_lse = Dispatcher::Instance().Call, std::shared_ptr>>( + {device, "ScaledDotProductAttentionForward"}, q, k, v, mask, dropout_p_, is_causal_, scale_, + enable_gqa_); + forward_out_ = std::get<0>(out_and_lse); + forward_lse_ = std::get<1>(out_and_lse); + auto out = forward_out_; + return {out}; +} + +void ScaledDotProductAttention::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + (void)output_tensors; + // Save q,k,v and mask (mask may be nullptr) + const auto &q = input_tensors[0]; + const auto &k = input_tensors[1]; + const auto &v = input_tensors[2]; + std::shared_ptr mask = nullptr; + has_attn_mask_input_ = (input_tensors.size() == 4); + if (input_tensors.size() == 4) { + mask = input_tensors[3]; + } + saved_tensors_ = {q, k, v, mask}; +} + +std::vector> ScaledDotProductAttention::Backward( + const std::vector> &grad_outputs) { + CHECK(saved_tensors_.size() == 4); + const auto &q = saved_tensors_[0]; + const auto &k = saved_tensors_[1]; + const auto &v = saved_tensors_[2]; + const auto &mask = saved_tensors_[3]; + + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + + CHECK(forward_out_ != nullptr); + CHECK(forward_lse_ != nullptr); + + auto grads = Dispatcher::Instance().Call, std::shared_ptr, + std::shared_ptr>>( + {device, "ScaledDotProductAttentionBackward"}, grad_output, q, k, v, mask, forward_out_, forward_lse_, + dropout_p_, is_causal_, scale_, enable_gqa_); + + forward_out_ = nullptr; + forward_lse_ = nullptr; + + if (has_attn_mask_input_) { + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads), nullptr}; + } + + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cuda/flash_attention.cu b/infini_train/src/kernels/cuda/flash_attention.cu new file mode 100644 index 00000000..d12fa6a3 --- /dev/null +++ b/infini_train/src/kernels/cuda/flash_attention.cu @@ -0,0 +1,564 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" +//#include "infini_train/src/core/cuda/cuda_stream.h" +#include "infini_train/include/common/common.h" // ComputeStrides +#include // cudaStream_t + +// 强烈建议使用 NVIDIA 提供的 frontend 库,否则原始 API 会写到手软 +#include +namespace fe = cudnn_frontend; + + +namespace infini_train::kernels::cuda { + +namespace { +constexpr int64_t Q_UID = 101; +constexpr int64_t K_UID = 102; +constexpr int64_t V_UID = 103; +constexpr int64_t MASK_UID = 104; +constexpr int64_t O_UID = 201; +constexpr int64_t STATS_UID = 202; + +constexpr int64_t dO_UID = 301; +constexpr int64_t dQ_UID = 401; +constexpr int64_t dK_UID = 402; +constexpr int64_t dV_UID = 403; + +struct WorkspaceCache { + void *ptr = nullptr; + size_t size = 0; +}; + +static inline std::size_t hash_combine(std::size_t seed, std::size_t v) { + return seed ^ (v + 0x9e3779b97f4a7c15ULL + (seed << 6U) + (seed >> 2U)); +} + +static inline uint32_t float_to_bits(float x) { + uint32_t bits; + std::memcpy(&bits, &x, sizeof(float)); + return bits; +} + +static std::size_t hash_dims(std::vector const &dims) { + std::size_t h = 0; + for (auto d : dims) { + h = hash_combine(h, std::hash{}(d)); + } + return h; +} + +struct FwdPlanKey { + std::vector q_dims; + std::vector k_dims; + std::vector v_dims; + std::vector mask_dims; + int dtype = 0; + bool is_causal = false; + bool has_mask = false; + uint32_t attn_scale_bits = 0; + + bool operator==(FwdPlanKey const &other) const { + return q_dims == other.q_dims && + k_dims == other.k_dims && + v_dims == other.v_dims && + mask_dims == other.mask_dims && + dtype == other.dtype && + is_causal == other.is_causal && + has_mask == other.has_mask && + attn_scale_bits == other.attn_scale_bits; + } +}; + +struct FwdPlanKeyHash { + std::size_t operator()(FwdPlanKey const &k) const { + std::size_t h = 0; + h = hash_combine(h, hash_dims(k.q_dims)); + h = hash_combine(h, hash_dims(k.k_dims)); + h = hash_combine(h, hash_dims(k.v_dims)); + h = hash_combine(h, hash_dims(k.mask_dims)); + h = hash_combine(h, std::hash{}(k.dtype)); + h = hash_combine(h, std::hash{}(k.is_causal)); + h = hash_combine(h, std::hash{}(k.has_mask)); + h = hash_combine(h, std::hash{}(k.attn_scale_bits)); + return h; + } +}; + +struct BwdPlanKey { + std::vector q_dims; + std::vector k_dims; + std::vector v_dims; + std::vector o_dims; + std::vector do_dims; + std::vector lse_dims; + std::vector mask_dims; + int dtype = 0; + bool is_causal = false; + bool has_mask = false; + uint32_t attn_scale_bits = 0; + + bool operator==(BwdPlanKey const &other) const { + return q_dims == other.q_dims && + k_dims == other.k_dims && + v_dims == other.v_dims && + o_dims == other.o_dims && + do_dims == other.do_dims && + lse_dims == other.lse_dims && + mask_dims == other.mask_dims && + dtype == other.dtype && + is_causal == other.is_causal && + has_mask == other.has_mask && + attn_scale_bits == other.attn_scale_bits; + } +}; + +struct BwdPlanKeyHash { + std::size_t operator()(BwdPlanKey const &k) const { + std::size_t h = 0; + h = hash_combine(h, hash_dims(k.q_dims)); + h = hash_combine(h, hash_dims(k.k_dims)); + h = hash_combine(h, hash_dims(k.v_dims)); + h = hash_combine(h, hash_dims(k.o_dims)); + h = hash_combine(h, hash_dims(k.do_dims)); + h = hash_combine(h, hash_dims(k.lse_dims)); + h = hash_combine(h, hash_dims(k.mask_dims)); + h = hash_combine(h, std::hash{}(k.dtype)); + h = hash_combine(h, std::hash{}(k.is_causal)); + h = hash_combine(h, std::hash{}(k.has_mask)); + h = hash_combine(h, std::hash{}(k.attn_scale_bits)); + return h; + } +}; + +struct CachedPlan { + std::shared_ptr graph; + int64_t workspace_size = 0; +}; + +using FwdPlanCache = std::unordered_map; +using BwdPlanCache = std::unordered_map; +} + +// helpers for cuDNN frontend path +static cudaStream_t get_cuda_stream(const ::infini_train::Device &device) { + auto impl = ::infini_train::core::GetDeviceGuardImpl(device.type()); + auto stream_obj = impl->GetStream(device); + auto cuda_stream = dynamic_cast(stream_obj)->cuda_stream(); + return cuda_stream; +} + +static cudnnHandle_t get_cudnn_handle(const ::infini_train::Device &device) { + //用来记录现在thread正在使用哪个cuda device,cudnn handle是和device绑定的,所以需要这个信息 + int cuda_device = 0; + CUDA_CHECK(cudaGetDevice(&cuda_device)); + + static thread_local std::unordered_map handles; + auto it = handles.find(cuda_device); + if (it == handles.end()) { + cudnnHandle_t handle; + cudnnCreate(&handle); + it = handles.emplace(cuda_device, handle).first; + } + + auto cuda_stream = get_cuda_stream(device); + cudnnSetStream(it->second, cuda_stream); + + return it->second; +} + +static void *acquire_workspace(WorkspaceCache &cache, size_t requested_bytes) { + if (requested_bytes == 0) { + return nullptr; + } + if (cache.ptr == nullptr || cache.size < requested_bytes) { + if (cache.ptr != nullptr) { + CUDA_CHECK(cudaFree(cache.ptr)); + } + CUDA_CHECK(cudaMalloc(&cache.ptr, requested_bytes)); + cache.size = requested_bytes; + } + return cache.ptr; +} + +static WorkspaceCache &forward_workspace_cache() { + static thread_local WorkspaceCache cache; + return cache; +} + +static WorkspaceCache &backward_workspace_cache() { + static thread_local WorkspaceCache cache; + return cache; +} + +static FwdPlanCache &forward_plan_cache() { + static thread_local FwdPlanCache cache; + return cache; +} + +static BwdPlanCache &backward_plan_cache() { + static thread_local BwdPlanCache cache; + return cache; +} + +static fe::DataType_t get_cudnn_dtype(const ::infini_train::DataType dtype); +static std::shared_ptr make_graph_tensor( + const std::shared_ptr &graph, + const std::shared_ptr &tensor, + const std::string &name, + int64_t uid); +static void check_fe_status(fe::error_t status, const char *stage); +static CachedPlan const &get_or_create_fwd_plan(const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + bool is_causal, + float attn_scale, + cudnnHandle_t handle); +static CachedPlan const &get_or_create_bwd_plan(const std::shared_ptr &grad_out, + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + const std::shared_ptr &out, + const std::shared_ptr &lse, + bool is_causal, + float attn_scale, + cudnnHandle_t handle); + +static std::tuple, std::shared_ptr> ExecuteSdpaForwardWithLse( + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool /*enable_gqa*/) { + if (dropout_p > 0.0) { + throw std::runtime_error("cuDNN frontend SDPA path currently does not support dropout in this minimal kernel"); + } + + auto out = std::make_shared(q->Dims(), q->Dtype(), q->GetDevice()); + + auto q_dims = q->Dims(); + CHECK_EQ(q_dims.size(), 4) << "SDPA expects 4D Q/K/V tensor layout [B, H, S, D]"; + std::vector lse_dims = {q_dims[0], q_dims[1], q_dims[2], 1}; + //lse(Log-sum-exp) + auto lse = std::make_shared(lse_dims, DataType::kFLOAT32, q->GetDevice()); + + cudnnHandle_t handle = get_cudnn_handle(q->GetDevice()); + + float attn_scale = scale.has_value() ? static_cast(scale.value()) + : 1.0f / std::sqrt(static_cast(q->Dims().back())); + + auto const &plan = get_or_create_fwd_plan(q, k, v, attn_mask, is_causal, attn_scale, handle); + void *workspace = acquire_workspace(forward_workspace_cache(), static_cast(plan.workspace_size)); + + std::unordered_map variant_pack = { + {Q_UID, q->DataPtr()}, + {K_UID, k->DataPtr()}, + {V_UID, v->DataPtr()}, + {O_UID, out->DataPtr()}, + {STATS_UID, lse->DataPtr()}, + }; + if (attn_mask) { + variant_pack[MASK_UID] = attn_mask->DataPtr(); + } + + auto exec_status = plan.graph->execute(handle, variant_pack, workspace); + check_fe_status(exec_status, "graph->execute"); + + return {out, lse}; +} + +static fe::DataType_t get_cudnn_dtype(const ::infini_train::DataType dtype) { + switch (dtype) { + case ::infini_train::DataType::kFLOAT32: + return fe::DataType_t::FLOAT; + case ::infini_train::DataType::kFLOAT16: + return fe::DataType_t::HALF; + case ::infini_train::DataType::kBFLOAT16: + return fe::DataType_t::BFLOAT16; + default: + throw std::runtime_error("unsupported dtype for cuDNN SDP"); + } +} + +static std::shared_ptr make_graph_tensor( + const std::shared_ptr &graph, + const std::shared_ptr &tensor, + const std::string &name, + int64_t uid) { + return graph->tensor(fe::graph::Tensor_attributes() + .set_name(name) + .set_uid(uid) + .set_dim(tensor->Dims()) + .set_stride(ComputeStrides(tensor->Dims())) + .set_data_type(get_cudnn_dtype(tensor->Dtype()))); +} + +static void check_fe_status(fe::error_t status, const char *stage) { + if (status.is_bad()) { + throw std::runtime_error(std::string(stage) + ": " + status.get_message()); + } +} + +static CachedPlan const &get_or_create_fwd_plan(const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + bool is_causal, + float attn_scale, + cudnnHandle_t handle) { + FwdPlanKey key; + key.q_dims = q->Dims(); + key.k_dims = k->Dims(); + key.v_dims = v->Dims(); + key.has_mask = (attn_mask != nullptr); + if (attn_mask) { + key.mask_dims = attn_mask->Dims(); + } + key.dtype = static_cast(q->Dtype()); + key.is_causal = is_causal; + key.attn_scale_bits = float_to_bits(attn_scale); + + //cache ——FwdPlanCache::map,根据key查找是否已经存在对应的plan,如果存在就直接返回,如果不存在就创建新的plan并插入cache + auto &cache = forward_plan_cache(); + auto it = cache.find(key); + //若能直接找到就返回对应的plan,优化速度 + if (it != cache.end()) { + return it->second; + } + + auto graph = std::make_shared(); + graph->set_io_data_type(get_cudnn_dtype(q->Dtype())) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_tensor = make_graph_tensor(graph, q, "Q", Q_UID); + auto k_tensor = make_graph_tensor(graph, k, "K", K_UID); + auto v_tensor = make_graph_tensor(graph, v, "V", V_UID); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_generate_stats(true) + .set_attn_scale(attn_scale); + + if (is_causal) { + sdpa_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + if (attn_mask) { + auto mask_tensor = make_graph_tensor(graph, attn_mask, "Bias", MASK_UID); + sdpa_options.set_bias(mask_tensor); + } + + auto [out_tensor, stats_tensor] = graph->sdpa(q_tensor, k_tensor, v_tensor, sdpa_options); + out_tensor->set_output(true) + .set_uid(O_UID) + .set_dim(q->Dims()) + .set_stride(ComputeStrides(q->Dims())); + std::vector lse_dims = {q->Dims()[0], q->Dims()[1], q->Dims()[2], 1}; + stats_tensor->set_output(true) + .set_uid(STATS_UID) + .set_dim(lse_dims) + .set_stride(ComputeStrides(lse_dims)) + .set_data_type(fe::DataType_t::FLOAT); + + check_fe_status(graph->build(handle, {fe::HeurMode_t::A}), "graph->build (fwd cache build)"); + + int64_t workspace_size = 0; + check_fe_status(graph->get_workspace_size(workspace_size), "graph->get_workspace_size (fwd cache build)"); + + CachedPlan plan; + plan.graph = graph; + plan.workspace_size = workspace_size; + auto inserted = cache.emplace(std::move(key), std::move(plan)); + return inserted.first->second; +} + +static CachedPlan const &get_or_create_bwd_plan(const std::shared_ptr &grad_out, + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + const std::shared_ptr &out, + const std::shared_ptr &lse, + bool is_causal, + float attn_scale, + cudnnHandle_t handle) { + BwdPlanKey key; + key.q_dims = q->Dims(); + key.k_dims = k->Dims(); + key.v_dims = v->Dims(); + key.o_dims = out->Dims(); + key.do_dims = grad_out->Dims(); + key.lse_dims = lse->Dims(); + key.has_mask = (attn_mask != nullptr); + if (attn_mask) { + key.mask_dims = attn_mask->Dims(); + } + key.dtype = static_cast(q->Dtype()); + key.is_causal = is_causal; + key.attn_scale_bits = float_to_bits(attn_scale); + + auto &cache = backward_plan_cache(); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + auto graph = std::make_shared(); + graph->set_io_data_type(get_cudnn_dtype(q->Dtype())) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_tensor = make_graph_tensor(graph, q, "Q", Q_UID); + auto k_tensor = make_graph_tensor(graph, k, "K", K_UID); + auto v_tensor = make_graph_tensor(graph, v, "V", V_UID); + auto o_tensor = make_graph_tensor(graph, out, "O", O_UID); + auto dO_tensor = make_graph_tensor(graph, grad_out, "dO", dO_UID); + auto lse_tensor = make_graph_tensor(graph, lse, "Stats", STATS_UID); + + auto sdpa_bwd_options = fe::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_attn_scale(attn_scale) + .set_deterministic_algorithm(true); + + if (is_causal) { + sdpa_bwd_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + if (attn_mask) { + auto mask_tensor = make_graph_tensor(graph, attn_mask, "Bias", MASK_UID); + sdpa_bwd_options.set_bias(mask_tensor); + } + + auto [dQ_tensor, dK_tensor, dV_tensor] = graph->sdpa_backward( + q_tensor, k_tensor, v_tensor, o_tensor, dO_tensor, lse_tensor, sdpa_bwd_options); + + dQ_tensor->set_output(true) + .set_uid(dQ_UID) + .set_dim(q->Dims()) + .set_stride(ComputeStrides(q->Dims())); + dK_tensor->set_output(true) + .set_uid(dK_UID) + .set_dim(k->Dims()) + .set_stride(ComputeStrides(k->Dims())); + dV_tensor->set_output(true) + .set_uid(dV_UID) + .set_dim(v->Dims()) + .set_stride(ComputeStrides(v->Dims())); + + check_fe_status(graph->build(handle, {fe::HeurMode_t::A}), "graph->build (bwd cache build)"); + + int64_t workspace_size = 0; + check_fe_status(graph->get_workspace_size(workspace_size), "graph->get_workspace_size (bwd cache build)"); + + CachedPlan plan; + plan.graph = graph; + plan.workspace_size = workspace_size; + auto inserted = cache.emplace(std::move(key), std::move(plan)); + return inserted.first->second; +} + +std::tuple, std::shared_ptr> ScaledDotProductAttentionForward( + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + return ExecuteSdpaForwardWithLse(q, k, v, attn_mask, dropout_p, is_causal, scale, enable_gqa); +} + +std::tuple, std::shared_ptr, std::shared_ptr> +ScaledDotProductAttentionBackward( + const std::shared_ptr &grad_out, + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + const std::shared_ptr &out, + const std::shared_ptr &lse, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + + auto dq = std::make_shared(q->Dims(), q->Dtype(), q->GetDevice()); + auto dk = std::make_shared(k->Dims(), k->Dtype(), k->GetDevice()); + auto dv = std::make_shared(v->Dims(), v->Dtype(), v->GetDevice()); + + if (dropout_p > 0.0) { + throw std::runtime_error("cuDNN frontend SDPA path currently does not support dropout in this minimal kernel"); + } + (void)enable_gqa; + + + // ---------- cuDNN frontend implementation ---------- + cudnnHandle_t handle = get_cudnn_handle(grad_out->GetDevice()); + + float attn_scale = scale.has_value() ? static_cast(scale.value()) + : 1.0f / std::sqrt(static_cast(q->Dims().back())); + + auto const &plan = get_or_create_bwd_plan(grad_out, q, k, v, attn_mask, out, lse, is_causal, attn_scale, handle); + void *workspace = acquire_workspace(backward_workspace_cache(), static_cast(plan.workspace_size)); + + std::unordered_map variant_pack = { + {Q_UID, q->DataPtr()}, + {K_UID, k->DataPtr()}, + {V_UID, v->DataPtr()}, + {O_UID, out->DataPtr()}, + {dO_UID, grad_out->DataPtr()}, + {STATS_UID, lse->DataPtr()}, + {dQ_UID, dq->DataPtr()}, + {dK_UID, dk->DataPtr()}, + {dV_UID, dv->DataPtr()}, + }; + if (attn_mask) { + variant_pack[MASK_UID] = attn_mask->DataPtr(); + } + + auto exec_status = plan.graph->execute(handle, variant_pack, workspace); + check_fe_status(exec_status, "graph->execute (backward)"); + + return {dq, dk, dv}; +} + +} + + +#define REGISTER_CUDA_LINEAR_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_LINEAR_KERNEL(ScaledDotProductAttentionBackward) +REGISTER_CUDA_LINEAR_KERNEL(ScaledDotProductAttentionForward) + +#undef REGISTER_CUDA_LINEAR_KERNEL diff --git a/infini_train/src/kernels/cuda/ysyx.code-workspace b/infini_train/src/kernels/cuda/ysyx.code-workspace new file mode 100644 index 00000000..afc35437 --- /dev/null +++ b/infini_train/src/kernels/cuda/ysyx.code-workspace @@ -0,0 +1,8 @@ +{ + "folders": [ + { + "path": "../../../../.." + } + ], + "settings": {} +} \ No newline at end of file diff --git a/third_party/cudnn-frontend b/third_party/cudnn-frontend new file mode 160000 index 00000000..d33027a4 --- /dev/null +++ b/third_party/cudnn-frontend @@ -0,0 +1 @@ +Subproject commit d33027a41a93af9c85f089c6364ab415fce98982