@@ -47,12 +47,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
4747
4848 logprob_response = None
4949
50- tokens = unwrap (generation .get ("tokens" ), [])
5150 token_probs = unwrap (generation .get ("token_probs" ), [])
52- logprobs = unwrap (generation .get ("logprobs" ), [])
5351 if token_probs :
52+ tokens = unwrap (generation .get ("tokens" ), [])
53+ logprobs = unwrap (generation .get ("logprobs" ), [])
5454 collected_token_probs = []
55- for output_token , token_logprob , top_logprobs in zip (
55+ for generated_token , generated_token_logprob , top_logprobs in zip (
5656 tokens , token_probs , logprobs , strict = True
5757 ):
5858 completion_logprobs = [
@@ -62,8 +62,8 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
6262
6363 collected_token_probs .append (
6464 ChatCompletionLogprobChoice (
65- token = output_token ,
66- logprob = token_logprob ,
65+ token = generated_token ,
66+ logprob = generated_token_logprob ,
6767 top_logprobs = completion_logprobs ,
6868 )
6969 )
@@ -112,22 +112,30 @@ def _create_stream_chunk(
112112 role = "assistant" , content = unwrap (generation .get ("text" ), "" )
113113 )
114114
115+ logprob_response = None
116+
115117 token_probs = unwrap (generation .get ("token_probs" ), {})
116118 if token_probs :
117- logprobs = unwrap (generation .get ("logprobs" ), {})
118- top_logprobs = [
119- ChatCompletionLogprob (token = token , logprob = logprob )
120- for token , logprob in logprobs .items ()
121- ]
122-
123- generated_token = next (iter (token_probs ))
124- token_prob_response = ChatCompletionLogprob (
125- token = generated_token ,
126- logprob = token_probs [generated_token ],
127- top_logprobs = top_logprobs ,
128- )
119+ tokens = unwrap (generation .get ("tokens" ), [])
120+ logprobs = unwrap (generation .get ("logprobs" ), [])
121+ collected_token_probs = []
122+ for generated_token , generated_token_logprob , top_logprobs in zip (
123+ tokens , token_probs , logprobs , strict = True
124+ ):
125+ completion_logprobs = [
126+ ChatCompletionLogprob (token = token , logprob = token_logprob )
127+ for token , token_logprob in top_logprobs .items ()
128+ ]
129+
130+ collected_token_probs .append (
131+ ChatCompletionLogprobChoice (
132+ token = generated_token ,
133+ logprob = generated_token_logprob ,
134+ top_logprobs = completion_logprobs ,
135+ )
136+ )
129137
130- logprob_response = ChatCompletionLogprobs (content = [ token_prob_response ] )
138+ logprob_response = ChatCompletionLogprobs (content = collected_token_probs )
131139
132140 choice = ChatCompletionStreamChoice (
133141 index = index ,
0 commit comments