Skip to content

Commit 958e222

Browse files
committed
Logprobs fixes for streaming chat/completions.
This also brings the two chat/completions code paths back into alignment.
1 parent 7ebb0b2 commit 958e222

1 file changed

Lines changed: 26 additions & 18 deletions

File tree

endpoints/OAI/utils/chat_completion.py

Lines changed: 26 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -47,12 +47,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
4747

4848
logprob_response = None
4949

50-
tokens = unwrap(generation.get("tokens"), [])
5150
token_probs = unwrap(generation.get("token_probs"), [])
52-
logprobs = unwrap(generation.get("logprobs"), [])
5351
if token_probs:
52+
tokens = unwrap(generation.get("tokens"), [])
53+
logprobs = unwrap(generation.get("logprobs"), [])
5454
collected_token_probs = []
55-
for output_token, token_logprob, top_logprobs in zip(
55+
for generated_token, generated_token_logprob, top_logprobs in zip(
5656
tokens, token_probs, logprobs, strict=True
5757
):
5858
completion_logprobs = [
@@ -62,8 +62,8 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
6262

6363
collected_token_probs.append(
6464
ChatCompletionLogprobChoice(
65-
token=output_token,
66-
logprob=token_logprob,
65+
token=generated_token,
66+
logprob=generated_token_logprob,
6767
top_logprobs=completion_logprobs,
6868
)
6969
)
@@ -112,22 +112,30 @@ def _create_stream_chunk(
112112
role="assistant", content=unwrap(generation.get("text"), "")
113113
)
114114

115+
logprob_response = None
116+
115117
token_probs = unwrap(generation.get("token_probs"), {})
116118
if token_probs:
117-
logprobs = unwrap(generation.get("logprobs"), {})
118-
top_logprobs = [
119-
ChatCompletionLogprob(token=token, logprob=logprob)
120-
for token, logprob in logprobs.items()
121-
]
122-
123-
generated_token = next(iter(token_probs))
124-
token_prob_response = ChatCompletionLogprob(
125-
token=generated_token,
126-
logprob=token_probs[generated_token],
127-
top_logprobs=top_logprobs,
128-
)
119+
tokens = unwrap(generation.get("tokens"), [])
120+
logprobs = unwrap(generation.get("logprobs"), [])
121+
collected_token_probs = []
122+
for generated_token, generated_token_logprob, top_logprobs in zip(
123+
tokens, token_probs, logprobs, strict=True
124+
):
125+
completion_logprobs = [
126+
ChatCompletionLogprob(token=token, logprob=token_logprob)
127+
for token, token_logprob in top_logprobs.items()
128+
]
129+
130+
collected_token_probs.append(
131+
ChatCompletionLogprobChoice(
132+
token=generated_token,
133+
logprob=generated_token_logprob,
134+
top_logprobs=completion_logprobs,
135+
)
136+
)
129137

130-
logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
138+
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
131139

132140
choice = ChatCompletionStreamChoice(
133141
index=index,

0 commit comments

Comments (0)