From abce8d2d552720c58f325a5c001e20d25c71505b Mon Sep 17 00:00:00 2001 From: bolunz Date: Tue, 30 Jun 2026 14:53:45 +0800 Subject: [PATCH 1/3] feat: init cp, support allgather and p2p attn --- example/gpt2/main.cc | 70 ++- example/llama3/main.cc | 71 ++- infini_train/include/checkpoint/checkpoint.h | 1 + .../include/checkpoint/checkpoint_manager.h | 1 + .../include/nn/parallel/context_parallel.h | 40 ++ infini_train/include/nn/parallel/global.h | 57 +- .../include/nn/parallel/process_group.h | 2 + infini_train/include/nn/parallel/utils.h | 8 + infini_train/src/checkpoint/checkpoint.cc | 8 +- .../src/checkpoint/checkpoint_manager.cc | 4 + infini_train/src/core/ccl/ccl_utils.cc | 143 +++++- .../transformer/causal_self_attention.cc | 76 ++- .../src/nn/modules/transformer/transformer.cc | 18 +- .../src/nn/parallel/context_parallel.cc | 486 ++++++++++++++++++ .../parallel/ddp/distributed_data_parallel.cc | 2 +- .../nn/parallel/ddp/param_and_grad_buffer.cc | 2 +- infini_train/src/nn/parallel/ddp/reducer.cc | 2 +- infini_train/src/nn/parallel/global.cc | 109 ++-- infini_train/src/nn/parallel/utils.cc | 27 + scripts/test_config.json | 118 ++++- tests/autograd/test_autograd.cc | 5 +- tests/common/test_main.cc | 2 +- 22 files changed, 1101 insertions(+), 151 deletions(-) create mode 100644 infini_train/include/nn/parallel/context_parallel.h create mode 100644 infini_train/src/nn/parallel/context_parallel.cc diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index b666b6b1..401c5ab6 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -4,8 +4,10 @@ #include #include #include +#include #include #include +#include #include "gflags/gflags.h" #include "glog/logging.h" @@ -19,6 +21,7 @@ #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h" #include "infini_train/include/nn/parallel/global.h" @@ -76,6 +79,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_uint32(context_parallel, 1, "Context Parallel world size"); +DEFINE_string(cp_comm_type, "p2p", "Context Parallel communication type (all_gather|p2p|a2a)"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); @@ -124,6 +129,9 @@ DEFINE_validator(model, [](const char *, const std::string &value) { return kSup DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); +DEFINE_validator(cp_comm_type, [](const char *, const std::string &value) { + return value == "all_gather" || value == "p2p" || value == "a2a"; +}); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -150,15 +158,27 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int cp_world_size = global::GetContextParallelSize(); + int dp_cp_world_size = ddp_world_size * cp_world_size; int pp_world_size = global::GetPipelineParallelSize(); + const auto cp_local_sequence_length = FLAGS_sequence_length / cp_world_size; + if (cp_world_size > 1) { + CHECK_EQ(FLAGS_sequence_length % cp_world_size, 0) << "sequence_length must be divisible by context_parallel."; + } if (FLAGS_sequence_parallel) { - CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) - << "sequence_length must be divisible by tp_world_size when SP is enabled (pad later if needed)."; + CHECK_EQ(cp_local_sequence_length % tp_world_size, 0) + << "CP-local sequence_length must be divisible by tp_world_size when SP is enabled."; + } + if (FLAGS_zero_stage >= 1) { + CHECK_GT(dp_cp_world_size, 1) << "ZeRO requires DP or CP world size greater than 1."; + CHECK_LE(FLAGS_zero_stage, 2) << "ZeRO-3 is not supported yet."; } int ddp_rank = 0; + int dp_cp_rank = 0; int tp_rank = 0; + int cp_rank = 0; int pp_rank = 0; // Set thread-local global rank @@ -166,7 +186,9 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::global::thread_global_rank = rank.GlobalRank(); const ProcessGroup *ddp_pg = nullptr; + const ProcessGroup *dp_cp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *cp_pg = nullptr; const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { @@ -179,6 +201,12 @@ void Train(const nn::parallel::Rank &rank) { ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank()); } + if (dp_cp_world_size > 1) { + dp_cp_pg = pg_factory->GetOrCreate(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()), + GetDataParallelWithContextGroupRanks(rank.GlobalRank())); + dp_cp_rank = dp_cp_pg->GetGroupRank(rank.GlobalRank()); + } + if (tp_world_size > 1) { tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), GetTensorParallelGroupRanks(rank.GlobalRank())); @@ -187,6 +215,13 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::tp_rank = tp_rank; } + if (cp_world_size > 1) { + cp_pg = pg_factory->GetOrCreate(GetContextParallelProcessGroupName(rank.GlobalRank()), + GetContextParallelGroupRanks(rank.GlobalRank())); + cp_rank = cp_pg->GetGroupRank(rank.GlobalRank()); + nn::parallel::cp_rank = cp_rank; + } + if (pp_world_size > 1) { pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), GetPipelineParallelGroupRanks(rank.GlobalRank())); @@ -272,13 +307,13 @@ void Train(const nn::parallel::Rank &rank) { if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. + // when context/sequence parallelism is enabled, use the CP-local and SP-local sequence length. auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + {FLAGS_batch_size, cp_local_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { @@ -286,7 +321,7 @@ void Train(const nn::parallel::Rank &rank) { = std::make_shared(mutable_chunks->at(chunk_id), rank, ddp_config); } } - } else if (ddp_world_size > 1) { + } else if (dp_cp_world_size > 1) { // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -325,7 +360,7 @@ void Train(const nn::parallel::Rank &rank) { ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; optimizer = std::make_shared(optimizer_creator, params_to_optimize, - model_chunks, ddp_world_size, ddp_rank); + model_chunks, dp_cp_world_size, dp_cp_rank); } else { optimizer = optimizer_creator(params_to_optimize); } @@ -376,6 +411,7 @@ void Train(const nn::parallel::Rank &rank) { .ddp_size = ddp_world_size, .tp_size = tp_world_size, .sp_size = sp_world_size, + .cp_size = cp_world_size, .pp_size = pp_world_size, .save_optimizer_state = FLAGS_save_optimizer_state, .checkpoint_root_dir = FLAGS_save, @@ -441,6 +477,10 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward"; @@ -472,13 +512,17 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); } - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); - function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); + function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, dp_cp_pg); lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } @@ -491,10 +535,11 @@ void Train(const nn::parallel::Rank &rank) { std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " - "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", + "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, CP={}, " + "PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, - pp_world_size); + cp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { if (tokenizer) { @@ -535,7 +580,8 @@ int main(int argc, char *argv[]) { auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); + FLAGS_context_parallel, FLAGS_cp_comm_type, FLAGS_pipeline_parallel, + FLAGS_virtual_pipeline_parallel); utils::PrecisionCheckEnv::Instance().Init(precision_config); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 5a191865..824bf7ab 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -3,7 +3,9 @@ #include #include #include +#include #include +#include #include "gflags/gflags.h" #include "glog/logging.h" @@ -18,6 +20,7 @@ #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h" #include "infini_train/include/nn/parallel/global.h" @@ -75,6 +78,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_uint32(context_parallel, 1, "Context Parallel world size"); +DEFINE_string(cp_comm_type, "p2p", "Context Parallel communication type (all_gather|p2p|a2a)"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision @@ -105,12 +110,16 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; + } // namespace DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); +DEFINE_validator(cp_comm_type, [](const char *, const std::string &value) { + return value == "all_gather" || value == "p2p" || value == "a2a"; +}); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -137,22 +146,36 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int cp_world_size = global::GetContextParallelSize(); + int dp_cp_world_size = ddp_world_size * cp_world_size; int pp_world_size = global::GetPipelineParallelSize(); + const auto cp_local_sequence_length = FLAGS_sequence_length / cp_world_size; + if (cp_world_size > 1) { + CHECK_EQ(FLAGS_sequence_length % cp_world_size, 0) << "sequence_length must be divisible by context_parallel."; + } if (FLAGS_sequence_parallel) { - CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) - << "sequence_length must be divisible by tp_world_size when SP is enabled (pad later if needed)."; + CHECK_EQ(cp_local_sequence_length % tp_world_size, 0) + << "CP-local sequence_length must be divisible by tp_world_size when SP is enabled."; + } + if (FLAGS_zero_stage >= 1) { + CHECK_GT(dp_cp_world_size, 1) << "ZeRO requires DP or CP world size greater than 1."; + CHECK_LE(FLAGS_zero_stage, 2) << "ZeRO-3 is not supported yet."; } int ddp_rank = 0; + int dp_cp_rank = 0; int tp_rank = 0; + int cp_rank = 0; int pp_rank = 0; // Set thread-local global rank nn::parallel::global::thread_global_rank = rank.GlobalRank(); const ProcessGroup *ddp_pg = nullptr; + const ProcessGroup *dp_cp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *cp_pg = nullptr; const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { @@ -165,6 +188,12 @@ void Train(const nn::parallel::Rank &rank) { ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank()); } + if (dp_cp_world_size > 1) { + dp_cp_pg = pg_factory->GetOrCreate(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()), + GetDataParallelWithContextGroupRanks(rank.GlobalRank())); + dp_cp_rank = dp_cp_pg->GetGroupRank(rank.GlobalRank()); + } + if (tp_world_size > 1) { tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), GetTensorParallelGroupRanks(rank.GlobalRank())); @@ -173,6 +202,13 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::tp_rank = tp_rank; } + if (cp_world_size > 1) { + cp_pg = pg_factory->GetOrCreate(GetContextParallelProcessGroupName(rank.GlobalRank()), + GetContextParallelGroupRanks(rank.GlobalRank())); + cp_rank = cp_pg->GetGroupRank(rank.GlobalRank()); + nn::parallel::cp_rank = cp_rank; + } + if (pp_world_size > 1) { pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), GetPipelineParallelGroupRanks(rank.GlobalRank())); @@ -243,13 +279,13 @@ void Train(const nn::parallel::Rank &rank) { if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. + // when context/sequence parallelism is enabled, use the CP-local and SP-local sequence length. auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + {FLAGS_batch_size, cp_local_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, pp_rank, device, model_config.GetChunkSize()); - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto ddp_config = DistributedDataParallelConfig{.zero_stage = FLAGS_zero_stage}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { @@ -257,7 +293,7 @@ void Train(const nn::parallel::Rank &rank) { = std::make_shared(mutable_chunks->at(chunk_id), rank, ddp_config); } } - } else if (ddp_world_size > 1) { + } else if (dp_cp_world_size > 1) { // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -305,7 +341,7 @@ void Train(const nn::parallel::Rank &rank) { ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; optimizer = std::make_shared(optimizer_creator, params_to_optimize, - model_chunks, ddp_world_size, ddp_rank); + model_chunks, dp_cp_world_size, dp_cp_rank); } else { optimizer = optimizer_creator(params_to_optimize); } @@ -356,6 +392,7 @@ void Train(const nn::parallel::Rank &rank) { .ddp_size = ddp_world_size, .tp_size = tp_world_size, .sp_size = sp_world_size, + .cp_size = cp_world_size, .pp_size = pp_world_size, .save_optimizer_state = FLAGS_save_optimizer_state, .checkpoint_root_dir = FLAGS_save, @@ -419,6 +456,10 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward"; // (bs, seq_len, vocab_size) @@ -449,13 +490,17 @@ void Train(const nn::parallel::Rank &rank) { consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (cp_world_size > 1) { + x = nn::parallel::SliceAlongCPRegionFunc(x, 1); + y = nn::parallel::SliceAlongCPRegionFunc(y, 1); + } lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); } - if (ddp_world_size > 1) { + if (dp_cp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); - function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); + function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, dp_cp_pg); lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } @@ -468,10 +513,11 @@ void Train(const nn::parallel::Rank &rank) { std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " - "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", + "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, CP={}, " + "PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, - pp_world_size); + cp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { // FIXME(jym): to support PP @@ -512,7 +558,8 @@ int main(int argc, char *argv[]) { auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); + FLAGS_context_parallel, FLAGS_cp_comm_type, FLAGS_pipeline_parallel, + FLAGS_virtual_pipeline_parallel); utils::PrecisionCheckEnv::Instance().Init(precision_config); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/infini_train/include/checkpoint/checkpoint.h b/infini_train/include/checkpoint/checkpoint.h index b122a17a..d8872b4e 100644 --- a/infini_train/include/checkpoint/checkpoint.h +++ b/infini_train/include/checkpoint/checkpoint.h @@ -28,6 +28,7 @@ struct TrainerState { int ddp_size = 1; int tp_size = 1; int sp_size = 1; + int cp_size = 1; int pp_size = 1; }; diff --git a/infini_train/include/checkpoint/checkpoint_manager.h b/infini_train/include/checkpoint/checkpoint_manager.h index cce14107..e50bbc31 100644 --- a/infini_train/include/checkpoint/checkpoint_manager.h +++ b/infini_train/include/checkpoint/checkpoint_manager.h @@ -45,6 +45,7 @@ struct SaveCheckpointArgs { int ddp_size = 1; int tp_size = 1; int sp_size = 1; + int cp_size = 1; int pp_size = 1; bool save_optimizer_state = true; std::filesystem::path checkpoint_root_dir; diff --git a/infini_train/include/nn/parallel/context_parallel.h b/infini_train/include/nn/parallel/context_parallel.h new file mode 100644 index 00000000..b1ecbb82 --- /dev/null +++ b/infini_train/include/nn/parallel/context_parallel.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::nn::parallel { + +extern thread_local int cp_rank; + +int GetContextParallelRank(); + +int64_t GetContextParallelSequenceStart(int64_t local_sequence_length); + +std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &input, int64_t dim); + +std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input); + +std::shared_ptr ContextParallelAttentionFunc(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask, bool mask_true_means_invalid, + float scale, int64_t n_rep); + +std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + bool mask_true_means_invalid, float scale, int64_t n_rep); + +std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask, bool mask_true_means_invalid, + float scale, int64_t n_rep); + +std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + bool mask_true_means_invalid, float scale, int64_t n_rep); + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 9373100f..0ec929c0 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -8,20 +8,24 @@ namespace infini_train::nn::parallel::global { extern thread_local int thread_global_rank; -enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 }; +enum Axis : uint8_t { DP = 0, TP = 1, CP = 2, PP = 3, AXIS_COUNT = 4 }; struct Layout { - int sizes[AXIS_COUNT]{1, 1, 1}; - // Default order according to Megatron-LM is TP-DP-PP. Ref: + int sizes[AXIS_COUNT]{1, 1, 1, 1}; + // Default order follows Megatron-LM's TP-CP-DP-PP layout for context parallelism. Ref: // https://github.com/NVIDIA/Megatron-LM/blob/e07c4a4450b6faa187a1ef4ec082a35ad7d2f085/megatron/core/parallel_state.py#L618 - Axis order[AXIS_COUNT]{TP, DP, PP}; - int strides[AXIS_COUNT]{1, 1, 1}; + Axis order[AXIS_COUNT]{TP, CP, DP, PP}; + int strides[AXIS_COUNT]{1, 1, 1, 1}; void InitStrides(); int RankOf(int dp, int tp, int pp) const; + int RankOf(int dp, int tp, int cp, int pp) const; void CoordOf(int rank, int &dp, int &tp, int &pp) const; + void CoordOf(int rank, int &dp, int &tp, int &cp, int &pp) const; int GroupId(Axis target, int dp, int tp, int pp) const; + int GroupId(Axis target, int dp, int tp, int cp, int pp) const; std::vector GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_pp) const; + std::vector GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_cp, int fixed_pp) const; }; class GlobalEnv { @@ -29,7 +33,8 @@ class GlobalEnv { static GlobalEnv &Instance(); void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size, int virtual_pipeline_parallel_size); + int context_parallel_size, const std::string &context_parallel_comm_type, int pipeline_parallel_size, + int virtual_pipeline_parallel_size); int nnodes() const; @@ -49,6 +54,10 @@ class GlobalEnv { bool sequence_parallel_enabled() const; + int context_parallel_size() const; + + const std::string &context_parallel_comm_type() const; + int data_parallel_size() const; int pipeline_parallel_size() const; @@ -75,6 +84,8 @@ class GlobalEnv { int tensor_parallel_size_ = 1; bool sequence_parallel_enabled_ = false; + int context_parallel_size_ = 1; + std::string context_parallel_comm_type_ = "p2p"; int data_parallel_size_ = 1; @@ -90,8 +101,16 @@ class GlobalEnv { inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, int pipeline_parallel_size, int virtual_pipeline_parallel) { GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + /*context_parallel_size=*/1, /*context_parallel_comm_type=*/"p2p", pipeline_parallel_size, virtual_pipeline_parallel); } +inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int context_parallel_size, const std::string &context_parallel_comm_type, + int pipeline_parallel_size, int virtual_pipeline_parallel) { + GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + context_parallel_size, context_parallel_comm_type, pipeline_parallel_size, + virtual_pipeline_parallel); +} inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); } @@ -102,6 +121,8 @@ inline int GetLocalProcRank() { return GlobalEnv::Instance().local_proc_rank(); inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_parallel_size(); } inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_parallel_size(); } inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); } +inline int GetContextParallelSize() { return GlobalEnv::Instance().context_parallel_size(); } +inline const std::string &GetContextParallelCommType() { return GlobalEnv::Instance().context_parallel_comm_type(); } inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtual_pipeline_parallel_size(); } @@ -114,12 +135,16 @@ inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtu * @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate. */ inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); } +inline int GetRankOf(int dp, int tp, int cp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, cp, pp); } /** * @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank. */ inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) { return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, pp); } +inline void GetCoordOf(int rank, int &dp, int &tp, int &cp, int &pp) { + return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, cp, pp); +} /** * @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis. @@ -127,13 +152,16 @@ inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) { inline int GetGroupId(Axis target, int dp, int tp, int pp) { return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp); } +inline int GetGroupId(Axis target, int dp, int tp, int cp, int pp) { + return GlobalEnv::Instance().layout().GroupId(target, dp, tp, cp, pp); +} /** * @brief Get the group ID that a given rank belongs to along a specific parallel axis. */ inline int GetGroupId(Axis target, int rank) { - int dp, tp, pp; - GetCoordOf(rank, dp, tp, pp); - return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp); + int dp, tp, cp, pp; + GetCoordOf(rank, dp, tp, cp, pp); + return GlobalEnv::Instance().layout().GroupId(target, dp, tp, cp, pp); } /** @@ -143,15 +171,18 @@ inline int GetGroupId(Axis target, int rank) { inline std::vector GetGroupRanks(Axis target, int dp, int tp, int pp) { return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); } +inline std::vector GetGroupRanks(Axis target, int dp, int tp, int cp, int pp) { + return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, cp, pp); +} /** * @brief Get all ranks that belong to the same group as the given rank * along a specified parallel axis (e.g., all ranks in the same DP group). */ inline std::vector GetGroupRanks(Axis target, int rank) { - int dp, tp, pp; - GetCoordOf(rank, dp, tp, pp); - return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); + int dp, tp, cp, pp; + GetCoordOf(rank, dp, tp, cp, pp); + return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, cp, pp); } /** @@ -160,7 +191,7 @@ inline std::vector GetGroupRanks(Axis target, int rank) { * The output is intended for debugging, logging, and runtime verification of * distributed parallelism configuration. * - * @param L The Layout describing DP / TP / PP sizes and axis ordering. + * @param L The Layout describing DP / TP / CP / PP sizes and axis ordering. * @param skip_trivial_axes * If true, axes whose size <= 1(i.e. parallel strategy that is not enabled) * will be marked as "unenabled" and their detailed group listing will be skipped. diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 74bf80c6..47e9781e 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -39,6 +39,8 @@ class ProcessGroup { virtual int GetGroupRank(int global_rank) const; + int WorldSize() const { return world_size_; } + // Asynchronous communication APIs (Compute / Communication stream decoupled) virtual std::shared_ptr AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op = function::ReduceOpType::kSum, diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index 4dc737e7..d23671e5 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -11,14 +11,22 @@ class Tensor; namespace infini_train::nn::parallel { std::string GetDataParallelProcessGroupName(int global_rank); +std::string GetDataParallelWithContextProcessGroupName(int global_rank); + std::string GetTensorParallelProcessGroupName(int global_rank); +std::string GetContextParallelProcessGroupName(int global_rank); + std::string GetPipelineParallelProcessGroupName(int global_rank); std::vector GetDataParallelGroupRanks(int global_rank); +std::vector GetDataParallelWithContextGroupRanks(int global_rank); + std::vector GetTensorParallelGroupRanks(int global_rank); +std::vector GetContextParallelGroupRanks(int global_rank); + std::vector GetPipelineParallelGroupRanks(int global_rank); // TP/SP Communication Helper Functions diff --git a/infini_train/src/checkpoint/checkpoint.cc b/infini_train/src/checkpoint/checkpoint.cc index 892ec497..468d217e 100644 --- a/infini_train/src/checkpoint/checkpoint.cc +++ b/infini_train/src/checkpoint/checkpoint.cc @@ -105,8 +105,8 @@ void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module &m state = LoadTrainerState(checkpoint_dir / "trainer_state.json"); LOG(ERROR) << "[CKPT] Load done: global_step=" << state.global_step << ", consumed_batches =" << state.consumed_batches << ", last_lr=" << state.last_lr - << ", topology(ddp,tp,sp,pp)=(" << state.ddp_size << "," << state.tp_size << "," << state.sp_size << "," - << state.pp_size << ")"; + << ", topology(ddp,tp,sp,cp,pp)=(" << state.ddp_size << "," << state.tp_size << "," << state.sp_size + << "," << state.cp_size << "," << state.pp_size << ")"; } void Checkpoint::SaveStateDict(const std::filesystem::path &path, @@ -186,13 +186,14 @@ void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const Train ofs << " \"n_head\": " << state.n_head << ",\n"; ofs << " \"n_kv_head\": " << state.n_kv_head << ",\n"; ofs << " \"n_embd\": " << state.n_embd << ",\n"; - ofs << " \"vocab_size\": " << state.vocab_size << "\n"; + ofs << " \"vocab_size\": " << state.vocab_size << ",\n"; ofs << " \"global_step\": " << state.global_step << ",\n"; ofs << " \"consumed_batches\": " << state.consumed_batches << ",\n"; ofs << " \"last_lr\": " << state.last_lr << ",\n"; ofs << " \"ddp_size\": " << state.ddp_size << ",\n"; ofs << " \"tp_size\": " << state.tp_size << ",\n"; ofs << " \"sp_size\": " << state.sp_size << ",\n"; + ofs << " \"cp_size\": " << state.cp_size << ",\n"; ofs << " \"pp_size\": " << state.pp_size << "\n"; ofs << "}\n"; } @@ -215,6 +216,7 @@ TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) { state.ddp_size = ExtractNumberField(content, "ddp_size", 1); state.tp_size = ExtractNumberField(content, "tp_size", 1); state.sp_size = ExtractNumberField(content, "sp_size", 1); + state.cp_size = ExtractNumberField(content, "cp_size", 1); state.pp_size = ExtractNumberField(content, "pp_size", 1); return state; } diff --git a/infini_train/src/checkpoint/checkpoint_manager.cc b/infini_train/src/checkpoint/checkpoint_manager.cc index 71e15c08..8730c588 100644 --- a/infini_train/src/checkpoint/checkpoint_manager.cc +++ b/infini_train/src/checkpoint/checkpoint_manager.cc @@ -28,6 +28,7 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs & int ddp_world_size = nn::parallel::global::GetDataParallelSize(); int tp_world_size = nn::parallel::global::GetTensorParallelSize(); int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int cp_world_size = nn::parallel::global::GetContextParallelSize(); int pp_world_size = nn::parallel::global::GetPipelineParallelSize(); std::filesystem::path resume_dir = args.resume_root; @@ -59,6 +60,8 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs & << "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size; CHECK_EQ(args.state.sp_size, sp_world_size) << "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size; + CHECK_EQ(args.state.cp_size, cp_world_size) + << "CP size mismatch: checkpoint has CP=" << args.state.cp_size << ", but current run has CP=" << cp_world_size; CHECK_EQ(args.state.pp_size, pp_world_size) << "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size; @@ -86,6 +89,7 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) { state.ddp_size = args.ddp_size; state.tp_size = args.tp_size; state.sp_size = args.sp_size; + state.cp_size = args.cp_size; state.pp_size = args.pp_size; Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.save_optimizer_state); diff --git a/infini_train/src/core/ccl/ccl_utils.cc b/infini_train/src/core/ccl/ccl_utils.cc index f8968c6b..a7868d66 100644 --- a/infini_train/src/core/ccl/ccl_utils.cc +++ b/infini_train/src/core/ccl/ccl_utils.cc @@ -1,55 +1,158 @@ #include "infini_train/include/core/ccl/ccl_utils.h" +#include +#include #include -#include +#include #include #include #include +#include +#include #include #include "glog/logging.h" namespace infini_train::core { namespace { -std::string UniqueIdFileName(const std::string &name, bool tmp = false) { - return "cclUniqueId_" + name + (tmp ? ".tmp" : ".bin"); + +constexpr int64_t kDefaultUniqueIdTimeoutSec = 600; + +std::string GetEnvString(const char *name, const std::string &default_value) { + const char *value = std::getenv(name); + return value == nullptr ? default_value : std::string(value); +} + +int64_t GetEnvInt64(const char *name, int64_t default_value) { + const char *value = std::getenv(name); + if (value == nullptr) { + return default_value; + } + try { + return std::stoll(value); + } catch (...) { + LOG(WARNING) << "Invalid integer environment variable " << name << "=" << value << ", fallback to " + << default_value; + return default_value; + } +} + +std::string SanitizeFileComponent(std::string value) { + std::replace_if( + value.begin(), value.end(), + [](unsigned char c) { return !(std::isalnum(c) || c == '_' || c == '-' || c == '.'); }, '_'); + return value; +} + +std::filesystem::path UniqueIdFilePath(const std::string &pg_name) { + std::string file_name = "cclUniqueId_"; + const std::string name_space = GetEnvString("INFINITRAIN_CCL_ID_NAMESPACE", ""); + if (!name_space.empty()) { + file_name += SanitizeFileComponent(name_space) + "_"; + } + file_name += SanitizeFileComponent(pg_name) + ".bin"; + return std::filesystem::path(GetEnvString("INFINITRAIN_CCL_ID_DIR", ".")) / file_name; +} + +std::filesystem::path UniqueIdTmpFilePath(const std::filesystem::path &file_path) { + const auto now = std::chrono::steady_clock::now().time_since_epoch().count(); + const std::string global_proc_rank = GetEnvString("GLOBAL_PROC_RANK", "0"); + return std::filesystem::path(file_path.string() + ".tmp." + global_proc_rank + "." + std::to_string(now)); +} + +void RemoveTmpFilesFor(const std::filesystem::path &file_path) { + const auto parent = file_path.parent_path(); + std::error_code ec; + if (!std::filesystem::exists(parent, ec)) { + return; + } + + const std::string prefix = file_path.filename().string() + ".tmp."; + for (const auto &entry : std::filesystem::directory_iterator(parent, ec)) { + if (ec) { + return; + } + const std::string name = entry.path().filename().string(); + if (name.rfind(prefix, 0) == 0) { + std::error_code remove_ec; + std::filesystem::remove(entry.path(), remove_ec); + } + } } } // namespace void WriteUniqueIdFile(const CclUniqueId &unique_id, const std::string &pg_name) { - const std::string tmp_path = UniqueIdFileName(pg_name, true); + const auto file_path = UniqueIdFilePath(pg_name); + const auto tmp_path = UniqueIdTmpFilePath(file_path); + + std::error_code ec; + std::filesystem::create_directories(file_path.parent_path(), ec); + CHECK(!ec) << "Failed to create CCL unique_id directory: " << file_path.parent_path() << ", error=" << ec.message(); + RemoveTmpFilesFor(file_path); std::ofstream ofs(tmp_path, std::ios::binary); CHECK(ofs.good()) << "Failed to open unique_id tmp file for write: " << tmp_path; const size_t size = unique_id.Size(); ofs.write(reinterpret_cast(unique_id.Data()), static_cast(size)); + CHECK(ofs.good()) << "Failed to write unique_id tmp file: " << tmp_path; ofs.close(); + CHECK(!ofs.fail()) << "Failed to close unique_id tmp file: " << tmp_path; - std::rename(tmp_path.c_str(), UniqueIdFileName(pg_name).c_str()); + const auto tmp_size = std::filesystem::file_size(tmp_path, ec); + CHECK(!ec && tmp_size == size) << "Invalid unique_id tmp file size. file=" << tmp_path << ", expected=" << size + << ", got=" << (ec ? 0 : tmp_size) << ", error=" << ec.message(); + + std::filesystem::rename(tmp_path, file_path, ec); + if (ec) { + std::error_code remove_ec; + std::filesystem::remove(tmp_path, remove_ec); + LOG(FATAL) << "Failed to publish unique_id file. tmp=" << tmp_path << ", final=" << file_path + << ", error=" << ec.message(); + } } void ReadUniqueIdFile(CclUniqueId *unique_id, const std::string &pg_name) { CHECK_NOTNULL(unique_id); - const std::string file_path = UniqueIdFileName(pg_name); - - while (!std::filesystem::exists(file_path)) { std::this_thread::sleep_for(std::chrono::microseconds(1000)); } + const auto file_path = UniqueIdFilePath(pg_name); + const auto timeout_sec + = std::max(1, GetEnvInt64("INFINITRAIN_CCL_ID_TIMEOUT_SEC", kDefaultUniqueIdTimeoutSec)); + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec); + const size_t expected_size = unique_id->Size(); + uintmax_t last_observed_size = 0; - std::ifstream ifs(file_path, std::ios::binary); - CHECK(ifs.good()) << "Failed to open unique_id file for read: " << file_path; - - std::string bytes((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - ifs.close(); + while (std::chrono::steady_clock::now() < deadline) { + std::error_code ec; + if (std::filesystem::exists(file_path, ec)) { + const auto file_size = std::filesystem::file_size(file_path, ec); + if (!ec) { + last_observed_size = file_size; + if (file_size == expected_size) { + std::ifstream ifs(file_path, std::ios::binary); + if (ifs.good()) { + std::string bytes((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + ifs.close(); + if (bytes.size() == expected_size) { + unique_id->Load(bytes.data(), bytes.size()); + return; + } + last_observed_size = bytes.size(); + } + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } - CHECK_EQ(bytes.size(), unique_id->Size()) - << "Mismatched unique_id size in file. expected=" << unique_id->Size() << ", got=" << bytes.size(); - unique_id->Load(bytes.data(), bytes.size()); + LOG(FATAL) << "Timed out waiting for CCL unique_id file. file=" << file_path << ", expected_size=" << expected_size + << ", last_observed_size=" << last_observed_size << ", timeout_sec=" << timeout_sec + << ". Set INFINITRAIN_CCL_ID_NAMESPACE for concurrent jobs sharing a directory."; } void CleanupUniqueIdFile(const std::string &pg_name) { - const std::string file_path = UniqueIdFileName(pg_name); - if (std::filesystem::exists(file_path)) { - std::filesystem::remove(file_path); - } + const auto file_path = UniqueIdFilePath(pg_name); + std::error_code ec; + std::filesystem::remove(file_path, ec); + RemoveTmpFilesFor(file_path); } } // namespace infini_train::core diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 8bed8193..40ce8d7d 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -1,6 +1,7 @@ #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include +#include #include #include #include @@ -12,6 +13,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -109,16 +111,25 @@ CausalSelfAttention::ForwardStandard(const std::vectorView({B, T, local_n_head_, head_dim})->Transpose(1, 2); v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); + const auto q_start = parallel::GetContextParallelSequenceStart(T); + const auto T_kv = T * parallel::global::GetContextParallelSize(); + + // (1, 1, T_local, T_global) + auto mask = buffers_[kParamBiasName]->Slice({0, 0, q_start, 0}, {1, 1, q_start + T, T_kv}, {1, 1, 1, 1}); + std::shared_ptr y; + if (parallel::global::GetContextParallelSize() > 1) { + y = parallel::ContextParallelAttentionFunc(q, k, v, mask, /*mask_true_means_invalid=*/false, + static_cast(1.0 / std::sqrt(head_dim)), /*n_rep=*/1); + } else { + // (B, h_l, T_local, T_global) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); + // (1, 1, T_local, T_global) -> eq 0 -> masked_fill -> (B, h_l, T_local, T_global) + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + // (B, h_l, T_local, T_global) + att = nn::function::Softmax(att, -1); + // (B, h_l, T, Dh) + y = att->Matmul(v); + } // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); @@ -219,10 +230,13 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector (B, T, H_local, D) via RepeatKV - k = RepeatKV(k, n_rep_); - v = RepeatKV(v, n_rep_); + const bool use_context_parallel = parallel::global::GetContextParallelSize() > 1; + if (!use_context_parallel) { + // align n_head in GQA + // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV + k = RepeatKV(k, n_rep_); + v = RepeatKV(v, n_rep_); + } // (B, T, H_local, D) -> (B, H_local, T, D) q = q->Transpose(1, 2); @@ -232,21 +246,27 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector (B, H_local, D, T) - // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); - if (mask) { - // mask: (1, 1, T, T) - att = att->MaskedFill(mask, std::numeric_limits::lowest()); + std::shared_ptr y; + if (use_context_parallel) { + y = parallel::ContextParallelAttentionFunc(q, k, v, mask, /*mask_true_means_invalid=*/true, + 1.0f / std::sqrt(static_cast(D)), n_rep_); + } else { + // manual implementation of attention + // this materializes the large (T,T) matrix for all the queries and keys + + // q: (B, H_local, T, D) + // k: (B, H_local, T, D) -> (B, H_local, D, T) + // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); + if (mask) { + // mask: (1, 1, T, T) + att = att->MaskedFill(mask, std::numeric_limits::lowest()); + } + // (B, H_local, T, T) + att = nn::function::Softmax(att, -1); + // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) + y = att->Matmul(v); } - // (B, H_local, T, T) - att = nn::function::Softmax(att, -1); - // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) - auto y = att->Matmul(v); // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local) y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); // output projection diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..d397eab7 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -16,6 +16,7 @@ #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/context_parallel.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" @@ -56,8 +57,9 @@ std::vector> TransformerFirstStage::Forward(const std::v nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); } + const int64_t cp_start = nn::parallel::GetContextParallelSequenceStart(x1->Dims()[1]); int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; - int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; + int64_t start = cp_start + (sequence_parallel_enabled ? tp_rank * t_local : 0); auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // (T) -> Embedding(T_max, C) -> (T, C) @@ -140,15 +142,17 @@ std::vector> TransformerChunk::Forward(const std::vector config_.use_scaled_rope, device); } - const auto t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); // full_seq_len + const auto local_t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); + const auto full_t = local_t * nn::parallel::global::GetContextParallelSize(); - // Dynamic start_pos (set to 0 for now) - int64_t start_pos = 0; - auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); + int64_t start_pos = nn::parallel::GetContextParallelSequenceStart(local_t); + auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + local_t, 1); // Create causal mask - std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(device)); - std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); + std::shared_ptr ones = std::make_shared(nn::function::Ones({full_t, full_t})->To(device)); + std::shared_ptr mask = nn::function::Triu(ones, 1) + ->Slice({start_pos, 0}, {start_pos + local_t, full_t}, {1, 1}) + ->View({1, 1, local_t, full_t}); std::shared_ptr start_pos_ptr = nullptr; diff --git a/infini_train/src/nn/parallel/context_parallel.cc b/infini_train/src/nn/parallel/context_parallel.cc new file mode 100644 index 00000000..4ffe881b --- /dev/null +++ b/infini_train/src/nn/parallel/context_parallel.cc @@ -0,0 +1,486 @@ +#include "infini_train/include/nn/parallel/context_parallel.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/process_group.h" +#include "infini_train/include/nn/parallel/reduce_op_type.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/nn/parallel/work.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +thread_local int cp_rank = 0; + +namespace { + +const ProcessGroup *GetCPGroup(const std::shared_ptr &tensor) { + auto cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0); + if (cp_size == 1) { + return nullptr; + } + return ProcessGroupFactory::Instance(tensor->GetDevice().type()) + ->Get(GetContextParallelProcessGroupName(tensor->GetDevice().Rank().GlobalRank())); +} + +// Comm Kernel Call Functions +std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tensor, const ProcessGroup *cp_group) { + const int cp_size = global::GetContextParallelSize(); + auto output_shape = tensor->Dims(); + output_shape[0] *= cp_size; + auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); + cp_group->AllGather(output, tensor, false); + return output; +} + +std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr &tensor, + const ProcessGroup *cp_group) { + const int cp_size = global::GetContextParallelSize(); + auto output_shape = tensor->Dims(); + CHECK_EQ(output_shape[0] % cp_size, 0) << "First dimension must be divisible by CP size"; + output_shape[0] /= cp_size; + auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); + cp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum, false); + return output; +} + +// Attention Helper Functions +std::shared_ptr NewZeroTensorLike(const std::shared_ptr &tensor) { + auto output = std::make_shared(tensor->Dims(), tensor->Dtype(), tensor->GetDevice()); + output->Fill(0.0f); + return output; +} + +std::shared_ptr RepeatKVHeads(const std::shared_ptr &x, int64_t n_rep) { + if (n_rep == 1) { + return x; + } + + const auto &shape = x->Dims(); + const int64_t B = shape[0], H = shape[1], T = shape[2], D = shape[3]; + return x->View({B, H, 1, T, D})->RepeatInterleave(n_rep, 2)->Contiguous()->View({B, H * n_rep, T, D}); +} + +std::shared_ptr SumRepeatedKVHeads(const std::shared_ptr &x, int64_t n_rep) { + if (n_rep == 1) { + return x; + } + + const auto &shape = x->Dims(); + const int64_t B = shape[0], H = shape[1], T = shape[2], D = shape[3]; + CHECK_EQ(H % n_rep, 0); + return x->View({B, H / n_rep, n_rep, T, D})->Sum(2); +} + +struct RingAttentionForwardResult { + std::shared_ptr output; + std::shared_ptr softmax_max; + std::shared_ptr softmax_sum; +}; + +RingAttentionForwardResult RingOnlineAttentionForward(const std::shared_ptr &q, + const std::shared_ptr &k_local, + const std::shared_ptr &v_local, + const std::shared_ptr &mask, bool mask_true_means_invalid, + float scale, int64_t n_rep) { + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 1); + CHECK(mask) << "CP ring attention expects a causal mask."; + + auto cp_group = GetCPGroup(q); + CHECK_NOTNULL(cp_group); + const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); + const int send_to = (rank + 1) % cp_size; + const int recv_from = (rank - 1 + cp_size) % cp_size; + + const int64_t local_t = q->Dims()[2]; + CHECK_EQ(k_local->Dims()[2], local_t); + CHECK_EQ(v_local->Dims()[2], local_t); + CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep); + CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep); + + auto current_k = k_local; + auto current_v = v_local; + std::shared_ptr running_max; + std::shared_ptr running_sum; + std::shared_ptr running_out; + + for (int step = 0; step < cp_size; ++step) { + std::shared_ptr next_k; + std::shared_ptr next_v; + std::shared_ptr send_work; + std::shared_ptr recv_work; + if (step + 1 < cp_size) { + next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); + next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); + if (rank % 2 == 0) { + send_work = cp_group->Send({current_k, current_v}, send_to, true); + recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); + } else { + recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); + send_work = cp_group->Send({current_k, current_v}, send_to, true); + } + } + + const int owner = (rank - step + cp_size) % cp_size; + const int64_t kv_start = static_cast(owner) * local_t; + const int64_t kv_end = kv_start + local_t; + + auto k_for_attn = RepeatKVHeads(current_k, n_rep); + auto v_for_attn = RepeatKVHeads(current_v, n_rep); + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale; + auto mask_chunk = mask->Slice(3, kv_start, kv_end); + auto invalid_mask = mask_true_means_invalid ? mask_chunk : (mask_chunk == 0); + scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + + auto chunk_max = scores->Max(-1, true); + auto probs = (scores - chunk_max)->Exp()->MaskedFill(invalid_mask, 0.0f); + auto chunk_sum = probs->Sum(-1, true); + auto chunk_out = probs->Matmul(v_for_attn); + + if (!running_out) { + running_max = chunk_max; + running_sum = chunk_sum; + running_out = chunk_out; + } else { + auto new_max + = nn::function::Stack(std::vector>{running_max, chunk_max}, -1)->Max(-1); + auto old_scale = (running_max - new_max)->Exp(); + auto new_scale = (chunk_max - new_max)->Exp(); + running_sum = running_sum * old_scale + chunk_sum * new_scale; + running_out = running_out * old_scale + chunk_out * new_scale; + running_max = new_max; + } + + if (recv_work) { + recv_work->WaitNonBlocking(); + send_work->WaitNonBlocking(); + current_k = next_k; + current_v = next_v; + } + } + + return {.output = running_out / running_sum, .softmax_max = running_max, .softmax_sum = running_sum}; +} + +std::shared_ptr ApplyAttention(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + bool mask_true_means_invalid, float scale) { + auto scores = q->Matmul(k->Transpose(-2, -1)) * scale; + if (mask) { + auto invalid_mask = mask_true_means_invalid ? mask : (mask == 0); + scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + } + auto probs = nn::function::Softmax(scores, -1); + return probs->Matmul(v); +} + +// Autograd Function Definitions +class GatherFromCPRegion : public autograd::Function { +public: + static constexpr char kType[] = "GatherFromCPRegionFunction"; + + explicit GatherFromCPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + auto input = input_tensors[0]; + if (global::GetContextParallelSize() == 1) { + return {std::make_shared(*input)}; + } + auto cp_group = GetCPGroup(input); + // FIXME(zbl): Megatron keeps sequence as dim 0. InfiniTrain uses [B, H, S, D], so move only + // the sequence dimension to dim 0 before the CP gather. + return {GatherAlongFirstDim(input->Transpose(0, 2), cp_group)->Transpose(0, 2)->Contiguous()}; + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + if (global::GetContextParallelSize() == 1) { + return {std::make_shared(*grad_outputs[0])}; + } + auto cp_group = GetCPGroup(grad_outputs[0]); + // FIXME(zbl): See Forward() for the [B, H, S, D] sequence-first bridge. + return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 2), cp_group)->Transpose(0, 2)->Contiguous()}; + } +}; + +class GatherFromCPHeadRegion : public autograd::Function { +public: + static constexpr char kType[] = "GatherFromCPHeadRegionFunction"; + + explicit GatherFromCPHeadRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + auto input = input_tensors[0]; + if (global::GetContextParallelSize() == 1) { + return {std::make_shared(*input)}; + } + auto cp_group = GetCPGroup(input); + // A2A CP shards heads across CP ranks. ProcessGroup all-gather works on dim 0, so move heads there. + return {GatherAlongFirstDim(input->Transpose(0, 1), cp_group)->Transpose(0, 1)->Contiguous()}; + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + if (global::GetContextParallelSize() == 1) { + return {std::make_shared(*grad_outputs[0])}; + } + auto cp_group = GetCPGroup(grad_outputs[0]); + return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 1), cp_group)->Transpose(0, 1)->Contiguous()}; + } +}; + +class AttnWithCPAndKVP2P : public autograd::Function { +public: + static constexpr char kType[] = "AttnWithCPAndKVP2PFunction"; + + AttnWithCPAndKVP2P(bool mask_true_means_invalid, float scale, int64_t n_rep) + : autograd::Function(kType), mask_true_means_invalid_(mask_true_means_invalid), scale_(scale), n_rep_(n_rep) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + CHECK_EQ(input_tensors.size(), 4); + auto result = RingOnlineAttentionForward(input_tensors[0], input_tensors[1], input_tensors[2], input_tensors[3], + mask_true_means_invalid_, scale_, n_rep_); + softmax_max_ = result.softmax_max; + softmax_sum_ = result.softmax_sum; + return {result.output}; + } + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override { + ctx_.SaveForBackward({input_tensors[0], input_tensors[1], input_tensors[2], input_tensors[3], output_tensors[0], + softmax_max_, softmax_sum_}); + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + CHECK_EQ(grad_outputs.size(), 1); + auto saved_tensors = ctx_.GetSavedTensors(); + CHECK_EQ(saved_tensors.size(), 7); + + const auto &q = saved_tensors[0]; + const auto &k_local = saved_tensors[1]; + const auto &v_local = saved_tensors[2]; + const auto &mask = saved_tensors[3]; + const auto &output = saved_tensors[4]; + const auto &softmax_max = saved_tensors[5]; + const auto &softmax_sum = saved_tensors[6]; + const auto &grad_output = grad_outputs[0]; + + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 1); + + auto cp_group = GetCPGroup(q); + CHECK_NOTNULL(cp_group); + const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); + const int send_to = (rank + 1) % cp_size; + const int recv_from = (rank - 1 + cp_size) % cp_size; + + const int64_t local_t = q->Dims()[2]; + CHECK_EQ(k_local->Dims()[2], local_t); + CHECK_EQ(v_local->Dims()[2], local_t); + CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep_); + CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep_); + + auto current_k = k_local; + auto current_v = v_local; + auto current_grad_k = NewZeroTensorLike(k_local); + auto current_grad_v = NewZeroTensorLike(v_local); + std::shared_ptr grad_q; + const auto softmax_delta = (grad_output * output)->Sum(-1, true); + + for (int step = 0; step < cp_size; ++step) { + const int owner = (rank - step + cp_size) % cp_size; + const int64_t kv_start = static_cast(owner) * local_t; + const int64_t kv_end = kv_start + local_t; + + auto k_for_attn = RepeatKVHeads(current_k, n_rep_); + auto v_for_attn = RepeatKVHeads(current_v, n_rep_); + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale_; + auto mask_chunk = mask->Slice(3, kv_start, kv_end); + auto invalid_mask = mask_true_means_invalid_ ? mask_chunk : (mask_chunk == 0); + scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + + auto probs = (scores - softmax_max)->Exp()->MaskedFill(invalid_mask, 0.0f) / softmax_sum; + auto grad_v_repeated = probs->Transpose(-2, -1)->Matmul(grad_output); + auto grad_probs = grad_output->Matmul(v_for_attn->Transpose(-2, -1)); + auto grad_scores = probs * (grad_probs - softmax_delta); + auto grad_k_repeated = grad_scores->Transpose(-2, -1)->Matmul(q) * scale_; + + current_grad_k = current_grad_k + SumRepeatedKVHeads(grad_k_repeated, n_rep_); + current_grad_v = current_grad_v + SumRepeatedKVHeads(grad_v_repeated, n_rep_); + + std::shared_ptr send_work; + std::shared_ptr recv_work; + std::shared_ptr next_k; + std::shared_ptr next_v; + std::shared_ptr next_grad_k; + std::shared_ptr next_grad_v; + + if (step + 1 < cp_size) { + next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); + next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); + next_grad_k = NewZeroTensorLike(k_local); + next_grad_v = NewZeroTensorLike(v_local); + if (rank % 2 == 0) { + send_work = cp_group->Send({current_k, current_v, current_grad_k, current_grad_v}, send_to, true); + recv_work = cp_group->Recv({next_k, next_v, next_grad_k, next_grad_v}, recv_from, true); + } else { + recv_work = cp_group->Recv({next_k, next_v, next_grad_k, next_grad_v}, recv_from, true); + send_work = cp_group->Send({current_k, current_v, current_grad_k, current_grad_v}, send_to, true); + } + } else { + next_grad_k = NewZeroTensorLike(k_local); + next_grad_v = NewZeroTensorLike(v_local); + if (rank % 2 == 0) { + send_work = cp_group->Send({current_grad_k, current_grad_v}, send_to, true); + recv_work = cp_group->Recv({next_grad_k, next_grad_v}, recv_from, true); + } else { + recv_work = cp_group->Recv({next_grad_k, next_grad_v}, recv_from, true); + send_work = cp_group->Send({current_grad_k, current_grad_v}, send_to, true); + } + } + + auto grad_q_chunk = grad_scores->Matmul(k_for_attn) * scale_; + grad_q = grad_q ? grad_q + grad_q_chunk : grad_q_chunk; + + recv_work->WaitNonBlocking(); + send_work->WaitNonBlocking(); + + if (step + 1 < cp_size) { + current_k = next_k; + current_v = next_v; + current_grad_k = next_grad_k; + current_grad_v = next_grad_v; + } else { + current_grad_k = next_grad_k; + current_grad_v = next_grad_v; + } + } + + return {grad_q, current_grad_k, current_grad_v, nullptr}; + } + +private: + bool mask_true_means_invalid_ = false; + float scale_ = 1.0f; + int64_t n_rep_ = 1; + std::shared_ptr softmax_max_; + std::shared_ptr softmax_sum_; +}; + +std::shared_ptr GatherFromCPHeadRegionFunc(const std::shared_ptr &input) { + if (global::GetContextParallelSize() == 1) { + return input; + } + return std::make_shared()->Apply({input})[0]; +} + +} // namespace + +// CP State Helper Functions +int GetContextParallelRank() { return cp_rank; } + +int64_t GetContextParallelSequenceStart(int64_t local_sequence_length) { + return static_cast(GetContextParallelRank()) * local_sequence_length; +} + +// CP Communication Helper Functions +std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &input, int64_t dim) { + const int cp_size = global::GetContextParallelSize(); + if (cp_size == 1) { + return input; + } + + int64_t normalized_dim = dim; + if (normalized_dim < 0) { + normalized_dim += static_cast(input->Dims().size()); + } + CHECK_GE(normalized_dim, 0); + CHECK_LT(normalized_dim, static_cast(input->Dims().size())); + const auto dim_size = input->Dims()[normalized_dim]; + CHECK_EQ(dim_size % cp_size, 0) << "Sequence dimension must be divisible by CP size"; + const int64_t local_size = dim_size / cp_size; + const int64_t start = GetContextParallelSequenceStart(local_size); + return input->Slice(normalized_dim, start, start + local_size, 1)->Contiguous(); +} + +std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input) { + if (global::GetContextParallelSize() == 1) { + return input; + } + return std::make_shared()->Apply({input})[0]; +} + +// CP Attention Backend Functions +std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + bool mask_true_means_invalid, float scale, int64_t n_rep) { + return std::make_shared(mask_true_means_invalid, scale, n_rep)->Apply({q, k, v, mask})[0]; +} + +std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask, bool mask_true_means_invalid, + float scale, int64_t n_rep) { + auto gathered_k = GatherFromCPRegionFunc(k); + auto gathered_v = GatherFromCPRegionFunc(v); + auto k_for_attn = RepeatKVHeads(gathered_k, n_rep); + auto v_for_attn = RepeatKVHeads(gathered_v, n_rep); + return ApplyAttention(q, k_for_attn, v_for_attn, mask, mask_true_means_invalid, scale); +} + +std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + bool mask_true_means_invalid, float scale, int64_t n_rep) { + const int cp_size = global::GetContextParallelSize(); + const int rank = GetCPGroup(q)->GetGroupRank(q->GetDevice().Rank().GlobalRank()); + const int64_t q_heads = q->Dims()[1]; + const int64_t kv_heads = k->Dims()[1]; + CHECK_EQ(q_heads % cp_size, 0) << "A2A CP requires local query heads divisible by CP size"; + CHECK_EQ(kv_heads % cp_size, 0) << "A2A CP requires local KV heads divisible by CP size"; + + // TODO(zbl): Replace this semantic all-gather/slice implementation with true QKV/O all-to-all + // once ProcessGroup exposes an async all-to-all primitive. + auto full_q = GatherFromCPRegionFunc(q); + auto full_k = GatherFromCPRegionFunc(k); + auto full_v = GatherFromCPRegionFunc(v); + auto full_mask = mask ? GatherFromCPRegionFunc(mask) : nullptr; + + const int64_t q_heads_per_cp = q_heads / cp_size; + const int64_t kv_heads_per_cp = kv_heads / cp_size; + auto q_shard = full_q->Slice(1, rank * q_heads_per_cp, (rank + 1) * q_heads_per_cp); + auto k_shard = full_k->Slice(1, rank * kv_heads_per_cp, (rank + 1) * kv_heads_per_cp); + auto v_shard = full_v->Slice(1, rank * kv_heads_per_cp, (rank + 1) * kv_heads_per_cp); + auto k_for_attn = RepeatKVHeads(k_shard, n_rep); + auto v_for_attn = RepeatKVHeads(v_shard, n_rep); + auto output_shard = ApplyAttention(q_shard, k_for_attn, v_for_attn, full_mask, mask_true_means_invalid, scale); + + auto gathered_heads = GatherFromCPHeadRegionFunc(output_shard); + const int64_t local_t = q->Dims()[2]; + const int64_t seq_start = static_cast(rank) * local_t; + return gathered_heads->Slice(2, seq_start, seq_start + local_t)->Contiguous(); +} + +std::shared_ptr ContextParallelAttentionFunc(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &mask, bool mask_true_means_invalid, + float scale, int64_t n_rep) { + CHECK_GT(global::GetContextParallelSize(), 1); + const auto comm_type = global::GetContextParallelCommType(); + if (comm_type == "p2p") { + return AttnFuncWithCPAndKVP2P(q, k, v, mask, mask_true_means_invalid, scale, n_rep); + } + if (comm_type == "a2a") { + return AttnFuncWithCPAndQKVOA2A(q, k, v, mask, mask_true_means_invalid, scale, n_rep); + } + return AttnFuncWithCPAndKVAllGather(q, k, v, mask, mask_true_means_invalid, scale, n_rep); +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index a3bfe008..bcade972 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -25,7 +25,7 @@ constexpr char kModuleName[] = "module"; DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, const Rank &rank, const DistributedDataParallelConfig ddp_config) : ddp_config_(ddp_config), - ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) { + ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelWithContextProcessGroupName(rank.GlobalRank()))) { CHECK(ddp_config_.zero_stage >= 0 && ddp_config_.zero_stage <= 3) << "DistributedDataParallel: zero_stage must be in 0/1/2/3."; if (ddp_config_.zero_stage == 3) { diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index ab3a8002..bdfcde91 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -413,7 +413,7 @@ ParamAndGradBuffer::ParamAndGradBuffer(const std::vector DistributedDataParallelConfig ddp_config) : params_(std::move(params)), ddp_pg_(std::move(ddp_pg)), ddp_config_(ddp_config) { if (ddp_pg_) { - ddp_world_size_ = global::GetDataParallelSize(); + ddp_world_size_ = static_cast(ddp_pg_->WorldSize()); } grads_.clear(); diff --git a/infini_train/src/nn/parallel/ddp/reducer.cc b/infini_train/src/nn/parallel/ddp/reducer.cc index 1bdd29e1..16d1a3e5 100644 --- a/infini_train/src/nn/parallel/ddp/reducer.cc +++ b/infini_train/src/nn/parallel/ddp/reducer.cc @@ -387,7 +387,7 @@ void Reducer::MarkBucketReady(size_t bucket_index) { void Reducer::FinalizeBucketDense(size_t bucket_index) { // NOTE(zbl): Assume mutex is on when entering this function auto &bucket = buckets_.at(bucket_index); - auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank)); + auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelWithContextProcessGroupName(bucket.device_rank)); if (comm_hook_) { std::vector> bucket_view{bucket.contents}; diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 65a3208e..75d6532d 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -2,7 +2,9 @@ #include #include +#include #include +#include #include "glog/logging.h" @@ -20,7 +22,6 @@ namespace infini_train::nn::parallel::global { thread_local int thread_global_rank = 0; void Layout::InitStrides() { - // Calculate strides int stride = 1; for (int i = 0; i < AXIS_COUNT; ++i) { const Axis ax = order[i]; @@ -29,9 +30,10 @@ void Layout::InitStrides() { } } -int Layout::RankOf(int dp, int tp, int pp) const { - // Return the thread rank given layout coords - const int coord[AXIS_COUNT] = {dp, tp, pp}; +int Layout::RankOf(int dp, int tp, int pp) const { return RankOf(dp, tp, /*cp=*/0, pp); } + +int Layout::RankOf(int dp, int tp, int cp, int pp) const { + const int coord[AXIS_COUNT] = {dp, tp, cp, pp}; int r = 0; for (int i = 0; i < AXIS_COUNT; ++i) { const Axis ax = static_cast(i); @@ -41,14 +43,20 @@ int Layout::RankOf(int dp, int tp, int pp) const { } void Layout::CoordOf(int rank, int &dp, int &tp, int &pp) const { - // Return the layout coords given thread rank + int cp = 0; + CoordOf(rank, dp, tp, cp, pp); +} + +void Layout::CoordOf(int rank, int &dp, int &tp, int &cp, int &pp) const { dp = (rank / strides[DP]) % sizes[DP]; tp = (rank / strides[TP]) % sizes[TP]; + cp = (rank / strides[CP]) % sizes[CP]; pp = (rank / strides[PP]) % sizes[PP]; } -int Layout::GroupId(Axis target, int dp, int tp, int pp) const { - // Return the parallel ProcessGroup ID where the rank is in +int Layout::GroupId(Axis target, int dp, int tp, int pp) const { return GroupId(target, dp, tp, /*cp=*/0, pp); } + +int Layout::GroupId(Axis target, int dp, int tp, int cp, int pp) const { int id = 0; int mult = 1; for (int i = AXIS_COUNT - 1; i >= 0; --i) { @@ -56,7 +64,7 @@ int Layout::GroupId(Axis target, int dp, int tp, int pp) const { if (ax == target) { continue; } - int c = (ax == DP ? dp : (ax == TP ? tp : pp)); + int c = (ax == DP ? dp : (ax == TP ? tp : (ax == CP ? cp : pp))); id += c * mult; mult *= sizes[ax]; } @@ -64,19 +72,24 @@ int Layout::GroupId(Axis target, int dp, int tp, int pp) const { } std::vector Layout::GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_pp) const { - // Return all the ranks within the same parallel ProcessGroup + return GroupRanks(target, fixed_dp, fixed_tp, /*fixed_cp=*/0, fixed_pp); +} + +std::vector Layout::GroupRanks(Axis target, int fixed_dp, int fixed_tp, int fixed_cp, int fixed_pp) const { std::vector ranks; ranks.reserve(sizes[target]); - int dp = fixed_dp, tp = fixed_tp, pp = fixed_pp; + int dp = fixed_dp, tp = fixed_tp, cp = fixed_cp, pp = fixed_pp; for (int v = 0; v < sizes[target]; ++v) { if (target == DP) { dp = v; } else if (target == TP) { tp = v; + } else if (target == CP) { + cp = v; } else { pp = v; } - ranks.push_back(RankOf(dp, tp, pp)); + ranks.push_back(RankOf(dp, tp, cp, pp)); } return ranks; } @@ -87,6 +100,7 @@ GlobalEnv &GlobalEnv::Instance() { } void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int context_parallel_size, const std::string &context_parallel_comm_type, int pipeline_parallel_size, int virtual_pipeline_parallel_size) { std::lock_guard lock(mutex_); @@ -102,12 +116,20 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq CHECK_GE(tensor_parallel_size, 1) << "Tensor Parallel size must be >= 1"; tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; + CHECK_GE(context_parallel_size, 1) << "Context Parallel size must be >= 1"; + context_parallel_size_ = context_parallel_size; + context_parallel_comm_type_ = context_parallel_comm_type; + CHECK_GE(pipeline_parallel_size, 1) << "Pipeline Parallel size must be >= 1"; pipeline_parallel_size_ = pipeline_parallel_size; virtual_pipeline_parallel_size_ = virtual_pipeline_parallel_size; - data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_; + + CHECK_EQ(world_size_ % (tensor_parallel_size_ * context_parallel_size_ * pipeline_parallel_size_), 0) + << "world_size must be divisible by TP * CP * PP"; + data_parallel_size_ = world_size_ / tensor_parallel_size_ / context_parallel_size_ / pipeline_parallel_size_; layout_.sizes[DP] = data_parallel_size_; layout_.sizes[TP] = tensor_parallel_size_; + layout_.sizes[CP] = context_parallel_size_; layout_.sizes[PP] = pipeline_parallel_size_; layout_.InitStrides(); @@ -159,6 +181,16 @@ bool GlobalEnv::sequence_parallel_enabled() const { return sequence_parallel_enabled_; } +int GlobalEnv::context_parallel_size() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return context_parallel_size_; +} + +const std::string &GlobalEnv::context_parallel_comm_type() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return context_parallel_comm_type_; +} + int GlobalEnv::data_parallel_size() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; return data_parallel_size_; @@ -180,7 +212,7 @@ Layout GlobalEnv::layout() const { } namespace { -inline const char *AxisName(Axis a) { return a == DP ? "DP" : (a == TP ? "TP" : "PP"); } +inline const char *AxisName(Axis a) { return a == DP ? "DP" : (a == TP ? "TP" : (a == CP ? "CP" : "PP")); } inline int NumGroups(const Layout &L, Axis target) { int n = 1; @@ -196,8 +228,8 @@ inline int NumGroups(const Layout &L, Axis target) { std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) { std::ostringstream oss; oss << std::format("\n=== Parallel Communication Groups ===\n" - "world_size = {}, config: {{DP={}, TP={}, PP={}}}, order: {{", - GetWorldSize(), L.sizes[DP], L.sizes[TP], L.sizes[PP]); + "world_size = {}, config: {{DP={}, TP={}, CP={}, PP={}}}, order: {{", + GetWorldSize(), L.sizes[DP], L.sizes[TP], L.sizes[CP], L.sizes[PP]); for (int i = 0; i < AXIS_COUNT; ++i) { oss << AxisName(L.order[i]) << (i + 1 == AXIS_COUNT ? "" : " -> "); } oss << "}\n"; @@ -208,51 +240,38 @@ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) { oss << std::format("[{}] size={}, unenabled\n", AxisName(ax), L.sizes[ax]); continue; } - // Build > mapping - std::vector>> groups; + + std::vector>> groups; for (int dp = 0; dp < (ax == DP ? 1 : L.sizes[DP]); ++dp) { for (int tp = 0; tp < (ax == TP ? 1 : L.sizes[TP]); ++tp) { - for (int pp = 0; pp < (ax == PP ? 1 : L.sizes[PP]); ++pp) { - int gid = L.GroupId(ax, dp, tp, pp); - groups.emplace_back(gid, std::make_tuple(dp, tp, pp)); + for (int cp = 0; cp < (ax == CP ? 1 : L.sizes[CP]); ++cp) { + for (int pp = 0; pp < (ax == PP ? 1 : L.sizes[PP]); ++pp) { + int gid = L.GroupId(ax, dp, tp, cp, pp); + groups.emplace_back(gid, std::make_tuple(dp, tp, cp, pp)); + } } } } - // Sort by the order of Group ID std::sort(groups.begin(), groups.end(), [](const auto &a, const auto &b) { return a.first < b.first; }); const int num_groups = NumGroups(L, ax); const auto name = AxisName(ax); oss << std::format("[{}] size={}, num_groups={}\n", name, L.sizes[ax], num_groups); - // Iterate and print in the order of Group ID for (const auto &pair : groups) { int gid = pair.first; - int dp, tp, pp; - std::tie(dp, tp, pp) = pair.second; - auto ranks = L.GroupRanks(ax, dp, tp, pp); - std::sort(ranks.begin(), ranks.end()); - - auto dp_size_str = (ax == DP) ? "-" : std::to_string(dp); - auto tp_size_str = (ax == TP) ? "-" : std::to_string(tp); - auto pp_size_str = (ax == PP) ? "-" : std::to_string(pp); - - std::string ranks_str; - ranks_str.reserve(ranks.size() * 4); - for (size_t i = 0; i < ranks.size(); ++i) { - if (i > 0) { - ranks_str += ", "; - } - ranks_str += std::to_string(ranks[i]); - } - oss << std::format(" - {} {} (dp={}, tp={}, pp={}): [{}]\n", name, gid, dp_size_str, tp_size_str, - pp_size_str, ranks_str); - } - if (a + 1 < AXIS_COUNT) { - oss << "\n"; + auto [dp, tp, cp, pp] = pair.second; + auto coord = std::format("dp={}, tp={}, cp={}, pp={}", ax == DP ? "-" : std::to_string(dp), + ax == TP ? "-" : std::to_string(tp), ax == CP ? "-" : std::to_string(cp), + ax == PP ? "-" : std::to_string(pp)); + + oss << std::format("- {} {} ({}): [", name, gid, coord); + auto ranks = L.GroupRanks(ax, dp, tp, cp, pp); + for (size_t i = 0; i < ranks.size(); ++i) { oss << ranks[i] << (i + 1 == ranks.size() ? "" : ", "); } + oss << "]\n"; } } - oss << "\n"; + oss << "=====================================\n"; return oss.str(); } diff --git a/infini_train/src/nn/parallel/utils.cc b/infini_train/src/nn/parallel/utils.cc index 93c6ae31..0a6d962e 100644 --- a/infini_train/src/nn/parallel/utils.cc +++ b/infini_train/src/nn/parallel/utils.cc @@ -8,18 +8,45 @@ std::string GetDataParallelProcessGroupName(int global_rank) { return "DP" + std::to_string(global::GetGroupId(global::DP, global_rank)); } +std::string GetDataParallelWithContextProcessGroupName(int global_rank) { + int dp, tp, cp, pp; + global::GetCoordOf(global_rank, dp, tp, cp, pp); + return "DP_CP" + std::to_string(tp * global::GetPipelineParallelSize() + pp); +} + std::string GetTensorParallelProcessGroupName(int global_rank) { return "TP" + std::to_string(global::GetGroupId(global::TP, global_rank)); } +std::string GetContextParallelProcessGroupName(int global_rank) { + return "CP" + std::to_string(global::GetGroupId(global::CP, global_rank)); +} + std::string GetPipelineParallelProcessGroupName(int global_rank) { return "PP" + std::to_string(global::GetGroupId(global::PP, global_rank)); } std::vector GetDataParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::DP, global_rank); } +std::vector GetDataParallelWithContextGroupRanks(int global_rank) { + int dp, tp, cp, pp; + global::GetCoordOf(global_rank, dp, tp, cp, pp); + std::vector ranks; + ranks.reserve(global::GetDataParallelSize() * global::GetContextParallelSize()); + for (int dp_idx = 0; dp_idx < global::GetDataParallelSize(); ++dp_idx) { + for (int cp_idx = 0; cp_idx < global::GetContextParallelSize(); ++cp_idx) { + ranks.push_back(global::GetRankOf(dp_idx, tp, cp_idx, pp)); + } + } + return ranks; +} + std::vector GetTensorParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::TP, global_rank); } +std::vector GetContextParallelGroupRanks(int global_rank) { + return global::GetGroupRanks(global::CP, global_rank); +} + std::vector GetPipelineParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::PP, global_rank); } diff --git a/scripts/test_config.json b/scripts/test_config.json index aa7de8d1..f33b8905 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -1,10 +1,10 @@ { "variables": { "BUILD_DIR": "../build", - "GPT2_INPUT_BIN": "/data1/shared/InfiniTrain-dev/data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin", - "GPT2_LLMC_FILEPATH": "/data1/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M.bin", - "LLAMA3_INPUT_BIN": "/data1/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin", - "LLAMA3_LLMC_FILEPATH": "/data1/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin", + "GPT2_INPUT_BIN": "/data/shared/InfiniTrain-dev/data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin", + "GPT2_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M.bin", + "LLAMA3_INPUT_BIN": "/data/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin", + "LLAMA3_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin", "PROFILE_LOG_DIR": "./profile_logs", "LOG_DIR": "./logs", "CKPT_ROOT_DIR": "/data1/ckpt", @@ -191,6 +191,62 @@ "pipeline_parallel": 2, "virtual_pipeline_parallel": 2 } + }, + { + "id": "9", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p" + } + }, + { + "id": "9_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p" + } + }, + { + "id": "10", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "context_parallel": 2, + "cp_comm_type": "p2p" + } + }, + { + "id": "10_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "context_parallel": 2, + "cp_comm_type": "p2p" + } } ] }, @@ -341,6 +397,58 @@ "zero_stage": 2 } }, + { + "id": "9_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 1 + } + }, + { + "id": "9_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 2 + } + }, + { + "id": "9_bfloat16_distopt", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 1 + } + }, + { + "id": "9_bfloat16_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "context_parallel": 2, + "cp_comm_type": "p2p", + "zero_stage": 2 + } + }, { "id": "8_distopt", "args": { @@ -673,4 +781,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tests/autograd/test_autograd.cc b/tests/autograd/test_autograd.cc index 8ed27acf..2972a45d 100644 --- a/tests/autograd/test_autograd.cc +++ b/tests/autograd/test_autograd.cc @@ -8,8 +8,8 @@ #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/linear.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/normalization.h" #include "infini_train/include/autograd/no_op.h" +#include "infini_train/include/autograd/normalization.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" @@ -123,7 +123,8 @@ TEST_P(AutogradForwardTest, SavedOutputIsPackedWithoutAutogradMeta) { } TEST_P(AutogradForwardTest, FunctionCtxNeedsInputGradAndSaveForBackward) { - auto requires_grad_input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); + auto requires_grad_input + = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); auto no_grad_input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), false); requires_grad_input->Fill(1.0f); no_grad_input->Fill(2.0f); diff --git a/tests/common/test_main.cc b/tests/common/test_main.cc index 54ef48dd..521ee4c8 100644 --- a/tests/common/test_main.cc +++ b/tests/common/test_main.cc @@ -4,6 +4,6 @@ int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); - infini_train::nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); + infini_train::nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, "p2p", 1, 1); return RUN_ALL_TESTS(); } From 8c75b6b962f50c97b91d2a9ba0758c418a718372 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 1 Jul 2026 10:02:04 +0800 Subject: [PATCH 2/3] feat: refactor and support a2a attn --- infini_train/include/core/ccl/ccl.h | 3 + .../include/nn/parallel/context_parallel.h | 14 +- .../include/nn/parallel/parallel_functional.h | 3 + .../include/nn/parallel/process_group.h | 3 + infini_train/src/core/ccl/ccl.cc | 5 + infini_train/src/core/ccl/cuda/nccl_impl.cc | 34 ++ infini_train/src/core/ccl/cuda/nccl_impl.h | 3 + .../transformer/causal_self_attention.cc | 7 +- .../src/nn/parallel/context_parallel.cc | 490 +++++++++++------- .../src/nn/parallel/parallel_functional.cc | 9 + infini_train/src/nn/parallel/process_group.cc | 27 + 11 files changed, 406 insertions(+), 192 deletions(-) diff --git a/infini_train/include/core/ccl/ccl.h b/infini_train/include/core/ccl/ccl.h index 626cb078..7d6b967b 100644 --- a/infini_train/include/core/ccl/ccl.h +++ b/infini_train/include/core/ccl/ccl.h @@ -53,6 +53,9 @@ class CclImpl { nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const; + virtual void AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const; + virtual void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const; diff --git a/infini_train/include/nn/parallel/context_parallel.h b/infini_train/include/nn/parallel/context_parallel.h index b1ecbb82..b73c1772 100644 --- a/infini_train/include/nn/parallel/context_parallel.h +++ b/infini_train/include/nn/parallel/context_parallel.h @@ -19,22 +19,20 @@ std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &in std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input); -std::shared_ptr ContextParallelAttentionFunc(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, - const std::shared_ptr &mask, bool mask_true_means_invalid, - float scale, int64_t n_rep); +std::shared_ptr AttnForwardFuncWithCP(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + float scale, int64_t n_rep); std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, const std::shared_ptr &mask, - bool mask_true_means_invalid, float scale, int64_t n_rep); + float scale, int64_t n_rep); std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, - const std::shared_ptr &mask, bool mask_true_means_invalid, - float scale, int64_t n_rep); + const std::shared_ptr &mask, float scale, int64_t n_rep); std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, const std::shared_ptr &mask, - bool mask_true_means_invalid, float scale, int64_t n_rep); + float scale, int64_t n_rep); } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h index 2eed56f4..1d39c38b 100644 --- a/infini_train/include/nn/parallel/parallel_functional.h +++ b/infini_train/include/nn/parallel/parallel_functional.h @@ -25,6 +25,9 @@ std::shared_ptr AllGather(const std::shared_ptr &output, const std std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false); +std::shared_ptr AllToAll(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg = nullptr, bool async_op = false); + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &device_ids, int dim); diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 47e9781e..c9541321 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -54,6 +54,9 @@ class ProcessGroup { function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const; + virtual std::shared_ptr AllToAll(const std::shared_ptr &output, const std::shared_ptr &input, + bool async_op = false) const; + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op = false) const; diff --git a/infini_train/src/core/ccl/ccl.cc b/infini_train/src/core/ccl/ccl.cc index 92c14cc6..d981e521 100644 --- a/infini_train/src/core/ccl/ccl.cc +++ b/infini_train/src/core/ccl/ccl.cc @@ -54,6 +54,11 @@ void CclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_co LOG(FATAL) << "CclImpl::ReduceScatter is not implemented."; } +void CclImpl::AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const { + LOG(FATAL) << "CclImpl::AllToAll is not implemented."; +} + void CclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { LOG(FATAL) << "CclImpl::Send is not implemented."; diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.cc b/infini_train/src/core/ccl/cuda/nccl_impl.cc index 9e4b1a0d..98840e01 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.cc +++ b/infini_train/src/core/ccl/cuda/nccl_impl.cc @@ -146,6 +146,40 @@ void NcclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_c kNcclReduceOpMap.at(reduce_op), GetNcclComm(comm), GetCudaStream(stream))); } +void NcclImpl::AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const { + auto nccl_comm = GetNcclComm(comm); + auto cuda_stream = GetCudaStream(stream); + int nranks = 0; + int rank = 0; + NCCL_CHECK(ncclCommCount(nccl_comm, &nranks)); + NCCL_CHECK(ncclCommUserRank(nccl_comm, &rank)); + CHECK_GT(nranks, 0); + CHECK_GE(rank, 0); + CHECK_LT(rank, nranks); + + const size_t chunk_bytes = count * kDataTypeToSize.at(dtype); + auto send_ptr = static_cast(sendbuff); + auto recv_ptr = static_cast(recvbuff); + + if (chunk_bytes > 0) { + CUDA_CHECK(cudaMemcpyAsync(recv_ptr + static_cast(rank) * chunk_bytes, + send_ptr + static_cast(rank) * chunk_bytes, chunk_bytes, + cudaMemcpyDeviceToDevice, cuda_stream)); + } + + NCCL_CHECK(ncclGroupStart()); + for (int peer = 0; peer < nranks; ++peer) { + if (peer == rank) { + continue; + } + const auto offset = static_cast(peer) * chunk_bytes; + NCCL_CHECK(ncclSend(send_ptr + offset, count, kNcclDtypeMap.at(dtype), peer, nccl_comm, cuda_stream)); + NCCL_CHECK(ncclRecv(recv_ptr + offset, count, kNcclDtypeMap.at(dtype), peer, nccl_comm, cuda_stream)); + } + NCCL_CHECK(ncclGroupEnd()); +} + void NcclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { NCCL_CHECK(ncclSend(buff, count, kNcclDtypeMap.at(dtype), peer, GetNcclComm(comm), GetCudaStream(stream))); diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.h b/infini_train/src/core/ccl/cuda/nccl_impl.h index fca177fd..4fa68a87 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.h +++ b/infini_train/src/core/ccl/cuda/nccl_impl.h @@ -42,6 +42,9 @@ class NcclImpl final : public CclImpl { nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const override; + void AllToAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const override; + void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const override; diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 40ce8d7d..4f6510ec 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -118,8 +118,8 @@ CausalSelfAttention::ForwardStandard(const std::vectorSlice({0, 0, q_start, 0}, {1, 1, q_start + T, T_kv}, {1, 1, 1, 1}); std::shared_ptr y; if (parallel::global::GetContextParallelSize() > 1) { - y = parallel::ContextParallelAttentionFunc(q, k, v, mask, /*mask_true_means_invalid=*/false, - static_cast(1.0 / std::sqrt(head_dim)), /*n_rep=*/1); + y = parallel::AttnForwardFuncWithCP(q, k, v, mask == 0, static_cast(1.0 / std::sqrt(head_dim)), + /*n_rep=*/1); } else { // (B, h_l, T_local, T_global) auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); @@ -248,8 +248,7 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector y; if (use_context_parallel) { - y = parallel::ContextParallelAttentionFunc(q, k, v, mask, /*mask_true_means_invalid=*/true, - 1.0f / std::sqrt(static_cast(D)), n_rep_); + y = parallel::AttnForwardFuncWithCP(q, k, v, mask, 1.0f / std::sqrt(static_cast(D)), n_rep_); } else { // manual implementation of attention // this materializes the large (T,T) matrix for all the queries and keys diff --git a/infini_train/src/nn/parallel/context_parallel.cc b/infini_train/src/nn/parallel/context_parallel.cc index 4ffe881b..f424e366 100644 --- a/infini_train/src/nn/parallel/context_parallel.cc +++ b/infini_train/src/nn/parallel/context_parallel.cc @@ -25,16 +25,19 @@ namespace { const ProcessGroup *GetCPGroup(const std::shared_ptr &tensor) { auto cp_size = global::GetContextParallelSize(); CHECK_GT(cp_size, 0); - if (cp_size == 1) { - return nullptr; - } return ProcessGroupFactory::Instance(tensor->GetDevice().type()) ->Get(GetContextParallelProcessGroupName(tensor->GetDevice().Rank().GlobalRank())); } // Comm Kernel Call Functions -std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tensor, const ProcessGroup *cp_group) { +std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tensor) { const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0) << "Context Parallel group not initialized"; + if (cp_size == 1) { + return tensor; + } + + auto cp_group = GetCPGroup(tensor); auto output_shape = tensor->Dims(); output_shape[0] *= cp_size; auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); @@ -42,9 +45,14 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso return output; } -std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr &tensor, - const ProcessGroup *cp_group) { +std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr &tensor) { const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0) << "Context Parallel group not initialized"; + if (cp_size == 1) { + return tensor; + } + + auto cp_group = GetCPGroup(tensor); auto output_shape = tensor->Dims(); CHECK_EQ(output_shape[0] % cp_size, 0) << "First dimension must be divisible by CP size"; output_shape[0] /= cp_size; @@ -53,6 +61,77 @@ std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr return output; } +std::shared_ptr AllToAllAlongFirstDim(const std::shared_ptr &tensor) { + // Tensor P is split along first dim in [P0 | P1 | ... | Pn] + // Each rank j sends Pj to every other rank and receive the rest of P from every other rank + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 0) << "Context Parallel group not initialized"; + if (cp_size == 1) { + return tensor; + } + + auto cp_group = GetCPGroup(tensor); + auto output_shape = tensor->Dims(); + CHECK_EQ(output_shape[0] % cp_size, 0) << "First dimension must be divisible by CP size"; + auto output = std::make_shared(output_shape, tensor->Dtype(), tensor->GetDevice()); + cp_group->AllToAll(output, tensor, false); + return output; +} + +std::shared_ptr AllToAllSeqToHead(const std::shared_ptr &input) { + if (global::GetContextParallelSize() == 1) { + return input; + } + + const int cp_size = global::GetContextParallelSize(); + const auto &shape = input->Dims(); + CHECK_EQ(shape.size(), 4); + const int64_t B = shape[0], H = shape[1], T_l = shape[2], D = shape[3]; + CHECK_EQ(H % cp_size, 0) << "A2A CP requires head dimension divisible by CP size"; + const int64_t H_per_cp = H / cp_size; + + // input: (B, H, T_l, D) + // + // send_input: (H, B, T_l, D), split dim 0 into CP chunks of H_per_cp heads. + auto send_input = input->Transpose(0, 1)->Contiguous(); + // exchanged: (H, B, T_l, D), dim 0 chunks are ordered by source sequence-owner rank. + auto exchanged = AllToAllAlongFirstDim(send_input); + // output: (B, H_per_cp, T_g, D) + return exchanged->View({cp_size, H_per_cp, B, T_l, D}) + ->Transpose(0, 2) + ->Contiguous() + ->View({B, H_per_cp, static_cast(cp_size) * T_l, D}); +} + +std::shared_ptr AllToAllHeadToSeq(const std::shared_ptr &input) { + if (global::GetContextParallelSize() == 1) { + return input; + } + + const int cp_size = global::GetContextParallelSize(); + const auto &shape = input->Dims(); + CHECK_EQ(shape.size(), 4); + const int64_t B = shape[0], H_per_cp = shape[1], T_g = shape[2], D = shape[3]; + CHECK_EQ(T_g % cp_size, 0) << "A2A CP requires sequence dimension divisible by CP size"; + const int64_t T_l = T_g / cp_size; + + // input: (B, H_per_cp, T_g, D) + // + // send_input: (CP * H_per_cp, B, T_l, D), split dim 0 into CP sequence-owner chunks. + auto send_input = input->View({B, H_per_cp, cp_size, T_l, D}) + ->Transpose(0, 2) + ->Contiguous() + ->View({static_cast(cp_size) * H_per_cp, B, T_l, D}); + // exchanged: (CP * H_per_cp, B, T_l, D), dim 0 chunks are ordered by source head-owner rank. + auto exchanged = AllToAllAlongFirstDim(send_input); + // output: (B, CP * H_per_cp, T_l, D) + return exchanged->View({cp_size, H_per_cp, B, T_l, D}) + ->Transpose(1, 2) + ->Transpose(0, 1) + ->Contiguous() + ->View({B, static_cast(cp_size) * H_per_cp, T_l, D}); +} + // Attention Helper Functions std::shared_ptr NewZeroTensorLike(const std::shared_ptr &tensor) { auto output = std::make_shared(tensor->Dims(), tensor->Dtype(), tensor->GetDevice()); @@ -81,106 +160,17 @@ std::shared_ptr SumRepeatedKVHeads(const std::shared_ptr &x, int return x->View({B, H / n_rep, n_rep, T, D})->Sum(2); } -struct RingAttentionForwardResult { - std::shared_ptr output; - std::shared_ptr softmax_max; - std::shared_ptr softmax_sum; -}; - -RingAttentionForwardResult RingOnlineAttentionForward(const std::shared_ptr &q, - const std::shared_ptr &k_local, - const std::shared_ptr &v_local, - const std::shared_ptr &mask, bool mask_true_means_invalid, - float scale, int64_t n_rep) { - const int cp_size = global::GetContextParallelSize(); - CHECK_GT(cp_size, 1); - CHECK(mask) << "CP ring attention expects a causal mask."; - - auto cp_group = GetCPGroup(q); - CHECK_NOTNULL(cp_group); - const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); - const int send_to = (rank + 1) % cp_size; - const int recv_from = (rank - 1 + cp_size) % cp_size; - - const int64_t local_t = q->Dims()[2]; - CHECK_EQ(k_local->Dims()[2], local_t); - CHECK_EQ(v_local->Dims()[2], local_t); - CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep); - CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep); - - auto current_k = k_local; - auto current_v = v_local; - std::shared_ptr running_max; - std::shared_ptr running_sum; - std::shared_ptr running_out; - - for (int step = 0; step < cp_size; ++step) { - std::shared_ptr next_k; - std::shared_ptr next_v; - std::shared_ptr send_work; - std::shared_ptr recv_work; - if (step + 1 < cp_size) { - next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); - next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); - if (rank % 2 == 0) { - send_work = cp_group->Send({current_k, current_v}, send_to, true); - recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); - } else { - recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); - send_work = cp_group->Send({current_k, current_v}, send_to, true); - } - } - - const int owner = (rank - step + cp_size) % cp_size; - const int64_t kv_start = static_cast(owner) * local_t; - const int64_t kv_end = kv_start + local_t; - - auto k_for_attn = RepeatKVHeads(current_k, n_rep); - auto v_for_attn = RepeatKVHeads(current_v, n_rep); - auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale; - auto mask_chunk = mask->Slice(3, kv_start, kv_end); - auto invalid_mask = mask_true_means_invalid ? mask_chunk : (mask_chunk == 0); - scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); - - auto chunk_max = scores->Max(-1, true); - auto probs = (scores - chunk_max)->Exp()->MaskedFill(invalid_mask, 0.0f); - auto chunk_sum = probs->Sum(-1, true); - auto chunk_out = probs->Matmul(v_for_attn); - - if (!running_out) { - running_max = chunk_max; - running_sum = chunk_sum; - running_out = chunk_out; - } else { - auto new_max - = nn::function::Stack(std::vector>{running_max, chunk_max}, -1)->Max(-1); - auto old_scale = (running_max - new_max)->Exp(); - auto new_scale = (chunk_max - new_max)->Exp(); - running_sum = running_sum * old_scale + chunk_sum * new_scale; - running_out = running_out * old_scale + chunk_out * new_scale; - running_max = new_max; - } - - if (recv_work) { - recv_work->WaitNonBlocking(); - send_work->WaitNonBlocking(); - current_k = next_k; - current_v = next_v; - } - } - - return {.output = running_out / running_sum, .softmax_max = running_max, .softmax_sum = running_sum}; -} - -std::shared_ptr ApplyAttention(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - bool mask_true_means_invalid, float scale) { +std::shared_ptr ApplyCoreAttention(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + float scale) { + // scores: (B, H, T_q, T_k) auto scores = q->Matmul(k->Transpose(-2, -1)) * scale; if (mask) { - auto invalid_mask = mask_true_means_invalid ? mask : (mask == 0); - scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + scores = scores->MaskedFill(mask, std::numeric_limits::lowest()); } + // probs: (B, H, T_q, T_k) auto probs = nn::function::Softmax(scores, -1); + // output: (B, H, T_q, D) return probs->Matmul(v); } @@ -193,47 +183,44 @@ class GatherFromCPRegion : public autograd::Function { std::vector> Forward(const std::vector> &input_tensors) override { auto input = input_tensors[0]; - if (global::GetContextParallelSize() == 1) { - return {std::make_shared(*input)}; - } - auto cp_group = GetCPGroup(input); - // FIXME(zbl): Megatron keeps sequence as dim 0. InfiniTrain uses [B, H, S, D], so move only + // FIXME(zbl): Megatron keeps sequence as dim 0. We uses [B, H, S, D], so move only // the sequence dimension to dim 0 before the CP gather. - return {GatherAlongFirstDim(input->Transpose(0, 2), cp_group)->Transpose(0, 2)->Contiguous()}; + return {GatherAlongFirstDim(input->Transpose(0, 2))->Transpose(0, 2)->Contiguous()}; } std::vector> Backward(const std::vector> &grad_outputs) override { - if (global::GetContextParallelSize() == 1) { - return {std::make_shared(*grad_outputs[0])}; - } - auto cp_group = GetCPGroup(grad_outputs[0]); // FIXME(zbl): See Forward() for the [B, H, S, D] sequence-first bridge. - return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 2), cp_group)->Transpose(0, 2)->Contiguous()}; + return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 2))->Transpose(0, 2)->Contiguous()}; } }; -class GatherFromCPHeadRegion : public autograd::Function { +class AllToAllSeqToHeadCPRegion : public autograd::Function { public: - static constexpr char kType[] = "GatherFromCPHeadRegionFunction"; + static constexpr char kType[] = "AllToAllSeqToHeadCPRegionFunction"; - explicit GatherFromCPHeadRegion() : autograd::Function(kType) {} + explicit AllToAllSeqToHeadCPRegion() : autograd::Function(kType) {} std::vector> Forward(const std::vector> &input_tensors) override { - auto input = input_tensors[0]; - if (global::GetContextParallelSize() == 1) { - return {std::make_shared(*input)}; - } - auto cp_group = GetCPGroup(input); - // A2A CP shards heads across CP ranks. ProcessGroup all-gather works on dim 0, so move heads there. - return {GatherAlongFirstDim(input->Transpose(0, 1), cp_group)->Transpose(0, 1)->Contiguous()}; + return {AllToAllSeqToHead(input_tensors[0])}; } std::vector> Backward(const std::vector> &grad_outputs) override { - if (global::GetContextParallelSize() == 1) { - return {std::make_shared(*grad_outputs[0])}; - } - auto cp_group = GetCPGroup(grad_outputs[0]); - return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 1), cp_group)->Transpose(0, 1)->Contiguous()}; + return {AllToAllHeadToSeq(grad_outputs[0])}; + } +}; + +class AllToAllHeadToSeqCPRegion : public autograd::Function { +public: + static constexpr char kType[] = "AllToAllHeadToSeqCPRegionFunction"; + + explicit AllToAllHeadToSeqCPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + return {AllToAllHeadToSeq(input_tensors[0])}; + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {AllToAllSeqToHead(grad_outputs[0])}; } }; @@ -241,36 +228,161 @@ class AttnWithCPAndKVP2P : public autograd::Function { public: static constexpr char kType[] = "AttnWithCPAndKVP2PFunction"; - AttnWithCPAndKVP2P(bool mask_true_means_invalid, float scale, int64_t n_rep) - : autograd::Function(kType), mask_true_means_invalid_(mask_true_means_invalid), scale_(scale), n_rep_(n_rep) {} + AttnWithCPAndKVP2P(float scale, int64_t n_rep) : autograd::Function(kType), scale_(scale), n_rep_(n_rep) {} std::vector> Forward(const std::vector> &input_tensors) override { CHECK_EQ(input_tensors.size(), 4); - auto result = RingOnlineAttentionForward(input_tensors[0], input_tensors[1], input_tensors[2], input_tensors[3], - mask_true_means_invalid_, scale_, n_rep_); - softmax_max_ = result.softmax_max; - softmax_sum_ = result.softmax_sum; - return {result.output}; + // Shape notation: + // B: batch size, H_q: local query heads after TP, H_kv: local KV heads before GQA repeat, + // T_l: CP-local sequence length, T_g: global sequence length, D: head dimension. + // q: (B, H_q, T_l, D) + const auto &q = input_tensors[0]; + // k_local: (B, H_kv, T_l, D) + const auto &k_local = input_tensors[1]; + // v_local: (B, H_kv, T_l, D) + const auto &v_local = input_tensors[2]; + // mask: (1, 1, T_l, T_g), true values are invalid attention locations. + const auto &mask = input_tensors[3]; + const int cp_size = global::GetContextParallelSize(); + CHECK_GT(cp_size, 1); + CHECK(mask) << "CP ring attention expects a causal mask."; + + auto cp_group = GetCPGroup(q); + CHECK_NOTNULL(cp_group); + const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); + const int send_to = (rank + 1) % cp_size; + const int recv_from = (rank - 1 + cp_size) % cp_size; + + const int64_t local_t = q->Dims()[2]; + CHECK_EQ(k_local->Dims()[2], local_t); + CHECK_EQ(v_local->Dims()[2], local_t); + CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep_); + CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep_); + + // current_k: (B, H_kv, T_l, D), owned by rank `(rank - step + cp_size) % cp_size`. + auto current_k = k_local; + // current_v: (B, H_kv, T_l, D), owned by rank `(rank - step + cp_size) % cp_size`. + auto current_v = v_local; + // running_max: (B, H_q, T_l, 1) + std::shared_ptr running_max; + // running_sum: (B, H_q, T_l, 1) + std::shared_ptr running_sum; + // running_out: (B, H_q, T_l, D) + std::shared_ptr running_out; + + for (int step = 0; step < cp_size; ++step) { + std::shared_ptr next_k; + std::shared_ptr next_v; + std::shared_ptr send_work; + std::shared_ptr recv_work; + if (step + 1 < cp_size) { + // next_k: (B, H_kv, T_l, D) + next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); + // next_v: (B, H_kv, T_l, D) + next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); + if (rank % 2 == 0) { + send_work = cp_group->Send({current_k, current_v}, send_to, true); + recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); + } else { + recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); + send_work = cp_group->Send({current_k, current_v}, send_to, true); + } + } + + const int owner = (rank - step + cp_size) % cp_size; + const int64_t kv_start = static_cast(owner) * local_t; + const int64_t kv_end = kv_start + local_t; + + // k_for_attn: (B, H_q, T_l, D) + auto k_for_attn = RepeatKVHeads(current_k, n_rep_); + // v_for_attn: (B, H_q, T_l, D) + auto v_for_attn = RepeatKVHeads(current_v, n_rep_); + // scores: (B, H_q, T_l, T_l) + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale_; + // invalid_mask: (1, 1, T_l, T_l) + auto invalid_mask = mask->Slice(3, kv_start, kv_end); + scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + + // chunk_max: (B, H_q, T_l, 1) + auto chunk_max = scores->Max(-1, true); + // probs: (B, H_q, T_l, T_l) + auto probs = (scores - chunk_max)->Exp()->MaskedFill(invalid_mask, 0.0f); + // chunk_sum: (B, H_q, T_l, 1) + auto chunk_sum = probs->Sum(-1, true); + // chunk_out: (B, H_q, T_l, D) + auto chunk_out = probs->Matmul(v_for_attn); + + if (!running_out) { + running_max = chunk_max; + running_sum = chunk_sum; + running_out = chunk_out; + } else { + // new_max: (B, H_q, T_l, 1) + auto new_max + = nn::function::Stack(std::vector>{running_max, chunk_max}, -1)->Max(-1); + // old_scale: (B, H_q, T_l, 1) + auto old_scale = (running_max - new_max)->Exp(); + // new_scale: (B, H_q, T_l, 1) + auto new_scale = (chunk_max - new_max)->Exp(); + running_sum = running_sum * old_scale + chunk_sum * new_scale; + running_out = running_out * old_scale + chunk_out * new_scale; + running_max = new_max; + } + + if (recv_work) { + recv_work->WaitNonBlocking(); + send_work->WaitNonBlocking(); + current_k = next_k; + current_v = next_v; + } + } + + // output: (B, H_q, T_l, D) + // running_max: (B, H_q, T_l, 1) + // running_sum: (B, H_q, T_l, 1) + return {running_out / running_sum, running_max, running_sum}; } void SetupContext(const std::vector> &input_tensors, const std::vector> &output_tensors) override { - ctx_.SaveForBackward({input_tensors[0], input_tensors[1], input_tensors[2], input_tensors[3], output_tensors[0], - softmax_max_, softmax_sum_}); + CHECK_EQ(output_tensors.size(), 3); + const auto &output = output_tensors[0]; + const auto &softmax_max = output_tensors[1]; + const auto &softmax_sum = output_tensors[2]; + ctx_.MarkNonDifferentiable({softmax_max, softmax_sum}); + ctx_.SaveForBackward( + {input_tensors[0], input_tensors[1], input_tensors[2], input_tensors[3], output, softmax_max, softmax_sum}); } std::vector> Backward(const std::vector> &grad_outputs) override { - CHECK_EQ(grad_outputs.size(), 1); + // This backward is intentionally handwritten. Although the p2p Forward is composed from + // small Tensor ops, InfiniTrain executes autograd::Function::Forward() under NoGradGuard, + // so those ops are not recorded in the autograd graph. Raw CP p2p Send/Recv also has no + // autograd edge; the backward ring must explicitly accumulate every Q shard's contribution + // and return each K/V chunk's gradients to its owner rank. + // + // Shape notation: + // B: batch size, H_q: local query heads after TP, H_kv: local KV heads before GQA repeat, + // T_l: CP-local sequence length, T_g: global sequence length, D: head dimension. + CHECK_GE(grad_outputs.size(), 1); auto saved_tensors = ctx_.GetSavedTensors(); CHECK_EQ(saved_tensors.size(), 7); + // q: (B, H_q, T_l, D) const auto &q = saved_tensors[0]; + // k_local: (B, H_kv, T_l, D) const auto &k_local = saved_tensors[1]; + // v_local: (B, H_kv, T_l, D) const auto &v_local = saved_tensors[2]; + // mask: (1, 1, T_l, T_g), true values are invalid attention locations. const auto &mask = saved_tensors[3]; + // output: (B, H_q, T_l, D) const auto &output = saved_tensors[4]; + // softmax_max: (B, H_q, T_l, 1) const auto &softmax_max = saved_tensors[5]; + // softmax_sum: (B, H_q, T_l, 1) const auto &softmax_sum = saved_tensors[6]; + // grad_output: (B, H_q, T_l, D) const auto &grad_output = grad_outputs[0]; const int cp_size = global::GetContextParallelSize(); @@ -288,11 +400,17 @@ class AttnWithCPAndKVP2P : public autograd::Function { CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep_); CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep_); + // current_k: (B, H_kv, T_l, D) auto current_k = k_local; + // current_v: (B, H_kv, T_l, D) auto current_v = v_local; + // current_grad_k: (B, H_kv, T_l, D) auto current_grad_k = NewZeroTensorLike(k_local); + // current_grad_v: (B, H_kv, T_l, D) auto current_grad_v = NewZeroTensorLike(v_local); + // grad_q: (B, H_q, T_l, D) std::shared_ptr grad_q; + // softmax_delta: (B, H_q, T_l, 1) const auto softmax_delta = (grad_output * output)->Sum(-1, true); for (int step = 0; step < cp_size; ++step) { @@ -300,19 +418,28 @@ class AttnWithCPAndKVP2P : public autograd::Function { const int64_t kv_start = static_cast(owner) * local_t; const int64_t kv_end = kv_start + local_t; + // k_for_attn: (B, H_q, T_l, D) auto k_for_attn = RepeatKVHeads(current_k, n_rep_); + // v_for_attn: (B, H_q, T_l, D) auto v_for_attn = RepeatKVHeads(current_v, n_rep_); + // scores: (B, H_q, T_l, T_l) auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale_; - auto mask_chunk = mask->Slice(3, kv_start, kv_end); - auto invalid_mask = mask_true_means_invalid_ ? mask_chunk : (mask_chunk == 0); + // invalid_mask: (1, 1, T_l, T_l) + auto invalid_mask = mask->Slice(3, kv_start, kv_end); scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); + // probs: (B, H_q, T_l, T_l) auto probs = (scores - softmax_max)->Exp()->MaskedFill(invalid_mask, 0.0f) / softmax_sum; + // grad_v_repeated: (B, H_q, T_l, D) auto grad_v_repeated = probs->Transpose(-2, -1)->Matmul(grad_output); + // grad_probs: (B, H_q, T_l, T_l) auto grad_probs = grad_output->Matmul(v_for_attn->Transpose(-2, -1)); + // grad_scores: (B, H_q, T_l, T_l) auto grad_scores = probs * (grad_probs - softmax_delta); + // grad_k_repeated: (B, H_q, T_l, D) auto grad_k_repeated = grad_scores->Transpose(-2, -1)->Matmul(q) * scale_; + // SumRepeatedKVHeads maps repeated GQA gradients from (B, H_q, T_l, D) to (B, H_kv, T_l, D). current_grad_k = current_grad_k + SumRepeatedKVHeads(grad_k_repeated, n_rep_); current_grad_v = current_grad_v + SumRepeatedKVHeads(grad_v_repeated, n_rep_); @@ -324,9 +451,14 @@ class AttnWithCPAndKVP2P : public autograd::Function { std::shared_ptr next_grad_v; if (step + 1 < cp_size) { + // Send current K/V plus accumulated K/V grads to the next rank; receive the previous owner chunk. + // next_k: (B, H_kv, T_l, D) next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); + // next_v: (B, H_kv, T_l, D) next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); + // next_grad_k: (B, H_kv, T_l, D) next_grad_k = NewZeroTensorLike(k_local); + // next_grad_v: (B, H_kv, T_l, D) next_grad_v = NewZeroTensorLike(v_local); if (rank % 2 == 0) { send_work = cp_group->Send({current_k, current_v, current_grad_k, current_grad_v}, send_to, true); @@ -336,7 +468,10 @@ class AttnWithCPAndKVP2P : public autograd::Function { send_work = cp_group->Send({current_k, current_v, current_grad_k, current_grad_v}, send_to, true); } } else { + // Last step only needs to rotate accumulated K/V grads back to the local owner rank. + // next_grad_k: (B, H_kv, T_l, D) next_grad_k = NewZeroTensorLike(k_local); + // next_grad_v: (B, H_kv, T_l, D) next_grad_v = NewZeroTensorLike(v_local); if (rank % 2 == 0) { send_work = cp_group->Send({current_grad_k, current_grad_v}, send_to, true); @@ -347,6 +482,7 @@ class AttnWithCPAndKVP2P : public autograd::Function { } } + // grad_q_chunk: (B, H_q, T_l, D) auto grad_q_chunk = grad_scores->Matmul(k_for_attn) * scale_; grad_q = grad_q ? grad_q + grad_q_chunk : grad_q_chunk; @@ -368,18 +504,16 @@ class AttnWithCPAndKVP2P : public autograd::Function { } private: - bool mask_true_means_invalid_ = false; float scale_ = 1.0f; int64_t n_rep_ = 1; - std::shared_ptr softmax_max_; - std::shared_ptr softmax_sum_; }; -std::shared_ptr GatherFromCPHeadRegionFunc(const std::shared_ptr &input) { - if (global::GetContextParallelSize() == 1) { - return input; - } - return std::make_shared()->Apply({input})[0]; +std::shared_ptr AllToAllSeqToHeadCPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input})[0]; +} + +std::shared_ptr AllToAllHeadToSeqCPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input})[0]; } } // namespace @@ -412,75 +546,71 @@ std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &in } std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input) { - if (global::GetContextParallelSize() == 1) { - return input; - } return std::make_shared()->Apply({input})[0]; } // CP Attention Backend Functions std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, const std::shared_ptr &mask, - bool mask_true_means_invalid, float scale, int64_t n_rep) { - return std::make_shared(mask_true_means_invalid, scale, n_rep)->Apply({q, k, v, mask})[0]; + float scale, int64_t n_rep) { + return std::make_shared(scale, n_rep)->Apply({q, k, v, mask})[0]; } std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, - const std::shared_ptr &mask, bool mask_true_means_invalid, - float scale, int64_t n_rep) { + const std::shared_ptr &mask, float scale, int64_t n_rep) { + // gathered_k: (B, H_kv, T_g, D) auto gathered_k = GatherFromCPRegionFunc(k); + // gathered_v: (B, H_kv, T_g, D) auto gathered_v = GatherFromCPRegionFunc(v); + // k_for_attn: (B, H_q, T_g, D) auto k_for_attn = RepeatKVHeads(gathered_k, n_rep); + // v_for_attn: (B, H_q, T_g, D) auto v_for_attn = RepeatKVHeads(gathered_v, n_rep); - return ApplyAttention(q, k_for_attn, v_for_attn, mask, mask_true_means_invalid, scale); + return ApplyCoreAttention(q, k_for_attn, v_for_attn, mask, scale); } std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, const std::shared_ptr &mask, - bool mask_true_means_invalid, float scale, int64_t n_rep) { + float scale, int64_t n_rep) { const int cp_size = global::GetContextParallelSize(); - const int rank = GetCPGroup(q)->GetGroupRank(q->GetDevice().Rank().GlobalRank()); const int64_t q_heads = q->Dims()[1]; const int64_t kv_heads = k->Dims()[1]; CHECK_EQ(q_heads % cp_size, 0) << "A2A CP requires local query heads divisible by CP size"; CHECK_EQ(kv_heads % cp_size, 0) << "A2A CP requires local KV heads divisible by CP size"; - // TODO(zbl): Replace this semantic all-gather/slice implementation with true QKV/O all-to-all - // once ProcessGroup exposes an async all-to-all primitive. - auto full_q = GatherFromCPRegionFunc(q); - auto full_k = GatherFromCPRegionFunc(k); - auto full_v = GatherFromCPRegionFunc(v); + // q_shard: (B, H_q/CP, T_g, D) + auto q_shard = AllToAllSeqToHeadCPRegionFunc(q); + // k_shard: (B, H_kv/CP, T_g, D) + auto k_shard = AllToAllSeqToHeadCPRegionFunc(k); + // v_shard: (B, H_kv/CP, T_g, D) + auto v_shard = AllToAllSeqToHeadCPRegionFunc(v); + // full_mask: (1, 1, T_g, T_g) auto full_mask = mask ? GatherFromCPRegionFunc(mask) : nullptr; - const int64_t q_heads_per_cp = q_heads / cp_size; - const int64_t kv_heads_per_cp = kv_heads / cp_size; - auto q_shard = full_q->Slice(1, rank * q_heads_per_cp, (rank + 1) * q_heads_per_cp); - auto k_shard = full_k->Slice(1, rank * kv_heads_per_cp, (rank + 1) * kv_heads_per_cp); - auto v_shard = full_v->Slice(1, rank * kv_heads_per_cp, (rank + 1) * kv_heads_per_cp); + // k_for_attn: (B, H_q/CP, T_g, D) auto k_for_attn = RepeatKVHeads(k_shard, n_rep); + // v_for_attn: (B, H_q/CP, T_g, D) auto v_for_attn = RepeatKVHeads(v_shard, n_rep); - auto output_shard = ApplyAttention(q_shard, k_for_attn, v_for_attn, full_mask, mask_true_means_invalid, scale); + // output_shard: (B, H_q/CP, T_g, D) + auto output_shard = ApplyCoreAttention(q_shard, k_for_attn, v_for_attn, full_mask, scale); - auto gathered_heads = GatherFromCPHeadRegionFunc(output_shard); - const int64_t local_t = q->Dims()[2]; - const int64_t seq_start = static_cast(rank) * local_t; - return gathered_heads->Slice(2, seq_start, seq_start + local_t)->Contiguous(); + // output: (B, H_q, T_l, D) + return AllToAllHeadToSeqCPRegionFunc(output_shard); } -std::shared_ptr ContextParallelAttentionFunc(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, - const std::shared_ptr &mask, bool mask_true_means_invalid, - float scale, int64_t n_rep) { +std::shared_ptr AttnForwardFuncWithCP(const std::shared_ptr &q, const std::shared_ptr &k, + const std::shared_ptr &v, const std::shared_ptr &mask, + float scale, int64_t n_rep) { CHECK_GT(global::GetContextParallelSize(), 1); const auto comm_type = global::GetContextParallelCommType(); if (comm_type == "p2p") { - return AttnFuncWithCPAndKVP2P(q, k, v, mask, mask_true_means_invalid, scale, n_rep); + return AttnFuncWithCPAndKVP2P(q, k, v, mask, scale, n_rep); } if (comm_type == "a2a") { - return AttnFuncWithCPAndQKVOA2A(q, k, v, mask, mask_true_means_invalid, scale, n_rep); + return AttnFuncWithCPAndQKVOA2A(q, k, v, mask, scale, n_rep); } - return AttnFuncWithCPAndKVAllGather(q, k, v, mask, mask_true_means_invalid, scale, n_rep); + return AttnFuncWithCPAndKVAllGather(q, k, v, mask, scale, n_rep); } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index ffd218d7..3aa46731 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -40,6 +40,15 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const return pg->ReduceScatter(output, input, reduce_op, async_op); } +std::shared_ptr AllToAll(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg, bool async_op) { + auto device = output->GetDevice().type(); + if (pg == nullptr) { + pg = ProcessGroupFactory::Instance(device)->GetDefaultProcessGroup(); + } + return pg->AllToAll(output, input, async_op); +} + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &devices, int dim) { std::vector>> output_tensors; diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..2b5f53a1 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -194,6 +194,33 @@ std::shared_ptr ProcessGroup::ReduceScatter(const std::shared_ptr } } +std::shared_ptr ProcessGroup::AllToAll(const std::shared_ptr &output, + const std::shared_ptr &input, bool async_op) const { + auto device = input->GetDevice(); + CHECK_EQ(device, output->GetDevice()); + CHECK(input->Dtype() == output->Dtype()); + CHECK_EQ(input->NumElements(), output->NumElements()); + CHECK_EQ(input->NumElements() % world_size_, 0) << "AllToAll input must be evenly divisible by world size"; + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + ccl_impl_->AllToAll(input->DataPtr(), output->DataPtr(), input->NumElements() / world_size_, input->Dtype(), comm, + comm_stream); + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::shared_ptr ProcessGroup::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); From d03ba162732e501244b68924e9c4b72efb28c36f Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 1 Jul 2026 14:15:41 +0800 Subject: [PATCH 3/3] feat: support batch_send_recv --- .../include/nn/parallel/context_parallel.h | 12 +- .../include/nn/parallel/process_group.h | 13 ++ .../transformer/causal_self_attention.cc | 5 +- .../src/nn/parallel/context_parallel.cc | 174 ++++++++++-------- infini_train/src/nn/parallel/process_group.cc | 40 ++++ scripts/test_config.json | 8 +- 6 files changed, 163 insertions(+), 89 deletions(-) diff --git a/infini_train/include/nn/parallel/context_parallel.h b/infini_train/include/nn/parallel/context_parallel.h index b73c1772..309d8802 100644 --- a/infini_train/include/nn/parallel/context_parallel.h +++ b/infini_train/include/nn/parallel/context_parallel.h @@ -2,6 +2,7 @@ #include #include +#include namespace infini_train { class Tensor; @@ -20,19 +21,16 @@ std::shared_ptr SliceAlongCPRegionFunc(const std::shared_ptr &in std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &input); std::shared_ptr AttnForwardFuncWithCP(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale, int64_t n_rep); + const std::shared_ptr &v, const std::shared_ptr &mask); std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale, int64_t n_rep); + const std::shared_ptr &v, const std::shared_ptr &mask); std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, - const std::shared_ptr &mask, float scale, int64_t n_rep); + const std::shared_ptr &mask); std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale, int64_t n_rep); + const std::shared_ptr &v, const std::shared_ptr &mask); } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index c9541321..d5c4f3ac 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -30,6 +30,17 @@ class Work; namespace infini_train::nn::parallel { +enum class P2POpType { + kSend, + kRecv, +}; + +struct P2POp { + P2POpType type; + std::shared_ptr tensor; + int peer_rank; +}; + class ProcessGroup { public: explicit ProcessGroup(Device::DeviceType backend, const std::string &process_group_name, @@ -63,6 +74,8 @@ class ProcessGroup { virtual std::shared_ptr Recv(std::vector> tensors, int src_rank, bool async_op = false) const; + virtual std::shared_ptr BatchSendRecv(const std::vector &ops, bool async_op = false) const; + // Legacy communication APIs (Single-stream) virtual std::vector> BroadCast(const std::vector> &input_tensors) const; diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 4f6510ec..e7425a58 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -118,8 +118,7 @@ CausalSelfAttention::ForwardStandard(const std::vectorSlice({0, 0, q_start, 0}, {1, 1, q_start + T, T_kv}, {1, 1, 1, 1}); std::shared_ptr y; if (parallel::global::GetContextParallelSize() > 1) { - y = parallel::AttnForwardFuncWithCP(q, k, v, mask == 0, static_cast(1.0 / std::sqrt(head_dim)), - /*n_rep=*/1); + y = parallel::AttnForwardFuncWithCP(q, k, v, mask == 0); } else { // (B, h_l, T_local, T_global) auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); @@ -248,7 +247,7 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector y; if (use_context_parallel) { - y = parallel::AttnForwardFuncWithCP(q, k, v, mask, 1.0f / std::sqrt(static_cast(D)), n_rep_); + y = parallel::AttnForwardFuncWithCP(q, k, v, mask); } else { // manual implementation of attention // this materializes the large (T,T) matrix for all the queries and keys diff --git a/infini_train/src/nn/parallel/context_parallel.cc b/infini_train/src/nn/parallel/context_parallel.cc index f424e366..eee4d69e 100644 --- a/infini_train/src/nn/parallel/context_parallel.cc +++ b/infini_train/src/nn/parallel/context_parallel.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/parallel/context_parallel.h" +#include #include #include #include @@ -139,6 +140,37 @@ std::shared_ptr NewZeroTensorLike(const std::shared_ptr &tensor) return output; } +std::vector> P2PCommunicate(int rank, const std::vector> &send_tensors, + int send_dst, + const std::vector> &recv_tensors, + int recv_src, const ProcessGroup *cp_group, bool batch_p2p_comm) { + std::vector ops; + ops.reserve(send_tensors.size() + recv_tensors.size()); + std::vector> works; + + if (rank % 2 == 0) { + if (batch_p2p_comm) { + for (const auto &tensor : send_tensors) { ops.push_back({P2POpType::kSend, tensor, send_dst}); } + for (const auto &tensor : recv_tensors) { ops.push_back({P2POpType::kRecv, tensor, recv_src}); } + } else { + works.push_back(cp_group->Send(send_tensors, send_dst, true)); + works.push_back(cp_group->Recv(recv_tensors, recv_src, true)); + } + } else { + if (batch_p2p_comm) { + for (const auto &tensor : recv_tensors) { ops.push_back({P2POpType::kRecv, tensor, recv_src}); } + for (const auto &tensor : send_tensors) { ops.push_back({P2POpType::kSend, tensor, send_dst}); } + } else { + works.push_back(cp_group->Recv(recv_tensors, recv_src, true)); + works.push_back(cp_group->Send(send_tensors, send_dst, true)); + } + } + if (batch_p2p_comm) { + works.push_back(cp_group->BatchSendRecv(ops, true)); + } + return works; +} + std::shared_ptr RepeatKVHeads(const std::shared_ptr &x, int64_t n_rep) { if (n_rep == 1) { return x; @@ -161,8 +193,8 @@ std::shared_ptr SumRepeatedKVHeads(const std::shared_ptr &x, int } std::shared_ptr ApplyCoreAttention(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale) { + const std::shared_ptr &v, const std::shared_ptr &mask) { + const float scale = static_cast(1.0 / std::sqrt(static_cast(q->Dims().back()))); // scores: (B, H, T_q, T_k) auto scores = q->Matmul(k->Transpose(-2, -1)) * scale; if (mask) { @@ -189,7 +221,8 @@ class GatherFromCPRegion : public autograd::Function { } std::vector> Backward(const std::vector> &grad_outputs) override { - // FIXME(zbl): See Forward() for the [B, H, S, D] sequence-first bridge. + // FIXME(zbl): Megatron keeps sequence as dim 0. We uses [B, H, S, D], so move only + // the sequence dimension to dim 0 before the CP gather. return {ReduceScatterAlongFirstDim(grad_outputs[0]->Transpose(0, 2))->Transpose(0, 2)->Contiguous()}; } }; @@ -228,7 +261,7 @@ class AttnWithCPAndKVP2P : public autograd::Function { public: static constexpr char kType[] = "AttnWithCPAndKVP2PFunction"; - AttnWithCPAndKVP2P(float scale, int64_t n_rep) : autograd::Function(kType), scale_(scale), n_rep_(n_rep) {} + AttnWithCPAndKVP2P() : autograd::Function(kType) {} std::vector> Forward(const std::vector> &input_tensors) override { CHECK_EQ(input_tensors.size(), 4); @@ -252,12 +285,16 @@ class AttnWithCPAndKVP2P : public autograd::Function { const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); const int send_to = (rank + 1) % cp_size; const int recv_from = (rank - 1 + cp_size) % cp_size; + // NOTE(zbl): Megatron-LM enables batched P2P by default for CP=2 on pre-Blackwell GPUs. + const bool batch_p2p_comm = (cp_size == 2); const int64_t local_t = q->Dims()[2]; CHECK_EQ(k_local->Dims()[2], local_t); CHECK_EQ(v_local->Dims()[2], local_t); - CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep_); - CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep_); + CHECK_EQ(k_local->Dims()[1], v_local->Dims()[1]); + CHECK_EQ(q->Dims()[1] % k_local->Dims()[1], 0); + const int64_t n_rep = q->Dims()[1] / k_local->Dims()[1]; + const float scale = static_cast(1.0 / std::sqrt(static_cast(q->Dims().back()))); // current_k: (B, H_kv, T_l, D), owned by rank `(rank - step + cp_size) % cp_size`. auto current_k = k_local; @@ -273,20 +310,14 @@ class AttnWithCPAndKVP2P : public autograd::Function { for (int step = 0; step < cp_size; ++step) { std::shared_ptr next_k; std::shared_ptr next_v; - std::shared_ptr send_work; - std::shared_ptr recv_work; + std::vector> p2p_works; if (step + 1 < cp_size) { // next_k: (B, H_kv, T_l, D) next_k = std::make_shared(k_local->Dims(), k_local->Dtype(), k_local->GetDevice()); // next_v: (B, H_kv, T_l, D) next_v = std::make_shared(v_local->Dims(), v_local->Dtype(), v_local->GetDevice()); - if (rank % 2 == 0) { - send_work = cp_group->Send({current_k, current_v}, send_to, true); - recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); - } else { - recv_work = cp_group->Recv({next_k, next_v}, recv_from, true); - send_work = cp_group->Send({current_k, current_v}, send_to, true); - } + p2p_works = P2PCommunicate(rank, {current_k, current_v}, send_to, {next_k, next_v}, recv_from, cp_group, + batch_p2p_comm); } const int owner = (rank - step + cp_size) % cp_size; @@ -294,11 +325,11 @@ class AttnWithCPAndKVP2P : public autograd::Function { const int64_t kv_end = kv_start + local_t; // k_for_attn: (B, H_q, T_l, D) - auto k_for_attn = RepeatKVHeads(current_k, n_rep_); + auto k_for_attn = RepeatKVHeads(current_k, n_rep); // v_for_attn: (B, H_q, T_l, D) - auto v_for_attn = RepeatKVHeads(current_v, n_rep_); + auto v_for_attn = RepeatKVHeads(current_v, n_rep); // scores: (B, H_q, T_l, T_l) - auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale_; + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale; // invalid_mask: (1, 1, T_l, T_l) auto invalid_mask = mask->Slice(3, kv_start, kv_end); scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); @@ -329,9 +360,8 @@ class AttnWithCPAndKVP2P : public autograd::Function { running_max = new_max; } - if (recv_work) { - recv_work->WaitNonBlocking(); - send_work->WaitNonBlocking(); + if (!p2p_works.empty()) { + for (const auto &work : p2p_works) { work->WaitNonBlocking(); } current_k = next_k; current_v = next_v; } @@ -355,15 +385,14 @@ class AttnWithCPAndKVP2P : public autograd::Function { } std::vector> Backward(const std::vector> &grad_outputs) override { - // This backward is intentionally handwritten. Although the p2p Forward is composed from - // small Tensor ops, InfiniTrain executes autograd::Function::Forward() under NoGradGuard, - // so those ops are not recorded in the autograd graph. Raw CP p2p Send/Recv also has no - // autograd edge; the backward ring must explicitly accumulate every Q shard's contribution - // and return each K/V chunk's gradients to its owner rank. - // // Shape notation: - // B: batch size, H_q: local query heads after TP, H_kv: local KV heads before GQA repeat, - // T_l: CP-local sequence length, T_g: global sequence length, D: head dimension. + // - B: batch size + // - H_q: local query heads after TP + // - H_kv: local KV heads before GQA repeat + // - T_l: CP-local sequence length + // - T_g: global sequence length + // - D: head dimension. + CHECK_GE(grad_outputs.size(), 1); auto saved_tensors = ctx_.GetSavedTensors(); CHECK_EQ(saved_tensors.size(), 7); @@ -393,12 +422,16 @@ class AttnWithCPAndKVP2P : public autograd::Function { const int rank = cp_group->GetGroupRank(q->GetDevice().Rank().GlobalRank()); const int send_to = (rank + 1) % cp_size; const int recv_from = (rank - 1 + cp_size) % cp_size; + // NOTE(zbl): Megatron-LM enables batched P2P by default for CP=2 on pre-Blackwell GPUs. + const bool batch_p2p_comm = (cp_size == 2); const int64_t local_t = q->Dims()[2]; CHECK_EQ(k_local->Dims()[2], local_t); CHECK_EQ(v_local->Dims()[2], local_t); - CHECK_EQ(q->Dims()[1], k_local->Dims()[1] * n_rep_); - CHECK_EQ(q->Dims()[1], v_local->Dims()[1] * n_rep_); + CHECK_EQ(k_local->Dims()[1], v_local->Dims()[1]); + CHECK_EQ(q->Dims()[1] % k_local->Dims()[1], 0); + const int64_t n_rep = q->Dims()[1] / k_local->Dims()[1]; + const float scale = static_cast(1.0 / std::sqrt(static_cast(q->Dims().back()))); // current_k: (B, H_kv, T_l, D) auto current_k = k_local; @@ -419,11 +452,11 @@ class AttnWithCPAndKVP2P : public autograd::Function { const int64_t kv_end = kv_start + local_t; // k_for_attn: (B, H_q, T_l, D) - auto k_for_attn = RepeatKVHeads(current_k, n_rep_); + auto k_for_attn = RepeatKVHeads(current_k, n_rep); // v_for_attn: (B, H_q, T_l, D) - auto v_for_attn = RepeatKVHeads(current_v, n_rep_); + auto v_for_attn = RepeatKVHeads(current_v, n_rep); // scores: (B, H_q, T_l, T_l) - auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale_; + auto scores = q->Matmul(k_for_attn->Transpose(-2, -1)) * scale; // invalid_mask: (1, 1, T_l, T_l) auto invalid_mask = mask->Slice(3, kv_start, kv_end); scores = scores->MaskedFill(invalid_mask, std::numeric_limits::lowest()); @@ -437,14 +470,13 @@ class AttnWithCPAndKVP2P : public autograd::Function { // grad_scores: (B, H_q, T_l, T_l) auto grad_scores = probs * (grad_probs - softmax_delta); // grad_k_repeated: (B, H_q, T_l, D) - auto grad_k_repeated = grad_scores->Transpose(-2, -1)->Matmul(q) * scale_; + auto grad_k_repeated = grad_scores->Transpose(-2, -1)->Matmul(q) * scale; // SumRepeatedKVHeads maps repeated GQA gradients from (B, H_q, T_l, D) to (B, H_kv, T_l, D). - current_grad_k = current_grad_k + SumRepeatedKVHeads(grad_k_repeated, n_rep_); - current_grad_v = current_grad_v + SumRepeatedKVHeads(grad_v_repeated, n_rep_); + current_grad_k = current_grad_k + SumRepeatedKVHeads(grad_k_repeated, n_rep); + current_grad_v = current_grad_v + SumRepeatedKVHeads(grad_v_repeated, n_rep); - std::shared_ptr send_work; - std::shared_ptr recv_work; + std::vector> p2p_works; std::shared_ptr next_k; std::shared_ptr next_v; std::shared_ptr next_grad_k; @@ -460,34 +492,24 @@ class AttnWithCPAndKVP2P : public autograd::Function { next_grad_k = NewZeroTensorLike(k_local); // next_grad_v: (B, H_kv, T_l, D) next_grad_v = NewZeroTensorLike(v_local); - if (rank % 2 == 0) { - send_work = cp_group->Send({current_k, current_v, current_grad_k, current_grad_v}, send_to, true); - recv_work = cp_group->Recv({next_k, next_v, next_grad_k, next_grad_v}, recv_from, true); - } else { - recv_work = cp_group->Recv({next_k, next_v, next_grad_k, next_grad_v}, recv_from, true); - send_work = cp_group->Send({current_k, current_v, current_grad_k, current_grad_v}, send_to, true); - } + p2p_works + = P2PCommunicate(rank, {current_k, current_v, current_grad_k, current_grad_v}, send_to, + {next_k, next_v, next_grad_k, next_grad_v}, recv_from, cp_group, batch_p2p_comm); } else { // Last step only needs to rotate accumulated K/V grads back to the local owner rank. // next_grad_k: (B, H_kv, T_l, D) next_grad_k = NewZeroTensorLike(k_local); // next_grad_v: (B, H_kv, T_l, D) next_grad_v = NewZeroTensorLike(v_local); - if (rank % 2 == 0) { - send_work = cp_group->Send({current_grad_k, current_grad_v}, send_to, true); - recv_work = cp_group->Recv({next_grad_k, next_grad_v}, recv_from, true); - } else { - recv_work = cp_group->Recv({next_grad_k, next_grad_v}, recv_from, true); - send_work = cp_group->Send({current_grad_k, current_grad_v}, send_to, true); - } + p2p_works = P2PCommunicate(rank, {current_grad_k, current_grad_v}, send_to, {next_grad_k, next_grad_v}, + recv_from, cp_group, batch_p2p_comm); } // grad_q_chunk: (B, H_q, T_l, D) - auto grad_q_chunk = grad_scores->Matmul(k_for_attn) * scale_; + auto grad_q_chunk = grad_scores->Matmul(k_for_attn) * scale; grad_q = grad_q ? grad_q + grad_q_chunk : grad_q_chunk; - recv_work->WaitNonBlocking(); - send_work->WaitNonBlocking(); + for (const auto &work : p2p_works) { work->WaitNonBlocking(); } if (step + 1 < cp_size) { current_k = next_k; @@ -502,10 +524,6 @@ class AttnWithCPAndKVP2P : public autograd::Function { return {grad_q, current_grad_k, current_grad_v, nullptr}; } - -private: - float scale_ = 1.0f; - int64_t n_rep_ = 1; }; std::shared_ptr AllToAllSeqToHeadCPRegionFunc(const std::shared_ptr &input) { @@ -551,14 +569,16 @@ std::shared_ptr GatherFromCPRegionFunc(const std::shared_ptr &in // CP Attention Backend Functions std::shared_ptr AttnFuncWithCPAndKVP2P(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale, int64_t n_rep) { - return std::make_shared(scale, n_rep)->Apply({q, k, v, mask})[0]; + const std::shared_ptr &v, const std::shared_ptr &mask) { + return std::make_shared()->Apply({q, k, v, mask})[0]; } std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr &q, const std::shared_ptr &k, const std::shared_ptr &v, - const std::shared_ptr &mask, float scale, int64_t n_rep) { + const std::shared_ptr &mask) { + CHECK_EQ(k->Dims()[1], v->Dims()[1]); + CHECK_EQ(q->Dims()[1] % k->Dims()[1], 0); + const int64_t n_rep = q->Dims()[1] / k->Dims()[1]; // gathered_k: (B, H_kv, T_g, D) auto gathered_k = GatherFromCPRegionFunc(k); // gathered_v: (B, H_kv, T_g, D) @@ -567,17 +587,19 @@ std::shared_ptr AttnFuncWithCPAndKVAllGather(const std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale, int64_t n_rep) { + const std::shared_ptr &v, + const std::shared_ptr &mask) { const int cp_size = global::GetContextParallelSize(); const int64_t q_heads = q->Dims()[1]; const int64_t kv_heads = k->Dims()[1]; + CHECK_EQ(kv_heads, v->Dims()[1]); CHECK_EQ(q_heads % cp_size, 0) << "A2A CP requires local query heads divisible by CP size"; CHECK_EQ(kv_heads % cp_size, 0) << "A2A CP requires local KV heads divisible by CP size"; + CHECK_EQ(q_heads % kv_heads, 0); // q_shard: (B, H_q/CP, T_g, D) auto q_shard = AllToAllSeqToHeadCPRegionFunc(q); @@ -588,29 +610,31 @@ std::shared_ptr AttnFuncWithCPAndQKVOA2A(const std::shared_ptr & // full_mask: (1, 1, T_g, T_g) auto full_mask = mask ? GatherFromCPRegionFunc(mask) : nullptr; + const int64_t n_rep = q_shard->Dims()[1] / k_shard->Dims()[1]; // k_for_attn: (B, H_q/CP, T_g, D) auto k_for_attn = RepeatKVHeads(k_shard, n_rep); // v_for_attn: (B, H_q/CP, T_g, D) auto v_for_attn = RepeatKVHeads(v_shard, n_rep); // output_shard: (B, H_q/CP, T_g, D) - auto output_shard = ApplyCoreAttention(q_shard, k_for_attn, v_for_attn, full_mask, scale); + auto output_shard = ApplyCoreAttention(q_shard, k_for_attn, v_for_attn, full_mask); // output: (B, H_q, T_l, D) return AllToAllHeadToSeqCPRegionFunc(output_shard); } std::shared_ptr AttnForwardFuncWithCP(const std::shared_ptr &q, const std::shared_ptr &k, - const std::shared_ptr &v, const std::shared_ptr &mask, - float scale, int64_t n_rep) { + const std::shared_ptr &v, const std::shared_ptr &mask) { CHECK_GT(global::GetContextParallelSize(), 1); const auto comm_type = global::GetContextParallelCommType(); if (comm_type == "p2p") { - return AttnFuncWithCPAndKVP2P(q, k, v, mask, scale, n_rep); - } - if (comm_type == "a2a") { - return AttnFuncWithCPAndQKVOA2A(q, k, v, mask, scale, n_rep); + return AttnFuncWithCPAndKVP2P(q, k, v, mask); + } else if (comm_type == "a2a") { + return AttnFuncWithCPAndQKVOA2A(q, k, v, mask); + } else if (comm_type == "all_gather") { + return AttnFuncWithCPAndKVAllGather(q, k, v, mask); + } else { + LOG(FATAL) << "AttnForwardFuncWithCP: Unsupported communication type " << comm_type << "."; } - return AttnFuncWithCPAndKVAllGather(q, k, v, mask, scale, n_rep); } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 2b5f53a1..9cddcc0c 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -275,6 +275,46 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } } +std::shared_ptr ProcessGroup::BatchSendRecv(const std::vector &ops, bool async_op) const { + CHECK_GT(ops.size(), 0); + CHECK_NOTNULL(ops[0].tensor); + auto device = ops[0].tensor->GetDevice(); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + + { + core::CclGroupGuard ccl_group_guard(backend_); + for (const auto &op : ops) { + CHECK_NOTNULL(op.tensor); + CHECK_EQ(device, op.tensor->GetDevice()); + CHECK_GE(op.peer_rank, 0); + CHECK_LT(op.peer_rank, world_size_); + if (op.type == P2POpType::kSend) { + ccl_impl_->Send(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank, comm, + comm_stream); + } else { + ccl_impl_->Recv(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank, comm, + comm_stream); + } + } + } + + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::vector> ProcessGroup::BroadCast(const std::vector> &input_tensors) const { std::vector> outputs; diff --git a/scripts/test_config.json b/scripts/test_config.json index f33b8905..be6dad77 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -1,10 +1,10 @@ { "variables": { "BUILD_DIR": "../build", - "GPT2_INPUT_BIN": "/data/shared/InfiniTrain-dev/data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin", - "GPT2_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M.bin", - "LLAMA3_INPUT_BIN": "/data/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin", - "LLAMA3_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin", + "GPT2_INPUT_BIN": "/data1/shared/InfiniTrain-dev/data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin", + "GPT2_LLMC_FILEPATH": "/data1/shared/InfiniTrain-dev/data/llmc/gpt2/gpt2_124M.bin", + "LLAMA3_INPUT_BIN": "/data1/shared/InfiniTrain-dev/data/llmc/llama3/tinyshakespeare/tiny_shakespeare_train.bin", + "LLAMA3_LLMC_FILEPATH": "/data1/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin", "PROFILE_LOG_DIR": "./profile_logs", "LOG_DIR": "./logs", "CKPT_ROOT_DIR": "/data1/ckpt",