fix: accumulate domain loss predictions instead of overwriting#523
fix: accumulate domain loss predictions instead of overwriting#523zhpjunfei wants to merge 1 commit into
Conversation
|
张峻飞 seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it. |
| new_predictions[tower_loss_name].append( | ||
| (domain_index, tower_domain_loss_predict_value) | ||
| ] | ||
| ) |
There was a problem hiding this comment.
Good catch on the overwrite. One thing worth handling in the same PR: now that the list actually accumulates across domains, the sorted(..., key=lambda x: x[0]) a few lines below (line 196) sorts on domain_index as a string (it comes from rsplit("_", 1)[1]). For ≥10 domains this gives lexicographic order — "10" sorts before "2" — which then misaligns the stacked tensor with the integer index from batch.labels[self._domain_input_name] passed to torch.gather. Before this fix the list always had one element so the ordering never mattered; this fix exposes it.
Suggest casting the index to int when appending so the subsequent sort is numeric:
| new_predictions[tower_loss_name].append( | |
| (domain_index, tower_domain_loss_predict_value) | |
| ] | |
| ) | |
| new_predictions[tower_loss_name].append( | |
| (int(domain_index), tower_domain_loss_predict_value) | |
| ) |
|
Nice, targeted fix — the A couple of notes worth considering:
|
What
修复
PEPNet._select_domain_task_output中 domain 预测值累积的 bug。Why
new_predictions初始化为defaultdict(list),但旧代码使用=赋值单元素列表,导致每次循环覆盖之前的值。最终只有最后一个(domain_index, value)参与排序和 stack,使得多 domain 场景下的torch.gather选择结果错误。How
将
new_predictions[tower_loss_name] = [...]改为.append(...),正确累积所有 domain 的预测值。