@@ -139,16 +139,14 @@ def tokenize_trajectory(
139139 # Find the index of the last assistant message
140140 last_assistant_index = - 1
141141 for i , message in enumerate (history .messages_and_choices ):
142- if (
143- isinstance (message , dict )
144- and message ["role" ] == "assistant"
145- and allow_training_without_logprobs
146- ):
147- last_assistant_index = i
148- elif not isinstance (message , dict ) and (
149- message .logprobs or allow_training_without_logprobs
150- ):
151- last_assistant_index = i
142+ if isinstance (message , dict ):
143+ # Message dict
144+ if message ["role" ] == "assistant" and allow_training_without_logprobs :
145+ last_assistant_index = i
146+ else :
147+ # Choice object
148+ if message .logprobs is not None or allow_training_without_logprobs :
149+ last_assistant_index = i
152150 # If there are no trainable assistant messages, return None
153151 if last_assistant_index == - 1 :
154152 return None
@@ -189,7 +187,7 @@ def tokenize_trajectory(
189187 (
190188 message_or_choice
191189 if isinstance (message_or_choice , dict )
192- and not message_or_choice ["role" ] = = "assistant"
190+ and message_or_choice ["role" ] ! = "assistant"
193191 else {
194192 "role" : "assistant" ,
195193 "content" : sentinal_token ,
@@ -205,7 +203,7 @@ def tokenize_trajectory(
205203 assistant_mask : list [int ] = [0 ] * len (token_ids )
206204 logprobs = [float ("nan" )] * len (token_ids )
207205 for message in messages_and_choices :
208- if isinstance (message , dict ) and not message ["role" ] = = "assistant" :
206+ if isinstance (message , dict ) and message ["role" ] ! = "assistant" :
209207 continue
210208 start = token_ids .index (sentinal_token_id )
211209 end = start + 1
@@ -214,6 +212,7 @@ def tokenize_trajectory(
214212 except IndexError :
215213 end_token_id = None
216214 if isinstance (message , dict ):
215+ # Message dict
217216 content = message .get ("content" )
218217 assert isinstance (content , str )
219218 content_token_ids = tokenizer .encode (
@@ -224,6 +223,7 @@ def tokenize_trajectory(
224223 logprobs [start :end ] = [float ("nan" )] * len (content_token_ids )
225224 assistant_mask [start :end ] = [1 ] * len (content_token_ids )
226225 else :
226+ # Choice object
227227 choice = message
228228 assert choice .logprobs or allow_training_without_logprobs , (
229229 "Chat completion choices must have logprobs"
0 commit comments