Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "third_party/eigen"]
path = third_party/eigen
url = git@github.com:InfiniTensor/eigen-mirror.git
[submodule "third_party/cudnn-frontend"]
path = third_party/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,22 @@ if(USE_CUDA)
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})

# ========== cuDNN 库 ==========
find_library(CUDNN_LIBRARY cudnn REQUIRED)
message(STATUS "Found cuDNN at: ${CUDNN_LIBRARY}")
# ========================================


# CUDA compilation options
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")

# Only compile CUDA kernels / cuda sources here (your original used src/*.cu)
file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)

add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS})
target_include_directories(infini_train_cuda_kernels PUBLIC
${PROJECT_SOURCE_DIR}/third_party/cudnn-frontend/include)

set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90")

target_link_libraries(infini_train_cuda_kernels
Expand All @@ -92,6 +101,8 @@ if(USE_CUDA)
CUDA::cudart
CUDA::cublas
CUDA::cuda_driver
CUDA::nvrtc
${CUDNN_LIBRARY}
)

if(USE_NCCL)
Expand Down
10 changes: 7 additions & 3 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ DEFINE_string(
precision_check, "",
"precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH");

// Flash attention parameters
DEFINE_bool(flash, false, "Whether to enable flash attention");

// LoRA parameters
DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)");
DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor");
Expand Down Expand Up @@ -189,16 +192,17 @@ void Train(const nn::parallel::Rank &rank) {
// init the model, either from scratch or from OpenAI pretrained checkpoint
GPT2Config model_config;
std::shared_ptr<nn::Module> model = nullptr;

if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.flash = FLAGS_flash;
model = std::make_shared<GPT2>(model_config);
} else {
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
model_config.flash = FLAGS_flash;
}

model->To(device);

utils::PrecisionChecker::BuildNameMap(model.get());
Expand Down
81 changes: 61 additions & 20 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <vector>

#include "glog/logging.h"
#include "gflags/gflags.h"

#include "example/common/utils.h"
#include "infini_train/include/device.h"
Expand Down Expand Up @@ -105,24 +106,63 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
// (B, T, C) == (bs, seq_len, n_embd)
return {y};
auto Flag_flash = config_.flash;
const bool use_flash_sdpa = Flag_flash && q->Dtype() == DataType::kBFLOAT16;
static bool logged_attention_path = false;
if (!logged_attention_path) {
LOG(ERROR) << "[GPT2][AttentionPath] flash_flag=" << Flag_flash
<< ", q_dtype=" << kDataTypeToDesc.at(q->Dtype())
<< ", selected=" << (use_flash_sdpa ? "cuDNN_SDPA_flash" : "matmul_softmax_fallback");
logged_attention_path = true;
}

std::shared_ptr<infini_train::Tensor> y;

if (Flag_flash) {
// cuDNN SDPA path: causal masking should be enabled by `is_causal=true`.
// Do not pass the 0/1 tril mask as additive bias (it is not -inf mask).
auto q_flash = q;
auto k_flash = k;
auto v_flash = v;
if (q->Dtype() != DataType::kBFLOAT16) {
q_flash = std::make_shared<Tensor>(q->To(DataType::kBFLOAT16));
k_flash = std::make_shared<Tensor>(k->To(DataType::kBFLOAT16));
v_flash = std::make_shared<Tensor>(v->To(DataType::kBFLOAT16));
}
y = nn::function::ScaledDotProductAttention(q_flash, k_flash, v_flash, nullptr, 0.0, true, std::nullopt,
false);

if (y->Dtype() != q->Dtype()) {
y = std::make_shared<Tensor>(y->To(q->Dtype()));
}
// ensure expected layout: (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
// (B, T, C) == (bs, seq_len, n_embd)
return {y};
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
// (B, T, C) == (bs, seq_len, n_embd)
return {y};
}
}

MLP::MLP(const GPT2Config &config) : CloneableModule(kType) {
Expand Down Expand Up @@ -356,7 +396,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
}
} // namespace

std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -384,7 +424,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
.original_vocab_size = vocab_size,
.n_layer = n_layer,
.n_head = n_head,
.n_embd = n_embd});
.n_embd = n_embd,
.flash = flash});

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
Expand Down
3 changes: 2 additions & 1 deletion example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool flash = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down Expand Up @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const;

Expand Down
5 changes: 4 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ DEFINE_string(lora_target_modules, "c_attn,c_proj,c_fc,c_fc2", "LoRA target modu
DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training");
DEFINE_string(lora_load_path, "", "Path to load LoRA weights from");

DEFINE_bool(flash, false, "Whether to enable flash attention");

using namespace infini_train;

namespace {
Expand Down Expand Up @@ -170,8 +172,9 @@ void Train(const nn::parallel::Rank &rank) {
LLaMA3Config model_config = LLaMA3Config();
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model_config.flash = FLAGS_flash;
model = std::make_shared<LLaMA3>(model_config);
}

Expand Down
5 changes: 4 additions & 1 deletion example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <vector>

#include "glog/logging.h"
#include "gflags/gflags.h"

#include "example/common/utils.h"
#include "infini_train/include/device.h"
Expand Down Expand Up @@ -457,7 +458,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -496,8 +497,10 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.rope_theta = rope_theta,
.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.flash = flash,
.max_gen_batch_size = max_gen_bs});


// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
int pp_size = nn::parallel::global::GetPipelineParallelSize();
int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize();
Expand Down
2 changes: 1 addition & 1 deletion example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
44 changes: 44 additions & 0 deletions infini_train/include/autograd/scaled_dot_product_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <memory>
#include <optional>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

// Autograd node for scaled dot-product attention (SDPA).
// The parameter set (dropout_p, is_causal, scale, enable_gqa) mirrors the
// PyTorch-style SDPA signature — presumably torch.nn.functional.scaled_dot_product_attention;
// confirm against the .cc implementation, which is not visible here.
class ScaledDotProductAttention : public Function {
public:
    static constexpr char kType[] = "ScaledDotProductAttention";

    // dropout_p: attention-weight dropout probability (0.0 = no dropout).
    // is_causal: when true, the kernel applies causal (lower-triangular) masking
    //            itself; callers then pass no explicit attention mask.
    // scale:     optional softmax scaling factor; std::nullopt presumably means
    //            the default 1/sqrt(head_dim) — TODO confirm in the kernel.
    // enable_gqa: grouped-query-attention toggle — semantics defined in the .cc.
    ScaledDotProductAttention(double dropout_p, bool is_causal,
                              std::optional<double> scale, bool enable_gqa)
        : Function(kType), dropout_p_(dropout_p), is_causal_(is_causal), scale_(scale),
          enable_gqa_(enable_gqa) {}

    // input_tensors: {query, key, value[, attn_mask]} — an optional trailing
    // attention mask is implied by has_attn_mask_input_ below; exact layout is
    // fixed by the implementation. Returns the attention output tensor(s).
    std::vector<std::shared_ptr<Tensor>> Forward(
        const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    // Records whatever Backward() needs from the forward pass (see the
    // SaveForBackward note on the members below).
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    // grad_outputs: gradients w.r.t. Forward()'s outputs; returns gradients
    // w.r.t. the forward inputs, in the same order.
    std::vector<std::shared_ptr<Tensor>> Backward(
        const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    double dropout_p_ = 0.0;

    bool is_causal_ = false;
    std::optional<double> scale_ = std::nullopt;
    bool enable_gqa_ = false;
    // True when the caller supplied an explicit attention mask as a fourth input.
    bool has_attn_mask_input_ = false;
    // Forward output cached for the backward pass.
    std::shared_ptr<Tensor> forward_out_ = nullptr;
    // "lse" — presumably the per-row log-sum-exp of the softmax, the standard
    // auxiliary saved by flash-attention kernels for backward; verify in the .cc.
    std::shared_ptr<Tensor> forward_lse_ = nullptr;
    // Saved tensors for backward can be managed via Function's SaveForBackward helper
};
} // namespace infini_train::autograd
1 change: 1 addition & 0 deletions infini_train/include/nn/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cstdint>
#include <memory>
#include <vector>
#include <optional>

namespace infini_train {
class Tensor;
Expand Down
3 changes: 3 additions & 0 deletions infini_train/include/nn/parallel/ddp/distributed_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class DistributedOptimizer final : public infini_train::Optimizer {
void StartGradSync();
void FinishGradSync();

// Forward microbatch boundary info to bucket groups.
void SetIsLastMicrobatch(bool is_last_microbatch);

void StartParamSync(bool force_sync = false);
void FinishParamSync(bool skip_next_bucket_dispatch = false);

Expand Down
3 changes: 3 additions & 0 deletions infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class ParamAndGradBucketGroup {
// When all params in a bucket group are ready, will call StartGradSync()
void RegisterGradReady(const std::shared_ptr<Tensor> &parameter);

// Mark whether current backward corresponds to the last microbatch in a gradient accumulation window.
void SetIsLastMicrobatch(bool is_last_microbatch);

// Start grad reduce
void StartGradSync();

Expand Down
Loading