Commit 3acab3b

Logprobs fixes for streaming chat/completions.

This also brings the two chat/completions code paths back into alignment.

1 parent: 7ebb0b2

1 file changed: endpoints/OAI/utils/chat_completion.py (27 additions, 23 deletions)
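
Distilled from the diff below, a minimal sketch of the per-token collection logic both endpoints now share. Treat all names here as assumptions: the dataclasses stand in for the project's real response models, unwrap() for its None-coalescing helper, and the generation dict shape is inferred from the diff, not taken from the project's API.

# Minimal sketch of the shared logprob collection; all stand-ins, see above.
from dataclasses import dataclass
from typing import List, Optional


def unwrap(value, default):
    # Stand-in for the project's unwrap(): fall back when value is None
    return default if value is None else value


@dataclass
class ChatCompletionLogprob:
    token: str
    logprob: float


@dataclass
class ChatCompletionLogprobChoice:
    token: str
    logprob: float
    top_logprobs: List[ChatCompletionLogprob]


@dataclass
class ChatCompletionLogprobs:
    content: List[ChatCompletionLogprobChoice]


def collect_logprobs(generation: dict) -> Optional[ChatCompletionLogprobs]:
    # One entry per generated token, each carrying its top alternatives
    token_probs = unwrap(generation.get("token_probs"), [])
    if not token_probs:
        return None

    tokens = unwrap(generation.get("tokens"), [])
    logprobs = unwrap(generation.get("logprobs"), [])

    collected_token_probs = []
    for generated_token, generated_token_logprob, top_logprobs in zip(
        tokens, token_probs, logprobs, strict=True
    ):
        completion_logprobs = [
            ChatCompletionLogprob(token=token, logprob=token_logprob)
            for token, token_logprob in top_logprobs.items()
        ]
        collected_token_probs.append(
            ChatCompletionLogprobChoice(
                token=generated_token,
                logprob=generated_token_logprob,
                top_logprobs=completion_logprobs,
            )
        )
    return ChatCompletionLogprobs(content=collected_token_probs)
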
@@ -47,26 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
 
         logprob_response = None
 
-        tokens = unwrap(generation.get("tokens"), [])
         token_probs = unwrap(generation.get("token_probs"), [])
-        logprobs = unwrap(generation.get("logprobs"), [])
         if token_probs:
+            tokens = unwrap(generation.get("tokens"), [])
+            logprobs = unwrap(generation.get("logprobs"), [])
             collected_token_probs = []
-            for output_token, token_logprob, top_logprobs in zip(
+            for generated_token, generated_token_logprob, top_logprobs in zip(
                 tokens, token_probs, logprobs, strict=True
             ):
                 completion_logprobs = [
                     ChatCompletionLogprob(token=token, logprob=token_logprob)
                     for token, token_logprob in top_logprobs.items()
                 ]
 
-                collected_token_probs.append(
-                    ChatCompletionLogprobChoice(
-                        token=output_token,
-                        logprob=token_logprob,
-                        top_logprobs=completion_logprobs,
-                    )
-                )
+                collected_token_probs.append(ChatCompletionLogprobChoice(
+                    token=generated_token,
+                    logprob=generated_token_logprob,
+                    top_logprobs=completion_logprobs,
+                ))
 
             logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
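
Two details in this first hunk are easy to miss: tokens and logprobs are now unwrapped only once token_probs is known to be non-empty, and the loop keeps zip's strict=True (Python 3.10+), which turns a length mismatch between the three sequences into an immediate error instead of silent truncation:

# zip() stops at the shortest iterable by default; strict=True raises,
# so desynced token/probability sequences fail loudly
tokens = ["Hello", ",", " world"]
token_probs = [-0.1, -0.5]  # one entry short

print(list(zip(tokens, token_probs)))  # [('Hello', -0.1), (',', -0.5)]

try:
    list(zip(tokens, token_probs, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1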

@@ -112,22 +110,28 @@ def _create_stream_chunk(
         role="assistant", content=unwrap(generation.get("text"), "")
     )
 
+    logprob_response = None
+
     token_probs = unwrap(generation.get("token_probs"), {})
     if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), {})
-        top_logprobs = [
-            ChatCompletionLogprob(token=token, logprob=logprob)
-            for token, logprob in logprobs.items()
-        ]
-
-        generated_token = next(iter(token_probs))
-        token_prob_response = ChatCompletionLogprob(
-            token=generated_token,
-            logprob=token_probs[generated_token],
-            top_logprobs=top_logprobs,
-        )
+        tokens = unwrap(generation.get("tokens"), [])
+        logprobs = unwrap(generation.get("logprobs"), [])
+        collected_token_probs = []
+        for generated_token, generated_token_logprob, top_logprobs in zip(
+            tokens, token_probs, logprobs, strict=True
+        ):
+            completion_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=token_logprob)
+                for token, token_logprob in top_logprobs.items()
+            ]
+
+            collected_token_probs.append(ChatCompletionLogprobChoice(
+                token=generated_token,
+                logprob=generated_token_logprob,
+                top_logprobs=completion_logprobs,
+            ))
 
-        logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
     choice = ChatCompletionStreamChoice(
         index=index,
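
The behavioral fix lives in this second hunk: the old streaming path built a single entry from next(iter(token_probs)), so a chunk carrying more than one generated token only ever reported logprobs for its first token. The new loop mirrors the non-streaming path and emits one entry per token. A quick check using the hypothetical collect_logprobs() sketch above (the generation shape is still an assumption inferred from the diff):

# A hypothetical two-token stream chunk
generation = {
    "tokens": ["Hel", "lo"],
    "token_probs": [-0.11, -0.23],
    "logprobs": [
        {"Hel": -0.11, "He": -2.4},  # top alternatives for the first token
        {"lo": -0.23, "la": -3.1},   # top alternatives for the second token
    ],
}

response = collect_logprobs(generation)
for choice in response.content:
    print(choice.token, choice.logprob, [t.token for t in choice.top_logprobs])
# Hel -0.11 ['Hel', 'He']   <- the old code stopped after this entry
# lo -0.23 ['lo', 'la']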
