distributed_optimizer.h
#pragma once

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

#include "infini_train/include/optimizer.h"

namespace infini_train::nn {
class Module;
namespace parallel {
class ParamAndGradBuffer;
class ParamAndGradBucketGroup;
} // namespace parallel
} // namespace infini_train::nn

namespace infini_train::nn::parallel {

class DistributedOptimizer final : public infini_train::Optimizer {
public:
    DistributedOptimizer(OptimizerCreator base_optimizer_creator,
                         const std::vector<std::shared_ptr<Tensor>> &full_params,
                         const std::vector<std::shared_ptr<Module>> &model_chunks, size_t ddp_world_size,
                         size_t ddp_rank);
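    // Run one update with the base optimizer (presumably over this rank's
    // parameter shard; see shard_params_ below).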
    void Step() override;
    void ZeroGrad(bool set_to_none = true) override;
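    // Launch / wait for asynchronous gradient synchronization across
    // data-parallel ranks.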
    void StartGradSync();
    void FinishGradSync();
    // Forward microbatch boundary info to bucket groups.
    void SetIsLastMicrobatch(bool is_last_microbatch);
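    // Launch / wait for synchronization of updated parameters back to all
    // data-parallel ranks; see the force_sync / skip_next_bucket_dispatch
    // flags on the methods below.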
    void StartParamSync(bool force_sync = false);
    void FinishParamSync(bool skip_next_bucket_dispatch = false);
private:
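    // Build shard parameters and bind gradients to them.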
    void BuildShardParamsAndBindGrads();
private:
    // Inherited from the DDP model.
    std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers_;
    std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups_;

    // Data-parallel (DP) info.
    size_t ddp_world_size_;
    size_t ddp_rank_;

    // This rank's parameter shards.
    std::vector<std::shared_ptr<Tensor>> shard_params_;

    // Base optimizer (e.g. SGD, Adam).
    std::shared_ptr<Optimizer> base_optimizer_;
};

} // namespace infini_train::nn::parallel
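// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header). It shows how the
// interface above might be driven from a data-parallel training loop. The
// helpers MakeAdamCreator and RunForwardBackward, and the variables
// full_params / model_chunks / rank / num_microbatches, are hypothetical
// stand-ins; only the DistributedOptimizer calls come from the declaration
// above, and the call ordering is an assumption (gradient sync before Step,
// parameter sync after).
//
//   auto optim = std::make_shared<DistributedOptimizer>(
//       MakeAdamCreator(/*lr=*/1e-4),  // hypothetical OptimizerCreator factory
//       full_params,                   // all model parameters
//       model_chunks,                  // DDP-wrapped module chunks
//       /*ddp_world_size=*/8, /*ddp_rank=*/rank);
//
//   for (size_t mb = 0; mb < num_microbatches; ++mb) {
//       optim->SetIsLastMicrobatch(mb + 1 == num_microbatches);
//       RunForwardBackward(model_chunks, mb);  // accumulate grads per bucket
//   }
//   optim->StartGradSync();   // launch gradient reduction across DP ranks
//   optim->FinishGradSync();  // block until gradients are ready
//   optim->Step();            // base optimizer updates the local shard
//   optim->ZeroGrad();        // clear grads for the next iteration
//   optim->StartParamSync();  // begin redistributing updated parameters
//   optim->FinishParamSync(); // block until all ranks see the new parameters
// ---------------------------------------------------------------------------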