From 6fae6b7d60302bb6612a0ca69378962526d412d1 Mon Sep 17 00:00:00 2001 From: cx Date: Fri, 3 Jul 2026 09:11:44 +0000 Subject: [PATCH] Support torchrun-style InfiniTrain multi-process launch --- example/gpt2/main.cc | 7 +- example/llama3/main.cc | 7 +- infini_train/include/nn/parallel/global.h | 1 + infini_train/src/device.cc | 11 +- infini_train/src/nn/parallel/data_parallel.cc | 6 +- .../parallel/ddp/distributed_data_parallel.cc | 6 +- infini_train/src/nn/parallel/global.cc | 29 ++++- infini_train/src/nn/parallel/process_group.cc | 2 +- scripts/run_models_and_profile.bash | 19 ++- scripts/test_config.json | 120 ++++++++++++------ tools/infini_run/infini_run.cc | 100 +++++++++++++-- 11 files changed, 238 insertions(+), 70 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index b666b6b1..ecb44df5 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -170,7 +170,7 @@ void Train(const nn::parallel::Rank &rank) { const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { - device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); + device = Device(Device::DeviceType::kCUDA, global::GetLocalDeviceIndex(rank.thread_rank())); auto *pg_factory = ProcessGroupFactory::Instance(device.type()); if (ddp_world_size > 1) { @@ -540,7 +540,6 @@ int main(int argc, char *argv[]) { LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); - // NOTE(dcj): currently we only support single process if (FLAGS_nthread_per_process > 1) { std::vector threads; for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) { @@ -551,7 +550,9 @@ int main(int argc, char *argv[]) { for (auto &thread : threads) { thread.join(); } } else { - Train({0, 0, 1, 1}); + nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), 0, nn::parallel::global::GetNprocPerNode(), + FLAGS_nthread_per_process); + Train(rank); } gflags::ShutDownCommandLineFlags(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 5a191865..ea2d144a 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -156,7 +156,7 @@ void Train(const nn::parallel::Rank &rank) { const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { - device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); + device = Device(Device::DeviceType::kCUDA, global::GetLocalDeviceIndex(rank.thread_rank())); auto *pg_factory = ProcessGroupFactory::Instance(device.type()); if (ddp_world_size > 1) { @@ -517,7 +517,6 @@ int main(int argc, char *argv[]) { LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); - // NOTE(dcj): currently we only support single process if (FLAGS_nthread_per_process > 1) { std::vector threads; for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) { @@ -528,7 +527,9 @@ int main(int argc, char *argv[]) { for (auto &thread : threads) { thread.join(); } } else { - Train({0, 0, 1, 1}); + nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), 0, nn::parallel::global::GetNprocPerNode(), + FLAGS_nthread_per_process); + Train(rank); } gflags::ShutDownCommandLineFlags(); diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 9373100f..c7c74a99 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -98,6 +98,7 @@ inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); } inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); } inline int GetGlobalProcRank() { return GlobalEnv::Instance().global_proc_rank(); } inline int GetLocalProcRank() { return GlobalEnv::Instance().local_proc_rank(); } +inline int GetLocalDeviceIndex(int thread_rank = 0) { return GetLocalProcRank() * GetNthreadPerProc() + thread_rank; } inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_parallel_size(); } inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_parallel_size(); } diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc index 1bb3aaad..db10cd54 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -33,7 +33,16 @@ std::string Device::ToString() const { } nn::parallel::Rank Device::Rank() const { - return {nn::parallel::global::GetGlobalProcRank(), index_, nn::parallel::global::GetNprocPerNode(), + if (IsCPU()) { + return {nn::parallel::global::GetGlobalProcRank(), 0, nn::parallel::global::GetNprocPerNode(), + nn::parallel::global::GetNthreadPerProc()}; + } + + const int thread_rank = index_ - nn::parallel::global::GetLocalDeviceIndex(); + CHECK_GE(thread_rank, 0) << "CUDA device index is outside the current process rank range"; + CHECK_LT(thread_rank, nn::parallel::global::GetNthreadPerProc()) + << "CUDA device index is outside the current process rank range"; + return {nn::parallel::global::GetGlobalProcRank(), thread_rank, nn::parallel::global::GetNprocPerNode(), nn::parallel::global::GetNthreadPerProc()}; } diff --git a/infini_train/src/nn/parallel/data_parallel.cc b/infini_train/src/nn/parallel/data_parallel.cc index c48761b7..6199d39f 100644 --- a/infini_train/src/nn/parallel/data_parallel.cc +++ b/infini_train/src/nn/parallel/data_parallel.cc @@ -60,7 +60,11 @@ ParallelApply(const std::vector> &modules, DataParallel::DataParallel(const std::shared_ptr &module, int dim, Device::DeviceType device_type) : dim_(dim) { devices_.reserve(global::GetNthreadPerProc()); - for (int index = 0; index < global::GetNthreadPerProc(); ++index) { devices_.emplace_back(device_type, index); } + for (int thread_rank = 0; thread_rank < global::GetNthreadPerProc(); ++thread_rank) { + const int device_index + = device_type == Device::DeviceType::kCUDA ? global::GetLocalDeviceIndex(thread_rank) : thread_rank; + devices_.emplace_back(device_type, device_index); + } CHECK_GT(devices_.size(), 0) << "No available devices found"; output_device_ = devices_.at(0); diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index a3bfe008..1d4aa467 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -10,6 +10,7 @@ #include "infini_train/include/autograd/function_hook.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/rank.h" @@ -36,7 +37,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod continue; } auto device = param->GetDevice(); - CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module"; + CHECK_EQ(device.index(), global::GetLocalDeviceIndex(rank.thread_rank())) + << "All parameters must be on the same device as the module"; if (!ddp_config.gradient_bucketing_enabled && ddp_config.zero_stage < 1) { auto hook = std::make_unique( function::ReduceOpType::kAvg, ddp_pg_); @@ -44,7 +46,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod } } for (auto &buffer : module->Buffers()) { - CHECK_EQ(buffer->GetDevice().index(), rank.thread_rank()) + CHECK_EQ(buffer->GetDevice().index(), global::GetLocalDeviceIndex(rank.thread_rank())) << "All buffers must be on the same device as the module"; } modules_[kModuleName] = std::move(module); diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 65a3208e..e4399128 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -13,6 +13,8 @@ int GetEnvAsInt(const std::string &name, int default_value) { return value ? std::atoi(value) : default_value; } +bool HasEnv(const std::string &name) { return std::getenv(name.c_str()) != nullptr; } + } // namespace namespace infini_train::nn::parallel::global { @@ -92,13 +94,30 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; - nnodes_ = GetEnvAsInt("NNODES", 1); - nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", 1); - world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process; - global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", 0); - local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", 0); + const int proc_world_size = GetEnvAsInt("PROC_WORLD_SIZE", GetEnvAsInt("WORLD_SIZE", 1)); + nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", GetEnvAsInt("LOCAL_WORLD_SIZE", 1)); + CHECK_GT(nproc_per_node_, 0) << "NPROC_PER_NODE/LOCAL_WORLD_SIZE must be positive"; + CHECK_GT(proc_world_size, 0) << "PROC_WORLD_SIZE/WORLD_SIZE must be positive"; + CHECK_EQ(proc_world_size % nproc_per_node_, 0) + << "PROC_WORLD_SIZE/WORLD_SIZE must be divisible by NPROC_PER_NODE/LOCAL_WORLD_SIZE"; + const bool nnodes_env_set = HasEnv("NNODES"); + nnodes_ = GetEnvAsInt("NNODES", proc_world_size / nproc_per_node_); + CHECK_GT(nnodes_, 0) << "NNODES must be positive"; + if (nnodes_env_set) { + CHECK_EQ(nnodes_ * nproc_per_node_, proc_world_size) + << "NNODES * NPROC_PER_NODE/LOCAL_WORLD_SIZE must equal PROC_WORLD_SIZE/WORLD_SIZE"; + } + global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", GetEnvAsInt("RANK", 0)); + local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", GetEnvAsInt("LOCAL_RANK", 0)); + CHECK_GE(global_proc_rank_, 0) << "GLOBAL_PROC_RANK/RANK must be non-negative"; + CHECK_LT(global_proc_rank_, proc_world_size) + << "GLOBAL_PROC_RANK/RANK must be less than PROC_WORLD_SIZE/WORLD_SIZE"; + CHECK_GE(local_proc_rank_, 0) << "LOCAL_PROC_RANK/LOCAL_RANK must be non-negative"; + CHECK_LT(local_proc_rank_, nproc_per_node_) + << "LOCAL_PROC_RANK/LOCAL_RANK must be less than NPROC_PER_NODE/LOCAL_WORLD_SIZE"; nthread_per_process_ = nthread_per_process; + world_size_ = proc_world_size * nthread_per_process; CHECK_GE(tensor_parallel_size, 1) << "Tensor Parallel size must be >= 1"; tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..9dab7848 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -99,7 +99,7 @@ void ProcessGroup::InitMultiProcess(const std::vector &ranks) { int global_thread_rank = lower_rank + i; auto it = std::ranges::find(ranks, global_thread_rank); if (it != ranks.end()) { - auto device = Device(backend_, i); + auto device = Device(backend_, global::GetLocalDeviceIndex(i)); core::DeviceGuard guard(device); core::CclComm *comm_raw = nullptr; diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 0e450bf1..234097f5 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -335,12 +335,25 @@ args_string_for_test() { | $args | (if has("save") then .save = namespaced_path(.save; $model; $run_mode) else . end) | (if has("load") then .load = namespaced_path(.load; $model; $resume_src_mode) else . end) + | del(.nproc_per_node) | to_entries[] | "--\(.key) \(.value|tostring)" ' "$CONFIG_FILE" | paste -sd' ' - | \ sed "s|@CKPT_ROOT_DIR@|${CKPT_ROOT_DIR}|g" } +model_cmd_for_test() { + local model_bin="$1" + local input_bin="$2" + local llmc_filepath="$3" + local arg_str="$4" + local nproc_per_node="$5" + + printf './infini_run --nproc_per_node=%s %s --input_bin %q --llmc_filepath %q --device cuda %s' \ + "$nproc_per_node" "$model_bin" \ + "$input_bin" "$llmc_filepath" "$arg_str" +} + # Run tests num_basic_compile_commands=$(jq '.basic_compile_commands | length' "$CONFIG_FILE") num_groups=$(jq '.test_groups | length' "$CONFIG_FILE") @@ -410,15 +423,17 @@ for ((id=0; id BuildLauncherArgv(int train_program_index, char **argv) { + std::vector launcher_argv; + launcher_argv.push_back(argv[0]); + for (int i = 1; i < train_program_index; ++i) { launcher_argv.push_back(argv[i]); } + launcher_argv.push_back(nullptr); + return launcher_argv; +} + +void SetEnvInt(const char *name, int value) { + const auto value_str = std::to_string(value); + setenv(name, value_str.c_str(), 1); +} + +} // namespace + int main(int argc, char **argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); + const int train_program_index = FindTrainProgramIndex(argc, argv); + std::vector launcher_argv = BuildLauncherArgv(train_program_index, argv); + int launcher_argc = static_cast(launcher_argv.size()) - 1; + char **launcher_argv_ptr = launcher_argv.data(); + gflags::ParseCommandLineFlags(&launcher_argc, &launcher_argv_ptr, true); google::InitGoogleLogging(argv[0]); - CHECK_GE(argc, 2) << "No training prgram specified!"; + CHECK_GT(FLAGS_nnodes, 0) << "nnodes must be positive"; + CHECK_GT(FLAGS_nproc_per_node, 0) << "nproc_per_node must be positive"; + CHECK_NE(FLAGS_rdzv_endpoint.find(':'), std::string::npos) << "rdzv_endpoint must be host:port"; - std::string train_program = argv[1]; + CHECK_LT(train_program_index, argc) << "No training program specified!"; + + std::string train_program = argv[train_program_index]; + CHECK_NE(train_program, "--") << "Explicit '--' separator is not supported; pass the training program directly " + "after infini_run launcher flags"; std::vector train_argv; - for (int i = 1; i < argc; ++i) { train_argv.push_back(argv[i]); } + for (int i = train_program_index; i < argc; ++i) { train_argv.push_back(argv[i]); } train_argv.push_back(nullptr); - int world_size = FLAGS_nnodes * FLAGS_nproc_per_node; + int proc_world_size = FLAGS_nnodes * FLAGS_nproc_per_node; std::string master_addr = FLAGS_rdzv_endpoint.substr(0, FLAGS_rdzv_endpoint.find(':')); std::string master_port = FLAGS_rdzv_endpoint.substr(FLAGS_rdzv_endpoint.find(':') + 1); @@ -32,16 +82,23 @@ int main(int argc, char **argv) { pid_t pid = fork(); if (pid == 0) { int global_proc_rank = FLAGS_node_rank * FLAGS_nproc_per_node + local_proc_rank; - setenv("NNODES", std::to_string(FLAGS_nnodes).c_str(), 1); - setenv("NPROC_PER_NODE", std::to_string(FLAGS_nproc_per_node).c_str(), 1); + SetEnvInt("NNODES", FLAGS_nnodes); + SetEnvInt("NPROC_PER_NODE", FLAGS_nproc_per_node); + SetEnvInt("LOCAL_WORLD_SIZE", FLAGS_nproc_per_node); setenv("MASTER_ADDR", master_addr.c_str(), 1); setenv("MASTER_PORT", master_port.c_str(), 1); - setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1); - setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1); + SetEnvInt("GLOBAL_PROC_RANK", global_proc_rank); + SetEnvInt("LOCAL_PROC_RANK", local_proc_rank); + SetEnvInt("RANK", global_proc_rank); + SetEnvInt("LOCAL_RANK", local_proc_rank); - setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1); + SetEnvInt("PROC_WORLD_SIZE", proc_world_size); + SetEnvInt("WORLD_SIZE", proc_world_size); + SetEnvInt("GROUP_RANK", FLAGS_node_rank); + SetEnvInt("ROLE_RANK", global_proc_rank); + SetEnvInt("ROLE_WORLD_SIZE", proc_world_size); execvp(train_program.c_str(), train_argv.data()); perror("exec failed"); @@ -49,10 +106,29 @@ int main(int argc, char **argv) { } } + int exit_code = 0; for (int i = 0; i < FLAGS_nproc_per_node; ++i) { int status; - wait(&status); + pid_t child = wait(&status); + if (child < 0) { + perror("wait failed"); + return 1; + } + + if (WIFEXITED(status)) { + int child_exit_code = WEXITSTATUS(status); + if (child_exit_code != 0 && exit_code == 0) { + exit_code = child_exit_code; + } + } else if (WIFSIGNALED(status)) { + int signal = WTERMSIG(status); + if (exit_code == 0) { + exit_code = 128 + signal; + } + } else if (exit_code == 0) { + exit_code = 1; + } } - return 0; + return exit_code; }