-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathmodule.hpp
More file actions
154 lines (123 loc) · 5.98 KB
/
module.hpp
File metadata and controls
154 lines (123 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#pragma once
#include "../tensor.hpp"
#include "parameter.hpp"
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace infinicore::nn {
/// Base class for objects that own named parameters, named buffers and named
/// child modules.
///
/// The non-template member functions are only declared in this header; the
/// notes on them below describe what the signatures show and flag anything
/// that has to be confirmed against their definitions.
class Module {
public:
    Module() = default;

    /// Returns a name -> Parameter map by const reference.
    /// NOTE(review): presumably this also covers submodule parameters under
    /// prefixed names (see collect_all_parameters) -- confirm in the definition.
    const std::unordered_map<std::string, Parameter> &state_dict() const;

    /// Loads tensors from a name -> Tensor map.
    /// NOTE(review): presumably walks submodules via
    /// load_state_dict_recursively -- confirm in the definition.
    void load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict);

    /// Loads the tensor `param` into the parameter called `name`.
    void load_parameter(const std::string &name, const Tensor &param);

    /// Variant of load_parameter with the same signature.
    /// NOTE(review): how it differs from load_parameter is not visible in this
    /// header -- confirm in the definition.
    void load_parameter_(const std::string &name, const Tensor &param);

    /// Loads the parameter called `name` from the raw memory at `data`.
    /// NOTE(review): the expected size/layout of `data` is not visible here.
    void load_parameter_from_blob(const std::string &name, const void *data);

protected:
    /// Registers `param` under `name` and returns a Tensor.
    Tensor register_parameter(const std::string &name, Parameter param);

    /// Registers `buffer` under `name` and returns a Tensor.
    Tensor register_buffer(const std::string &name, Parameter buffer);

    /// Adds an existing submodule to this module's hierarchy.
    ///
    /// @tparam M        type derived from Module (checked at compile time)
    /// @param name      key under which the submodule is stored; an entry that
    ///                  already exists under this key is replaced
    /// @param submodule the submodule to store
    /// @return the same pointer, so the call can be used in an assignment
    template <typename M>
    std::shared_ptr<M> add_module(const std::string &name, std::shared_ptr<M> submodule) {
        static_assert(std::is_base_of_v<Module, M>,
                      "Template parameter M must be derived from infinicore::nn::Module");
        // std::shared_ptr<M> converts implicitly to std::shared_ptr<Module>.
        submodules_[name] = submodule;
        return submodule;
    }

    /// Constructs a new submodule of type M and registers it under `name`.
    ///
    /// @tparam M   type derived from Module (checked at compile time)
    /// @param name key under which the submodule is stored
    /// @param args forwarded to M's constructor
    /// @return the newly created submodule
    template <typename M, typename... Args>
    std::shared_ptr<M> register_module(const std::string &name, Args &&...args) {
        static_assert(std::is_base_of_v<Module, M>,
                      "Template parameter M must be derived from infinicore::nn::Module");
        auto submodule = std::make_shared<M>(std::forward<Args>(args)...);
        return add_module(name, submodule);
    }

    /// Constructs `count` submodules of type M and registers them as
    /// "name.0", "name.1", ... in that order.
    ///
    /// The same constructor arguments are used for every submodule, so they
    /// are passed on as lvalues (i.e. M is constructed from copies/references,
    /// never from moved arguments).
    ///
    /// @tparam M    type derived from Module (checked at compile time)
    /// @param count number of submodules to create
    /// @param name  common name prefix
    /// @param args  constructor arguments shared by all submodules
    /// @return the created submodules, in index order
    template <typename M, typename... Args>
    std::vector<std::shared_ptr<M>> register_modules(size_t count, const std::string &name, Args &&...args) {
        static_assert(std::is_base_of_v<Module, M>,
                      "Template parameter M must be derived from infinicore::nn::Module");
        std::vector<std::shared_ptr<M>> modules;
        modules.reserve(count);
        for (size_t i = 0; i < count; ++i) {
            // Deliberately NOT std::forward: forwarding an rvalue on every
            // iteration would let the first module move from it and leave the
            // remaining modules to be built from moved-from arguments.
            modules.push_back(register_module<M>(name + "." + std::to_string(i), args...));
        }
        return modules;
    }

protected:
    Device device_;
    // Child modules keyed by the name they were registered under.
    std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
    // Buffers keyed by name.
    std::unordered_map<std::string, Parameter> buffers_;
    // Parameters keyed by name.
    std::unordered_map<std::string, Parameter> parameters_;

private:
    /// Helper for load_state_dict; `prefix` defaults to the empty string.
    /// NOTE(review): presumably `prefix` is the dotted path of this module
    /// within the hierarchy -- confirm in the definition.
    void load_state_dict_recursively(const std::unordered_map<std::string, Tensor> &_state_dict, const std::string &prefix = "");

    /// Adds entries to `all_params`; `prefix` defaults to the empty string.
    /// NOTE(review): presumably recurses into submodules_ and prepends
    /// `prefix` to each name -- confirm in the definition.
    void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const;
};
// ============================================================================
// PyTorch-like Macros for Convenient Module Registration
// ============================================================================
/**
* @brief Declare and register submodules, parameters and buffers, inferring each registered name from the member variable name
*
* Usage:
* @code
* class MyModel : public Module {
* protected:
* INFINICORE_NN_MODULE(Linear, layer1);
* INFINICORE_NN_MODULE(Linear, layer2);
* INFINICORE_NN_MODULE_VEC(Linear, layers);
* INFINICORE_NN_PARAMETER(scaling_factor);
*
* public:
* MyModel() {
* INFINICORE_NN_MODULE_INIT(layer1, 128, 64);
* INFINICORE_NN_MODULE_INIT(layer2, 64, 32);
* INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 32, 16);
* INFINICORE_NN_PARAMETER_INIT(scaling_factor, ({1}, DataType::F32, Device()));
* }
* };
* @endcode
*/
// Member declaration for one submodule: expands to a std::shared_ptr<ModuleType>
// named `<name>_` (note the trailing underscore appended to `name`).
#define INFINICORE_NN_MODULE(ModuleType, name) std::shared_ptr<ModuleType> name##_
// Member declaration for a list of submodules: expands to a
// std::vector<std::shared_ptr<ModuleType>> named `<name>_`.
#define INFINICORE_NN_MODULE_VEC(ModuleType, name) std::vector<std::shared_ptr<ModuleType>> name##_
// Constructor-side counterpart of INFINICORE_NN_MODULE: builds the submodule
// with the given constructor arguments, registers it under the stringified
// `name`, and stores the result in the member `<name>_`.
// The module type is recovered from the member's declared pointee type, so it
// does not have to be repeated here.
#define INFINICORE_NN_MODULE_INIT(name, ...) \
    name##_ = this->register_module<std::remove_reference_t<decltype(*name##_)>>(#name, ##__VA_ARGS__)
// Constructor-side counterpart of INFINICORE_NN_MODULE_VEC: creates `count`
// submodules of ModuleType via register_modules (same constructor arguments
// for each) and stores the resulting vector in the member `<name>_`.
// Usage:   INFINICORE_NN_MODULE_VEC_INIT(layers, count, ModuleType, ctor_args...)
// Example: INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 128, 64)
#define INFINICORE_NN_MODULE_VEC_INIT(name, count, ModuleType, ...) name##_ = this->register_modules<ModuleType>((count), #name, ##__VA_ARGS__)
// Member declaration for a parameter: expands to an
// infinicore::nn::Parameter named `<name>_`.
#define INFINICORE_NN_PARAMETER(name) infinicore::nn::Parameter name##_
// Initialize a parameter in a constructor: constructs the Parameter from the
// parenthesized argument list `args`, assigns it to the member `<name>_`, and
// registers it under the stringified `name`.
// Usage:   INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
//
// The two statements are wrapped in do { ... } while (0) so the macro behaves
// as ONE statement; otherwise `if (cond) INFINICORE_NN_PARAMETER_INIT(...);`
// would guard only the assignment and always run the registration.
#define INFINICORE_NN_PARAMETER_INIT(name, args)  \
    do {                                          \
        name##_ = infinicore::nn::Parameter args; \
        this->register_parameter(#name, name##_); \
    } while (0)
// Member declaration for a buffer: expands to an
// infinicore::nn::Parameter named `<name>_`.
#define INFINICORE_NN_BUFFER(name) infinicore::nn::Parameter name##_
// Initialize a buffer in a constructor: constructs the Parameter from the
// parenthesized argument list `args`, assigns it to the member `<name>_`, and
// registers it as a buffer under the stringified `name`.
// Usage:   INFINICORE_NN_BUFFER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_BUFFER_INIT(cache, ({max_seq_len, head_dim}, DataType::F32, device))
//
// The two statements are wrapped in do { ... } while (0) so the macro behaves
// as ONE statement; otherwise `if (cond) INFINICORE_NN_BUFFER_INIT(...);`
// would guard only the assignment and always run the registration.
#define INFINICORE_NN_BUFFER_INIT(name, args)     \
    do {                                          \
        name##_ = infinicore::nn::Parameter args; \
        this->register_buffer(#name, name##_);    \
    } while (0)
} // namespace infinicore::nn