#include "../../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace infinilm::engine::distributed {
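// Expose DistConfig to Python. The class provides a string conversion,
// reused below for both __repr__ and __str__.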
inline void bind_dist_config(py::module &m) {
    py::class_<DistConfig>(m, "DistConfig")
        .def(py::init<>(), "Default constructor, empty device list")
        .def(py::init<int>(), py::arg("tp_size"),
             "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1")
        .def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
             "Constructor with explicit device IDs")
        .def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
                       "List of device IDs used in tensor parallelism")
        .def("__repr__", [](const DistConfig &cfg) {
            return std::string(cfg);
        })
        .def("__str__", [](const DistConfig &cfg) {
            return std::string(cfg);
        });
}
} // namespace infinilm::engine::distributed
namespace infinilm::engine {
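// Expose InferEngine together with its nested Input/Output types.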
inline void bind_infer_engine(py::module &m) {
    py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");
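    // Constructor taking an in-memory model config. cache_config may be None;
    // the attention backend is parsed from its string name.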
    infer_engine
        .def(py::init([](const InfinilmModel::Config &cfg,
                         const distributed::DistConfig &dist,
                         infinicore::Device::Type dev,
                         std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
                         bool enable_graph_compiling,
                         const std::string &attention_backend) {
                 return std::make_shared<InferEngine>(
                     cfg,
                     dist,
                     dev,
                     cache_cfg ? cache_cfg.get() : nullptr,
                     enable_graph_compiling,
                     infinilm::backends::parse_attention_backend(attention_backend));
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
             py::arg("cache_config") = py::none(),
             py::arg("enable_graph_compiling") = false,
             py::arg("attention_backend") = "default")
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
py::dict result;
for (const auto &[name, param] : state_dict_tp) {
result[py::cast(name)] = infinicore::Tensor(param);
}
state_dict_tp_all.append(result);
}
return state_dict_tp_all;
})
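        // forward() releases the GIL so worker threads can run while Python waits.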
        .def(
            "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
                py::gil_scoped_release release;
                return self.forward(input);
            },
            "Run inference on all ranks with arbitrary arguments")
        .def(
            "reset_cache",
            [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
                self.reset_cache(cfg ? cfg.get() : nullptr);
            },
            py::arg("cache_config") = py::none())
        .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
            auto cfg = self.get_cache_config();
            return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr;
        })
        .def("__repr__", [](const InferEngine &self) {
            return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
        });
    infer_engine
        .def(py::init([](const std::string &model_path,
                         const distributed::DistConfig &dist,
                         infinicore::Device::Type dev,
                         std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
                         bool enable_graph_compiling,
                         const std::string &attention_backend) {
                 return std::make_shared<InferEngine>(
                     model_path,
                     dist,
                     dev,
                     cache_cfg ? cache_cfg.get() : nullptr,
                     enable_graph_compiling,
                     infinilm::backends::parse_attention_backend(attention_backend));
             }),
             py::arg("model_path") = "",
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
             py::arg("cache_config") = py::none(),
             py::arg("enable_graph_compiling") = false,
             py::arg("attention_backend") = "default");
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
py::dict result;
for (const auto &[name, param] : state_dict_tp) {
result[py::cast(name)] = infinicore::Tensor(param);
}
state_dict_tp_all.append(result);
}
return state_dict_tp_all;
})
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
py::gil_scoped_release release;
return self.forward(input);
},
"Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
    py::class_<InferEngine::Input>(infer_engine, "Input")
        .def(py::init([](std::optional<infinicore::Tensor> input_ids,
                         std::optional<infinicore::Tensor> position_ids,
                         std::optional<infinicore::Tensor> past_sequence_lengths,
                         std::optional<infinicore::Tensor> total_sequence_lengths,
                         std::optional<infinicore::Tensor> input_offsets,
                         std::optional<infinicore::Tensor> cu_seqlens,
                         std::optional<infinicore::Tensor> block_tables,
                         std::optional<infinicore::Tensor> slot_mapping,
                         py::kwargs kwargs) {
                 InferEngine::Input input{
                     std::move(input_ids),
                     std::move(position_ids),
                     std::move(past_sequence_lengths),
                     std::move(total_sequence_lengths),
                     std::move(input_offsets),
                     std::move(cu_seqlens),
                     std::move(block_tables),
                     std::move(slot_mapping),
                 };
                 // Explicit defaults
                 input.temperature = 1.0f;
                 input.top_p = 1.0f;
                 input.top_k = 1;
                 // Allowed keyword arguments
                 static const std::unordered_set<std::string> allowed_kwargs = {
                     "temperature",
                     "top_p",
                     "top_k",
                 };
                 for (auto &item : kwargs) {
                     const std::string key = py::cast<std::string>(item.first);
                     if (allowed_kwargs.find(key) == allowed_kwargs.end()) {
                         throw py::value_error(
                             "InferEngine.Input got an unexpected keyword argument '" + key + "'");
                     }
                     if (key == "temperature") {
                         input.temperature = py::cast<float>(item.second);
                     } else if (key == "top_p") {
                         input.top_p = py::cast<float>(item.second);
                     } else if (key == "top_k") {
                         input.top_k = py::cast<int>(item.second);
                     }
                 }
                 return input;
             }),
             py::arg("input_ids") = std::nullopt,
             py::arg("position_ids") = std::nullopt,
             py::arg("past_sequence_lengths") = std::nullopt,
             py::arg("total_sequence_lengths") = std::nullopt,
             py::arg("input_offsets") = std::nullopt,
             py::arg("cu_seqlens") = std::nullopt,
             py::arg("block_tables") = std::nullopt,
             py::arg("slot_mapping") = std::nullopt)
        .def_readwrite("input_ids", &InferEngine::Input::input_ids)
        .def_readwrite("position_ids", &InferEngine::Input::position_ids)
        .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
        .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
        .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
        .def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
        .def_readwrite("block_tables", &InferEngine::Input::block_tables)
        .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
        .def_readwrite("temperature", &InferEngine::Input::temperature)
        .def_readwrite("top_k", &InferEngine::Input::top_k)
        .def_readwrite("top_p", &InferEngine::Input::top_p);
    py::class_<InferEngine::Output>(infer_engine, "Output")
        .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
}
} // namespace infinilm::engine
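
// Usage sketch: these binders are meant to be called from the module entry
// point. The module name below is illustrative only; the actual
// PYBIND11_MODULE definition lives elsewhere in this project.
//
//   PYBIND11_MODULE(_infinilm, m) {
//       infinilm::engine::distributed::bind_dist_config(m);
//       infinilm::engine::bind_infer_engine(m);
//   }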