fix(pepnet): fix domain output overwrite in _select_domain_task_output#524
Merged
tiankongdeguiji merged 1 commit intoMay 24, 2026
Merged
Conversation
tiankongdeguiji
previously approved these changes
May 22, 2026
Collaborator
|
@zhpjunfei please run |
3077cce to
30b7c4d
Compare
30b7c4d to
63b68d4
Compare
Contributor
Author
|
失败测试:test_multi_tower_din_rtp_train_export — 测试的是 DIN + RTP 模型,不涉及 PEPNet |
tiankongdeguiji
approved these changes
May 24, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
概要
修复 PEPNet 模型
_select_domain_task_output中的两个 bug:Bug 1:域输出被覆盖(第 188 行)
new_predictions[tower_loss_name]使用的是赋值 (=) 而非append(),导致每个tower_loss_name只保留了最后一个域的输出。当task_domain_num > 1时,torch.gather收到的 stacked tensor 形状为[batch, 1]而非[batch, task_domain_num],崩溃报错:RuntimeError: index 1 is out of bounds for dimension 1 with size 1
Bug 2:≥10 个域时的字符串排序问题(第 189 行)
domain_index来自rsplit("_", 1)[1]是字符串类型,sorted(..., key=lambda x: x[0])按字典序排列。当task_domain_num >= 10时,"10"排在"2"之前,导致 stacked tensor 的列序与batch.labels传入torch.gather的整数索引不一致。修复方式是在append()时转为int。测试缺口
现有
test_pepnet只覆盖了predict()的前向传播和 shape 断言。有 bug 的代码路径(loss()→_select_domain_task_output)从未被测试覆盖。新增test_select_domain_task_output,构造 2 个 tower × 10 个域的合成 predictions dict,直接调用_select_domain_task_output,验证 gathered 的 logits/probs 与每行的domainf标签一致。