diff --git a/CMakeLists.txt b/CMakeLists.txt index e25de71d..ff4f0286 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -194,8 +194,8 @@ add_executable(gpt2 example/gpt2/main.cc example/common/tiny_shakespeare_dataset.cc example/common/utils.cc - example/gpt2/checkpoint_loader.cc example/common/tokenizer.cc + example/gpt2/checkpoint_loader.cc ) link_infini_train_exe(gpt2) @@ -203,8 +203,8 @@ add_executable(llama3 example/llama3/main.cc example/common/tiny_shakespeare_dataset.cc example/common/utils.cc - example/llama3/checkpoint_loader.cc example/common/tokenizer.cc + example/llama3/checkpoint_loader.cc ) link_infini_train_exe(llama3) diff --git a/example/gpt2/checkpoint_loader.cc b/example/gpt2/checkpoint_loader.cc index 4a7789e9..c0266be1 100644 --- a/example/gpt2/checkpoint_loader.cc +++ b/example/gpt2/checkpoint_loader.cc @@ -12,8 +12,6 @@ #include "glog/logging.h" -#include "example/common/utils.h" -#include "example/gpt2/config.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" @@ -24,6 +22,9 @@ #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" +#include "example/common/utils.h" +#include "example/gpt2/config.h" + using namespace infini_train; namespace nn = infini_train::nn; diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 67738e14..d27421c8 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -10,6 +11,7 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/checkpoint/checkpoint.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" @@ -29,6 +31,7 @@ #ifdef PROFILE_MODE #include "infini_train/include/profiler.h" #endif +#include "infini_train/include/checkpoint/checkpoint_manager.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/utils/global_module_hook_registry.h" #include "infini_train/include/utils/precision_check_config.h" @@ -39,6 +42,7 @@ #include "example/gpt2/checkpoint_loader.h" #include "example/gpt2/config.h" +// TODO(jym): Reorganize CLI flags into categories for better readability and maintainability. // I/O DEFINE_string(input_bin, "", "input .bin to train on"); DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on"); @@ -77,6 +81,11 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables saving"); +DEFINE_string(load, "", "checkpoint directory to resume from"); +DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints"); +DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep"); +DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints"); // precision check DEFINE_string( precision_check, "", @@ -315,9 +324,55 @@ void Train(const nn::parallel::Rank &rank) { auto impl = core::GetDeviceGuardImpl(device.type()); + int start_step = 0; + TrainerState state; + const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load, + .rank = rank, + .model = model, + .optimizer = optimizer, + .model_config = model_config, + .state = state, + .load_optimizer_state = false}); + start_step = resume_result.global_step; + size_t consumed_batches = resume_result.consumed_batches; + + // TODO(jym): Replace with Sampler abstraction when available. + // Skip dataloader to resume from the correct batch position. + if (consumed_batches > 0) { + size_t start = train_iter.BatchIndex(); + // Each rank processes every ddp_world_size-th batch starting from its own rank. + // num_skips calculates how many ++ iterations to reach the saved batch position. + size_t num_skips = (consumed_batches - start) / ddp_world_size; + for (size_t i = 0; i < num_skips; ++i) { ++train_iter; } + } + + auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) { + SaveCheckpoint({ + .save_dir = save_dir, + .global_step = global_step, + .consumed_batches = consumed_batches, + .last_lr = FLAGS_learning_rate, + .n_layer = model_config.n_layer, + .n_head = model_config.n_head, + .n_kv_head = model_config.n_kv_head, + .n_embd = model_config.n_embd, + .vocab_size = model_config.vocab_size, + .ddp_size = ddp_world_size, + .tp_size = tp_world_size, + .sp_size = sp_world_size, + .pp_size = pp_world_size, + .save_optimizer_state = FLAGS_save_optimizer_state, + .checkpoint_root_dir = FLAGS_save, + .max_checkpoint_keep = FLAGS_max_checkpoint_keep, + .rank = rank, + .model = *model, + .optimizer = *optimizer, + }); + }; + LOG(INFO) << "start training"; - for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) { // Reset precision check counters at start of each iteration for file overwrite utils::PrecisionChecker::ResetCounters(); @@ -367,6 +422,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); @@ -397,6 +453,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); @@ -431,6 +488,15 @@ void Train(const nn::parallel::Rank &rank) { } } } + + if (FLAGS_save_interval > 0 && (step + 1) % FLAGS_save_interval == 0) { + std::filesystem::path step_dir + = std::filesystem::path(FLAGS_save) / std::format("checkpoint_step_{:06d}", step + 1); + if (rank.IsParallel()) { + step_dir /= std::format("rank_{:06d}", rank.GlobalRank()); + } + save_checkpoint(step_dir, step + 1); + } } // Save LoRA weights if enabled and path specified @@ -439,6 +505,12 @@ void Train(const nn::parallel::Rank &rank) { nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path); } + std::filesystem::path final_dir = std::filesystem::path(FLAGS_save) / "checkpoint_final"; + if (rank.IsParallel()) { + final_dir /= std::format("rank_{:06d}", rank.GlobalRank()); + } + save_checkpoint(final_dir, FLAGS_num_iteration); + #ifdef PROFILE_MODE Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("gpt2.records.log"); diff --git a/example/llama3/checkpoint_loader.cc b/example/llama3/checkpoint_loader.cc index f29bc540..2d3b25f9 100644 --- a/example/llama3/checkpoint_loader.cc +++ b/example/llama3/checkpoint_loader.cc @@ -12,8 +12,6 @@ #include "glog/logging.h" -#include "example/common/utils.h" -#include "example/llama3/config.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" @@ -22,6 +20,9 @@ #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" +#include "example/common/utils.h" +#include "example/llama3/config.h" + using namespace infini_train; namespace nn = infini_train::nn; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index fadf205e..72538a73 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -8,6 +9,8 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/checkpoint/checkpoint.h" +#include "infini_train/include/checkpoint/checkpoint_manager.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" @@ -38,6 +41,7 @@ #include "example/llama3/checkpoint_loader.h" #include "example/llama3/config.h" +// TODO(jym): Reorganize CLI flags into categories for better readability and maintainability. // I/O DEFINE_string(input_bin, "", "input .bin to train on"); DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on"); @@ -75,6 +79,12 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables saving"); +DEFINE_string(load, "", "checkpoint directory to resume from"); +DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints"); +DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep"); +DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints"); + // precision check DEFINE_string( precision_check, "", @@ -293,7 +303,54 @@ void Train(const nn::parallel::Rank &rank) { auto impl = core::GetDeviceGuardImpl(device.type()); - for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + int start_step = 0; + TrainerState state; + const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load, + .rank = rank, + .model = model, + .optimizer = optimizer, + .model_config = model_config, + .state = state, + .load_optimizer_state = true}); + + start_step = resume_result.global_step; + size_t consumed_batches = resume_result.consumed_batches; + + // TODO(jym): Replace with Sampler abstraction when available. + // Skip dataloader to resume from the correct batch position. + if (consumed_batches > 0) { + size_t start = train_iter.BatchIndex(); + // Each rank processes every ddp_world_size-th batch starting from its own rank. + // num_skips calculates how many ++ iterations to reach the saved batch position. + size_t num_skips = (consumed_batches - start) / ddp_world_size; + for (size_t i = 0; i < num_skips; ++i) { ++train_iter; } + } + + auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) { + SaveCheckpoint({ + .save_dir = save_dir, + .global_step = global_step, + .consumed_batches = consumed_batches, + .last_lr = FLAGS_learning_rate, + .n_layer = model_config.n_layer, + .n_head = model_config.n_head, + .n_kv_head = model_config.n_kv_head, + .n_embd = model_config.n_embd, + .vocab_size = model_config.vocab_size, + .ddp_size = ddp_world_size, + .tp_size = tp_world_size, + .sp_size = sp_world_size, + .pp_size = pp_world_size, + .save_optimizer_state = FLAGS_save_optimizer_state, + .checkpoint_root_dir = FLAGS_save, + .max_checkpoint_keep = FLAGS_max_checkpoint_keep, + .rank = rank, + .model = *model, + .optimizer = *optimizer, + }); + }; + + for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) { // Reset precision check counters at start of each iteration for file overwrite utils::PrecisionChecker::ResetCounters(); @@ -343,6 +400,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); @@ -372,6 +430,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + consumed_batches = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); @@ -406,6 +465,15 @@ void Train(const nn::parallel::Rank &rank) { } } } + + if (FLAGS_save_interval > 0 && (step + 1) % FLAGS_save_interval == 0) { + std::filesystem::path step_dir + = std::filesystem::path(FLAGS_save) / std::format("checkpoint_step_{:06d}", step + 1); + if (rank.IsParallel()) { + step_dir /= std::format("rank_{:06d}", rank.GlobalRank()); + } + save_checkpoint(step_dir, step + 1); + } } // Save LoRA weights if enabled and path specified @@ -414,6 +482,12 @@ void Train(const nn::parallel::Rank &rank) { nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path); } + std::filesystem::path final_dir = std::filesystem::path(FLAGS_save) / "checkpoint_final"; + if (rank.IsParallel()) { + final_dir /= std::format("rank_{:06d}", rank.GlobalRank()); + } + save_checkpoint(final_dir, FLAGS_num_iteration); + #ifdef PROFILE_MODE Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("llama3.records.log"); diff --git a/infini_train/include/checkpoint/checkpoint.h b/infini_train/include/checkpoint/checkpoint.h new file mode 100644 index 00000000..b122a17a --- /dev/null +++ b/infini_train/include/checkpoint/checkpoint.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace infini_train { +class Optimizer; +class Tensor; +namespace nn { +class Module; +} + +struct TrainerState { + int64_t global_step = 0; + int64_t consumed_batches = 0; + // FIXME(jym): learning_rate should be restored from scheduler state, move `last_lr` from TrainerState to + // SchedulerState later + double last_lr = 0.0; + int64_t n_layer = 0; + int64_t n_head = 0; + int64_t n_kv_head = 0; + int64_t n_embd = 0; + int64_t vocab_size = 0; + int ddp_size = 1; + int tp_size = 1; + int sp_size = 1; + int pp_size = 1; +}; + +class Checkpoint { +public: + static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer, + const TrainerState &state, bool save_optimizer_state); + + static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer, + TrainerState &state, bool load_optimizer_state); + +private: + static void SaveStateDict(const std::filesystem::path &path, + const std::unordered_map> &state_dict); + + static std::unordered_map> LoadStateDict(const std::filesystem::path &path); + + static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state); + static TrainerState LoadTrainerState(const std::filesystem::path &path); +}; + +} // namespace infini_train diff --git a/infini_train/include/checkpoint/checkpoint_manager.h b/infini_train/include/checkpoint/checkpoint_manager.h new file mode 100644 index 00000000..cce14107 --- /dev/null +++ b/infini_train/include/checkpoint/checkpoint_manager.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/checkpoint/checkpoint.h" +#include "infini_train/include/dataloader.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/rank.h" +#include "infini_train/include/optimizer.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +namespace infini_train::nn { +class TransformerConfig; +} + +struct ResumeFromCheckpointArgs { + std::filesystem::path resume_root; + const nn::parallel::Rank &rank; + std::shared_ptr model; + std::shared_ptr optimizer; + const nn::TransformerConfig &model_config; + TrainerState &state; + bool load_optimizer_state; +}; + +struct ResumeFromCheckpointResult { + int global_step = 0; + size_t consumed_batches = 0; +}; + +struct SaveCheckpointArgs { + std::filesystem::path save_dir; + int64_t global_step = 0; + size_t consumed_batches = 0; + double last_lr = 0.0; + int64_t n_layer = 0; + int64_t n_head = 0; + int64_t n_kv_head = 0; + int64_t n_embd = 0; + int64_t vocab_size = 0; + int ddp_size = 1; + int tp_size = 1; + int sp_size = 1; + int pp_size = 1; + bool save_optimizer_state = true; + std::filesystem::path checkpoint_root_dir; + size_t max_checkpoint_keep = 0; + const nn::parallel::Rank &rank; + const nn::Module &model; + const Optimizer &optimizer; +}; + +ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args); + +void SaveCheckpoint(const SaveCheckpointArgs &args); diff --git a/infini_train/include/dataloader.h b/infini_train/include/dataloader.h index ad7fbcda..38fa02a4 100644 --- a/infini_train/include/dataloader.h +++ b/infini_train/include/dataloader.h @@ -24,6 +24,8 @@ class DataLoaderIterator { friend bool operator!=(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs); friend bool operator==(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs); + size_t BatchIndex() const; + private: const Dataset *dataset_ = nullptr; // not owned size_t batch_size_ = 0; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index f366661b..2b8cd6dc 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -61,6 +61,10 @@ class Module : public std::enable_shared_from_this { std::unordered_map> StateDict() const; + // Current behavior: missing keys / shape / dtype mismatches are FATAL errors; unexpected keys in state_dict are + // WARNING-only and silently ignored. + void LoadStateDict(const std::unordered_map> &state_dict); + // operator() calls hooks and Forward std::vector> operator()(const std::vector> &input_tensors); diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..559c4312 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -28,6 +28,10 @@ class DistributedOptimizer final : public infini_train::Optimizer { void ZeroGrad(bool set_to_none = true) override; + std::unordered_map> StateDict() const override; + + void LoadStateDict(const std::unordered_map> &state_dict) override; + void StartGradSync(); void FinishGradSync(); diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index fb0ae2d5..01051053 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include namespace infini_train { @@ -21,6 +23,10 @@ class Optimizer { virtual void Step() = 0; + virtual std::unordered_map> StateDict() const { return {}; }; + + virtual void LoadStateDict(const std::unordered_map> &state_dict) {} + protected: std::vector> params_; }; @@ -32,11 +38,7 @@ class SGD : public Optimizer { void Step() override; - static OptimizerCreator Create(float learning_rate) { - return [learning_rate](const std::vector> ¶ms) { - return std::make_shared(params, learning_rate); - }; - } + static OptimizerCreator Create(float learning_rate); private: const float learning_rate_ = 0.0; @@ -49,12 +51,11 @@ class Adam : public Optimizer { void Step() override; + std::unordered_map> StateDict() const override; + + void LoadStateDict(const std::unordered_map> &state_dict) override; static OptimizerCreator Create(float learning_rate = 1e-3, float beta1 = 0.9, float beta2 = 0.999, - float eps = 1e-8) { - return [=](const std::vector> ¶ms) { - return std::make_shared(params, learning_rate, beta1, beta2, eps); - }; - } + float eps = 1e-8); private: int64_t t_; diff --git a/infini_train/include/utils/string_utils.h b/infini_train/include/utils/string_utils.h new file mode 100644 index 00000000..070bb992 --- /dev/null +++ b/infini_train/include/utils/string_utils.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include + +namespace infini_train::utils { +std::string DimsToString(const std::vector &dims); +} diff --git a/infini_train/src/checkpoint/checkpoint.cc b/infini_train/src/checkpoint/checkpoint.cc new file mode 100644 index 00000000..892ec497 --- /dev/null +++ b/infini_train/src/checkpoint/checkpoint.cc @@ -0,0 +1,221 @@ +#include "infini_train/include/checkpoint/checkpoint.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +namespace infini_train { +namespace { +constexpr uint32_t kCkptMagic = 0x54504B43; // CKPT +constexpr uint32_t kCkptVersion = 1; + +void WriteString(std::ofstream *ofs, const std::string &value) { + uint32_t len = static_cast(value.size()); + ofs->write(reinterpret_cast(&len), sizeof(len)); + ofs->write(value.data(), len); +} + +std::string ReadString(std::ifstream *ifs) { + uint32_t len = 0; + ifs->read(reinterpret_cast(&len), sizeof(len)); + std::string s(len, '\0'); + ifs->read(s.data(), len); + return s; +} + +// TODO: This is a hand-rolled JSON field extractor. Replace with a proper JSON library (e.g., nlohmann/json) once +// available in the project dependencies. +template T ExtractNumberField(const std::string &content, const std::string &key, T fallback) { + const auto token = std::string("\"") + key + "\""; + const auto key_pos = content.find(token); + if (key_pos == std::string::npos) { + return fallback; + } + const auto colon_pos = content.find(':', key_pos); + if (colon_pos == std::string::npos) { + return fallback; + } + size_t value_start = colon_pos + 1; + while (value_start < content.size() && (content[value_start] == ' ' || content[value_start] == '\n')) { + ++value_start; + } + size_t value_end = value_start; + while (value_end < content.size() && content[value_end] != ',' && content[value_end] != '\n' + && content[value_end] != '}') { + ++value_end; + } + std::stringstream ss(content.substr(value_start, value_end - value_start)); + T value = fallback; + ss >> value; + if (ss.fail()) { + return fallback; + } + return value; +} +} // namespace + +void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer, + const TrainerState &state, bool save_optimizer_state) { + std::filesystem::create_directories(checkpoint_dir); + LOG(INFO) << "[CKPT] Save begin: dir=" << checkpoint_dir << ", global_step=" << state.global_step; + + const auto model_path = checkpoint_dir / ("model.ckpt"); + + SaveStateDict(model_path, model.StateDict()); + + if (save_optimizer_state) { + CHECK(optimizer != nullptr) << "Optimizer pointer is null, cannot save optimizer state."; + auto opt_state = optimizer->StateDict(); + if (!opt_state.empty()) { + const auto opt_path = checkpoint_dir / "optimizer.ckpt"; + SaveStateDict(opt_path, opt_state); + } + } + + SaveTrainerState(checkpoint_dir / "trainer_state.json", state); + LOG(ERROR) << "[CKPT] Save done: dir=" << checkpoint_dir; +} + +void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer, + TrainerState &state, bool load_optimizer_state) { + const auto model_path = checkpoint_dir / "model.ckpt"; + LOG(INFO) << "[CKPT] Loading model: " << model_path; + + model.LoadStateDict(LoadStateDict(model_path)); + + if (load_optimizer_state) { + CHECK(optimizer != nullptr) << "Optimizer pointer is null, cannot load optimizer state."; + const auto opt_path = checkpoint_dir / "optimizer.ckpt"; + if (std::filesystem::exists(opt_path)) { + LOG(INFO) << "[CKPT] Loading optimizer: " << opt_path; + optimizer->LoadStateDict(LoadStateDict(opt_path)); + } else { + LOG(FATAL) << "Optimizer checkpoint not found at: " << opt_path; + } + } + + state = LoadTrainerState(checkpoint_dir / "trainer_state.json"); + LOG(ERROR) << "[CKPT] Load done: global_step=" << state.global_step + << ", consumed_batches =" << state.consumed_batches << ", last_lr=" << state.last_lr + << ", topology(ddp,tp,sp,pp)=(" << state.ddp_size << "," << state.tp_size << "," << state.sp_size << "," + << state.pp_size << ")"; +} + +void Checkpoint::SaveStateDict(const std::filesystem::path &path, + const std::unordered_map> &state_dict) { + std::ofstream ofs(path, std::ios::binary); + CHECK(ofs.is_open()) << "Failed to open checkpoint file: " << path; + + uint32_t magic = kCkptMagic; + uint32_t version = kCkptVersion; + uint32_t count = static_cast(state_dict.size()); + ofs.write(reinterpret_cast(&magic), sizeof(magic)); + ofs.write(reinterpret_cast(&version), sizeof(version)); + ofs.write(reinterpret_cast(&count), sizeof(count)); + + for (const auto &[name, tensor] : state_dict) { + WriteString(&ofs, name); + + const int8_t dtype = static_cast(tensor->Dtype()); + ofs.write(reinterpret_cast(&dtype), sizeof(dtype)); + + const auto &dims = tensor->Dims(); + uint32_t ndim = static_cast(dims.size()); + ofs.write(reinterpret_cast(&ndim), sizeof(ndim)); + for (const auto dim : dims) { ofs.write(reinterpret_cast(&dim), sizeof(dim)); } + + Tensor cpu_tensor = tensor->To(Device()); + uint64_t bytes = static_cast(cpu_tensor.SizeInBytes()); + ofs.write(reinterpret_cast(&bytes), sizeof(bytes)); + ofs.write(reinterpret_cast(cpu_tensor.DataPtr()), static_cast(bytes)); + } +} + +std::unordered_map> Checkpoint::LoadStateDict(const std::filesystem::path &path) { + std::ifstream ifs(path, std::ios::binary); + CHECK(ifs.is_open()) << "Failed to open checkpoint file: " << path; + + uint32_t magic = 0; + uint32_t version = 0; + uint32_t count = 0; + ifs.read(reinterpret_cast(&magic), sizeof(magic)); + ifs.read(reinterpret_cast(&version), sizeof(version)); + ifs.read(reinterpret_cast(&count), sizeof(count)); + + CHECK_EQ(magic, kCkptMagic) << "Invalid checkpoint magic: " << path; + CHECK_EQ(version, kCkptVersion) << "Unsupported checkpoint version: " << path; + + std::unordered_map> state; + for (uint32_t i = 0; i < count; ++i) { + const std::string name = ReadString(&ifs); + + int8_t dtype_raw = 0; + ifs.read(reinterpret_cast(&dtype_raw), sizeof(dtype_raw)); + DataType dtype = static_cast(dtype_raw); + + uint32_t ndim = 0; + ifs.read(reinterpret_cast(&ndim), sizeof(ndim)); + std::vector dims(ndim); + for (uint32_t d = 0; d < ndim; ++d) { ifs.read(reinterpret_cast(&dims[d]), sizeof(dims[d])); } + + uint64_t bytes = 0; + ifs.read(reinterpret_cast(&bytes), sizeof(bytes)); + + auto tensor = std::make_shared(dims, dtype, Device()); + CHECK_EQ(bytes, tensor->SizeInBytes()) << "Tensor bytes mismatch for key: " << name; + ifs.read(reinterpret_cast(tensor->DataPtr()), static_cast(bytes)); + state.emplace(name, tensor); + } + + return state; +} + +void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const TrainerState &state) { + std::ofstream ofs(path); + CHECK(ofs.is_open()) << "Failed to open trainer state file: " << path; + ofs << "{\n"; + ofs << " \"n_layer\": " << state.n_layer << ",\n"; + ofs << " \"n_head\": " << state.n_head << ",\n"; + ofs << " \"n_kv_head\": " << state.n_kv_head << ",\n"; + ofs << " \"n_embd\": " << state.n_embd << ",\n"; + ofs << " \"vocab_size\": " << state.vocab_size << "\n"; + ofs << " \"global_step\": " << state.global_step << ",\n"; + ofs << " \"consumed_batches\": " << state.consumed_batches << ",\n"; + ofs << " \"last_lr\": " << state.last_lr << ",\n"; + ofs << " \"ddp_size\": " << state.ddp_size << ",\n"; + ofs << " \"tp_size\": " << state.tp_size << ",\n"; + ofs << " \"sp_size\": " << state.sp_size << ",\n"; + ofs << " \"pp_size\": " << state.pp_size << "\n"; + ofs << "}\n"; +} + +// TODO(jym): Add TrainerState JSON version compatibility, referencing PyTorch's checkpoint versioning. +TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) { + std::ifstream ifs(path); + CHECK(ifs.is_open()) << "Failed to open trainer state file: " << path; + const std::string content((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + + TrainerState state; + state.n_layer = ExtractNumberField(content, "n_layer", 0); + state.n_head = ExtractNumberField(content, "n_head", 0); + state.n_kv_head = ExtractNumberField(content, "n_kv_head", 0); + state.n_embd = ExtractNumberField(content, "n_embd", 0); + state.vocab_size = ExtractNumberField(content, "vocab_size", 0); + state.global_step = ExtractNumberField(content, "global_step", 0); + state.consumed_batches = ExtractNumberField(content, "consumed_batches", 0); + state.last_lr = ExtractNumberField(content, "last_lr", 0.0); + state.ddp_size = ExtractNumberField(content, "ddp_size", 1); + state.tp_size = ExtractNumberField(content, "tp_size", 1); + state.sp_size = ExtractNumberField(content, "sp_size", 1); + state.pp_size = ExtractNumberField(content, "pp_size", 1); + return state; +} +} // namespace infini_train diff --git a/infini_train/src/checkpoint/checkpoint_manager.cc b/infini_train/src/checkpoint/checkpoint_manager.cc new file mode 100644 index 00000000..71e15c08 --- /dev/null +++ b/infini_train/src/checkpoint/checkpoint_manager.cc @@ -0,0 +1,120 @@ +#include "infini_train/include/checkpoint/checkpoint_manager.h" + +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +// TODO(jym): ckpt is a new checkpoint format; bin is the legacy format. Keeping both as an interim solution; plan to +// consolidate into one later. +ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) { + ResumeFromCheckpointResult result; + if (args.resume_root.empty()) { + LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch."; + return result; + } + + int ddp_world_size = nn::parallel::global::GetDataParallelSize(); + int tp_world_size = nn::parallel::global::GetTensorParallelSize(); + int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1; + int pp_world_size = nn::parallel::global::GetPipelineParallelSize(); + + std::filesystem::path resume_dir = args.resume_root; + if (args.rank.IsParallel()) { + const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank()); + if (std::filesystem::exists(rank_dir)) { + resume_dir = rank_dir; + } + } + + Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state, args.load_optimizer_state); + + result.global_step = static_cast(args.state.global_step); + + CHECK_EQ(args.state.n_layer, args.model_config.n_layer) + << "n_layer mismatch: ckpt=" << args.state.n_layer << ", config=" << args.model_config.n_layer; + CHECK_EQ(args.state.n_head, args.model_config.n_head) + << "n_head mismatch: ckpt=" << args.state.n_head << ", config=" << args.model_config.n_head; + CHECK_EQ(args.state.n_kv_head, args.model_config.n_kv_head) + << "n_kv_head mismatch: ckpt=" << args.state.n_kv_head << ", config=" << args.model_config.n_kv_head; + CHECK_EQ(args.state.n_embd, args.model_config.n_embd) + << "n_embd mismatch: ckpt=" << args.state.n_embd << ", config=" << args.model_config.n_embd; + CHECK_EQ(args.state.vocab_size, args.model_config.vocab_size) + << "vocab_size mismatch: ckpt=" << args.state.vocab_size << ", config=" << args.model_config.vocab_size; + + CHECK_EQ(args.state.ddp_size, ddp_world_size) << "DDP size mismatch: checkpoint has DDP=" << args.state.ddp_size + << ", but current run has DDP=" << ddp_world_size; + CHECK_EQ(args.state.tp_size, tp_world_size) + << "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size; + CHECK_EQ(args.state.sp_size, sp_world_size) + << "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size; + CHECK_EQ(args.state.pp_size, pp_world_size) + << "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size; + + result.consumed_batches = static_cast(std::max(args.state.consumed_batches, 0)); + if (args.rank.IsMainRank()) { + LOG(INFO) << std::format("Resume training from step {}, last_lr {:.3e}, consumed_batches {}", + args.state.global_step, args.state.last_lr, args.state.consumed_batches); + } + + return result; +} + +void SaveCheckpoint(const SaveCheckpointArgs &args) { + const auto ckpt_start = std::chrono::high_resolution_clock::now(); + + TrainerState state; + state.global_step = args.global_step; + state.consumed_batches = static_cast(args.consumed_batches); + state.last_lr = args.last_lr; + state.n_layer = args.n_layer; + state.n_head = args.n_head; + state.n_kv_head = args.n_kv_head; + state.n_embd = args.n_embd; + state.vocab_size = args.vocab_size; + state.ddp_size = args.ddp_size; + state.tp_size = args.tp_size; + state.sp_size = args.sp_size; + state.pp_size = args.pp_size; + + Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.save_optimizer_state); + + const auto ckpt_end = std::chrono::high_resolution_clock::now(); + const double ckpt_ms = std::chrono::duration(ckpt_end - ckpt_start).count(); + + if (!args.rank.IsMainRank()) { + return; + } + + LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms); + + // FIXME(jym): Pruning currently relies on lexicographic sorting of directory names. + // This only works when step directories use zero-padded names (e.g. checkpoint_step_000042). + // If a future change introduces unpadded names, the prune order will be incorrect. + // Consider extracting the step number from the directory name and sorting numerically + // instead, once the checkpoint naming convention is finalized. + if (args.max_checkpoint_keep > 0 && std::filesystem::exists(args.checkpoint_root_dir)) { + std::vector ckpts; + for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) { + if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) { + ckpts.push_back(entry.path()); + } + } + std::sort(ckpts.begin(), ckpts.end()); + while (ckpts.size() > args.max_checkpoint_keep) { + std::filesystem::remove_all(ckpts.front()); + ckpts.erase(ckpts.begin()); + } + } +} diff --git a/infini_train/src/dataloader.cc b/infini_train/src/dataloader.cc index 322df553..b7cc94f2 100644 --- a/infini_train/src/dataloader.cc +++ b/infini_train/src/dataloader.cc @@ -78,6 +78,8 @@ bool operator==(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs) { return lhs.batch_idx_ == rhs.batch_idx_; } +size_t DataLoaderIterator::BatchIndex() const { return batch_idx_; } + DataLoader::DataLoader(const std::shared_ptr &dataset, size_t batch_size) : dataset_(dataset), batch_size_(batch_size), max_batch_idx_((dataset_->Size() + batch_size_ - 1) / batch_size_) {} diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 6d48dcab..81068fe8 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -11,9 +11,9 @@ #include "infini_train/include/autograd/function.h" #include "infini_train/include/common/hook.h" #include "infini_train/include/device.h" -#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h" #include "infini_train/include/utils/global_module_hook_registry.h" +#include "infini_train/include/utils/string_utils.h" #ifndef UNLIKELY #define UNLIKELY(x) __builtin_expect(!!(x), 0) @@ -147,6 +147,49 @@ std::unordered_map> Module::StateDict() con return state; } +void Module::LoadStateDict(const std::unordered_map> &state_dict) { + // Stage 1: Validate all keys, shapes, and dtypes without copying + std::vector error_msgs; + std::unordered_set visited_keys; + auto expected = StateDict(); + + for (const auto &[name, dst] : expected) { + visited_keys.insert(name); + if (!state_dict.contains(name)) { + error_msgs.push_back(std::format("Missing key: {}", name)); + continue; + } + const auto &src = state_dict.at(name); + if (dst->Dims() != src->Dims()) { + error_msgs.push_back(std::format("Shape mismatch for '{}': expected={}, got={}", name, + infini_train::utils::DimsToString(dst->Dims()), + infini_train::utils::DimsToString(src->Dims()))); + } + if (dst->Dtype() != src->Dtype()) { + error_msgs.push_back(std::format("Dtype mismatch for '{}': expected={}, got={}", name, + kDataTypeToDesc.at(dst->Dtype()), kDataTypeToDesc.at(src->Dtype()))); + } + } + + for (const auto &[name, src] : state_dict) { + if (!visited_keys.contains(name)) { + LOG(WARNING) << std::format("Unexpected key in state_dict: {}", name); + } + } + + if (!error_msgs.empty()) { + std::string msg = "LoadStateDict failed:"; + for (const auto &err : error_msgs) { msg += "\n " + err; } + LOG(FATAL) << msg; + } + + // Stage 2: All checks passed, now copy data + for (const auto &[name, dst] : expected) { + const auto &src = state_dict.at(name); + dst->CopyFrom(*src); + } +} + std::vector> Module::Forward(const std::vector> &input_tensors) { LOG(FATAL) << "Forward function not implemented for this module"; return {}; diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 3b86106b..d1831c0f 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -136,4 +136,13 @@ void DistributedOptimizer::Step() { FinishParamSync(/*skip_next_bucket_dispatch=*/true); } +std::unordered_map> DistributedOptimizer::StateDict() const { + CHECK(base_optimizer_) << "DistributedOptimizer: base optimizer is null."; + return base_optimizer_->StateDict(); +} + +void DistributedOptimizer::LoadStateDict(const std::unordered_map> &state_dict) { + CHECK(base_optimizer_) << "DistributedOptimizer: base optimizer is null."; + base_optimizer_->LoadStateDict(state_dict); +} } // namespace infini_train::nn::parallel diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index d5589b01..15925b2f 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -1,5 +1,6 @@ #include "infini_train/include/optimizer.h" +#include #include #include "infini_train/include/core/runtime/device_guard.h" @@ -32,6 +33,12 @@ void SGD::Step() { } } +OptimizerCreator SGD::Create(float learning_rate) { + return [learning_rate](const std::vector> ¶ms) { + return std::make_shared(params, learning_rate); + }; +} + Adam::Adam(const std::vector> ¶ms, float learning_rate, float beta1, float beta2, float eps) : Optimizer(params), t_(0), learning_rate_(learning_rate), beta1_(beta1), beta2_(beta2), eps_(eps) { @@ -62,5 +69,39 @@ void Adam::Step() { kernel.Call(grad, param, m, v, learning_rate_, beta1_, beta2_, eps_, t_); } } + +OptimizerCreator Adam::Create(float learning_rate, float beta1, float beta2, float eps) { + return [=](const std::vector> ¶ms) { + return std::make_shared(params, learning_rate, beta1, beta2, eps); + }; +} + +std::unordered_map> Adam::StateDict() const { + std::unordered_map> state; + for (size_t i = 0; i < m_.size(); ++i) { + state.emplace(std::format("adam.m.{}", i), m_[i]); + state.emplace(std::format("adam.v.{}", i), v_[i]); + } + + auto t_tensor = std::make_shared(std::vector{}, DataType::kINT64, Device()); + *static_cast(t_tensor->DataPtr()) = t_; + state.emplace("adam.t", t_tensor); + return state; +} + +void Adam::LoadStateDict(const std::unordered_map> &state_dict) { + for (size_t i = 0; i < m_.size(); ++i) { + const auto m_key = std::format("adam.m.{}", i); + const auto v_key = std::format("adam.v.{}", i); + CHECK(state_dict.contains(m_key)) << "Missing optimizer state: " << m_key; + CHECK(state_dict.contains(v_key)) << "Missing optimizer state: " << v_key; + m_[i]->CopyFrom(state_dict.at(m_key)); + v_[i]->CopyFrom(state_dict.at(v_key)); + } + + CHECK(state_dict.contains("adam.t")) << "Missing optimizer state: adam.t"; + const Tensor t_cpu = state_dict.at("adam.t")->To(Device()); + t_ = *static_cast(t_cpu.DataPtr()); +} } // namespace optimizers } // namespace infini_train diff --git a/infini_train/src/utils/string_utils.cc b/infini_train/src/utils/string_utils.cc new file mode 100644 index 00000000..e3d9c4ce --- /dev/null +++ b/infini_train/src/utils/string_utils.cc @@ -0,0 +1,18 @@ +#include "infini_train/include/utils/string_utils.h" + +#include + +namespace infini_train::utils { +std::string DimsToString(const std::vector &dims) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < dims.size(); ++i) { + if (i > 0) { + oss << ", "; + } + oss << dims[i]; + } + oss << "]"; + return oss.str(); +} +} // namespace infini_train::utils diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index e3c67293..c9080148 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -72,6 +72,7 @@ PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_ COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" +CKPT_ROOT_DIR="$(read_var CKPT_ROOT_DIR)"; : "${CKPT_ROOT_DIR:=/data1/ckpt}" mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR" @@ -114,6 +115,15 @@ clean_build_dir() { rm -rf "${BUILD_DIR:?}/"* } +# Clean checkpoint directories (called once at start of script) +clean_checkpoints() { + echo -e "\033[1;31m[CLEAN] Removing checkpoint directories from previous run\033[0m" + if [[ -d "$CKPT_ROOT_DIR" ]]; then + echo -e "\033[1;31m[CLEAN] Removing: ${CKPT_ROOT_DIR}\033[0m" + rm -rf "${CKPT_ROOT_DIR:?}" + fi +} + # Run a command and log output run_and_log() { local cmd="$1" @@ -208,15 +218,34 @@ move_profile_logs() { done } -# Build "--key value" arg string from test_groups[gi].tests[ti].args (shell-escaped) +# Build "--key value" arg string from tests[i].args. +# For checkpoint-related args, automatically isolate by model and run mode +# (resume/no_resume) to avoid cross-test overwrites in one-click runs. args_string_for_test() { local group_idx="$1" local test_idx="$2" - jq -r --argjson g "$group_idx" --argjson t "$test_idx" ' - .test_groups[$g].tests[$t].args - | to_entries[] - | "--\(.key)=\(.value|tostring)" - ' "$CONFIG_FILE" | paste -sd' ' - + local model_name="$3" + local test_id="$4" + + jq -r --argjson g "$group_idx" --argjson t "$test_idx" --arg model "$model_name" --arg test_id "$test_id" ' + def namespaced_path($p; $model; $mode): + if ($p | test("/checkpoint_step_[0-9]+($|/)")) then + ($p | capture("^(?.*)/(?checkpoint_step_[0-9]+(?:/.*)?)$")) as $m + | ($m.prefix + "/" + $model + "/" + $mode + "/" + $m.step) + else + ($p + "/" + $model + "/" + $mode) + end; + + .test_groups[$g].tests[$t].args as $args + | (if ($args | has("load")) then "resume" else "no_resume" end) as $run_mode + | (if (($args.load // "") | test("no_resume")) then "no_resume" else "resume" end) as $resume_src_mode + | $args + | (if has("save") then .save = namespaced_path(.save; $model; $run_mode) else . end) + | (if has("load") then .load = namespaced_path(.load; $model; $resume_src_mode) else . end) + | to_entries[] + | "--\(.key) \(.value|tostring)" + ' "$CONFIG_FILE" | paste -sd' ' - | \ + sed "s|@CKPT_ROOT_DIR@|${CKPT_ROOT_DIR}|g" } # Run tests @@ -269,16 +298,20 @@ for ((id=0; id +#include + +#include "gtest/gtest.h" + +#include "infini_train/include/checkpoint/checkpoint.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +#include "tests/common/test_utils.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +class CheckpointSerializationTest : public test::InfiniTrainTest {}; + +TEST_P(CheckpointSerializationTest, SaveAndLoadModelFP32) { + auto dir = std::filesystem::temp_directory_path() / "test_ckpt_fp32"; + std::filesystem::remove_all(dir); + + auto model1 = std::make_shared(3, 2, true, GetDevice()); + auto p1 = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32, GetDevice()); + p1->Fill(0.42f); + *model1->mutable_parameter("weight") = p1; + auto p2 = std::make_shared(std::vector{4}, DataType::kFLOAT32, GetDevice()); + p2->Fill(-1.5f); + *model1->mutable_parameter("bias") = p2; + + auto opt1 = std::make_shared(model1->Parameters(), 0.01); + TrainerState saved{.global_step = 42, .consumed_batches = 100}; + Checkpoint::Save(dir, *model1, opt1.get(), saved, true); + + auto model2 = std::make_shared(3, 2, true, GetDevice()); + auto q1 = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32, GetDevice()); + q1->Fill(0.0f); + *model2->mutable_parameter("weight") = q1; + auto q2 = std::make_shared(std::vector{4}, DataType::kFLOAT32, GetDevice()); + q2->Fill(0.0f); + *model2->mutable_parameter("bias") = q2; + auto opt2 = std::make_shared(model2->Parameters(), 0.01); + + TrainerState loaded; + Checkpoint::Load(dir, *model2, opt2.get(), loaded, true); + + EXPECT_EQ(loaded.global_step, 42); + EXPECT_EQ(loaded.consumed_batches, 100); + + auto w1_cpu = model2->parameter("weight")->To(Device()); + const float *data = static_cast(w1_cpu.DataPtr()); + for (int i = 0; i < 6; ++i) { EXPECT_NEAR(data[i], 0.42f, 1e-6); } + + std::filesystem::remove_all(dir); +} + +INFINI_TRAIN_REGISTER_TEST(CheckpointSerializationTest); diff --git a/tests/checkpoint/test_optimizer_state.cc b/tests/checkpoint/test_optimizer_state.cc new file mode 100644 index 00000000..1cbb8b9f --- /dev/null +++ b/tests/checkpoint/test_optimizer_state.cc @@ -0,0 +1,89 @@ +#include + +#include "gtest/gtest.h" + +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +#include "tests/common/test_utils.h" + +using namespace infini_train; + +class OptimizerStateTest : public test::InfiniTrainTest {}; + +// ---------- Adam StateDict ---------- +TEST_P(OptimizerStateTest, AdamStateDictKeys) { + auto param = std::make_shared(std::vector{3, 4}, DataType::kFLOAT32, GetDevice()); + param->set_requires_grad(true); + param->Fill(1.0f); + + auto adam = std::make_shared(std::vector>{{param}}, 0.001); + + adam->ZeroGrad(); + adam->Step(); // t=1 + adam->Step(); // t=2 + + auto state = adam->StateDict(); + EXPECT_GT(state.size(), 0); + EXPECT_TRUE(state.count("adam.m.0")); + EXPECT_TRUE(state.count("adam.v.0")); + EXPECT_TRUE(state.count("adam.t")); + + auto t_cpu = state["adam.t"]->To(Device()); + int64_t t_val = *static_cast(t_cpu.DataPtr()); + EXPECT_EQ(t_val, 2); +} + +// ---------- Adam LoadStateDict roundtrip ---------- +TEST_P(OptimizerStateTest, AdamStateDictRoundTrip) { + auto param1 = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice()); + param1->set_requires_grad(true); + param1->Fill(1.0f); + + auto adam1 = std::make_shared(std::vector>{{param1}}, 0.001); + adam1->ZeroGrad(); + adam1->Step(); + adam1->Step(); + adam1->Step(); + + auto saved = adam1->StateDict(); + + auto param2 = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice()); + param2->set_requires_grad(true); + param2->Fill(1.0f); + + auto adam2 = std::make_shared(std::vector>{{param2}}, 0.001); + adam2->LoadStateDict(saved); + + adam2->ZeroGrad(); + adam2->Step(); + auto restored = adam2->StateDict(); + auto t_cpu = restored["adam.t"]->To(Device()); + EXPECT_EQ(*static_cast(t_cpu.DataPtr()), 4); // 3 + 1 + + auto saved_state = adam1->StateDict(); + for (const auto &[key, tensor] : saved_state) { + if (key == "adam.t") { + continue; + } + ASSERT_TRUE(restored.count(key)) << "Missing optimizer state key: " << key; + auto s_cpu = tensor->To(Device()); + auto r_cpu = restored.at(key)->To(Device()); + EXPECT_EQ(s_cpu.Dims(), r_cpu.Dims()) << "Shape mismatch for " << key; + const float *s = static_cast(s_cpu.DataPtr()); + const float *r = static_cast(r_cpu.DataPtr()); + for (size_t i = 0; i < s_cpu.NumElements(); ++i) { + EXPECT_NEAR(s[i], r[i], 1e-6) << "Value mismatch for " << key << " at index " << i; + } + } +} + +// ---------- SGD ---------- +TEST_P(OptimizerStateTest, SGDStateDictEmpty) { + auto param = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice()); + param->set_requires_grad(true); + auto sgd = std::make_shared(std::vector>{{param}}, 0.01); + EXPECT_TRUE(sgd->StateDict().empty()); +} + +INFINI_TRAIN_REGISTER_TEST(OptimizerStateTest); diff --git a/tests/checkpoint/test_trainer_state.cc b/tests/checkpoint/test_trainer_state.cc new file mode 100644 index 00000000..532f2eff --- /dev/null +++ b/tests/checkpoint/test_trainer_state.cc @@ -0,0 +1,110 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "infini_train/include/checkpoint/checkpoint.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +#include "tests/common/test_utils.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +class TrainerStateTest : public test::InfiniTrainTest {}; + +TEST_P(TrainerStateTest, DefaultValues) { + TrainerState state; + EXPECT_EQ(state.global_step, 0); + EXPECT_EQ(state.consumed_batches, 0); + EXPECT_EQ(state.n_layer, 0); + EXPECT_EQ(state.n_head, 0); + EXPECT_EQ(state.n_kv_head, 0); + EXPECT_EQ(state.n_embd, 0); + EXPECT_EQ(state.vocab_size, 0); + EXPECT_EQ(state.ddp_size, 1); + EXPECT_EQ(state.tp_size, 1); + EXPECT_EQ(state.sp_size, 1); + EXPECT_EQ(state.pp_size, 1); + EXPECT_EQ(state.last_lr, 0.0); +} + +TEST_P(TrainerStateTest, TrainerStateFileCreated) { + auto dir = std::filesystem::temp_directory_path() / "test_trainer_json"; + std::filesystem::remove_all(dir); + + TrainerState saved{.global_step = 30, .consumed_batches = 1200, .last_lr = 0.001}; + + auto model = std::make_shared(1, 2, true, GetDevice()); + auto p = std::make_shared(std::vector{2}, DataType::kFLOAT32, GetDevice()); + p->Fill(1.0f); + *model->mutable_parameter("weight") = p; + auto opt = std::make_shared(model->Parameters(), 0.01); + + Checkpoint::Save(dir, *model, opt.get(), saved, true); + + EXPECT_TRUE(std::filesystem::exists(dir / "trainer_state.json")); + + std::ifstream ifs(dir / "trainer_state.json"); + std::string content((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + EXPECT_NE(content.find("\"global_step\""), std::string::npos); + EXPECT_NE(content.find("\"consumed_batches\""), std::string::npos); + + std::filesystem::remove_all(dir); +} + +TEST_P(TrainerStateTest, RoundTrip) { + auto dir = std::filesystem::temp_directory_path() / "test_trainer_rt"; + std::filesystem::remove_all(dir); + + TrainerState saved{ + .global_step = 99, + .consumed_batches = 5000, + .last_lr = 3e-4, + .n_layer = 24, + .n_head = 16, + .n_kv_head = 8, + .n_embd = 1024, + .vocab_size = 128256, + .ddp_size = 2, + .tp_size = 1, + .sp_size = 1, + .pp_size = 2, + }; + + auto model1 = std::make_shared(1, 3, true, GetDevice()); + auto p1 = std::make_shared(std::vector{3}, DataType::kFLOAT32, GetDevice()); + p1->Fill(0.5f); + *model1->mutable_parameter("weight") = p1; + auto opt1 = std::make_shared(model1->Parameters(), 0.01); + + Checkpoint::Save(dir, *model1, opt1.get(), saved, false); + + auto model2 = std::make_shared(1, 3, true, GetDevice()); + auto p2 = std::make_shared(std::vector{3}, DataType::kFLOAT32, GetDevice()); + p2->Fill(0.0f); + *model2->mutable_parameter("weight") = p2; + auto opt2 = std::make_shared(model2->Parameters(), 0.01); + + TrainerState loaded; + Checkpoint::Load(dir, *model2, opt2.get(), loaded, false); + + EXPECT_EQ(loaded.global_step, 99); + EXPECT_EQ(loaded.consumed_batches, 5000); + EXPECT_NEAR(loaded.last_lr, 3e-4, 1e-10); + EXPECT_EQ(loaded.n_layer, 24); + EXPECT_EQ(loaded.n_head, 16); + EXPECT_EQ(loaded.n_kv_head, 8); + EXPECT_EQ(loaded.n_embd, 1024); + EXPECT_EQ(loaded.vocab_size, 128256); + EXPECT_EQ(loaded.ddp_size, 2); + EXPECT_EQ(loaded.pp_size, 2); + + std::filesystem::remove_all(dir); +} + +INFINI_TRAIN_REGISTER_TEST(TrainerStateTest); diff --git a/tests/optimizer/test_optimizer_creation.cc b/tests/optimizer/test_optimizer_creation.cc index fbaa61e8..eac6d0b8 100644 --- a/tests/optimizer/test_optimizer_creation.cc +++ b/tests/optimizer/test_optimizer_creation.cc @@ -2,7 +2,6 @@ #include "gtest/gtest.h" -#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/optimizer.h" #include "infini_train/include/tensor.h"