Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 58 additions & 12 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
#include <format>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "gflags/gflags.h"
#include "glog/logging.h"
Expand All @@ -19,6 +21,7 @@
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/modules/transformer/transformer.h"
#include "infini_train/include/nn/parallel/context_parallel.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h"
#include "infini_train/include/nn/parallel/global.h"
Expand Down Expand Up @@ -76,6 +79,8 @@ DEFINE_int32(
"When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(context_parallel, 1, "Context Parallel world size");
DEFINE_string(cp_comm_type, "p2p", "Context Parallel communication type (all_gather|p2p|a2a)");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

Expand Down Expand Up @@ -124,6 +129,9 @@ DEFINE_validator(model, [](const char *, const std::string &value) { return kSup
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; });
DEFINE_validator(cp_comm_type, [](const char *, const std::string &value) {
return value == "all_gather" || value == "p2p" || value == "a2a";
});

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand All @@ -150,23 +158,37 @@ void Train(const nn::parallel::Rank &rank) {
int ddp_world_size = global::GetDataParallelSize();
int tp_world_size = global::GetTensorParallelSize();
int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1;
int cp_world_size = global::GetContextParallelSize();
int dp_cp_world_size = ddp_world_size * cp_world_size;
int pp_world_size = global::GetPipelineParallelSize();
const auto cp_local_sequence_length = FLAGS_sequence_length / cp_world_size;

if (cp_world_size > 1) {
CHECK_EQ(FLAGS_sequence_length % cp_world_size, 0) << "sequence_length must be divisible by context_parallel.";
}
if (FLAGS_sequence_parallel) {
CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0)
<< "sequence_length must be divisible by tp_world_size when SP is enabled (pad later if needed).";
CHECK_EQ(cp_local_sequence_length % tp_world_size, 0)
<< "CP-local sequence_length must be divisible by tp_world_size when SP is enabled.";
}
if (FLAGS_zero_stage >= 1) {
CHECK_GT(dp_cp_world_size, 1) << "ZeRO requires DP or CP world size greater than 1.";
CHECK_LE(FLAGS_zero_stage, 2) << "ZeRO-3 is not supported yet.";
}

int ddp_rank = 0;
int dp_cp_rank = 0;
int tp_rank = 0;
int cp_rank = 0;
int pp_rank = 0;

// Set thread-local global rank
// TODO(dcj): Use DeviceGuardImpl to get GlobalRank later.
nn::parallel::global::thread_global_rank = rank.GlobalRank();

const ProcessGroup *ddp_pg = nullptr;
const ProcessGroup *dp_cp_pg = nullptr;
const ProcessGroup *tp_pg = nullptr;
const ProcessGroup *cp_pg = nullptr;
const ProcessGroup *pp_pg = nullptr;

if (rank.IsParallel()) {
Expand All @@ -179,6 +201,12 @@ void Train(const nn::parallel::Rank &rank) {
ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
}

if (dp_cp_world_size > 1) {
dp_cp_pg = pg_factory->GetOrCreate(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()),
GetDataParallelWithContextGroupRanks(rank.GlobalRank()));
dp_cp_rank = dp_cp_pg->GetGroupRank(rank.GlobalRank());
}

if (tp_world_size > 1) {
tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
GetTensorParallelGroupRanks(rank.GlobalRank()));
Expand All @@ -187,6 +215,13 @@ void Train(const nn::parallel::Rank &rank) {
nn::parallel::tp_rank = tp_rank;
}

if (cp_world_size > 1) {
cp_pg = pg_factory->GetOrCreate(GetContextParallelProcessGroupName(rank.GlobalRank()),
GetContextParallelGroupRanks(rank.GlobalRank()));
cp_rank = cp_pg->GetGroupRank(rank.GlobalRank());
nn::parallel::cp_rank = cp_rank;
}

if (pp_world_size > 1) {
pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
GetPipelineParallelGroupRanks(rank.GlobalRank()));
Expand Down Expand Up @@ -272,21 +307,21 @@ void Train(const nn::parallel::Rank &rank) {

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
// when context/sequence parallelism is enabled, use the CP-local and SP-local sequence length.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
{FLAGS_batch_size, cp_local_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
if (dp_cp_world_size > 1) {
auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
}
}
} else if (ddp_world_size > 1) {
} else if (dp_cp_world_size > 1) {
// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
Expand Down Expand Up @@ -325,7 +360,7 @@ void Train(const nn::parallel::Rank &rank) {
? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
: std::vector<std::shared_ptr<nn::Module>>{model};
optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, params_to_optimize,
model_chunks, ddp_world_size, ddp_rank);
model_chunks, dp_cp_world_size, dp_cp_rank);
} else {
optimizer = optimizer_creator(params_to_optimize);
}
Expand Down Expand Up @@ -376,6 +411,7 @@ void Train(const nn::parallel::Rank &rank) {
.ddp_size = ddp_world_size,
.tp_size = tp_world_size,
.sp_size = sp_world_size,
.cp_size = cp_world_size,
.pp_size = pp_world_size,
.save_optimizer_state = FLAGS_save_optimizer_state,
.checkpoint_root_dir = FLAGS_save,
Expand Down Expand Up @@ -441,6 +477,10 @@ void Train(const nn::parallel::Rank &rank) {
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));
if (cp_world_size > 1) {
x = nn::parallel::SliceAlongCPRegionFunc(x, 1);
y = nn::parallel::SliceAlongCPRegionFunc(y, 1);
}

LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";

Expand Down Expand Up @@ -472,13 +512,17 @@ void Train(const nn::parallel::Rank &rank) {
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));
if (cp_world_size > 1) {
x = nn::parallel::SliceAlongCPRegionFunc(x, 1);
y = nn::parallel::SliceAlongCPRegionFunc(y, 1);
}

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
}

if (ddp_world_size > 1) {
if (dp_cp_world_size > 1) {
auto lossf_tensor = std::make_shared<Tensor>(&lossf, std::vector<int64_t>{}, DataType::kFLOAT32, device);
function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg);
function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, dp_cp_pg);
lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
}

Expand All @@ -491,10 +535,11 @@ void Train(const nn::parallel::Rank &rank) {
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, CP={}, "
"PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);
cp_world_size, pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
if (tokenizer) {
Expand Down Expand Up @@ -535,7 +580,8 @@ int main(int argc, char *argv[]) {

auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);
FLAGS_context_parallel, FLAGS_cp_comm_type, FLAGS_pipeline_parallel,
FLAGS_virtual_pipeline_parallel);
utils::PrecisionCheckEnv::Instance().Init(precision_config);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
Expand Down
Loading
Loading