Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ void Train(const nn::parallel::Rank &rank) {
const ProcessGroup *pp_pg = nullptr;

if (rank.IsParallel()) {
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
device = Device(Device::DeviceType::kCUDA, global::GetLocalDeviceIndex(rank.thread_rank()));
auto *pg_factory = ProcessGroupFactory::Instance(device.type());

if (ddp_world_size > 1) {
Expand Down Expand Up @@ -540,7 +540,6 @@ int main(int argc, char *argv[]) {

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();

// NOTE(dcj): currently we only support single process
if (FLAGS_nthread_per_process > 1) {
std::vector<std::thread> threads;
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
Expand All @@ -551,7 +550,9 @@ int main(int argc, char *argv[]) {

for (auto &thread : threads) { thread.join(); }
} else {
Train({0, 0, 1, 1});
nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), 0, nn::parallel::global::GetNprocPerNode(),
FLAGS_nthread_per_process);
Train(rank);
}

gflags::ShutDownCommandLineFlags();
Expand Down
7 changes: 4 additions & 3 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ void Train(const nn::parallel::Rank &rank) {
const ProcessGroup *pp_pg = nullptr;

if (rank.IsParallel()) {
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
device = Device(Device::DeviceType::kCUDA, global::GetLocalDeviceIndex(rank.thread_rank()));
auto *pg_factory = ProcessGroupFactory::Instance(device.type());

if (ddp_world_size > 1) {
Expand Down Expand Up @@ -517,7 +517,6 @@ int main(int argc, char *argv[]) {

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();

// NOTE(dcj): currently we only support single process
if (FLAGS_nthread_per_process > 1) {
std::vector<std::thread> threads;
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
Expand All @@ -528,7 +527,9 @@ int main(int argc, char *argv[]) {

for (auto &thread : threads) { thread.join(); }
} else {
Train({0, 0, 1, 1});
nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), 0, nn::parallel::global::GetNprocPerNode(),
FLAGS_nthread_per_process);
Train(rank);
}

gflags::ShutDownCommandLineFlags();
Expand Down
1 change: 1 addition & 0 deletions infini_train/include/nn/parallel/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); }
/// Number of threads per process, as stored in the GlobalEnv singleton.
inline int GetNthreadPerProc() {
    auto &&env = GlobalEnv::Instance();
    return env.nthread_per_process();
}

/// Global process rank, as stored in the GlobalEnv singleton.
inline int GetGlobalProcRank() {
    auto &&env = GlobalEnv::Instance();
    return env.global_proc_rank();
}

/// Local process rank, as stored in the GlobalEnv singleton.
inline int GetLocalProcRank() {
    auto &&env = GlobalEnv::Instance();
    return env.local_proc_rank();
}
/// Translates a thread rank inside the current process into a device index.
///
/// The result is `GetLocalProcRank() * GetNthreadPerProc() + thread_rank`, i.e. every
/// process is given a consecutive run of `GetNthreadPerProc()` indices and `thread_rank`
/// is the offset inside that run. With the default argument the first index of the
/// current process's run is returned.
///
/// @param thread_rank Offset within the current process's run of indices (not range-checked here).
/// @return The computed device index.
inline int GetLocalDeviceIndex(int thread_rank = 0) {
    const int first_index_of_process = GetLocalProcRank() * GetNthreadPerProc();
    return first_index_of_process + thread_rank;
}

/// Tensor-parallel size, as stored in the GlobalEnv singleton.
inline int GetTensorParallelSize() {
    auto &&env = GlobalEnv::Instance();
    return env.tensor_parallel_size();
}

/// Sequence-parallel size, as stored in the GlobalEnv singleton.
inline int GetSequenceParallelSize() {
    auto &&env = GlobalEnv::Instance();
    return env.sequence_parallel_size();
}
Expand Down
11 changes: 10 additions & 1 deletion infini_train/src/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,16 @@ std::string Device::ToString() const {
}

nn::parallel::Rank Device::Rank() const {
return {nn::parallel::global::GetGlobalProcRank(), index_, nn::parallel::global::GetNprocPerNode(),
if (IsCPU()) {
return {nn::parallel::global::GetGlobalProcRank(), 0, nn::parallel::global::GetNprocPerNode(),
nn::parallel::global::GetNthreadPerProc()};
}

const int thread_rank = index_ - nn::parallel::global::GetLocalDeviceIndex();
CHECK_GE(thread_rank, 0) << "CUDA device index is outside the current process rank range";
CHECK_LT(thread_rank, nn::parallel::global::GetNthreadPerProc())
<< "CUDA device index is outside the current process rank range";
return {nn::parallel::global::GetGlobalProcRank(), thread_rank, nn::parallel::global::GetNprocPerNode(),
nn::parallel::global::GetNthreadPerProc()};
}

Expand Down
6 changes: 5 additions & 1 deletion infini_train/src/nn/parallel/data_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ ParallelApply(const std::vector<std::shared_ptr<Module>> &modules,

DataParallel::DataParallel(const std::shared_ptr<Module> &module, int dim, Device::DeviceType device_type) : dim_(dim) {
devices_.reserve(global::GetNthreadPerProc());
for (int index = 0; index < global::GetNthreadPerProc(); ++index) { devices_.emplace_back(device_type, index); }
for (int thread_rank = 0; thread_rank < global::GetNthreadPerProc(); ++thread_rank) {
const int device_index
= device_type == Device::DeviceType::kCUDA ? global::GetLocalDeviceIndex(thread_rank) : thread_rank;
devices_.emplace_back(device_type, device_index);
}

CHECK_GT(devices_.size(), 0) << "No available devices found";
output_device_ = devices_.at(0);
Expand Down
6 changes: 4 additions & 2 deletions infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "infini_train/include/autograd/function_hook.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/rank.h"
Expand All @@ -36,15 +37,16 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
continue;
}
auto device = param->GetDevice();
CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module";
CHECK_EQ(device.index(), global::GetLocalDeviceIndex(rank.thread_rank()))
<< "All parameters must be on the same device as the module";
if (!ddp_config.gradient_bucketing_enabled && ddp_config.zero_stage < 1) {
auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(
function::ReduceOpType::kAvg, ddp_pg_);
param->RegisterPostAccumulateGradHook(std::move(hook));
}
}
for (auto &buffer : module->Buffers()) {
CHECK_EQ(buffer->GetDevice().index(), rank.thread_rank())
CHECK_EQ(buffer->GetDevice().index(), global::GetLocalDeviceIndex(rank.thread_rank()))
<< "All buffers must be on the same device as the module";
}
modules_[kModuleName] = std::move(module);
Expand Down
29 changes: 24 additions & 5 deletions infini_train/src/nn/parallel/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ int GetEnvAsInt(const std::string &name, int default_value) {
return value ? std::atoi(value) : default_value;
}

// Reports whether the environment variable `name` exists in the process environment.
// A variable that is set to an empty string still counts as present, because only the
// pointer returned by std::getenv is examined, never the text it points to.
bool HasEnv(const std::string &name) {
    const char *raw_value = std::getenv(name.c_str());
    return raw_value != nullptr;
}

} // namespace

namespace infini_train::nn::parallel::global {
Expand Down Expand Up @@ -92,13 +94,30 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq

CHECK(!initialized_) << "Repeated initialization of GlobalEnv!";

nnodes_ = GetEnvAsInt("NNODES", 1);
nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", 1);
world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", 0);
local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", 0);
const int proc_world_size = GetEnvAsInt("PROC_WORLD_SIZE", GetEnvAsInt("WORLD_SIZE", 1));
nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", GetEnvAsInt("LOCAL_WORLD_SIZE", 1));
CHECK_GT(nproc_per_node_, 0) << "NPROC_PER_NODE/LOCAL_WORLD_SIZE must be positive";
CHECK_GT(proc_world_size, 0) << "PROC_WORLD_SIZE/WORLD_SIZE must be positive";
CHECK_EQ(proc_world_size % nproc_per_node_, 0)
<< "PROC_WORLD_SIZE/WORLD_SIZE must be divisible by NPROC_PER_NODE/LOCAL_WORLD_SIZE";
const bool nnodes_env_set = HasEnv("NNODES");
nnodes_ = GetEnvAsInt("NNODES", proc_world_size / nproc_per_node_);
CHECK_GT(nnodes_, 0) << "NNODES must be positive";
if (nnodes_env_set) {
CHECK_EQ(nnodes_ * nproc_per_node_, proc_world_size)
<< "NNODES * NPROC_PER_NODE/LOCAL_WORLD_SIZE must equal PROC_WORLD_SIZE/WORLD_SIZE";
}
global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", GetEnvAsInt("RANK", 0));
local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", GetEnvAsInt("LOCAL_RANK", 0));
CHECK_GE(global_proc_rank_, 0) << "GLOBAL_PROC_RANK/RANK must be non-negative";
CHECK_LT(global_proc_rank_, proc_world_size)
<< "GLOBAL_PROC_RANK/RANK must be less than PROC_WORLD_SIZE/WORLD_SIZE";
CHECK_GE(local_proc_rank_, 0) << "LOCAL_PROC_RANK/LOCAL_RANK must be non-negative";
CHECK_LT(local_proc_rank_, nproc_per_node_)
<< "LOCAL_PROC_RANK/LOCAL_RANK must be less than NPROC_PER_NODE/LOCAL_WORLD_SIZE";

nthread_per_process_ = nthread_per_process;
world_size_ = proc_world_size * nthread_per_process;
CHECK_GE(tensor_parallel_size, 1) << "Tensor Parallel size must be >= 1";
tensor_parallel_size_ = tensor_parallel_size;
sequence_parallel_enabled_ = sequence_parallel_enabled;
Expand Down
2 changes: 1 addition & 1 deletion infini_train/src/nn/parallel/process_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
int global_thread_rank = lower_rank + i;
auto it = std::ranges::find(ranks, global_thread_rank);
if (it != ranks.end()) {
auto device = Device(backend_, i);
auto device = Device(backend_, global::GetLocalDeviceIndex(i));
core::DeviceGuard guard(device);

core::CclComm *comm_raw = nullptr;
Expand Down
19 changes: 17 additions & 2 deletions scripts/run_models_and_profile.bash
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,25 @@ args_string_for_test() {
| $args
| (if has("save") then .save = namespaced_path(.save; $model; $run_mode) else . end)
| (if has("load") then .load = namespaced_path(.load; $model; $resume_src_mode) else . end)
| del(.nproc_per_node)
| to_entries[]
| "--\(.key) \(.value|tostring)"
' "$CONFIG_FILE" | paste -sd' ' - | \
sed "s|@CKPT_ROOT_DIR@|${CKPT_ROOT_DIR}|g"
}

# Print (without a trailing newline) the ./infini_run command line for one model test.
#   $1  model binary to launch (inserted as-is)
#   $2  value for --input_bin      (shell-quoted with %q)
#   $3  value for --llmc_filepath  (shell-quoted with %q)
#   $4  remaining argument string  (inserted as-is, not quoted)
#   $5  value for --nproc_per_node (inserted as-is)
model_cmd_for_test() {
    local binary="$1" dataset="$2" checkpoint="$3" extra_args="$4" nproc="$5"

    printf './infini_run --nproc_per_node=%s %s --input_bin %q --llmc_filepath %q --device cuda %s' \
        "$nproc" "$binary" "$dataset" "$checkpoint" "$extra_args"
}

# Run tests
num_basic_compile_commands=$(jq '.basic_compile_commands | length' "$CONFIG_FILE")
num_groups=$(jq '.test_groups | length' "$CONFIG_FILE")
Expand Down Expand Up @@ -410,15 +423,17 @@ for ((id=0; id<num_basic_compile_commands; ++id)); do

for ((ti=0; ti<num_tests; ++ti)); do
test_id=$(jq -r ".test_groups[$gi].tests[$ti].id" "$CONFIG_FILE")
nproc_per_node="$(jq -r ".test_groups[$gi].tests[$ti].args.nproc_per_node // empty" "$CONFIG_FILE")"
: "${nproc_per_node:=${INFINI_NPROC_PER_NODE:-1}}"
gpt2_arg_str="$(args_string_for_test "$gi" "$ti" "gpt2" "$test_id")"
llama3_arg_str="$(args_string_for_test "$gi" "$ti" "llama3" "$test_id")"

# gpt2
gpt2_cmd="${prefix}./gpt2 --input_bin ${GPT2_INPUT_BIN} --llmc_filepath ${GPT2_LLMC_FILEPATH} --device cuda ${gpt2_arg_str}"
gpt2_cmd="$(model_cmd_for_test "./gpt2" "$GPT2_INPUT_BIN" "$GPT2_LLMC_FILEPATH" "$gpt2_arg_str" "$nproc_per_node")"
run_and_log "$gpt2_cmd" "gpt2_${test_id}${log_suffix}" "$profile_flag" "$group_tag"

# llama3
llama3_cmd="${prefix}./llama3 --input_bin ${LLAMA3_INPUT_BIN} --llmc_filepath ${LLAMA3_LLMC_FILEPATH} --device cuda ${llama3_arg_str}"
llama3_cmd="$(model_cmd_for_test "./llama3" "$LLAMA3_INPUT_BIN" "$LLAMA3_LLMC_FILEPATH" "$llama3_arg_str" "$nproc_per_node")"
run_and_log "$llama3_cmd" "llama3_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
done

Expand Down
Loading
Loading