@@ -47,26 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
 
         logprob_response = None
 
-        tokens = unwrap(generation.get("tokens"), [])
         token_probs = unwrap(generation.get("token_probs"), [])
-        logprobs = unwrap(generation.get("logprobs"), [])
         if token_probs:
+            tokens = unwrap(generation.get("tokens"), [])
+            logprobs = unwrap(generation.get("logprobs"), [])
             collected_token_probs = []
-            for output_token, token_logprob, top_logprobs in zip(
+            for generated_token, generated_token_logprob, top_logprobs in zip(
                 tokens, token_probs, logprobs, strict=True
             ):
                 completion_logprobs = [
                     ChatCompletionLogprob(token=token, logprob=token_logprob)
                     for token, token_logprob in top_logprobs.items()
                 ]
 
-                collected_token_probs.append(
-                    ChatCompletionLogprobChoice(
-                        token=output_token,
-                        logprob=token_logprob,
-                        top_logprobs=completion_logprobs,
-                    )
-                )
+                collected_token_probs.append(ChatCompletionLogprobChoice(
+                    token=generated_token,
+                    logprob=generated_token_logprob,
+                    top_logprobs=completion_logprobs,
+                ))
 
             logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
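For context, the reworked loop treats `tokens`, `token_probs`, and `logprobs` as parallel, equal-length sequences. A minimal sketch of the shape it assumes (field names taken from the diff; the sample values are made up):

```python
# Hypothetical `generation` payload, shaped as the loop above expects:
# three parallel sequences with one entry per generated token.
generation = {
    "tokens": ["Hello", ","],            # chosen token at each position
    "token_probs": [-0.01, -0.35],       # logprob of each chosen token
    "logprobs": [                        # top alternatives per position
        {"Hello": -0.01, "Hi": -4.2},
        {",": -0.35, "!": -1.7},
    ],
}

# zip(..., strict=True) (Python 3.10+) raises ValueError instead of
# silently truncating if the backend ever returns mismatched lengths.
for token, token_logprob, top_logprobs in zip(
    generation["tokens"], generation["token_probs"], generation["logprobs"],
    strict=True,
):
    print(token, token_logprob, top_logprobs)
```

Beyond the rename, hoisting the `tokens`/`logprobs` unwrapping into the `if token_probs:` branch avoids two needless lookups when a generation carries no logprob data, and `generated_token_logprob` no longer shares a name with the `token_logprob` used inside the comprehension, which makes the subsequent `append` unambiguous to read.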
@@ -112,22 +110,28 @@ def _create_stream_chunk(
         role="assistant", content=unwrap(generation.get("text"), "")
     )
 
+    logprob_response = None
+
     token_probs = unwrap(generation.get("token_probs"), {})
     if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), {})
-        top_logprobs = [
-            ChatCompletionLogprob(token=token, logprob=logprob)
-            for token, logprob in logprobs.items()
-        ]
-
-        generated_token = next(iter(token_probs))
-        token_prob_response = ChatCompletionLogprob(
-            token=generated_token,
-            logprob=token_probs[generated_token],
-            top_logprobs=top_logprobs,
-        )
+        tokens = unwrap(generation.get("tokens"), [])
+        logprobs = unwrap(generation.get("logprobs"), [])
+        collected_token_probs = []
+        for generated_token, generated_token_logprob, top_logprobs in zip(
+            tokens, token_probs, logprobs, strict=True
+        ):
+            completion_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=token_logprob)
+                for token, token_logprob in top_logprobs.items()
+            ]
+
+            collected_token_probs.append(ChatCompletionLogprobChoice(
+                token=generated_token,
+                logprob=generated_token_logprob,
+                top_logprobs=completion_logprobs,
+            ))
 
-        logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
     choice = ChatCompletionStreamChoice(
         index=index,
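The streaming path previously reported logprobs only for the first token of a chunk (`next(iter(token_probs))`); it now accumulates every token the same way `_create_response` does, and the new `logprob_response = None` initializer guarantees the name is bound even for chunks that carry no `token_probs`. Both paths lean on `unwrap` to substitute defaults for missing fields; its implementation isn't shown in this diff, so the sketch below is an assumption about its behavior:

```python
from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap(value: Optional[T], default: T) -> T:
    # Assumed behavior of the project's `unwrap` helper: fall back to
    # `default` when the backend omitted the field entirely.
    return value if value is not None else default

# e.g. unwrap(generation.get("token_probs"), {}) yields {} when the field
# is absent, so the `if token_probs:` guard cleanly skips the logprob block.
```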