-
Notifications
You must be signed in to change notification settings - Fork 73
feat(moe): add MoE inference and expert parallel support #444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,14 +87,15 @@ void RankWorker::load_param(const std::string &name, | |
| //------------------------------------------------------ | ||
| // load_params -- synchronous batch load | ||
| //------------------------------------------------------ | ||
| void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> ¶ms) { | ||
| void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> ¶ms, bool strict) { | ||
| { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| if (should_exit_) { | ||
| throw std::runtime_error("RankWorker is closing; cannot load_params"); | ||
| } | ||
|
|
||
| pending_params_ = params; | ||
| pending_params_strict_ = strict; | ||
| job_cmd_ = Command::LOAD_BATCH; | ||
| has_job_ = true; | ||
| job_done_ = false; | ||
|
|
@@ -295,6 +296,7 @@ void RankWorker::thread_loop() { | |
| std::string local_param_name; | ||
| infinicore::Tensor local_param; | ||
| std::unordered_map<std::string, infinicore::Tensor> local_params; | ||
| bool local_params_strict = true; | ||
| Input local_args; | ||
| std::unique_ptr<cache::CacheConfig> local_cache_config; | ||
|
|
||
|
|
@@ -314,6 +316,10 @@ void RankWorker::thread_loop() { | |
| local_param = pending_param_; | ||
| } else if (local_cmd == Command::LOAD_BATCH) { | ||
| local_params = std::move(pending_params_); | ||
| // strict is copied with the batch because loading runs on | ||
| // the worker thread after the caller releases the mutex. | ||
| local_params_strict = pending_params_strict_; | ||
| pending_params_strict_ = true; | ||
| pending_params_.clear(); | ||
| } else if (local_cmd == Command::PREPROCESS) { | ||
|
|
||
|
|
@@ -353,7 +359,7 @@ void RankWorker::thread_loop() { | |
|
|
||
| } else if (local_cmd == Command::LOAD_BATCH) { | ||
| try { | ||
| model_->load_parameters_no_sync(local_params); | ||
| model_->load_parameters_no_sync(local_params, local_params_strict); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 等价于这个写法么 model_->load_parameters_no_sync(local_params, strict);
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,现在等价于直接调用 model_->load_parameters_no_sync(local_params, local_params_strict)。这里需要把 strict 继续传下去,否则 Python 侧传入的 non-strict load 对 MoE packed weight 不生效。 |
||
| infinicore::context::syncStream(); | ||
| } catch (const std::exception &e) { | ||
| { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个新增的replay_output变量,以及graph编译时新增和修改的代码。可以注释或解释一下么,不知道啥意思
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充注释。这里的 replay_output 是 graph capture 时为输出保留的普通 Output handle;compiled 里保存的是 GraphTensor/graph 对象,replay 后需要通过这个 handle 拿回模型输出。这样 get_compiled 时可以直接返回可复用的 graph replay 结果。
这个不影响static 推理,已测试