Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ build/
*.log
*.report.rank*
*.records.log.rank*
CPackConfig.cmake
CPackSourceConfig.cmake
__pycache__/
.trae/

!cuda-report/**/*.log
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ if(USE_CUDA)
include_directories(${CUDAToolkit_INCLUDE_DIRS})

# CUDA compilation options
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr -prec-div=true -prec-sqrt=true -fmad=false")

# Only compile CUDA kernels / cuda sources here (your original used src/*.cu)
file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)
Expand Down Expand Up @@ -204,3 +204,8 @@ link_infini_train_exe(test_precision_check)
add_executable(test_lora test/lora/test_lora.cc)
link_infini_train_exe(test_lora)

add_executable(test_flash_layout test/flash_attn/test_flash_layout.cc)
link_infini_train_exe(test_flash_layout)

add_executable(test_flash_precision test/flash_attn/test_flash_precision.cc)
link_infini_train_exe(test_flash_precision)
2 changes: 1 addition & 1 deletion example/common/tiny_shakespeare_dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TinyShakespeareFile ReadTinyShakespeareFile(const std::string &path, size_t sequ

std::variant<std::vector<uint16_t>, std::vector<int32_t>> buffer;
if (text_file.type == TinyShakespeareType::kUINT16) {
CHECK_LE(sequence_length, 1024); // GPT-2: max_seq_length = 1024
// CHECK_LE(sequence_length, 1024); // GPT-2: max_seq_length = 1024
buffer = std::vector<uint16_t>(num_sequences * sequence_length);
} else if (text_file.type == TinyShakespeareType::kUINT32) {
CHECK_LE(sequence_length, 8192); // LLaMA-3: max_seq_length = 8192
Expand Down
6 changes: 6 additions & 0 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DEFINE_string(input_bin, "", "input .bin to train on");
DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on");
DEFINE_string(tokenizer_bin, "", "input .bin to tokenizer");
DEFINE_bool(flash, false, "Enable FlashAttention");
// model bin file is downloaded and processed using the script at
// https://github.com/karpathy/llm.c/blob/master/train_gpt2.py
DEFINE_string(llmc_filepath, "", "llmc model file path to load from");
Expand Down Expand Up @@ -194,9 +195,14 @@ void Train(const nn::parallel::Rank &rank) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.use_flash_attn = FLAGS_flash;
if (FLAGS_sequence_length > model_config.block_size) {
model_config.block_size = FLAGS_sequence_length;
}
model = std::make_shared<GPT2>(model_config);
} else {
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
// TODO: support flash attn in pretrained path if needed
}

model->To(device);
Expand Down
37 changes: 25 additions & 12 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,31 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
std::shared_ptr<Tensor> y = nullptr;
if (config_.use_flash_attn && !(q->requires_grad() || k->requires_grad() || v->requires_grad())) {
// q, k, v: (B, h_l, T, Dh) -> (B, T, h_l, Dh)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);
// (B, T, h_l, Dh)
y = nn::function::ScaledDotProductAttention(q, k, v, nullptr, 0.0, true);
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh)
y = y->Transpose(1, 2);
}

// (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Contiguous()->View({B, T, local_C});

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
Expand Down
1 change: 1 addition & 0 deletions example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool use_flash_attn = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down
29 changes: 29 additions & 0 deletions example/llama3/Dockerfile.llama3
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Use official NVIDIA CUDA image (devel variant: includes nvcc for building InfiniTrain)
FROM nvidia/cuda:12.2.0-devel-ubuntu22.04

# Install dependencies.
# DEBIAN_FRONTEND=noninteractive prevents tzdata/apt prompts from hanging the build.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
    python3 python3-pip git cmake curl \
    && rm -rf /var/lib/apt/lists/*

# Install PyTorch (cu121 wheels are ABI-compatible with the 12.2 runtime).
# --no-cache-dir keeps the pip wheel cache out of the image layer.
RUN pip3 install --no-cache-dir torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121

# Install Transformers
RUN pip3 install --no-cache-dir transformers==4.40.0

# Install FlashAttention Python bindings — used only for unit testing and
# reference checks against the C++ implementation. NOTE(review): flash-attn
# compiles CUDA kernels at install time unless a prebuilt wheel matches; this
# layer can be slow — confirm a wheel exists for torch 2.2.1/cu121.
RUN pip3 install --no-cache-dir flash-attn==2.5.6

# Set up working directory
WORKDIR /app

# Copy InfiniTrain source
COPY . /app

# Build InfiniTrain (-p so a pre-existing build/ dir from the COPY does not fail the layer)
RUN mkdir -p build && cd build && cmake .. && make -j$(nproc)

# Default command
CMD ["bash", "example/llama3/run_llama3_7b.sh"]
16 changes: 15 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <optional>
#include <unordered_set>

#include <cuda_runtime.h>
#include "gflags/gflags.h"
#include "glog/logging.h"

Expand Down Expand Up @@ -40,6 +41,7 @@
DEFINE_string(input_bin, "", "input .bin to train on");
DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on");
DEFINE_string(tokenizer_bin, "", "input .bin to tokenizer");
DEFINE_bool(flash, false, "Enable FlashAttention");
// model bin file is downloaded and processed using the script at
// https://github.com/karpathy/llm.c/blob/master/train_llama3.py
DEFINE_string(llmc_filepath, "", "llmc model file path to load from");
Expand Down Expand Up @@ -170,8 +172,9 @@ void Train(const nn::parallel::Rank &rank) {
LLaMA3Config model_config = LLaMA3Config();
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model_config.use_flash_attn = FLAGS_flash;
model = std::make_shared<LLaMA3>(model_config);
}

Expand Down Expand Up @@ -422,6 +425,17 @@ int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);

if (FLAGS_device == kDeviceCUDA) {
int cuda_device_count = 0;
auto cuda_status = cudaGetDeviceCount(&cuda_device_count);
if (cuda_status != cudaSuccess || cuda_device_count <= 0) {
FLAGS_device = kDeviceCPU;
FLAGS_llmc_filepath.clear();
FLAGS_flash = false;
cudaGetLastError();
}
}

auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);
Expand Down
51 changes: 31 additions & 20 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,26 +217,36 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y = nullptr;
if (config_.use_flash_attn && !(q->requires_grad() || k->requires_grad() || v->requires_grad())) {
// q, k, v: (B, H_local, T, D) -> (B, T, H_local, D)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);
// (B, T, H_local, D)
y = nn::function::ScaledDotProductAttention(q, k, v, nullptr, 0.0, true);
} else {
// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D)
y = y->Transpose(1, 2);
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});

// (B, T, H_local, D) -> (B, T, C_local)
y = y->Contiguous()->View({B, T, C_local});
// output projection
// (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
Expand Down Expand Up @@ -457,7 +467,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool use_flash_attn) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -496,6 +506,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.rope_theta = rope_theta,
.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.use_flash_attn = use_flash_attn,
.max_gen_batch_size = max_gen_bs});

// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
Expand Down
15 changes: 8 additions & 7 deletions example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
struct LLaMA3Config {
// ref: https://huggingface.co/meta-llama/Llama-3.2-1B
// Model basic config
int64_t block_size = 8192; // Max seq_len
// Optimize: Reduce default model size to fit in smaller GPUs and ensure stability
int64_t block_size = 2048; // Max seq_len
int64_t vocab_size = 128256; // Vocab size
int64_t n_layer = 16; // Num of transformer layers
int64_t n_head = 32; // Num of heads in MHA
int64_t n_kv_head = 8; // Num of Key/Value heads(< n_head if using GQA)
int64_t n_embd = 2048; // Hidden size
int64_t n_layer = 8; // Num of transformer layers
int64_t n_head = 16; // Num of heads in MHA
int64_t n_kv_head = 4; // Num of Key/Value heads(< n_head if using GQA)
int64_t n_embd = 1024; // Hidden size

// FFN config
std::optional<float> ffn_dim_multiplier = 1.5f; // FFN dim multiplier
Expand All @@ -36,7 +37,7 @@ struct LLaMA3Config {

// Inference
bool use_kv = false; // kv cache
bool flash = false; // flash attention
bool use_flash_attn = false; // flash attention
int64_t max_gen_batch_size = 4; // max batch size during inference
};

Expand Down Expand Up @@ -179,7 +180,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool use_flash_attn = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
61 changes: 61 additions & 0 deletions example/llama3/probe_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch
import sys
import os
import subprocess

def get_cuda_version():
    """Return the CUDA version PyTorch was built against, or "N/A".

    ``torch.version.cuda`` is ``None`` on CPU-only builds; normalize that to
    "N/A" so the report never prints the literal string "None". Uses getattr
    instead of a bare ``except:`` (which would also swallow KeyboardInterrupt).
    """
    version = getattr(torch.version, "cuda", None)
    return version if version else "N/A"

def get_flash_attn_version():
    """Report the installed flash_attn package version, or "Not Installed"."""
    try:
        import flash_attn
    except ImportError:
        return "Not Installed"
    return flash_attn.__version__

def get_gpu_info():
    """Describe each visible CUDA device (name, compute capability, memory in GB).

    Returns one line per device joined with newlines, or "No GPU Available"
    when CUDA is not usable from this process.
    """
    if not torch.cuda.is_available():
        return "No GPU Available"

    lines = []
    for idx in range(torch.cuda.device_count()):
        p = torch.cuda.get_device_properties(idx)
        mem_gb = p.total_memory / 1024**3
        lines.append(f"Device {idx}: {p.name}, Compute {p.major}.{p.minor}, Memory {mem_gb:.2f} GB")
    return "\n".join(lines)

def get_libcudart_version():
    """Best-effort query of the system CUDA runtime (libcudart) version.

    Loads libcudart via ctypes and calls ``cudaRuntimeGetVersion``; the value
    is encoded as 1000*major + 10*minor (e.g. 12020 -> "12.2"). Returns the
    string "Unknown" when no runtime library can be located or queried, which
    matches the previous stub's only return value, so callers are unaffected.
    """
    try:
        lib_name = ctypes.util.find_library("cudart")
        if not lib_name:
            return "Unknown"
        lib = ctypes.CDLL(lib_name)
        version = ctypes.c_int(0)
        # cudaRuntimeGetVersion returns 0 (cudaSuccess) on success.
        if lib.cudaRuntimeGetVersion(ctypes.byref(version)) != 0:
            return "Unknown"
        return f"{version.value // 1000}.{(version.value % 1000) // 10}"
    except (OSError, AttributeError):
        return "Unknown"

import json

print("InfiniTrain Environment Probe")
print("-" * 30)
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA (PyTorch): {get_cuda_version()}")
print(f"FlashAttention (Python): {get_flash_attn_version()}")
print("-" * 30)
print("GPU Info:")
print(get_gpu_info())
print("-" * 30)

# Persist a machine-readable snapshot for later comparison / bug reports.
# Environment variables whose names suggest credentials are redacted so the
# snapshot file can be shared without leaking secrets (the previous version
# dumped dict(os.environ) verbatim, including any tokens/keys in scope).
_SENSITIVE_MARKERS = ("TOKEN", "SECRET", "KEY", "PASSWORD", "CREDENTIAL")
env_vars = {
    name: ("<redacted>" if any(m in name.upper() for m in _SENSITIVE_MARKERS) else value)
    for name, value in os.environ.items()
}

snapshot = {
    "pytorch": torch.__version__,
    "cuda": get_cuda_version(),
    "flash_attn": get_flash_attn_version(),
    "gpu_info": get_gpu_info(),
    "env_vars": env_vars,
}

with open("env_snapshot.json", "w") as f:
    json.dump(snapshot, f, indent=4)

print("Snapshot saved to env_snapshot.json")
Loading