Fix batch embedding averaging for batch_size > 1#3839

Open
Chessing234 wants to merge 2 commits into lm-sys:main from Chessing234:fix/issue-3785-batch-embedding-averaging

Conversation

@Chessing234

Summary

  • Fix incorrect embedding computation when batch_size > 1 in ModelWorker.get_embeddings()
  • Previously, token_num was computed as a single scalar summing tokens across the entire batch (torch.sum(attention_mask).item()), but sum_embeddings is per-sequence (shape [batch_size, hidden_dim]). Dividing a per-sequence tensor by a batch-wide scalar produces wrong averages for every sequence except when batch_size == 1.
  • Changed to per-sequence token counts using attention_mask.sum(dim=1, keepdim=True) so each sequence's embedding is divided by its own token count. The ret["token_num"] return value remains a scalar (total tokens) for API compatibility.
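A minimal sketch of the fix described above, assuming PyTorch and the tensor names from this description (`sum_embeddings`, `attention_mask`); this is an illustration, not the exact `ModelWorker.get_embeddings()` code:

```python
import torch

def mean_pool(sum_embeddings: torch.Tensor, attention_mask: torch.Tensor):
    # sum_embeddings: [batch_size, hidden_dim], per-sequence sums over tokens
    # attention_mask: [batch_size, seq_len], 1 for real tokens, 0 for padding

    # Buggy version: a single scalar count over the whole batch
    #   token_num = torch.sum(attention_mask).item()
    #   embeddings = sum_embeddings / token_num   # wrong for batch_size > 1

    # Fixed: per-sequence counts, broadcast over hidden_dim
    token_counts = attention_mask.sum(dim=1, keepdim=True)  # [batch_size, 1]
    embeddings = sum_embeddings / token_counts

    # The returned token count stays a batch-wide scalar for API compatibility
    token_num = int(attention_mask.sum().item())
    return embeddings, token_num
```

With two sequences of 2 and 3 real tokens, each row is divided by its own count, while `token_num` still reports the total (5).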

Fixes #3785

Test plan

  • Verify embedding output is identical for batch_size=1 (no regression)
  • Compare embeddings computed with batch_size=1 vs batch_size>1 for the same inputs -- they should now match
  • Test both embed_in_truncate and chunked (non-truncate) code paths
  • Test with use_cls_pooling enabled and disabled
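The batch-vs-single comparison in the test plan can be sketched as below; `mean_pool` here is a hypothetical stand-in for the worker's pooling logic, not the actual test code:

```python
import torch

def mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # hidden: [batch, seq, dim]; mask: [batch, seq] with 1 for real tokens
    summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)
    return summed / mask.sum(dim=1, keepdim=True)  # per-sequence averaging

torch.manual_seed(0)
hidden = torch.randn(3, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1],
                     [1, 0, 0, 0, 0]])

# Embeddings from one batched call...
batched = mean_pool(hidden, mask)
# ...should match embeddings computed one sequence at a time (batch_size=1)
single = torch.cat([mean_pool(hidden[i:i + 1], mask[i:i + 1]) for i in range(3)])
assert torch.allclose(batched, single, atol=1e-6)
```

With the old scalar `token_num`, the `batched` rows would be scaled by the batch-wide token total and this assertion would fail for any batch with more than one sequence.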

🤖 Generated with Claude Code

Chessing234 and others added 2 commits April 6, 2026 19:02
Initialize x before the loop to prevent UnboundLocalError if
generate_stream_gate yields no items.

Fixes lm-sys#3786

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
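The pattern this commit guards against can be sketched as follows (a generic illustration; `generate_stream_gate` is the real generator named in the commit, but the function below is hypothetical):

```python
def last_output(gen):
    # Initialize x before the loop: if gen yields no items, the loop body
    # never runs and referencing x afterwards would raise UnboundLocalError.
    x = None
    for x in gen:
        pass
    return x

# Safe for both empty and non-empty generators
last_output(iter([]))
last_output(iter([1, 2, 3]))
```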
Compute per-sequence token counts instead of a single scalar across
the entire batch. This fixes incorrect embeddings when batch_size > 1.

Fixes lm-sys#3785

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

Batch embedding averaging is incorrect for batch_size > 1

1 participant