@@ -767,21 +767,25 @@ def get_special_tokens(
767767 }
768768
769769 def get_logprobs (self , token_ids : torch .Tensor , token_probs : torch .Tensor ):
770- top_tokens = [
771- self .tokenizer .extended_id_to_piece .get (
772- index , self .tokenizer .id_to_piece [index ]
773- )
774- for index in token_ids .flatten ().tolist ()
775- ]
770+ logprobs = []
771+ for token_idx in range (token_ids .shape [1 ]):
772+ top_tokens = [
773+ self .tokenizer .extended_id_to_piece .get (
774+ index , self .tokenizer .id_to_piece [index ]
775+ )
776+ for index in token_ids [0 , token_idx ].tolist ()
777+ ]
776778
777- top_values = torch .log (token_probs ). flatten ( ).tolist ()
779+ top_values = torch .log (token_probs [ 0 , token_idx ] ).tolist ()
778780
779- # Cannot return -inf in JSON
780- cleaned_values = [
781- - 1000 if value == float ("-inf" ) else value for value in top_values
782- ]
781+ # Cannot return -inf in JSON
782+ cleaned_values = [
783+ - 1000 if value == float ("-inf" ) else value for value in top_values
784+ ]
783785
784- return dict (zip_longest (top_tokens , cleaned_values ))
786+ logprobs .append (dict (zip_longest (top_tokens , cleaned_values )))
787+
788+ return logprobs
785789
786790 async def generate (self , prompt : str , ** kwargs ):
787791 """Generate a response to a prompt"""
@@ -793,8 +797,9 @@ async def generate(self, prompt: str, **kwargs):
793797 "text" : "" ,
794798 "prompt_tokens" : 0 ,
795799 "generation_tokens" : 0 ,
800+ "tokens" : [],
796801 "offset" : [],
797- "token_probs" : {} ,
802+ "token_probs" : [] ,
798803 "logprobs" : [],
799804 }
800805
@@ -811,13 +816,14 @@ async def generate(self, prompt: str, **kwargs):
811816 if len (generations ) > 0 :
812817 for generation in generations :
813818 joined_generation ["text" ] += unwrap (generation .get ("text" ), "" )
814- joined_generation ["offset" ].append (unwrap (generation .get ("offset" ), - 1 ))
815- joined_generation ["token_probs" ].update (
816- unwrap (generation .get ("token_probs" ), {})
819+ joined_generation ["tokens" ].extend (unwrap (generation .get ("tokens" ), []))
820+ joined_generation ["offset" ].extend (unwrap (generation .get ("offset" ), []))
821+ joined_generation ["token_probs" ].extend (
822+ unwrap (generation .get ("token_probs" ), [])
817823 )
818824
819825 # Include empty logprob dicts for index preservation
820- joined_generation ["logprobs" ].append (
826+ joined_generation ["logprobs" ].extend (
821827 unwrap (generation .get ("logprobs" ), {})
822828 )
823829
@@ -1145,7 +1151,6 @@ async def generate_gen(
11451151 "text" : chunk ,
11461152 "prompt_tokens" : context_len ,
11471153 "generated_tokens" : generated_tokens ,
1148- "offset" : len (full_response ),
11491154 }
11501155
11511156 if request_logprobs > 0 :
@@ -1164,11 +1169,41 @@ async def generate_gen(
11641169 logprobs = self .get_logprobs (top_tokens , top_probs )
11651170 generation ["logprobs" ] = logprobs
11661171
1167- # The first logprob is the selected token prob
1168- generation ["token_probs" ] = {
1169- token : logprobs [token ]
1170- for token in list (logprobs .keys ())[:1 ]
1171- }
1172+ token_ids = unwrap (
1173+ result .get ("token_ids" ),
1174+ torch .empty (0 ),
1175+ )
1176+
1177+ token_probs = unwrap (
1178+ result .get ("token_probs" ),
1179+ torch .empty (0 ),
1180+ )
1181+
1182+ if token_ids .numel () > 0 and token_probs .numel () > 0 :
1183+ token_ids = token_ids .flatten ().tolist ()
1184+ token_probs = token_probs .flatten ().tolist ()
1185+
1186+ tokens = [
1187+ self .tokenizer .extended_id_to_piece .get (
1188+ index , self .tokenizer .id_to_piece [index ]
1189+ )
1190+ for index in token_ids
1191+ ]
1192+
1193+ generation ["tokens" ] = tokens
1194+ generation ["token_probs" ] = [
1195+ math .log (prob ) for prob in token_probs
1196+ ]
1197+
1198+ # Calculate the offset of each token in the output,
1199+ # working backwards from the end.
1200+ offsets = []
1201+ token_offset = 0
1202+ for token in tokens :
1203+ token_offset += len (token )
1204+ offsets .append (len (full_response ) - token_offset )
1205+ offsets .reverse ()
1206+ generation ["offset" ] = offsets
11721207
11731208 yield generation
11741209
0 commit comments