Skip to content

fix(pepnet): fix domain output overwrite in _select_domain_task_output#524

Merged
tiankongdeguiji merged 1 commit into
alibaba:masterfrom
zhpjunfei:fix/pepnet-domain-output-bug
May 24, 2026
Merged

fix(pepnet): fix domain output overwrite in _select_domain_task_output#524
tiankongdeguiji merged 1 commit into
alibaba:masterfrom
zhpjunfei:fix/pepnet-domain-output-bug

Conversation

@zhpjunfei
Copy link
Copy Markdown
Contributor

概要

修复 PEPNet 模型 _select_domain_task_output 中的两个 bug:

Bug 1:域输出被覆盖(第 188 行)

new_predictions[tower_loss_name] 使用的是赋值 (=) 而非 append(),导致每个 tower_loss_name 只保留了最后一个域的输出。当 task_domain_num > 1 时,torch.gather 收到的 stacked tensor 形状为 [batch, 1] 而非 [batch, task_domain_num],崩溃报错:
RuntimeError: index 1 is out of bounds for dimension 1 with size 1

Bug 2:≥10 个域时的字符串排序问题(第 189 行)

domain_index 来自 rsplit("_", 1)[1] 是字符串类型,sorted(..., key=lambda x: x[0]) 按字典序排列。当 task_domain_num >= 10 时,"10" 排在 "2" 之前,导致 stacked tensor 的列序与 batch.labels 传入 torch.gather 的整数索引不一致。修复方式是在 append() 时转为 int

测试缺口

现有 test_pepnet 只覆盖了 predict() 的前向传播和 shape 断言。有 bug 的代码路径(loss()_select_domain_task_output)从未被测试覆盖。新增 test_select_domain_task_output,构造 2 个 tower × 10 个域的合成 predictions dict,直接调用 _select_domain_task_output,验证 gathered 的 logits/probs 与每行的 domainf 标签一致。

@CLAassistant
Copy link
Copy Markdown

CLAassistant commented May 22, 2026

CLA assistant check
All committers have signed the CLA.

@tiankongdeguiji
Copy link
Copy Markdown
Collaborator

@zhpjunfei please run pre-commit install and run pre-commit -a to format your code

@zhpjunfei zhpjunfei force-pushed the fix/pepnet-domain-output-bug branch from 30b7c4d to 63b68d4 Compare May 22, 2026 12:18
@zhpjunfei
Copy link
Copy Markdown
Contributor Author

失败测试:test_multi_tower_din_rtp_train_export — 测试的是 DIN + RTP 模型,不涉及 PEPNet
我只改了 tzrec/models/pepnet.py 和 tzrec/models/pepnet_test.py
其他 871 个测试全部通过(包括 PEPNet 单元测试), 说明这个失败是 pre-existing 且无关的

@tiankongdeguiji tiankongdeguiji merged commit 3d6c84c into alibaba:master May 24, 2026
7 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants