Skip to content

Commit 73545e5

Browse files
committed
fix: Handle duplicate texts correctly in embed_stream
Addresses Copilot review comment: duplicate texts caused incorrect embedding index assignment. Previously, when batch_texts contained duplicate texts, all embeddings for those duplicates would be assigned the same index (the index of the first occurrence) because list.index() always returns the first match. Now the code tracks used indices and assigns each embedding to the next unused occurrence of its text in the batch, ensuring correct index assignment even with duplicate texts. Example: texts = ['hello', 'world', 'hello']. Before: indices would be [0, 1, 0] (wrong). After: indices are [0, 1, 2] (correct).
1 parent 7c198ea commit 73545e5

1 file changed

Lines changed: 16 additions & 4 deletions

File tree

src/cohere/base_client.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ def embed_stream(
12091209
for batch_start in range(0, len(texts_list), batch_size):
12101210
batch_end = min(batch_start + batch_size, len(texts_list))
12111211
batch_texts = texts_list[batch_start:batch_end]
1212-
1212+
12131213
# Get response for this batch
12141214
response = self._raw_client.embed(
12151215
texts=batch_texts,
@@ -1219,15 +1219,27 @@ def embed_stream(
12191219
truncate=truncate,
12201220
request_options=request_options,
12211221
)
1222-
1222+
12231223
# Parse embeddings from response incrementally
12241224
parser = StreamingEmbedParser(response._response, batch_texts)
1225+
# Track used indices to handle duplicate texts correctly
1226+
used_batch_indices = set()
1227+
12251228
for embedding in parser.iter_embeddings():
12261229
# The parser sets embedding.text correctly for multiple embedding types
12271230
# Adjust the global index based on text position in batch
12281231
if embedding.text and embedding.text in batch_texts:
1229-
text_idx_in_batch = batch_texts.index(embedding.text)
1230-
embedding.index = batch_start + text_idx_in_batch
1232+
# Find the next unused occurrence of this text in the batch
1233+
# This handles duplicate texts correctly
1234+
text_idx_in_batch = None
1235+
for idx, text in enumerate(batch_texts):
1236+
if text == embedding.text and idx not in used_batch_indices:
1237+
text_idx_in_batch = idx
1238+
used_batch_indices.add(idx)
1239+
break
1240+
1241+
if text_idx_in_batch is not None:
1242+
embedding.index = batch_start + text_idx_in_batch
12311243
yield embedding
12321244

12331245
def rerank(

0 commit comments

Comments
 (0)