fix: accumulate domain loss predictions instead of overwriting #523

Closed
zhpjunfei wants to merge 1 commit into alibaba:master from zhpjunfei:master

Conversation

@zhpjunfei
Contributor

What

Fixes a bug in PEPNet._select_domain_task_output where the per-domain prediction values were not accumulated.

Why

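new_predictions is initialized as defaultdict(list), but the old code used = to assign a single-element list, so every loop iteration overwrote the previous value. In the end only the last (domain_index, value) pair took part in the sort and stack, which made torch.gather select the wrong result in multi-domain scenarios.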

How

Changes new_predictions[tower_loss_name] = [...] to .append(...), so that the predictions of all domains are accumulated correctly.
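
To make the difference concrete, here is a minimal sketch of the two behaviours. It is not the code from tzrec/models/pepnet.py: the key names ("probs_ctr_0" and so on), the placeholder string values, and the collect helper are assumptions for illustration only. The only things taken from this PR are the defaultdict(list) container and the rsplit("_", 1) split into tower name and domain index.

```python
from collections import defaultdict

# Hypothetical stand-ins: keys follow "<tower_loss_name>_<domain_index>",
# and the values are placeholder strings instead of prediction tensors.
predictions = {"probs_ctr_0": "p0", "probs_ctr_1": "p1", "probs_ctr_2": "p2"}


def collect(predictions, accumulate):
    """Group predictions by tower name, with or without the fix."""
    new_predictions = defaultdict(list)
    for name, value in predictions.items():
        tower_loss_name, domain_index = name.rsplit("_", 1)
        if accumulate:
            # Fixed behaviour: every domain is kept.
            new_predictions[tower_loss_name].append((domain_index, value))
        else:
            # Old behaviour: each iteration replaces the whole list.
            new_predictions[tower_loss_name] = [(domain_index, value)]
    return dict(new_predictions)


overwritten = collect(predictions, accumulate=False)
accumulated = collect(predictions, accumulate=True)
print(overwritten)  # {'probs_ctr': [('2', 'p2')]}
print(accumulated)  # {'probs_ctr': [('0', 'p0'), ('1', 'p1'), ('2', 'p2')]}
```

With the assignment, only domain 2 survives the loop; with append, all three domains reach the later sort and stack.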

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


张峻飞 seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label May 22, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label May 22, 2026
Comment thread tzrec/models/pepnet.py
Comment on lines +188 to +190
new_predictions[tower_loss_name].append(
    (domain_index, tower_domain_loss_predict_value)
)

Good catch on the overwrite. One thing worth handling in the same PR: now that the list actually accumulates across domains, the sorted(..., key=lambda x: x[0]) a few lines below (line 196) sorts on domain_index as a string (it comes from rsplit("_", 1)[1]). For ≥10 domains this gives lexicographic order — "10" sorts before "2" — which then misaligns the stacked tensor with the integer index from batch.labels[self._domain_input_name] passed to torch.gather. Before this fix the list always had one element so the ordering never mattered; this fix exposes it.

Suggest casting the index to int when appending so the subsequent sort is numeric:

Suggested change
- new_predictions[tower_loss_name].append(
-     (domain_index, tower_domain_loss_predict_value)
- )
+ new_predictions[tower_loss_name].append(
+     (int(domain_index), tower_domain_loss_predict_value)
+ )
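
For reference, a small standalone sketch of the ordering problem, assuming 12 domains and placeholder values (both are illustrative assumptions, not taken from pepnet.py). It only uses the built-in sorted with the same key=lambda x: x[0] as the code under review:

```python
# Hypothetical (domain_index, value) pairs as they would look after
# rsplit("_", 1): the index is still a string.
pairs = [(str(i), f"p{i}") for i in range(12)]

# Sorting on the string index gives lexicographic order.
as_strings = [idx for idx, _ in sorted(pairs, key=lambda x: x[0])]

# Casting to int first gives the numeric order that the integer
# domain label used for gathering expects.
int_pairs = [(int(idx), value) for idx, value in pairs]
as_ints = [idx for idx, _ in sorted(int_pairs, key=lambda x: x[0])]

print(as_strings)  # ['0', '1', '10', '11', '2', '3', '4', '5', '6', '7', '8', '9']
print(as_ints)     # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

With string keys, position 2 of the stacked tensor would hold domain 10 rather than domain 2, so a row labelled with domain 2 would pick up the wrong prediction.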

@github-actions

Nice, targeted fix — the defaultdict(list) + = assignment was indeed silently dropping all but the last domain.

A couple of notes worth considering:

  • Test gap: _select_domain_task_output is only reached through loss() / update_metric(), and the existing test_pepnet in tzrec/models/pepnet_test.py only calls the forward pass and asserts output shape. The buggy code path was unexercised, and even if it were, a shape-only assertion wouldn't have caught the overwrite. A small unit test that builds a synthetic predictions dict with ≥2 towers × ≥2 domains and verifies the gathered values match the per-row domain label would lock this in.
  • Latent sort bug exposed by this fix: see inline comment on the changed lines — domain_index is sorted as a string, so it breaks at ≥10 domains. Worth folding into this PR since the fix is what makes the sort path actually matter.
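
A rough, torch-free sketch of the check such a test could make is below. Everything in it is an assumption for illustration: select_by_domain is a hypothetical list-based stand-in for the stack + torch.gather step, not the real _select_domain_task_output, and the key names, values, and the premise that domain indices are contiguous from 0 are made up. A real test would build tensors and call the model's loss() / update_metric() path instead.

```python
from collections import defaultdict


def select_by_domain(predictions, domain_labels):
    """For each row, pick the prediction belonging to that row's domain.

    Hypothetical stand-in: plain lists replace tensors, and list indexing
    replaces stacking followed by a gather on the domain label.
    Assumes domain indices are contiguous and start at 0.
    """
    grouped = defaultdict(list)
    for name, values in predictions.items():
        tower_loss_name, domain_index = name.rsplit("_", 1)
        # Accumulate (not overwrite) and keep the index numeric.
        grouped[tower_loss_name].append((int(domain_index), values))

    selected = {}
    for tower_loss_name, pairs in grouped.items():
        columns = [values for _, values in sorted(pairs, key=lambda x: x[0])]
        selected[tower_loss_name] = [
            columns[domain][row] for row, domain in enumerate(domain_labels)
        ]
    return selected


# 2 towers x 2 domains, 3 rows; values encode (tower, domain, row).
predictions = {
    "probs_ctr_0": [0.10, 0.11, 0.12],
    "probs_ctr_1": [0.20, 0.21, 0.22],
    "probs_cvr_0": [0.30, 0.31, 0.32],
    "probs_cvr_1": [0.40, 0.41, 0.42],
}
domain_labels = [1, 0, 1]  # per-row domain label

selected = select_by_domain(predictions, domain_labels)
assert selected == {
    "probs_ctr": [0.20, 0.11, 0.22],
    "probs_cvr": [0.40, 0.31, 0.42],
}
```

Because every value is distinct, the assertion fails if any domain is dropped or the columns end up in the wrong order, which a shape-only check would not detect.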

4 participants