# Module-level logger, named after this module per the standard logging
# convention so handlers/levels can be configured per-module.
log = logging.getLogger(__name__)

# Shared rich-style highlighter instance for pretty-printing JSON output.
# NOTE(review): JSONHighlighter is imported elsewhere in this file
# (outside the visible chunk) — presumably from rich.highlighter; confirm.
JSON_HIGHLIGHTER = JSONHighlighter()

# NOTE(review): the former module constant MAX_GENERATOR_TOKENS = 8092 was
# removed by this change; callers now pass max_generator_tokens explicitly
# (see top_level_choose_action), with -1 meaning "no cap".
6665class OutlinesTransformersADM (ActionBasedADM ):
6766 def __init__ (self ,
@@ -193,7 +192,8 @@ def kdma_value_to_system_prompt(kdma, value):
193192 else :
194193 return None
195194
196- def _state_to_top_level_prompt (self , scenario_state , actions ):
195+ @staticmethod
196+ def _static_state_to_top_level_prompt (action_selection_prompt_template , scenario_description , scenario_state , actions ):
197197 """
198198 Generate prompt dialog based on given state and actions
199199 """
@@ -203,11 +203,23 @@ def _state_to_top_level_prompt(self, scenario_state, actions):
203203 scenario_state
204204 )
205205
206- scenario_description = self .scenario_description_template (scenario_state )
207- prompt = self .action_selection_prompt_template (scenario_description , choices )
206+ prompt = action_selection_prompt_template (scenario_description , choices )
208207
209208 return prompt , choices
210209
210+ def _state_to_top_level_prompt (self , scenario_state , actions ):
211+ """
212+ Generate prompt dialog based on given state and actions
213+ """
214+ scenario_description = self .scenario_description_template (scenario_state )
215+ return OutlinesTransformersADM ._static_state_to_top_level_prompt (
216+ self .action_selection_prompt_template ,
217+ scenario_description ,
218+ scenario_state ,
219+ actions
220+ )
221+
222+
211223 # Function borrowed from
212224 # https://docs.python.org/3/library/itertools.html#itertools.batched
213225 # (since itertools.batched is only available in Python 3.12 or newer):
@@ -231,24 +243,25 @@ def run_in_batches(cls, inference_function, inputs, batch_size):
231243 outputs .extend (output )
232244 return outputs
233245
234- def top_level_choose_action (self ,
235- scenario_state ,
236- available_actions ,
237- alignment_target ,
238- num_positive_samples = 1 ,
239- num_negative_samples = 0 ,
240- generator_batch_size = 5 ,
241- kdma_descriptions_map = 'align_system/prompt_engineering/kdma_descriptions.yml' ,
242- reasoning_max_length = 512 ,
243- generator_seed = - 1 ,
244- shuffle_choices = True ,
245- ** kwargs ):
246- if self .baseline and num_negative_samples > 0 :
246+ @staticmethod
247+ def get_dialogs (scenario_state ,
248+ available_actions ,
249+ alignment_target ,
250+ num_positive_samples = 1 ,
251+ num_negative_samples = 0 ,
252+ kdma_descriptions_map = 'align_system/prompt_engineering/kdma_descriptions.yml' ,
253+ shuffle_choices = True ,
254+ baseline = False ,
255+ scenario_description_template = scenario_state_description_1 ,
256+ action_selection_prompt_template = action_selection_prompt ,
257+ baseline_system_prompt = baseline_system_prompt ,
258+ ** kwargs ):
259+ if baseline and num_negative_samples > 0 :
247260 raise RuntimeError ("No notion of negative samples for baseline run" )
248- if self . baseline and "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
261+ if baseline and "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
249262 raise RuntimeError ("No notion of incontext examples for baseline run" )
250263
251- scenario_description = self . scenario_description_template (scenario_state )
264+ scenario_description = scenario_description_template (scenario_state )
252265 # Important that the choices stay in the same order as the
253266 # available actions as we'll use the selected index later to
254267 # map to the corresponding action
@@ -261,12 +274,11 @@ def top_level_choose_action(self,
261274 positive_icl_examples = []
262275 negative_icl_examples = []
263276 incontext_settings = kwargs .get ("incontext" , {})
264- if not self .baseline and alignment_target is not None :
265- kdma_values = alignment_target .kdma_values
266277
278+ if not baseline and alignment_target is not None :
279+ kdma_values = alignment_target .kdma_values
267280 if len (kdma_values ) != 1 :
268281 raise RuntimeError ("This ADM assumes a single KDMA target, aborting!" )
269-
270282 kdma_value = kdma_values [0 ]
271283 if isinstance (kdma_value , KDMAValue ):
272284 kdma_value = kdma_value .to_dict ()
@@ -280,8 +292,8 @@ def top_level_choose_action(self,
280292 kdma_descriptions = yaml .load (f , Loader = yaml .FullLoader )
281293 name = kdma_descriptions [kdma ]['name' ]
282294
283- positive_system_prompt = self . __class__ .kdma_value_to_system_prompt (kdma , value )
284- negative_system_prompt = self . __class__ .kdma_value_to_system_prompt (kdma , negative_value )
295+ positive_system_prompt = OutlinesTransformersADM .kdma_value_to_system_prompt (kdma , value )
296+ negative_system_prompt = OutlinesTransformersADM .kdma_value_to_system_prompt (kdma , negative_value )
285297
286298 if positive_system_prompt is None :
287299 raise RuntimeError ("Couldn't find system prompt for kdma: {}, and "
@@ -291,8 +303,7 @@ def top_level_choose_action(self,
291303 "value: {}." .format (kdma , negative_value ))
292304
293305 if "incontext" in kwargs and "number" in incontext_settings and incontext_settings ["number" ] > 0 :
294- scenario_to_match = self .scenario_description_template (scenario_state )
295- prompt_to_match , _ = self ._state_to_top_level_prompt (scenario_state , available_actions )
306+ prompt_to_match , _ = OutlinesTransformersADM ._state_to_top_level_prompt (action_selection_prompt_template , scenario_state , available_actions )
296307
297308 # Create positive ICL example generators
298309 positive_target = {'kdma' : kdma , 'name' : name , 'value' : value }
@@ -301,7 +312,7 @@ def top_level_choose_action(self,
301312 # Get subset of relevant of examples
302313 positive_selected_icl_examples = positive_icl_example_generator .select_icl_examples (
303314 sys_kdma_name = kdma ,
304- scenario_description_to_match = scenario_to_match ,
315+ scenario_description_to_match = scenario_description ,
305316 prompt_to_match = prompt_to_match ,
306317 state_comparison = scenario_state
307318 )
@@ -321,7 +332,7 @@ def top_level_choose_action(self,
321332 # Get subset of relevant of examples
322333 negative_selected_icl_examples = negative_icl_example_generator .select_icl_examples (
323334 sys_kdma_name = kdma ,
324- scenario_description_to_match = scenario_to_match ,
335+ scenario_description_to_match = scenario_description ,
325336 prompt_to_match = prompt_to_match ,
326337 state_comparison = scenario_state
327338 )
@@ -331,17 +342,17 @@ def top_level_choose_action(self,
331342 {"role" : "assistant" , "content" : f'{ icl_sample ["response" ]} ' }
332343 ])
333344 else :
334- positive_system_prompt = self . baseline_system_prompt ()
345+ positive_system_prompt = baseline_system_prompt ()
335346 if num_negative_samples > 0 :
336347 raise RuntimeError ("No notion of negative samples for baseline run" )
337348 if "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
338349 raise RuntimeError ("No notion of incontext examples for baseline run" )
350+ negative_system_prompt = None # Not used in baseline
339351
340352 positive_dialogs = []
341353 for _ in range (num_positive_samples ):
342- shuffled_choices = random .sample (choices , len (choices )) if shuffle_choices else choices
343-
344- prompt = self .action_selection_prompt_template (scenario_description , shuffled_choices )
354+ shuf = random .sample (choices , len (choices )) if shuffle_choices else choices
355+ prompt = action_selection_prompt (scenario_description , shuf )
345356 dialog = [{'role' : 'system' , 'content' : positive_system_prompt }]
346357 dialog .extend (positive_icl_examples )
347358 dialog .append ({'role' : 'user' , 'content' : prompt })
@@ -350,15 +361,55 @@ def top_level_choose_action(self,
350361
351362 negative_dialogs = []
352363 for _ in range (num_negative_samples ):
353- shuffled_choices = random .sample (choices , len (choices )) if shuffle_choices else choices
354-
355- prompt = self .action_selection_prompt_template (scenario_description , shuffled_choices )
364+ shuf = random .sample (choices , len (choices )) if shuffle_choices else choices
365+ prompt = action_selection_prompt (scenario_description , shuf )
356366 dialog = [{'role' : 'system' , 'content' : negative_system_prompt }]
357367 dialog .extend (negative_icl_examples )
358368 dialog .append ({'role' : 'user' , 'content' : prompt })
359-
360369 negative_dialogs .append (dialog )
361370
371+ return {"scenario_description" : scenario_description ,
372+ "choices" : choices ,
373+ "positive_system_prompt" : positive_system_prompt ,
374+ "negative_system_prompt" : negative_system_prompt ,
375+ "positive_dialogs" : positive_dialogs ,
376+ "negative_dialogs" : negative_dialogs }
377+
378+ def top_level_choose_action (self ,
379+ scenario_state ,
380+ available_actions ,
381+ alignment_target ,
382+ num_positive_samples = 1 ,
383+ num_negative_samples = 0 ,
384+ generator_batch_size = 5 ,
385+ kdma_descriptions_map = 'align_system/prompt_engineering/kdma_descriptions.yml' ,
386+ reasoning_max_length = 512 ,
387+ generator_seed = - 1 ,
388+ max_generator_tokens = - 1 ,
389+ shuffle_choices = True ,
390+ ** kwargs ):
391+ if self .baseline and num_negative_samples > 0 :
392+ raise RuntimeError ("No notion of negative samples for baseline run" )
393+ if self .baseline and "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
394+ raise RuntimeError ("No notion of incontext examples for baseline run" )
395+
396+ dialogs_data = OutlinesTransformersADM .get_dialogs (
397+ scenario_state ,
398+ available_actions ,
399+ alignment_target ,
400+ num_positive_samples ,
401+ num_negative_samples ,
402+ kdma_descriptions_map ,
403+ shuffle_choices ,
404+ baseline = self .baseline ,
405+ scenario_description_template = self .scenario_description_template ,
406+ action_selection_prompt_template = self .action_selection_prompt_template ,
407+ baseline_system_prompt = self .baseline_system_prompt ,
408+ )
409+ choices = dialogs_data ["choices" ]
410+ positive_dialogs = dialogs_data ["positive_dialogs" ]
411+ negative_dialogs = dialogs_data ["negative_dialogs" ]
412+
362413 # Need to set the whitespace_pattern to prevent the state
363414 # machine from looping indefinitely in some cases, see:
364415 # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
@@ -367,12 +418,14 @@ def top_level_choose_action(self,
367418 action_choice_json_schema (json .dumps (choices ), reasoning_max_length ),
368419 sampler = self .sampler ,
369420 whitespace_pattern = r"[ ]?" )
421+
422+ if max_generator_tokens >= 0 :
423+ generator = partial (generator , max_tokens = max_generator_tokens )
370424
371425 if generator_seed >= 0 :
372426 torch .manual_seed (generator_seed )
373427 if torch .cuda .is_available ():
374428 torch .cuda .manual_seed (generator_seed )
375- generator = partial (generator , max_tokens = MAX_GENERATOR_TOKENS )
376429
377430
378431 dialog_texts = [self .dialog_to_prompt (d ) for d in
0 commit comments