diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index b666b6b1..401c5ab6 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -4,8 +4,10 @@ #include #include #include +#include #include #include +#include #include "gflags/gflags.h" #include "glog/logging.h" @@ -19,6 +21,7 @@ #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h" #include "infini_train/include/nn/parallel/global.h" @@ -76,6 +79,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_uint32(context_parallel, 1, "Context Parallel world size"); +DEFINE_string(cp_comm_type, "p2p", "Context Parallel communication type (all_gather|p2p|a2a)"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); @@ -124,6 +129,9 @@ DEFINE_validator(model, [](const char *, const std::string &value) { return kSup DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); +DEFINE_validator(cp_comm_type, [](const char *, const std::string &value) { + return value == "all_gather" || value == "p2p" || value == "a2a"; +}); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -150,15 +158,27 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int cp_world_size = global::GetContextParallelSize(); + int dp_cp_world_size = ddp_world_size * cp_world_size; int pp_world_size = global::GetPipelineParallelSize(); + const auto cp_local_sequence_length = FLAGS_sequence_length / cp_world_size; + if (cp_world_size > 1) { + CHECK_EQ(FLAGS_sequence_length % cp_world_size, 0) << "sequence_length must be divisible by context_parallel."; + } if (FLAGS_sequence_parallel) { - CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) - << "sequence_length must be divisible by tp_world_size when SP is enabled (pad later if needed)."; + CHECK_EQ(cp_local_sequence_length % tp_world_size, 0) + << "CP-local sequence_length must be divisible by tp_world_size when SP is enabled."; + } + if (FLAGS_zero_stage >= 1) { + CHECK_GT(dp_cp_world_size, 1) << "ZeRO requires DP or CP world size greater than 1."; + CHECK_LE(FLAGS_zero_stage, 2) << "ZeRO-3 is not supported yet."; } int ddp_rank = 0; + int dp_cp_rank = 0; int tp_rank = 0; + int cp_rank = 0; int pp_rank = 0; // Set thread-local global rank @@ -166,7 +186,9 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::global::thread_global_rank = rank.GlobalRank(); const ProcessGroup *ddp_pg = nullptr; + const ProcessGroup *dp_cp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *cp_pg = nullptr; const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { @@ -179,6 +201,12 @@ void Train(const nn::parallel::Rank &rank) { ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank()); } + if (dp_cp_world_size > 1) { + dp_cp_pg = pg_factory->GetOrCreate(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()), + GetDataParallelWithContextGroupRanks(rank.GlobalRank())); + dp_cp_rank = dp_cp_pg->GetGroupRank(rank.GlobalRank()); + } + if (tp_world_size > 1) { tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), GetTensorParallelGroupRanks(rank.GlobalRank())); @@ -187,6 +215,13 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::tp_rank = tp_rank; } + if (cp_world_size > 1) { + cp_pg = pg_factory->GetOrCreate(GetContextParallelProcessGroupName(rank.GlobalRank()), + GetContextParallelGroupRanks(rank.GlobalRank())); + cp_rank = cp_pg->GetGroupRank(rank.GlobalRank()); + nn::parallel::cp_rank = cp_rank; + } + if (pp_world_size > 1) { pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), GetPipelineParallelGroupRanks(rank.GlobalRank())); @@ -272,13 +307,13 @@ void Train(const nn::parallel::Rank &rank) { if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. + // when context/sequence parallelism is enabled, use the CP-local and SP-local sequence length. auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + {FLAGS_batch_size, cp_local_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { @@ -286,7 +321,7 @@ void Train(const nn::parallel::Rank &rank) { = std::make_shared(mutable_chunks->at(chunk_id), rank, ddp_config); } } - } else if (ddp_world_size > 1) { + } else if (dp_cp_world_size > 1) { // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -325,7 +360,7 @@ void Train(const nn::parallel::Rank &rank) { ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; optimizer = std::make_shared(optimizer_creator, params_to_optimize, - model_chunks, ddp_world_size, ddp_rank); + model_chunks, dp_cp_world_size, dp_cp_rank); } else { optimizer = optimizer_creator(params_to_optimize); } @@ -376,6 +411,7 @@ void Train(const nn::parallel::Rank &rank) { .ddp_size = ddp_world_size, .tp_size = tp_world_size, .sp_size = sp_world_size, + .cp_size = cp_world_size, .pp_size = pp_world_size, .save_optimizer_state = FLAGS_save_optimizer_state, .checkpoint_root_dir = FLAGS_save, @@ -441,6 +477,10 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward"; @@ -472,13 +512,17 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); } - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); - function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); + function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, dp_cp_pg); lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } @@ -491,10 +535,11 @@ void Train(const nn::parallel::Rank &rank) { std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " - "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", + "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, CP={}, " + "PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, - pp_world_size); + cp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { if (tokenizer) { @@ -535,7 +580,8 @@ int main(int argc, char *argv[]) { auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); + FLAGS_context_parallel, FLAGS_cp_comm_type, FLAGS_pipeline_parallel, + FLAGS_virtual_pipeline_parallel); utils::PrecisionCheckEnv::Instance().Init(precision_config); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 5a191865..824bf7ab 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -3,7 +3,9 @@ #include #include #include +#include #include +#include #include "gflags/gflags.h" #include "glog/logging.h" @@ -18,6 +20,7 @@ #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h" #include "infini_train/include/nn/parallel/global.h" @@ -75,6 +78,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_uint32(context_parallel, 1, "Context Parallel world size"); +DEFINE_string(cp_comm_type, "p2p", "Context Parallel communication type (all_gather|p2p|a2a)"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision @@ -105,12 +110,16 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; + } // namespace DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); +DEFINE_validator(cp_comm_type, [](const char *, const std::string &value) { + return value == "all_gather" || value == "p2p" || value == "a2a"; +}); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -137,22 +146,36 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int cp_world_size = global::GetContextParallelSize(); + int dp_cp_world_size = ddp_world_size * cp_world_size; int pp_world_size = global::GetPipelineParallelSize(); + const auto cp_local_sequence_length = FLAGS_sequence_length / cp_world_size; + if (cp_world_size > 1) { + CHECK_EQ(FLAGS_sequence_length % cp_world_size, 0) << "sequence_length must be divisible by context_parallel."; + } if (FLAGS_sequence_parallel) { - CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) - << "sequence_length must be divisible by tp_world_size when SP is enabled (pad later if needed)."; + CHECK_EQ(cp_local_sequence_length % tp_world_size, 0) + << "CP-local sequence_length must be divisible by tp_world_size when SP is enabled."; + } + if (FLAGS_zero_stage >= 1) { + CHECK_GT(dp_cp_world_size, 1) << "ZeRO requires DP or CP world size greater than 1."; + CHECK_LE(FLAGS_zero_stage, 2) << "ZeRO-3 is not supported yet."; } int ddp_rank = 0; + int dp_cp_rank = 0; int tp_rank = 0; + int cp_rank = 0; int pp_rank = 0; // Set thread-local global rank nn::parallel::global::thread_global_rank = rank.GlobalRank(); const ProcessGroup *ddp_pg = nullptr; + const ProcessGroup *dp_cp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *cp_pg = nullptr; const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { @@ -165,6 +188,12 @@ void Train(const nn::parallel::Rank &rank) { ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank()); } + if (dp_cp_world_size > 1) { + dp_cp_pg = pg_factory->GetOrCreate(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()), + GetDataParallelWithContextGroupRanks(rank.GlobalRank())); + dp_cp_rank = dp_cp_pg->GetGroupRank(rank.GlobalRank()); + } + if (tp_world_size > 1) { tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), GetTensorParallelGroupRanks(rank.GlobalRank())); @@ -173,6 +202,13 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::tp_rank = tp_rank; } + if (cp_world_size > 1) { + cp_pg = pg_factory->GetOrCreate(GetContextParallelProcessGroupName(rank.GlobalRank()), + GetContextParallelGroupRanks(rank.GlobalRank())); + cp_rank = cp_pg->GetGroupRank(rank.GlobalRank()); + nn::parallel::cp_rank = cp_rank; + } + if (pp_world_size > 1) { pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), GetPipelineParallelGroupRanks(rank.GlobalRank())); @@ -243,13 +279,13 @@ void Train(const nn::parallel::Rank &rank) { if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. + // when context/sequence parallelism is enabled, use the CP-local and SP-local sequence length. auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + {FLAGS_batch_size, cp_local_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { @@ -257,7 +293,7 @@ void Train(const nn::parallel::Rank &rank) { = std::make_shared(mutable_chunks->at(chunk_id), rank, ddp_config); } } - } else if (ddp_world_size > 1) { + } else if (dp_cp_world_size > 1) { // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -305,7 +341,7 @@ void Train(const nn::parallel::Rank &rank) { ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; optimizer = std::make_shared(optimizer_creator, params_to_optimize, - model_chunks, ddp_world_size, ddp_rank); + model_chunks, dp_cp_world_size, dp_cp_rank); } else { optimizer = optimizer_creator(params_to_optimize); } @@ -356,6 +392,7 @@ void Train(const nn::parallel::Rank &rank) { .ddp_size = ddp_world_size, .tp_size = tp_world_size, .sp_size = sp_world_size, + .cp_size = cp_world_size, .pp_size = pp_world_size, .save_optimizer_state = FLAGS_save_optimizer_state, .checkpoint_root_dir = FLAGS_save, @@ -419,6 +456,10 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward"; // (bs, seq_len, vocab_size) @@ -449,13 +490,17 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); } - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); - function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); + function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, dp_cp_pg); lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } @@ -468,10 +513,11 @@ void Train(const nn::parallel::Rank &rank) { std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " - "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", + "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, CP={}, " + "PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, - pp_world_size); + cp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { // FIXME(jym): to support PP @@ -512,7 +558,8 @@ int main(int argc, char *argv[]) { auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); + FLAGS_context_parallel, FLAGS_cp_comm_type, FLAGS_pipeline_parallel, + FLAGS_virtual_pipeline_parallel); utils::PrecisionCheckEnv::Instance().Init(precision_config); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/infini_train/include/checkpoint/checkpoint.h b/infini_train/include/checkpoint/checkpoint.h index b122a17a..d8872b4e 100644 --- a/infini_train/include/checkpoint/checkpoint.h +++ b/infini_train/include/checkpoint/checkpoint.h @@ -28,6 +28,7 @@ struct TrainerState { int ddp_size = 1; int tp_size = 1; int sp_size = 1; + int cp_size = 1; int pp_size = 1; }; diff --git a/infini_train/include/checkpoint/checkpoint_manager.h b/infini_train/include/checkpoint/checkpoint_manager.h index cce14107..e50bbc31 100644 --- a/infini_train/include/checkpoint/checkpoint_manager.h +++ b/infini_train/include/checkpoint/checkpoint_manager.h @@ -45,6 +45,7 @@ struct SaveCheckpointArgs { int ddp_size = 1; int tp_size = 1; int sp_size = 1; + int cp_size = 1; int pp_size = 1; bool save_optimizer_state = true; std::filesystem::path checkpoint_root_dir; diff --git a/infini_train/include/core/ccl/ccl.h b/infini_train/include/core/ccl/ccl.h index 626cb078..7d6b967b 100644 --- a/infini_train/include/core/ccl/ccl.h +++ b/infini_train/include/core/ccl/ccl.h @@ -53,6 +53,9 @@ class CclImpl { nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const; + virtual void AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const; + virtual void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const; diff --git a/infini_train/include/nn/parallel/context_parallel.h b/infini_train/include/nn/parallel/context_parallel.h new file mode 100644 index 00000000..309d8802 --- /dev/null +++ b/infini_train/include/nn/parallel/context_parallel.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::nn::parallel { + +extern thread_local int cp_rank; + +int GetContextParallelRank(); + +int64_t GetContextParallelSequenceStart(int64_t local_sequence_length); + +std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &input, int64_t dim); + +std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input); + +std::shared_ptr AttnForwardFuncWithCP(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask); + +std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask); + +std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask); + +std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask); + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 9373100f..0ec929c0 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -8,20 +8,24 @@ namespace infini_train::nn::parallel::global { extern thread_local int thread_global_rank; -enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 }; +enum Axis : uint8_t { DP = 0, TP = 1, CP = 2, PP = 3, AXIS_COUNT = 4 }; struct Layout { - int sizes[AXIS_COUNT]{1, 1, 1}; - // Default order according to Megatron-LM is TP-DP-PP. Ref: + int sizes[AXIS_COUNT]{1, 1, 1, 1}; + // Default order follows Megatron-LM's TP-CP-DP-PP layout for context parallelism. Ref: // https://github.com/NVIDIA/Megatron-LM/blob/e07c4a4450b6faa187a1ef4ec082a35ad7d2f085/megatron/core/parallel_state.py#L618 - Axis order[AXIS_COUNT]{TP, DP, PP}; - int strides[AXIS_COUNT]{1, 1, 1}; + Axis order[AXIS_COUNT]{TP, CP, DP, PP}; + int strides[AXIS_COUNT]{1, 1, 1, 1}; void InitStrides(); int RankOf(int dp, int tp, int pp) const; + int RankOf(int dp, int tp, int cp, int pp) const; void CoordOf(int rank, int &dp, int &tp, int &pp) const; + void CoordOf(int rank, int &dp, int &tp, int &cp, int &pp) const; int GroupId(Axis target, int dp, int tp, int pp) const; + int GroupId(Axis target, int dp, int tp, int cp, int pp) const; std::vector GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_pp) const; + std::vector GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_cp, int fixed_pp) const; }; class GlobalEnv { @@ -29,7 +33,8 @@ class GlobalEnv { static GlobalEnv &Instance(); void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size, int virtual_pipeline_parallel_size); + int context_parallel_size, const std::string &context_parallel_comm_type, int pipeline_parallel_size, + int virtual_pipeline_parallel_size); int nnodes() const; @@ -49,6 +54,10 @@ class GlobalEnv { bool sequence_parallel_enabled() const; + int context_parallel_size() const; + + const std::string &context_parallel_comm_type() const; + int data_parallel_size() const; int pipeline_parallel_size() const; @@ -75,6 +84,8 @@ class GlobalEnv { int tensor_parallel_size_ = 1; bool sequence_parallel_enabled_ = false; + int context_parallel_size_ = 1; + std::string context_parallel_comm_type_ = "p2p"; int data_parallel_size_ = 1; @@ -90,8 +101,16 @@ class GlobalEnv { inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, int pipeline_parallel_size, int virtual_pipeline_parallel) { GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + /*context_parallel_size=*/1, /*context_parallel_comm_type=*/"p2p", pipeline_parallel_size, virtual_pipeline_parallel); } +inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int context_parallel_size, const std::string &context_parallel_comm_type, + int pipeline_parallel_size, int virtual_pipeline_parallel) { + GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + context_parallel_size, context_parallel_comm_type, pipeline_parallel_size, + virtual_pipeline_parallel); +} inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); } @@ -102,6 +121,8 @@ inline int GetLocalProcRank() { return GlobalEnv::Instance().local_proc_rank(); inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_parallel_size(); } inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_parallel_size(); } inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); } +inline int GetContextParallelSize() { return GlobalEnv::Instance().context_parallel_size(); } +inline const std::string &GetContextParallelCommType() { return GlobalEnv::Instance().context_parallel_comm_type(); } inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtual_pipeline_parallel_size(); } @@ -114,12 +135,16 @@ inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtu * @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate. */ inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); } +inline int GetRankOf(int dp, int tp, int cp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, cp, pp); } /** * @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank. */ inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) { return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, pp); } +inline void GetCoordOf(int rank, int &dp, int &tp, int &cp, int &pp) { + return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, cp, pp); +} /** * @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis. @@ -127,13 +152,16 @@ inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) { inline int GetGroupId(Axis target, int dp, int tp, int pp) { return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp); } +inline int GetGroupId(Axis target, int dp, int tp, int cp, int pp) { + return GlobalEnv::Instance().layout().GroupId(target, dp, tp, cp, pp); +} /** * @brief Get the group ID that a given rank belongs to along a specific parallel axis. */ inline int GetGroupId(Axis target, int rank) { - int dp, tp, pp; - GetCoordOf(rank, dp, tp, pp); - return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp); + int dp, tp, cp, pp; + GetCoordOf(rank, dp, tp, cp, pp); + return GlobalEnv::Instance().layout().GroupId(target, dp, tp, cp, pp); } /** @@ -143,15 +171,18 @@ inline int GetGroupId(Axis target, int rank) { inline std::vector GetGroupRanks(Axis target, int dp, int tp, int pp) { return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); } +inline std::vector GetGroupRanks(Axis target, int dp, int tp, int cp, int pp) { + return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, cp, pp); +} /** * @brief Get all ranks that belong to the same group as the given rank * along a specified parallel axis (e.g., all ranks in the same DP group). */ inline std::vector GetGroupRanks(Axis target, int rank) { - int dp, tp, pp; - GetCoordOf(rank, dp, tp, pp); - return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); + int dp, tp, cp, pp; + GetCoordOf(rank, dp, tp, cp, pp); + return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, cp, pp); } /** @@ -160,7 +191,7 @@ inline std::vector GetGroupRanks(Axis target, int rank) { * The output is intended for debugging, logging, and runtime verification of * distributed parallelism configuration. * - * @param L The Layout describing DP / TP / PP sizes and axis ordering. + * @param L The Layout describing DP / TP / CP / PP sizes and axis ordering. * @param skip_trivial_axes * If true, axes whose size <= 1(i.e. parallel strategy that is not enabled) * will be marked as "unenabled" and their detailed group listing will be skipped. diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h index 2eed56f4..1d39c38b 100644 --- a/infini_train/include/nn/parallel/parallel_functional.h +++ b/infini_train/include/nn/parallel/parallel_functional.h @@ -25,6 +25,9 @@ std::shared_ptr AllGather(const std::shared_ptr &output, const std std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false); +std::shared_ptr AllToAll(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg = nullptr, bool async_op = false); + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &device_ids, int dim); diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 74bf80c6..d5c4f3ac 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -30,6 +30,17 @@ class Work; namespace infini_train::nn::parallel { +enum class P2POpType { + kSend, + kRecv, +}; + +struct P2POp { + P2POpType type; + std::shared_ptr tensor; + int peer_rank; +}; + class ProcessGroup { public: explicit ProcessGroup(Device::DeviceType backend, const std::string &process_group_name, @@ -39,6 +50,8 @@ class ProcessGroup { virtual int GetGroupRank(int global_rank) const; + int WorldSize() const { return world_size_; } + // Asynchronous communication APIs (Compute / Communication stream decoupled) virtual std::shared_ptr AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op = function::ReduceOpType::kSum, @@ -52,12 +65,17 @@ class ProcessGroup { function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const; + virtual std::shared_ptr AllToAll(const std::shared_ptr &output, const std::shared_ptr &input, + bool async_op = false) const; + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op = false) const; virtual std::shared_ptr Recv(std::vector> tensors, int src_rank, bool async_op = false) const; + virtual std::shared_ptr BatchSendRecv(const std::vector &ops, bool async_op = false) const; + // Legacy communication APIs (Single-stream) virtual std::vector> BroadCast(const std::vector> &input_tensors) const; diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index 4dc737e7..d23671e5 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -11,14 +11,22 @@ class Tensor; namespace infini_train::nn::parallel { std::string GetDataParallelProcessGroupName(int global_rank); +std::string GetDataParallelWithContextProcessGroupName(int global_rank); + std::string GetTensorParallelProcessGroupName(int global_rank); +std::string GetContextParallelProcessGroupName(int global_rank); + std::string GetPipelineParallelProcessGroupName(int global_rank); std::vector GetDataParallelGroupRanks(int global_rank); +std::vector GetDataParallelWithContextGroupRanks(int global_rank); + std::vector GetTensorParallelGroupRanks(int global_rank); +std::vector GetContextParallelGroupRanks(int global_rank); + std::vector GetPipelineParallelGroupRanks(int global_rank); // TP/SP Communication Helper Functions diff --git a/infini_train/src/checkpoint/checkpoint.cc b/infini_train/src/checkpoint/checkpoint.cc index 892ec497..468d217e 100644 --- a/infini_train/src/checkpoint/checkpoint.cc +++ b/infini_train/src/checkpoint/checkpoint.cc @@ -105,8 +105,8 @@ void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module &m state = LoadTrainerState(checkpoint_dir / "trainer_state.json"); LOG(ERROR) << "[CKPT] Load done: global_step=" << state.global_step << ", consumed_batches =" << state.consumed_batches << ", last_lr=" << state.last_lr - << ", topology(ddp,tp,sp,pp)=(" << state.ddp_size << "," << state.tp_size << "," << state.sp_size << "," - << state.pp_size << ")"; + << ", topology(ddp,tp,sp,cp,pp)=(" << state.ddp_size << "," << state.tp_size << "," << state.sp_size + << "," << state.cp_size << "," << state.pp_size << ")"; } void Checkpoint::SaveStateDict(const std::filesystem::path &path, @@ -186,13 +186,14 @@ void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const Train ofs << " \"n_head\": " << state.n_head << ",\n"; ofs << " \"n_kv_head\": " << state.n_kv_head << ",\n"; ofs << " \"n_embd\": " << state.n_embd << ",\n"; - ofs << " \"vocab_size\": " << state.vocab_size << "\n"; + ofs << " \"vocab_size\": " << state.vocab_size << ",\n"; ofs << " \"global_step\": " << state.global_step << ",\n"; ofs << " \"consumed_batches\": " << state.consumed_batches << ",\n"; ofs << " \"last_lr\": " << state.last_lr << ",\n"; ofs << " \"ddp_size\": " << state.ddp_size << ",\n"; ofs << " \"tp_size\": " << state.tp_size << ",\n"; ofs << " \"sp_size\": " << state.sp_size << ",\n"; + ofs << " \"cp_size\": " << state.cp_size << ",\n"; ofs << " \"pp_size\": " << state.pp_size << "\n"; ofs << "}\n"; } @@ -215,6 +216,7 @@ TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) { state.ddp_size = ExtractNumberField(content, "ddp_size", 1); state.tp_size = ExtractNumberField(content, "tp_size", 1); state.sp_size = ExtractNumberField(content, "sp_size", 1); + state.cp_size = ExtractNumberField(content, "cp_size", 1); state.pp_size = ExtractNumberField(content, "pp_size", 1); return state; } diff --git a/infini_train/src/checkpoint/checkpoint_manager.cc b/infini_train/src/checkpoint/checkpoint_manager.cc index 71e15c08..8730c588 100644 --- a/infini_train/src/checkpoint/checkpoint_manager.cc +++ b/infini_train/src/checkpoint/checkpoint_manager.cc @@ -28,6 +28,7 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs & int ddp_world_size = nn::parallel::global::GetDataParallelSize(); int tp_world_size = nn::parallel::global::GetTensorParallelSize(); int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int cp_world_size = nn::parallel::global::GetContextParallelSize(); int pp_world_size = nn::parallel::global::GetPipelineParallelSize(); std::filesystem::path resume_dir = args.resume_root; @@ -59,6 +60,8 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs & << "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size; CHECK_EQ(args.state.sp_size, sp_world_size) << "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size; + CHECK_EQ(args.state.cp_size, cp_world_size) + << "CP size mismatch: checkpoint has CP=" << args.state.cp_size << ", but current run has CP=" << cp_world_size; CHECK_EQ(args.state.pp_size, pp_world_size) << "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size; @@ -86,6 +89,7 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) { state.ddp_size = args.ddp_size; state.tp_size = args.tp_size; state.sp_size = args.sp_size; + state.cp_size = args.cp_size; state.pp_size = args.pp_size; Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.save_optimizer_state); diff --git a/infini_train/src/core/ccl/ccl.cc b/infini_train/src/core/ccl/ccl.cc index 92c14cc6..d981e521 100644 --- a/infini_train/src/core/ccl/ccl.cc +++ b/infini_train/src/core/ccl/ccl.cc @@ -54,6 +54,11 @@ void CclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_co LOG(FATAL) << "CclImpl::ReduceScatter is not implemented."; } +void CclImpl::AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const { + LOG(FATAL) << "CclImpl::AllToAll is not implemented."; +} + void CclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { LOG(FATAL) << "CclImpl::Send is not implemented."; diff --git a/infini_train/src/core/ccl/ccl_utils.cc b/infini_train/src/core/ccl/ccl_utils.cc index f8968c6b..a7868d66 100644 --- a/infini_train/src/core/ccl/ccl_utils.cc +++ b/infini_train/src/core/ccl/ccl_utils.cc @@ -1,55 +1,158 @@ #include "infini_train/include/core/ccl/ccl_utils.h" +#include +#include #include -#include +#include #include #include #include +#include +#include #include #include "glog/logging.h" namespace infini_train::core { namespace { -std::string UniqueIdFileName(const std::string &name, bool tmp = false) { - return "cclUniqueId_" + name + (tmp ? ".tmp" : ".bin"); + +constexpr int64_t kDefaultUniqueIdTimeoutSec = 600; + +std::string GetEnvString(const char *name, const std::string &default_value) { + const char *value = std::getenv(name); + return value == nullptr ? default_value : std::string(value); +} + +int64_t GetEnvInt64(const char *name, int64_t default_value) { + const char *value = std::getenv(name); + if (value == nullptr) { + return default_value; + } + try { + return std::stoll(value); + } catch (...) { + LOG(WARNING) << "Invalid integer environment variable " << name << "=" << value << ", fallback to " + << default_value; + return default_value; + } +} + +std::string SanitizeFileComponent(std::string value) { + std::replace_if( + value.begin(), value.end(), + [](unsigned char c) { return !(std::isalnum(c) || c == '_' || c == '-' || c == '.'); }, '_'); + return value; +} + +std::filesystem::path UniqueIdFilePath(const std::string &pg_name) { + std::string file_name = "cclUniqueId_"; + const std::string name_space = GetEnvString("INFINITRAIN_CCL_ID_NAMESPACE", ""); + if (!name_space.empty()) { + file_name += SanitizeFileComponent(name_space) + "_"; + } + file_name += SanitizeFileComponent(pg_name) + ".bin"; + return std::filesystem::path(GetEnvString("INFINITRAIN_CCL_ID_DIR", ".")) / file_name; +} + +std::filesystem::path UniqueIdTmpFilePath(const std::filesystem::path &file_path) { + const auto now = std::chrono::steady_clock::now().time_since_epoch().count(); + const std::string global_proc_rank = GetEnvString("GLOBAL_PROC_RANK", "0"); + return std::filesystem::path(file_path.string() + ".tmp." + global_proc_rank + "." + std::to_string(now)); +} + +void RemoveTmpFilesFor(const std::filesystem::path &file_path) { + const auto parent = file_path.parent_path(); + std::error_code ec; + if (!std::filesystem::exists(parent, ec)) { + return; + } + + const std::string prefix = file_path.filename().string() + ".tmp."; + for (const auto &entry : std::filesystem::directory_iterator(parent, ec)) { + if (ec) { + return; + } + const std::string name = entry.path().filename().string(); + if (name.rfind(prefix, 0) == 0) { + std::error_code remove_ec; + std::filesystem::remove(entry.path(), remove_ec); + } + } } } // namespace void WriteUniqueIdFile(const CclUniqueId &unique_id, const std::string &pg_name) { - const std::string tmp_path = UniqueIdFileName(pg_name, true); + const auto file_path = UniqueIdFilePath(pg_name); + const auto tmp_path = UniqueIdTmpFilePath(file_path); + + std::error_code ec; + std::filesystem::create_directories(file_path.parent_path(), ec); + CHECK(!ec) << "Failed to create CCL unique_id directory: " << file_path.parent_path() << ", error=" << ec.message(); + RemoveTmpFilesFor(file_path); std::ofstream ofs(tmp_path, std::ios::binary); CHECK(ofs.good()) << "Failed to open unique_id tmp file for write: " << tmp_path; const size_t size = unique_id.Size(); ofs.write(reinterpret_cast(unique_id.Data()), static_cast(size)); + CHECK(ofs.good()) << "Failed to write unique_id tmp file: " << tmp_path; ofs.close(); + CHECK(!ofs.fail()) << "Failed to close unique_id tmp file: " << tmp_path; - std::rename(tmp_path.c_str(), UniqueIdFileName(pg_name).c_str()); + const auto tmp_size = std::filesystem::file_size(tmp_path, ec); + CHECK(!ec && tmp_size == size) << "Invalid unique_id tmp file size. file=" << tmp_path << ", expected=" << size + << ", got=" << (ec ? 0 : tmp_size) << ", error=" << ec.message(); + + std::filesystem::rename(tmp_path, file_path, ec); + if (ec) { + std::error_code remove_ec; + std::filesystem::remove(tmp_path, remove_ec); + LOG(FATAL) << "Failed to publish unique_id file. tmp=" << tmp_path << ", final=" << file_path + << ", error=" << ec.message(); + } } void ReadUniqueIdFile(CclUniqueId *unique_id, const std::string &pg_name) { CHECK_NOTNULL(unique_id); - const std::string file_path = UniqueIdFileName(pg_name); - - while (!std::filesystem::exists(file_path)) { std::this_thread::sleep_for(std::chrono::microseconds(1000)); } + const auto file_path = UniqueIdFilePath(pg_name); + const auto timeout_sec + = std::max(1, GetEnvInt64("INFINITRAIN_CCL_ID_TIMEOUT_SEC", kDefaultUniqueIdTimeoutSec)); + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec); + const size_t expected_size = unique_id->Size(); + uintmax_t last_observed_size = 0; - std::ifstream ifs(file_path, std::ios::binary); - CHECK(ifs.good()) << "Failed to open unique_id file for read: " << file_path; - - std::string bytes((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - ifs.close(); + while (std::chrono::steady_clock::now() < deadline) { + std::error_code ec; + if (std::filesystem::exists(file_path, ec)) { + const auto file_size = std::filesystem::file_size(file_path, ec); + if (!ec) { + last_observed_size = file_size; + if (file_size == expected_size) { + std::ifstream ifs(file_path, std::ios::binary); + if (ifs.good()) { + std::string bytes((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + ifs.close(); + if (bytes.size() == expected_size) { + unique_id->Load(bytes.data(), bytes.size()); + return; + } + last_observed_size = bytes.size(); + } + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } - CHECK_EQ(bytes.size(), unique_id->Size()) - << "Mismatched unique_id size in file. expected=" << unique_id->Size() << ", got=" << bytes.size(); - unique_id->Load(bytes.data(), bytes.size()); + LOG(FATAL) << "Timed out waiting for CCL unique_id file. file=" << file_path << ", expected_size=" << expected_size + << ", last_observed_size=" << last_observed_size << ", timeout_sec=" << timeout_sec + << ". Set INFINITRAIN_CCL_ID_NAMESPACE for concurrent jobs sharing a directory."; } void CleanupUniqueIdFile(const std::string &pg_name) { - const std::string file_path = UniqueIdFileName(pg_name); - if (std::filesystem::exists(file_path)) { - std::filesystem::remove(file_path); - } + const auto file_path = UniqueIdFilePath(pg_name); + std::error_code ec; + std::filesystem::remove(file_path, ec); + RemoveTmpFilesFor(file_path); } } // namespace infini_train::core diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.cc b/infini_train/src/core/ccl/cuda/nccl_impl.cc index 9e4b1a0d..98840e01 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.cc +++ b/infini_train/src/core/ccl/cuda/nccl_impl.cc @@ -146,6 +146,40 @@ void NcclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_c kNcclReduceOpMap.at(reduce_op), GetNcclComm(comm), GetCudaStream(stream))); } +void NcclImpl::AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const { + auto nccl_comm = GetNcclComm(comm); + auto cuda_stream = GetCudaStream(stream); + int nranks = 0; + int rank = 0; + NCCL_CHECK(ncclCommCount(nccl_comm, &nranks)); + NCCL_CHECK(ncclCommUserRank(nccl_comm, &rank)); + CHECK_GT(nranks, 0); + CHECK_GE(rank, 0); + CHECK_LT(rank, nranks); + + const size_t chunk_bytes = count * kDataTypeToSize.at(dtype); + auto send_ptr = static_cast(sendbuff); + auto recv_ptr = static_cast(recvbuff); + + if (chunk_bytes > 0) { + CUDA_CHECK(cudaMemcpyAsync(recv_ptr + static_cast(rank) * chunk_bytes, + send_ptr + static_cast(rank) * chunk_bytes, chunk_bytes, + cudaMemcpyDeviceToDevice, cuda_stream)); + } + + NCCL_CHECK(ncclGroupStart()); + for (int peer = 0; peer < nranks; ++peer) { + if (peer == rank) { + continue; + } + const auto offset = static_cast(peer) * chunk_bytes; + NCCL_CHECK(ncclSend(send_ptr + offset, count, kNcclDtypeMap.at(dtype), peer, nccl_comm, cuda_stream)); + NCCL_CHECK(ncclRecv(recv_ptr + offset, count, kNcclDtypeMap.at(dtype), peer, nccl_comm, cuda_stream)); + } + NCCL_CHECK(ncclGroupEnd()); +} + void NcclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { NCCL_CHECK(ncclSend(buff, count, kNcclDtypeMap.at(dtype), peer, GetNcclComm(comm), GetCudaStream(stream))); diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.h b/infini_train/src/core/ccl/cuda/nccl_impl.h index fca177fd..4fa68a87 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.h +++ b/infini_train/src/core/ccl/cuda/nccl_impl.h @@ -42,6 +42,9 @@ class NcclImpl final : public CclImpl { nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const override; + void AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const override; + void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const override; diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 8bed8193..e7425a58 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -1,6 +1,7 @@ #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include +#include #include #include #include @@ -12,6 +13,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -109,16 +111,24 @@ CausalSelfAttention::ForwardStandard(const std::vectorView({B, T, local_n_head_, head_dim})->Transpose(1, 2); v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); + const auto q_start = parallel::GetContextParallelSequenceStart(T); + const auto T_kv = T * parallel::global::GetContextParallelSize(); + + // (1, 1, T_local, T_global) + auto mask = buffers_[kParamBiasName]->Slice({0, 0, q_start, 0}, {1, 1, q_start + T, T_kv}, {1, 1, 1, 1}); + std::shared_ptr y; + if (parallel::global::GetContextParallelSize() > 1) { + y = parallel::AttnForwardFuncWithCP(q, k, v, mask == 0); + } else { + // (B, h_l, T_local, T_global) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); + // (1, 1, T_local, T_global) -> eq 0 -> masked_fill -> (B, h_l, T_local, T_global) + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + // (B, h_l, T_local, T_global) + att = nn::function::Softmax(att, -1); + // (B, h_l, T, Dh) + y = att->Matmul(v); + } // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); @@ -219,10 +229,13 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector (B, T, H_local, D) via RepeatKV - k = RepeatKV(k, n_rep_); - v = RepeatKV(v, n_rep_); + const bool use_context_parallel = parallel::global::GetContextParallelSize() > 1; + if (!use_context_parallel) { + // align n_head in GQA + // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV + k = RepeatKV(k, n_rep_); + v = RepeatKV(v, n_rep_); + } // (B, T, H_local, D) -> (B, H_local, T, D) q = q->Transpose(1, 2); @@ -232,21 +245,26 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector (B, H_local, D, T) - // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); - if (mask) { - // mask: (1, 1, T, T) - att = att->MaskedFill(mask, std::numeric_limits::lowest()); + std::shared_ptr y; + if (use_context_parallel) { + y = parallel::AttnForwardFuncWithCP(q, k, v, mask); + } else { + // manual implementation of attention + // this materializes the large (T,T) matrix for all the queries and keys + + // q: (B, H_local, T, D) + // k: (B, H_local, T, D) -> (B, H_local, D, T) + // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); + if (mask) { + // mask: (1, 1, T, T) + att = att->MaskedFill(mask, std::numeric_limits::lowest()); + } + // (B, H_local, T, T) + att = nn::function::Softmax(att, -1); + // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) + y = att->Matmul(v); } - // (B, H_local, T, T) - att = nn::function::Softmax(att, -1); - // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) - auto y = att->Matmul(v); // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local) y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); // output projection diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..d397eab7 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -16,6 +16,7 @@ #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" @@ -56,8 +57,9 @@ std::vector> TransformerFirstStage::Forward(const std::v nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); } + const int64_t cp_start = nn::parallel::GetContextParallelSequenceStart(x1->Dims()[1]); int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; - int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; + int64_t start = cp_start + (sequence_parallel_enabled ? tp_rank * t_local : 0); auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // (T) -> Embedding(T_max, C) -> (T, C) @@ -140,15 +142,17 @@ std::vector> TransformerChunk::Forward(const std::vector config_.use_scaled_rope, device); } - const auto t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); // full_seq_len + const auto local_t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); + const auto full_t = local_t * nn::parallel::global::GetContextParallelSize(); - // Dynamic start_pos (set to 0 for now) - int64_t start_pos = 0; - auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); + int64_t start_pos = nn::parallel::GetContextParallelSequenceStart(local_t); + auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + local_t, 1); // Create causal mask - std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(device)); - std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); + std::shared_ptr ones = std::make_shared(nn::function::Ones({full_t, full_t})->To(device)); + std::shared_ptr mask = nn::function::Triu(ones, 1) + ->Slice({start_pos, 0}, {start_pos + local_t, full_t}, {1, 1}) + ->View({1, 1, local_t, full_t}); std::shared_ptr start_pos_ptr = nullptr; diff --git a/infini_train/src/nn/parallel/context_parallel.cc b/infini_train/src/nn/parallel/context_parallel.cc new file mode 100644 index 00000000..eee4d69e --- /dev/null +++ b/infini_train/src/nn/parallel/context_parallel.cc @@ -0,0 +1,640 @@ +#include "infini_train/include/nn/parallel/context_parallel.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/process_group.h" +#include "infini_train/include/nn/parallel/reduce_op_type.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/nn/parallel/work.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +thread_local int cp_rank = 0; + +namespace { + +const ProcessGroup *GetCPGroup(const std::shared_ptr &tensor) { + auto cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0); + return ProcessGroupFactory::Instance(tensor->GetDevice().type()) + ->Get(GetContextParallelProcessGroupName(tensor->GetDevice().Rank().GlobalRank())); +} + +// Comm Kernel Call Functions +std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tensor) { + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0) << "Context Parallel group not initialized"; + if (cp_size == 1) { + return tensor; + } + + auto cp_group = GetCPGroup(tensor); + auto output_shape = tensor->Dims(); + output_shape[0] *= cp_size; + auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); + cp_group->AllGather(output, tensor, false); + return output; +} + +std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr &tensor) { + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0) << "Context Parallel group not initialized"; + if (cp_size == 1) { + return tensor; + } + + auto cp_group = GetCPGroup(tensor); + auto output_shape = tensor->Dims(); + CHECK_EQ(output_shape[0] % cp_size, 0) << "First dimension must be divisible by CP size"; + output_shape[0] /= cp_size; + auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); + cp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum, false); + return output; +} + +std::shared_ptr AllToAllAlongFirstDim(const std::shared_ptr &tensor) { + // Tensor P is split along first dim in [P0 | P1 | ... | Pn] + // Each rank j sends Pj to every other rank and receive the rest of P from every other rank + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0) << "Context Parallel group not initialized"; + if (cp_size == 1) { + return tensor; + } + + auto cp_group = GetCPGroup(tensor); + auto output_shape = tensor->Dims(); + CHECK_EQ(output_shape[0] % cp_size, 0) << "First dimension must be divisible by CP size"; + auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); + cp_group->AllToAll(output, tensor, false); + return output; +} + +std::shared_ptr AllToAllSeqToHead(const std::shared_ptr &input) { + if (global::GetContextParallelSize() == 1) { + return input; + } + + const int cp_size = global::GetContextParallelSize(); + const auto &shape = input->Dims(); + CHECK_EQ(shape.size(), 4); + const int64_t B = shape[0], H = shape[1], T_l = shape[2], D = shape[3]; + CHECK_EQ(H % cp_size, 0) << "A2A CP requires head dimension divisible by CP size"; + const int64_t H_per_cp = H / cp_size; + + // input: (B, H, T_l, D) + // + // send_input: (H, B, T_l, D), split dim 0 into CP chunks of H_per_cp heads. + auto send_input = input->Transpose(0, 1)->Contiguous(); + // exchanged: (H, B, T_l, D), dim 0 chunks are ordered by source sequence-owner rank. + auto exchanged = AllToAllAlongFirstDim(send_input); + // output: (B, H_per_cp, T_g, D) + return exchanged->View({cp_size, H_per_cp, B, T_l, D}) + ->Transpose(0, 2) + ->Contiguous() + ->View({B, H_per_cp, static_cast(cp_size) * T_l, D}); +} + +std::shared_ptr AllToAllHeadToSeq(const std::shared_ptr &input) { + if (global::GetContextParallelSize() == 1) { + return input; + } + + const int cp_size = global::GetContextParallelSize(); + const auto &shape = input->Dims(); + CHECK_EQ(shape.size(), 4); + const int64_t B = shape[0], H_per_cp = shape[1], T_g = shape[2], D = shape[3]; + CHECK_EQ(T_g % cp_size, 0) << "A2A CP requires sequence dimension divisible by CP size"; + const int64_t T_l = T_g / cp_size; + + // input: (B, H_per_cp, T_g, D) + // + // send_input: (CP * H_per_cp, B, T_l, D), split dim 0 into CP sequence-owner chunks. + auto send_input = input->View({B, H_per_cp, cp_size, T_l, D}) + ->Transpose(0, 2) + ->Contiguous() + ->View({static_cast(cp_size) * H_per_cp, B, T_l, D}); + // exchanged: (CP * H_per_cp, B, T_l, D), dim 0 chunks are ordered by source head-owner rank. + auto exchanged = AllToAllAlongFirstDim(send_input); + // output: (B, CP * H_per_cp, T_l, D) + return exchanged->View({cp_size, H_per_cp, B, T_l, D}) + ->Transpose(1, 2) + ->Transpose(0, 1) + ->Contiguous() + ->View({B, static_cast(cp_size) * H_per_cp, T_l, D}); +} + +// Attention Helper Functions +std::shared_ptr NewZeroTensorLike(const std::shared_ptr &tensor) { + auto output = std::make_shared(tensor->Dims(), tensor->Dtype(), tensor->GetDevice()); + output->Fill(0.0f); + return output; +} + +std::vector> P2PCommunicate(int rank, const std::vector> &send_tensors, + int send_dst, + const std::vector> &recv_tensors, + int recv_src, const ProcessGroup *cp_group, bool batch_p2p_comm) { + std::vector ops; + ops.reserve(send_tensors.size() + recv_tensors.size()); + std::vector> works; + + if (rank % 2 == 0) { + if (batch_p2p_comm) { + for (const auto &tensor : send_tensors) { ops.push_back({P2POpType::kSend, tensor, send_dst}); } + for (const auto &tensor : recv_tensors) { ops.push_back({P2POpType::kRecv, tensor, recv_src}); } + } else { + works.push_back(cp_group->Send(send_tensors, send_dst, true)); + works.push_back(cp_group->Recv(recv_tensors, recv_src, true)); + } + } else { + if (batch_p2p_comm) { + for (const auto &tensor : recv_tensors) { ops.push_back({P2POpType::kRecv, tensor, recv_src}); } + for (const auto &tensor : send_tensors) { ops.push_back({P2POpType::kSend, tensor, send_dst}); } + } else { + works.push_back(cp_group->Recv(recv_tensors, recv_src, true)); + works.push_back(cp_group->Send(send_tensors, send_dst, true)); + } + } + if (batch_p2p_comm) { + works.push_back(cp_group->BatchSendRecv(ops, true)); + } + return works; +} + +std::shared_ptr RepeatKVHeads(const std::shared_ptr &x, int64_t n_rep) { + if (n_rep == 1) { + return x; + } + + const auto &shape = x->Dims(); + const int64_t B = shape[0], H = shape[1], T = shape[2], D = shape[3]; + return x->View({B, H, 1, T, D})->RepeatInterleave(n_rep, 2)->Contiguous()->View({B, H * n_rep, T, D}); +} + +std::shared_ptr SumRepeatedKVHeads(const std::shared_ptr &x, int64_t n_rep) { + if (n_rep == 1) { + return x; + } + + const auto &shape = x->Dims(); + const int64_t B = shape[0], H = shape[1], T = shape[2], D = shape[3]; + CHECK_EQ(H % n_rep, 0); + return x->View({B, H / n_rep, n_rep, T, D})->Sum(2); +} + +std::shared_ptr ApplyCoreAttention(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask) { + const float scale = static_cast(1.0 / std::sqrt(static_cast(q->Dims().back()))); + // scores: (B, H, T_q, T_k) + auto scores = q->Matmul(k->Transpose(-2, -1)) * scale; + if (mask) { + scores = scores->MaskedFill(mask, std::numeric_limits::lowest()); + } + // probs: (B, H, T_q, T_k) + auto probs = nn::function::Softmax(scores, -1); + // output: (B, H, T_q, D) + return probs->Matmul(v); +} + +// Autograd Function Definitions +class GatherFromCPRegion : public autograd::Function { +public: + static constexpr char kType[] = "GatherFromCPRegionFunction"; + + explicit GatherFromCPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + auto input = input_tensors[0]; + // FIXME(zbl): Megatron keeps sequence as dim 0. We uses [B, H, S, D], so move only + // the sequence dimension to dim 0 before the CP gather. + return {GatherAlongFirstDim(input->Transpose(0, 2))->Transpose(0, 2)->Contiguous()}; + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + // FIXME(zbl): Megatron keeps sequence as dim 0. We uses [B, H, S, D], so move only + // the sequence dimension to dim 0 before the CP gather. + return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 2))->Transpose(0, 2)->Contiguous()}; + } +}; + +class AllToAllSeqToHeadCPRegion : public autograd::Function { +public: + static constexpr char kType[] = "AllToAllSeqToHeadCPRegionFunction"; + + explicit AllToAllSeqToHeadCPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + return {AllToAllSeqToHead(input_tensors[0])}; + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {AllToAllHeadToSeq(grad_outputs[0])}; + } +}; + +class AllToAllHeadToSeqCPRegion : public autograd::Function { +public: + static constexpr char kType[] = "AllToAllHeadToSeqCPRegionFunction"; + + explicit AllToAllHeadToSeqCPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + return {AllToAllHeadToSeq(input_tensors[0])}; + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {AllToAllSeqToHead(grad_outputs[0])}; + } +}; + +class AttnWithCPAndKVP2P : public autograd::Function { +public: + static constexpr char kType[] = "AttnWithCPAndKVP2PFunction"; + + AttnWithCPAndKVP2P() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + CHECK_EQ(input_tensors.size(), 4); + // Shape notation: + // B: batch size, H_q: local query heads after TP, H_kv: local KV heads before GQA repeat, + // T_l: CP-local sequence length, T_g: global sequence length, D: head dimension. + // q: (B, H_q, T_l, D) + const auto &q = input_tensors[0]; + // k_local: (B, H_kv, T_l, D) + const auto &k_local = input_tensors[1]; + // v_local: (B, H_kv, T_l, D) + const auto &v_local = input_tensors[2]; + // mask: (1, 1, T_l, T_g), true values are invalid attention locations. + const auto &mask = input_tensors[3]; + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 1); + CHECK(mask) << "CP ring attention expects a causal mask."; + + auto cp_group = GetCPGroup(q); + CHECK_NOTNULL(cp_group); + const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); + const int send_to = (rank + 1) % cp_size; + const int recv_from = (rank - 1 + cp_size) % cp_size; + // NOTE(zbl): Megatron-LM enables batched P2P by default for CP=2 on pre-Blackwell GPUs. + const bool batch_p2p_comm = (cp_size == 2); + + const int64_t local_t = q->Dims()[2]; + CHECK_EQ(k_local->Dims()[2], local_t); + CHECK_EQ(v_local->Dims()[2], local_t); + CHECK_EQ(k_local->Dims()[1], v_local->Dims()[1]); + CHECK_EQ(q->Dims()[1] % k_local->Dims()[1], 0); + const int64_t n_rep = q->Dims()[1] / k_local->Dims()[1]; + const float scale = static_cast(1.0 / std::sqrt(static_cast(q->Dims().back()))); + + // current_k: (B, H_kv, T_l, D), owned by rank `(rank - step + cp_size) % cp_size`. + auto current_k = k_local; + // current_v: (B, H_kv, T_l, D), owned by rank `(rank - step + cp_size) % cp_size`. + auto current_v = v_local; + // running_max: (B, H_q, T_l, 1) + std::shared_ptr running_max; + // running_sum: (B, H_q, T_l, 1) + std::shared_ptr running_sum; + // running_out: (B, H_q, T_l, D) + std::shared_ptr running_out; + + for (int step = 0; step < cp_size; ++step) { + std::shared_ptr next_k; + std::shared_ptr next_v; + std::vector> p2p_works; + if (step + 1 < cp_size) { + // next_k: (B, H_kv, T_l, D) + next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); + // next_v: (B, H_kv, T_l, D) + next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); + p2p_works = P2PCommunicate(rank, {current_k, current_v}, send_to, {next_k, next_v}, recv_from, cp_group, + batch_p2p_comm); + } + + const int owner = (rank - step + cp_size) % cp_size; + const int64_t kv_start = static_cast(owner) * local_t; + const int64_t kv_end = kv_start + local_t; + + // k_for_attn: (B, H_q, T_l, D) + auto k_for_attn = RepeatKVHeads(current_k, n_rep); + // v_for_attn: (B, H_q, T_l, D) + auto v_for_attn = RepeatKVHeads(current_v, n_rep); + // scores: (B, H_q, T_l, T_l) + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale; + // invalid_mask: (1, 1, T_l, T_l) + auto invalid_mask = mask->Slice(3, kv_start, kv_end); + scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + + // chunk_max: (B, H_q, T_l, 1) + auto chunk_max = scores->Max(-1, true); + // probs: (B, H_q, T_l, T_l) + auto probs = (scores - chunk_max)->Exp()->MaskedFill(invalid_mask, 0.0f); + // chunk_sum: (B, H_q, T_l, 1) + auto chunk_sum = probs->Sum(-1, true); + // chunk_out: (B, H_q, T_l, D) + auto chunk_out = probs->Matmul(v_for_attn); + + if (!running_out) { + running_max = chunk_max; + running_sum = chunk_sum; + running_out = chunk_out; + } else { + // new_max: (B, H_q, T_l, 1) + auto new_max + = nn::function::Stack(std::vector>{running_max, chunk_max}, -1)->Max(-1); + // old_scale: (B, H_q, T_l, 1) + auto old_scale = (running_max - new_max)->Exp(); + // new_scale: (B, H_q, T_l, 1) + auto new_scale = (chunk_max - new_max)->Exp(); + running_sum = running_sum * old_scale + chunk_sum * new_scale; + running_out = running_out * old_scale + chunk_out * new_scale; + running_max = new_max; + } + + if (!p2p_works.empty()) { + for (const auto &work : p2p_works) { work->WaitNonBlocking(); } + current_k = next_k; + current_v = next_v; + } + } + + // output: (B, H_q, T_l, D) + // running_max: (B, H_q, T_l, 1) + // running_sum: (B, H_q, T_l, 1) + return {running_out / running_sum, running_max, running_sum}; + } + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override { + CHECK_EQ(output_tensors.size(), 3); + const auto &output = output_tensors[0]; + const auto &softmax_max = output_tensors[1]; + const auto &softmax_sum = output_tensors[2]; + ctx_.MarkNonDifferentiable({softmax_max, softmax_sum}); + ctx_.SaveForBackward( + {input_tensors[0], input_tensors[1], input_tensors[2], input_tensors[3], output, softmax_max, softmax_sum}); + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + // Shape notation: + // - B: batch size + // - H_q: local query heads after TP + // - H_kv: local KV heads before GQA repeat + // - T_l: CP-local sequence length + // - T_g: global sequence length + // - D: head dimension. + + CHECK_GE(grad_outputs.size(), 1); + auto saved_tensors = ctx_.GetSavedTensors(); + CHECK_EQ(saved_tensors.size(), 7); + + // q: (B, H_q, T_l, D) + const auto &q = saved_tensors[0]; + // k_local: (B, H_kv, T_l, D) + const auto &k_local = saved_tensors[1]; + // v_local: (B, H_kv, T_l, D) + const auto &v_local = saved_tensors[2]; + // mask: (1, 1, T_l, T_g), true values are invalid attention locations. + const auto &mask = saved_tensors[3]; + // output: (B, H_q, T_l, D) + const auto &output = saved_tensors[4]; + // softmax_max: (B, H_q, T_l, 1) + const auto &softmax_max = saved_tensors[5]; + // softmax_sum: (B, H_q, T_l, 1) + const auto &softmax_sum = saved_tensors[6]; + // grad_output: (B, H_q, T_l, D) + const auto &grad_output = grad_outputs[0]; + + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 1); + + auto cp_group = GetCPGroup(q); + CHECK_NOTNULL(cp_group); + const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); + const int send_to = (rank + 1) % cp_size; + const int recv_from = (rank - 1 + cp_size) % cp_size; + // NOTE(zbl): Megatron-LM enables batched P2P by default for CP=2 on pre-Blackwell GPUs. + const bool batch_p2p_comm = (cp_size == 2); + + const int64_t local_t = q->Dims()[2]; + CHECK_EQ(k_local->Dims()[2], local_t); + CHECK_EQ(v_local->Dims()[2], local_t); + CHECK_EQ(k_local->Dims()[1], v_local->Dims()[1]); + CHECK_EQ(q->Dims()[1] % k_local->Dims()[1], 0); + const int64_t n_rep = q->Dims()[1] / k_local->Dims()[1]; + const float scale = static_cast(1.0 / std::sqrt(static_cast(q->Dims().back()))); + + // current_k: (B, H_kv, T_l, D) + auto current_k = k_local; + // current_v: (B, H_kv, T_l, D) + auto current_v = v_local; + // current_grad_k: (B, H_kv, T_l, D) + auto current_grad_k = NewZeroTensorLike(k_local); + // current_grad_v: (B, H_kv, T_l, D) + auto current_grad_v = NewZeroTensorLike(v_local); + // grad_q: (B, H_q, T_l, D) + std::shared_ptr grad_q; + // softmax_delta: (B, H_q, T_l, 1) + const auto softmax_delta = (grad_output * output)->Sum(-1, true); + + for (int step = 0; step < cp_size; ++step) { + const int owner = (rank - step + cp_size) % cp_size; + const int64_t kv_start = static_cast(owner) * local_t; + const int64_t kv_end = kv_start + local_t; + + // k_for_attn: (B, H_q, T_l, D) + auto k_for_attn = RepeatKVHeads(current_k, n_rep); + // v_for_attn: (B, H_q, T_l, D) + auto v_for_attn = RepeatKVHeads(current_v, n_rep); + // scores: (B, H_q, T_l, T_l) + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale; + // invalid_mask: (1, 1, T_l, T_l) + auto invalid_mask = mask->Slice(3, kv_start, kv_end); + scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + + // probs: (B, H_q, T_l, T_l) + auto probs = (scores - softmax_max)->Exp()->MaskedFill(invalid_mask, 0.0f) / softmax_sum; + // grad_v_repeated: (B, H_q, T_l, D) + auto grad_v_repeated = probs->Transpose(-2, -1)->Matmul(grad_output); + // grad_probs: (B, H_q, T_l, T_l) + auto grad_probs = grad_output->Matmul(v_for_attn->Transpose(-2, -1)); + // grad_scores: (B, H_q, T_l, T_l) + auto grad_scores = probs * (grad_probs - softmax_delta); + // grad_k_repeated: (B, H_q, T_l, D) + auto grad_k_repeated = grad_scores->Transpose(-2, -1)->Matmul(q) * scale; + + // SumRepeatedKVHeads maps repeated GQA gradients from (B, H_q, T_l, D) to (B, H_kv, T_l, D). + current_grad_k = current_grad_k + SumRepeatedKVHeads(grad_k_repeated, n_rep); + current_grad_v = current_grad_v + SumRepeatedKVHeads(grad_v_repeated, n_rep); + + std::vector> p2p_works; + std::shared_ptr next_k; + std::shared_ptr next_v; + std::shared_ptr next_grad_k; + std::shared_ptr next_grad_v; + + if (step + 1 < cp_size) { + // Send current K/V plus accumulated K/V grads to the next rank; receive the previous owner chunk. + // next_k: (B, H_kv, T_l, D) + next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); + // next_v: (B, H_kv, T_l, D) + next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); + // next_grad_k: (B, H_kv, T_l, D) + next_grad_k = NewZeroTensorLike(k_local); + // next_grad_v: (B, H_kv, T_l, D) + next_grad_v = NewZeroTensorLike(v_local); + p2p_works + = P2PCommunicate(rank, {current_k, current_v, current_grad_k, current_grad_v}, send_to, + {next_k, next_v, next_grad_k, next_grad_v}, recv_from, cp_group, batch_p2p_comm); + } else { + // Last step only needs to rotate accumulated K/V grads back to the local owner rank. + // next_grad_k: (B, H_kv, T_l, D) + next_grad_k = NewZeroTensorLike(k_local); + // next_grad_v: (B, H_kv, T_l, D) + next_grad_v = NewZeroTensorLike(v_local); + p2p_works = P2PCommunicate(rank, {current_grad_k, current_grad_v}, send_to, {next_grad_k, next_grad_v}, + recv_from, cp_group, batch_p2p_comm); + } + + // grad_q_chunk: (B, H_q, T_l, D) + auto grad_q_chunk = grad_scores->Matmul(k_for_attn) * scale; + grad_q = grad_q ? grad_q + grad_q_chunk : grad_q_chunk; + + for (const auto &work : p2p_works) { work->WaitNonBlocking(); } + + if (step + 1 < cp_size) { + current_k = next_k; + current_v = next_v; + current_grad_k = next_grad_k; + current_grad_v = next_grad_v; + } else { + current_grad_k = next_grad_k; + current_grad_v = next_grad_v; + } + } + + return {grad_q, current_grad_k, current_grad_v, nullptr}; + } +}; + +std::shared_ptr AllToAllSeqToHeadCPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input})[0]; +} + +std::shared_ptr AllToAllHeadToSeqCPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input})[0]; +} + +} // namespace + +// CP State Helper Functions +int GetContextParallelRank() { return cp_rank; } + +int64_t GetContextParallelSequenceStart(int64_t local_sequence_length) { + return static_cast(GetContextParallelRank()) * local_sequence_length; +} + +// CP Communication Helper Functions +std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &input, int64_t dim) { + const int cp_size = global::GetContextParallelSize(); + if (cp_size == 1) { + return input; + } + + int64_t normalized_dim = dim; + if (normalized_dim < 0) { + normalized_dim += static_cast(input->Dims().size()); + } + CHECK_GE(normalized_dim, 0); + CHECK_LT(normalized_dim, static_cast(input->Dims().size())); + const auto dim_size = input->Dims()[normalized_dim]; + CHECK_EQ(dim_size % cp_size, 0) << "Sequence dimension must be divisible by CP size"; + const int64_t local_size = dim_size / cp_size; + const int64_t start = GetContextParallelSequenceStart(local_size); + return input->Slice(normalized_dim, start, start + local_size, 1)->Contiguous(); +} + +std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input})[0]; +} + +// CP Attention Backend Functions +std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask) { + return std::make_shared()->Apply({q, k, v, mask})[0]; +} + +std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask) { + CHECK_EQ(k->Dims()[1], v->Dims()[1]); + CHECK_EQ(q->Dims()[1] % k->Dims()[1], 0); + const int64_t n_rep = q->Dims()[1] / k->Dims()[1]; + // gathered_k: (B, H_kv, T_g, D) + auto gathered_k = GatherFromCPRegionFunc(k); + // gathered_v: (B, H_kv, T_g, D) + auto gathered_v = GatherFromCPRegionFunc(v); + // k_for_attn: (B, H_q, T_g, D) + auto k_for_attn = RepeatKVHeads(gathered_k, n_rep); + // v_for_attn: (B, H_q, T_g, D) + auto v_for_attn = RepeatKVHeads(gathered_v, n_rep); + return ApplyCoreAttention(q, k_for_attn, v_for_attn, mask); +} + +std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask) { + const int cp_size = global::GetContextParallelSize(); + const int64_t q_heads = q->Dims()[1]; + const int64_t kv_heads = k->Dims()[1]; + CHECK_EQ(kv_heads, v->Dims()[1]); + CHECK_EQ(q_heads % cp_size, 0) << "A2A CP requires local query heads divisible by CP size"; + CHECK_EQ(kv_heads % cp_size, 0) << "A2A CP requires local KV heads divisible by CP size"; + CHECK_EQ(q_heads % kv_heads, 0); + + // q_shard: (B, H_q/CP, T_g, D) + auto q_shard = AllToAllSeqToHeadCPRegionFunc(q); + // k_shard: (B, H_kv/CP, T_g, D) + auto k_shard = AllToAllSeqToHeadCPRegionFunc(k); + // v_shard: (B, H_kv/CP, T_g, D) + auto v_shard = AllToAllSeqToHeadCPRegionFunc(v); + // full_mask: (1, 1, T_g, T_g) + auto full_mask = mask ? GatherFromCPRegionFunc(mask) : nullptr; + + const int64_t n_rep = q_shard->Dims()[1] / k_shard->Dims()[1]; + // k_for_attn: (B, H_q/CP, T_g, D) + auto k_for_attn = RepeatKVHeads(k_shard, n_rep); + // v_for_attn: (B, H_q/CP, T_g, D) + auto v_for_attn = RepeatKVHeads(v_shard, n_rep); + // output_shard: (B, H_q/CP, T_g, D) + auto output_shard = ApplyCoreAttention(q_shard, k_for_attn, v_for_attn, full_mask); + + // output: (B, H_q, T_l, D) + return AllToAllHeadToSeqCPRegionFunc(output_shard); +} + +std::shared_ptr AttnForwardFuncWithCP(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask) { + CHECK_GT(global::GetContextParallelSize(), 1); + const auto comm_type = global::GetContextParallelCommType(); + if (comm_type == "p2p") { + return AttnFuncWithCPAndKVP2P(q, k, v, mask); + } else if (comm_type == "a2a") { + return AttnFuncWithCPAndQKVOA2A(q, k, v, mask); + } else if (comm_type == "all_gather") { + return AttnFuncWithCPAndKVAllGather(q, k, v, mask); + } else { + LOG(FATAL) << "AttnForwardFuncWithCP: Unsupported communication type " << comm_type << "."; + } +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index a3bfe008..bcade972 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -25,7 +25,7 @@ constexpr char kModuleName[] = "module"; DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, const Rank &rank, const DistributedDataParallelConfig ddp_config) : ddp_config_(ddp_config), - ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) { + ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()))) { CHECK(ddp_config_.zero_stage >= 0 && ddp_config_.zero_stage <= 3) << "DistributedDataParallel: zero_stage must be in 0/1/2/3."; if (ddp_config_.zero_stage == 3) { diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index ab3a8002..bdfcde91 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -413,7 +413,7 @@ ParamAndGradBuffer::ParamAndGradBuffer(const std::vector DistributedDataParallelConfig ddp_config) : params_(std::move(params)), ddp_pg_(std::move(ddp_pg)), ddp_config_(ddp_config) { if (ddp_pg_) { - ddp_world_size_ = global::GetDataParallelSize(); + ddp_world_size_ = static_cast(ddp_pg_->WorldSize()); } grads_.clear(); diff --git a/infini_train/src/nn/parallel/ddp/reducer.cc b/infini_train/src/nn/parallel/ddp/reducer.cc index 1bdd29e1..16d1a3e5 100644 --- a/infini_train/src/nn/parallel/ddp/reducer.cc +++ b/infini_train/src/nn/parallel/ddp/reducer.cc @@ -387,7 +387,7 @@ void Reducer::MarkBucketReady(size_t bucket_index) { void Reducer::FinalizeBucketDense(size_t bucket_index) { // NOTE(zbl): Assume mutex is on when entering this function auto &bucket = buckets_.at(bucket_index); - auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank)); + auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelWithContextProcessGroupName(bucket.device_rank)); if (comm_hook_) { std::vector> bucket_view{bucket.contents}; diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 65a3208e..75d6532d 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -2,7 +2,9 @@ #include #include +#include #include +#include #include "glog/logging.h" @@ -20,7 +22,6 @@ namespace infini_train::nn::parallel::global { thread_local int thread_global_rank = 0; void Layout::InitStrides() { - // Calculate strides int stride = 1; for (int i = 0; i < AXIS_COUNT; ++i) { const Axis ax = order[i]; @@ -29,9 +30,10 @@ void Layout::InitStrides() { } } -int Layout::RankOf(int dp, int tp, int pp) const { - // Return the thread rank given layout coords - const int coord[AXIS_COUNT] = {dp, tp, pp}; +int Layout::RankOf(int dp, int tp, int pp) const { return RankOf(dp, tp, /*cp=*/0, pp); } + +int Layout::RankOf(int dp, int tp, int cp, int pp) const { + const int coord[AXIS_COUNT] = {dp, tp, cp, pp}; int r = 0; for (int i = 0; i < AXIS_COUNT; ++i) { const Axis ax = static_cast(i); @@ -41,14 +43,20 @@ int Layout::RankOf(int dp, int tp, int pp) const { } void Layout::CoordOf(int rank, int &dp, int &tp, int &pp) const { - // Return the layout coords given thread rank + int cp = 0; + CoordOf(rank, dp, tp, cp, pp); +} + +void Layout::CoordOf(int rank, int &dp, int &tp, int &cp, int &pp) const { dp = (rank / strides[DP]) % sizes[DP]; tp = (rank / strides[TP]) % sizes[TP]; + cp = (rank / strides[CP]) % sizes[CP]; pp = (rank / strides[PP]) % sizes[PP]; } -int Layout::GroupId(Axis target, int dp, int tp, int pp) const { - // Return the parallel ProcessGroup ID where the rank is in +int Layout::GroupId(Axis target, int dp, int tp, int pp) const { return GroupId(target, dp, tp, /*cp=*/0, pp); } + +int Layout::GroupId(Axis target, int dp, int tp, int cp, int pp) const { int id = 0; int mult = 1; for (int i = AXIS_COUNT - 1; i >= 0; --i) { @@ -56,7 +64,7 @@ int Layout::GroupId(Axis target, int dp, int tp, int pp) const { if (ax == target) { continue; } - int c = (ax == DP ? dp : (ax == TP ? tp : pp)); + int c = (ax == DP ? dp : (ax == TP ? tp : (ax == CP ? cp : pp))); id += c * mult; mult *= sizes[ax]; } @@ -64,19 +72,24 @@ int Layout::GroupId(Axis target, int dp, int tp, int pp) const { } std::vector Layout::GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_pp) const { - // Return all the ranks within the same parallel ProcessGroup + return GroupRanks(target, fixed_dp, fixed_tp, /*fixed_cp=*/0, fixed_pp); +} + +std::vector Layout::GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_cp, int fixed_pp) const { std::vector ranks; ranks.reserve(sizes[target]); - int dp = fixed_dp, tp = fixed_tp, pp = fixed_pp; + int dp = fixed_dp, tp = fixed_tp, cp = fixed_cp, pp = fixed_pp; for (int v = 0; v < sizes[target]; ++v) { if (target == DP) { dp = v; } else if (target == TP) { tp = v; + } else if (target == CP) { + cp = v; } else { pp = v; } - ranks.push_back(RankOf(dp, tp, pp)); + ranks.push_back(RankOf(dp, tp, cp, pp)); } return ranks; } @@ -87,6 +100,7 @@ GlobalEnv &GlobalEnv::Instance() { } void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int context_parallel_size, const std::string &context_parallel_comm_type, int pipeline_parallel_size, int virtual_pipeline_parallel_size) { std::lock_guard lock(mutex_); @@ -102,12 +116,20 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq CHECK_GE(tensor_parallel_size, 1) << "Tensor Parallel size must be >= 1"; tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; + CHECK_GE(context_parallel_size, 1) << "Context Parallel size must be >= 1"; + context_parallel_size_ = context_parallel_size; + context_parallel_comm_type_ = context_parallel_comm_type; + CHECK_GE(pipeline_parallel_size, 1) << "Pipeline Parallel size must be >= 1"; pipeline_parallel_size_ = pipeline_parallel_size; virtual_pipeline_parallel_size_ = virtual_pipeline_parallel_size; - data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_; + + CHECK_EQ(world_size_ % (tensor_parallel_size_ * context_parallel_size_ * pipeline_parallel_size_), 0) + << "world_size must be divisible by TP * CP * PP"; + data_parallel_size_ = world_size_ / tensor_parallel_size_ / context_parallel_size_ / pipeline_parallel_size_; layout_.sizes[DP] = data_parallel_size_; layout_.sizes[TP] = tensor_parallel_size_; + layout_.sizes[CP] = context_parallel_size_; layout_.sizes[PP] = pipeline_parallel_size_; layout_.InitStrides(); @@ -159,6 +181,16 @@ bool GlobalEnv::sequence_parallel_enabled() const { return sequence_parallel_enabled_; } +int GlobalEnv::context_parallel_size() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return context_parallel_size_; +} + +const std::string &GlobalEnv::context_parallel_comm_type() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return context_parallel_comm_type_; +} + int GlobalEnv::data_parallel_size() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; return data_parallel_size_; @@ -180,7 +212,7 @@ Layout GlobalEnv::layout() const { } namespace { -inline const char *AxisName(Axis a) { return a == DP ? "DP" : (a == TP ? "TP" : "PP"); } +inline const char *AxisName(Axis a) { return a == DP ? "DP" : (a == TP ? "TP" : (a == CP ? "CP" : "PP")); } inline int NumGroups(const Layout &L, Axis target) { int n = 1; @@ -196,8 +228,8 @@ inline int NumGroups(const Layout &L, Axis target) { std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) { std::ostringstream oss; oss << std::format("\n=== Parallel Communication Groups ===\n" - "world_size = {}, config: {{DP={}, TP={}, PP={}}}, order: {{", - GetWorldSize(), L.sizes[DP], L.sizes[TP], L.sizes[PP]); + "world_size = {}, config: {{DP={}, TP={}, CP={}, PP={}}}, order: {{", + GetWorldSize(), L.sizes[DP], L.sizes[TP], L.sizes[CP], L.sizes[PP]); for (int i = 0; i < AXIS_COUNT; ++i) { oss << AxisName(L.order[i]) << (i + 1 == AXIS_COUNT ? "" : " -> "); } oss << "}\n"; @@ -208,51 +240,38 @@ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) { oss << std::format("[{}] size={}, unenabled\n", AxisName(ax), L.sizes[ax]); continue; } - // Build > mapping - std::vector>> groups; + + std::vector>> groups; for (int dp = 0; dp < (ax == DP ? 1 : L.sizes[DP]); ++dp) { for (int tp = 0; tp < (ax == TP ? 1 : L.sizes[TP]); ++tp) { - for (int pp = 0; pp < (ax == PP ? 1 : L.sizes[PP]); ++pp) { - int gid = L.GroupId(ax, dp, tp, pp); - groups.emplace_back(gid, std::make_tuple(dp, tp, pp)); + for (int cp = 0; cp < (ax == CP ? 1 : L.sizes[CP]); ++cp) { + for (int pp = 0; pp < (ax == PP ? 1 : L.sizes[PP]); ++pp) { + int gid = L.GroupId(ax, dp, tp, cp, pp); + groups.emplace_back(gid, std::make_tuple(dp, tp, cp, pp)); + } } } } - // Sort by the order of Group ID std::sort(groups.begin(), groups.end(), [](const auto &a, const auto &b) { return a.first < b.first; }); const int num_groups = NumGroups(L, ax); const auto name = AxisName(ax); oss << std::format("[{}] size={}, num_groups={}\n", name, L.sizes[ax], num_groups); - // Iterate and print in the order of Group ID for (const auto &pair : groups) { int gid = pair.first; - int dp, tp, pp; - std::tie(dp, tp, pp) = pair.second; - auto ranks = L.GroupRanks(ax, dp, tp, pp); - std::sort(ranks.begin(), ranks.end()); - - auto dp_size_str = (ax == DP) ? "-" : std::to_string(dp); - auto tp_size_str = (ax == TP) ? "-" : std::to_string(tp); - auto pp_size_str = (ax == PP) ? "-" : std::to_string(pp); - - std::string ranks_str; - ranks_str.reserve(ranks.size() * 4); - for (size_t i = 0; i < ranks.size(); ++i) { - if (i > 0) { - ranks_str += ", "; - } - ranks_str += std::to_string(ranks[i]); - } - oss << std::format(" - {} {} (dp={}, tp={}, pp={}): [{}]\n", name, gid, dp_size_str, tp_size_str, - pp_size_str, ranks_str); - } - if (a + 1 < AXIS_COUNT) { - oss << "\n"; + auto [dp, tp, cp, pp] = pair.second; + auto coord = std::format("dp={}, tp={}, cp={}, pp={}", ax == DP ? "-" : std::to_string(dp), + ax == TP ? "-" : std::to_string(tp), ax == CP ? "-" : std::to_string(cp), + ax == PP ? "-" : std::to_string(pp)); + + oss << std::format("- {} {} ({}): [", name, gid, coord); + auto ranks = L.GroupRanks(ax, dp, tp, cp, pp); + for (size_t i = 0; i < ranks.size(); ++i) { oss << ranks[i] << (i + 1 == ranks.size() ? "" : ", "); } + oss << "]\n"; } } - oss << "\n"; + oss << "=====================================\n"; return oss.str(); } diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index ffd218d7..3aa46731 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -40,6 +40,15 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const return pg->ReduceScatter(output, input, reduce_op, async_op); } +std::shared_ptr AllToAll(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg, bool async_op) { + auto device = output->GetDevice().type(); + if (pg == nullptr) { + pg = ProcessGroupFactory::Instance(device)->GetDefaultProcessGroup(); + } + return pg->AllToAll(output, input, async_op); +} + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &devices, int dim) { std::vector>> output_tensors; diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..9cddcc0c 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -194,6 +194,33 @@ std::shared_ptr ProcessGroup::ReduceScatter(const std::shared_ptr } } +std::shared_ptr ProcessGroup::AllToAll(const std::shared_ptr &output, + const std::shared_ptr &input, bool async_op) const { + auto device = input->GetDevice(); + CHECK_EQ(device, output->GetDevice()); + CHECK(input->Dtype() == output->Dtype()); + CHECK_EQ(input->NumElements(), output->NumElements()); + CHECK_EQ(input->NumElements() % world_size_, 0) << "AllToAll input must be evenly divisible by world size"; + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + ccl_impl_->AllToAll(input->DataPtr(), output->DataPtr(), input->NumElements() / world_size_, input->Dtype(), comm, + comm_stream); + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::shared_ptr ProcessGroup::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); @@ -248,6 +275,46 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } } +std::shared_ptr ProcessGroup::BatchSendRecv(const std::vector &ops, bool async_op) const { + CHECK_GT(ops.size(), 0); + CHECK_NOTNULL(ops[0].tensor); + auto device = ops[0].tensor->GetDevice(); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + + { + core::CclGroupGuard ccl_group_guard(backend_); + for (const auto &op : ops) { + CHECK_NOTNULL(op.tensor); + CHECK_EQ(device, op.tensor->GetDevice()); + CHECK_GE(op.peer_rank, 0); + CHECK_LT(op.peer_rank, world_size_); + if (op.type == P2POpType::kSend) { + ccl_impl_->Send(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank, comm, + comm_stream); + } else { + ccl_impl_->Recv(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank, comm, + comm_stream); + } + } + } + + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::vector> ProcessGroup::BroadCast(const std::vector> &input_tensors) const { std::vector> outputs; diff --git a/infini_train/src/nn/parallel/utils.cc b/infini_train/src/nn/parallel/utils.cc index 93c6ae31..0a6d962e 100644 --- a/infini_train/src/nn/parallel/utils.cc +++ b/infini_train/src/nn/parallel/utils.cc @@ -8,18 +8,45 @@ std::string GetDataParallelProcessGroupName(int global_rank) { return "DP" + std::to_string(global::GetGroupId(global::DP, global_rank)); } +std::string GetDataParallelWithContextProcessGroupName(int global_rank) { + int dp, tp, cp, pp; + global::GetCoordOf(global_rank, dp, tp, cp, pp); + return "DP_CP" + std::to_string(tp * global::GetPipelineParallelSize() + pp); +} + std::string GetTensorParallelProcessGroupName(int global_rank) { return "TP" + std::to_string(global::GetGroupId(global::TP, global_rank)); } +std::string GetContextParallelProcessGroupName(int global_rank) { + return "CP" + std::to_string(global::GetGroupId(global::CP, global_rank)); +} + std::string GetPipelineParallelProcessGroupName(int global_rank) { return "PP" + std::to_string(global::GetGroupId(global::PP, global_rank)); } std::vector GetDataParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::DP, global_rank); } +std::vector GetDataParallelWithContextGroupRanks(int global_rank) { + int dp, tp, cp, pp; + global::GetCoordOf(global_rank, dp, tp, cp, pp); + std::vector ranks; + ranks.reserve(global::GetDataParallelSize() * global::GetContextParallelSize()); + for (int dp_idx = 0; dp_idx < global::GetDataParallelSize(); ++dp_idx) { + for (int cp_idx = 0; cp_idx < global::GetContextParallelSize(); ++cp_idx) { + ranks.push_back(global::GetRankOf(dp_idx, tp, cp_idx, pp)); + } + } + return ranks; +} + std::vector GetTensorParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::TP, global_rank); } +std::vector GetContextParallelGroupRanks(int global_rank) { + return global::GetGroupRanks(global::CP, global_rank); +} + std::vector GetPipelineParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::PP, global_rank); } diff --git a/scripts/test_config.json b/scripts/test_config.json index aa7de8d1..be6dad77 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -191,6 +191,62 @@ "pipeline_parallel": 2, "virtual_pipeline_parallel": 2 } + }, + { + "id": "9", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p" + } + }, + { + "id": "9_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p" + } + }, + { + "id": "10", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "context_parallel": 2, + "cp_comm_type": "p2p" + } + }, + { + "id": "10_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "context_parallel": 2, + "cp_comm_type": "p2p" + } } ] }, @@ -341,6 +397,58 @@ "zero_stage": 2 } }, + { + "id": "9_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 1 + } + }, + { + "id": "9_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 2 + } + }, + { + "id": "9_bfloat16_distopt", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 1 + } + }, + { + "id": "9_bfloat16_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 2 + } + }, { "id": "8_distopt", "args": { @@ -673,4 +781,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tests/autograd/test_autograd.cc b/tests/autograd/test_autograd.cc index 8ed27acf..2972a45d 100644 --- a/tests/autograd/test_autograd.cc +++ b/tests/autograd/test_autograd.cc @@ -8,8 +8,8 @@ #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/linear.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/normalization.h" #include "infini_train/include/autograd/no_op.h" +#include "infini_train/include/autograd/normalization.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" @@ -123,7 +123,8 @@ TEST_P(AutogradForwardTest, SavedOutputIsPackedWithoutAutogradMeta) { } TEST_P(AutogradForwardTest, FunctionCtxNeedsInputGradAndSaveForBackward) { - auto requires_grad_input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); + auto requires_grad_input + = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); auto no_grad_input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), false); requires_grad_input->Fill(1.0f); no_grad_input->Fill(2.0f); diff --git a/tests/common/test_main.cc b/tests/common/test_main.cc index 54ef48dd..521ee4c8 100644 --- a/tests/common/test_main.cc +++ b/tests/common/test_main.cc @@ -4,6 +4,6 @@ int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); - infini_train::nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); + infini_train::nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, "p2p", 1, 1); return RUN_ALL_TESTS(); }