Skip to content

Commit e6ed50d

Browse files
committed
Factor get_dialogs to static method in outlines_adm
1 parent 35b496c commit e6ed50d

1 file changed

Lines changed: 90 additions & 37 deletions

File tree

align_system/algorithms/outlines_adm.py

Lines changed: 90 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,6 @@
6161
log = logging.getLogger(__name__)
6262
JSON_HIGHLIGHTER = JSONHighlighter()
6363

64-
MAX_GENERATOR_TOKENS = 8092
6564

6665
class OutlinesTransformersADM(ActionBasedADM):
6766
def __init__(self,
@@ -193,7 +192,8 @@ def kdma_value_to_system_prompt(kdma, value):
193192
else:
194193
return None
195194

196-
def _state_to_top_level_prompt(self, scenario_state, actions):
195+
@staticmethod
196+
def _static_state_to_top_level_prompt(action_selection_prompt_template, scenario_description, scenario_state, actions):
197197
"""
198198
Generate prompt dialog based on given state and actions
199199
"""
@@ -203,11 +203,23 @@ def _state_to_top_level_prompt(self, scenario_state, actions):
203203
scenario_state
204204
)
205205

206-
scenario_description = self.scenario_description_template(scenario_state)
207-
prompt = self.action_selection_prompt_template(scenario_description, choices)
206+
prompt = action_selection_prompt_template(scenario_description, choices)
208207

209208
return prompt, choices
210209

210+
def _state_to_top_level_prompt(self, scenario_state, actions):
211+
"""
212+
Generate prompt dialog based on given state and actions
213+
"""
214+
scenario_description = self.scenario_description_template(scenario_state)
215+
return OutlinesTransformersADM._static_state_to_top_level_prompt(
216+
self.action_selection_prompt_template,
217+
scenario_description,
218+
scenario_state,
219+
actions
220+
)
221+
222+
211223
# Function borrowed from
212224
# https://docs.python.org/3/library/itertools.html#itertools.batched
213225
# (since itertools.batched is only available in Python 3.12 or newer):
@@ -231,24 +243,25 @@ def run_in_batches(cls, inference_function, inputs, batch_size):
231243
outputs.extend(output)
232244
return outputs
233245

234-
def top_level_choose_action(self,
235-
scenario_state,
236-
available_actions,
237-
alignment_target,
238-
num_positive_samples=1,
239-
num_negative_samples=0,
240-
generator_batch_size=5,
241-
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
242-
reasoning_max_length=512,
243-
generator_seed = -1,
244-
shuffle_choices=True,
245-
**kwargs):
246-
if self.baseline and num_negative_samples > 0:
246+
@staticmethod
247+
def get_dialogs(scenario_state,
248+
available_actions,
249+
alignment_target,
250+
num_positive_samples=1,
251+
num_negative_samples=0,
252+
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
253+
shuffle_choices=True,
254+
baseline=False,
255+
scenario_description_template=scenario_state_description_1,
256+
action_selection_prompt_template=action_selection_prompt,
257+
baseline_system_prompt=baseline_system_prompt,
258+
**kwargs):
259+
if baseline and num_negative_samples > 0:
247260
raise RuntimeError("No notion of negative samples for baseline run")
248-
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
261+
if baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
249262
raise RuntimeError("No notion of incontext examples for baseline run")
250263

251-
scenario_description = self.scenario_description_template(scenario_state)
264+
scenario_description = scenario_description_template(scenario_state)
252265
# Important that the choices stay in the same order as the
253266
# available actions as we'll use the selected index later to
254267
# map to the corresponding action
@@ -261,12 +274,11 @@ def top_level_choose_action(self,
261274
positive_icl_examples = []
262275
negative_icl_examples = []
263276
incontext_settings=kwargs.get("incontext", {})
264-
if not self.baseline and alignment_target is not None:
265-
kdma_values = alignment_target.kdma_values
266277

278+
if not baseline and alignment_target is not None:
279+
kdma_values = alignment_target.kdma_values
267280
if len(kdma_values) != 1:
268281
raise RuntimeError("This ADM assumes a single KDMA target, aborting!")
269-
270282
kdma_value = kdma_values[0]
271283
if isinstance(kdma_value, KDMAValue):
272284
kdma_value = kdma_value.to_dict()
@@ -280,8 +292,8 @@ def top_level_choose_action(self,
280292
kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader)
281293
name = kdma_descriptions[kdma]['name']
282294

283-
positive_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, value)
284-
negative_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, negative_value)
295+
positive_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, value)
296+
negative_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, negative_value)
285297

286298
if positive_system_prompt is None:
287299
raise RuntimeError("Couldn't find system prompt for kdma: {}, and "
@@ -291,8 +303,7 @@ def top_level_choose_action(self,
291303
"value: {}.".format(kdma, negative_value))
292304

293305
if "incontext" in kwargs and "number" in incontext_settings and incontext_settings["number"] > 0:
294-
scenario_to_match = self.scenario_description_template(scenario_state)
295-
prompt_to_match, _ = self._state_to_top_level_prompt(scenario_state, available_actions)
306+
prompt_to_match, _ = OutlinesTransformersADM._state_to_top_level_prompt(action_selection_prompt_template, scenario_state, available_actions)
296307

297308
# Create positive ICL example generators
298309
positive_target = {'kdma': kdma, 'name': name, 'value': value}
@@ -301,7 +312,7 @@ def top_level_choose_action(self,
301312
# Get subset of relevant of examples
302313
positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
303314
sys_kdma_name=kdma,
304-
scenario_description_to_match=scenario_to_match,
315+
scenario_description_to_match=scenario_description,
305316
prompt_to_match=prompt_to_match,
306317
state_comparison=scenario_state
307318
)
@@ -321,7 +332,7 @@ def top_level_choose_action(self,
321332
# Get subset of relevant of examples
322333
negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
323334
sys_kdma_name=kdma,
324-
scenario_description_to_match=scenario_to_match,
335+
scenario_description_to_match=scenario_description,
325336
prompt_to_match=prompt_to_match,
326337
state_comparison=scenario_state
327338
)
@@ -331,17 +342,17 @@ def top_level_choose_action(self,
331342
{"role": "assistant", "content": f'{icl_sample["response"]}'}
332343
])
333344
else:
334-
positive_system_prompt = self.baseline_system_prompt()
345+
positive_system_prompt = baseline_system_prompt()
335346
if num_negative_samples > 0:
336347
raise RuntimeError("No notion of negative samples for baseline run")
337348
if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
338349
raise RuntimeError("No notion of incontext examples for baseline run")
350+
negative_system_prompt = None # Not used in baseline
339351

340352
positive_dialogs = []
341353
for _ in range(num_positive_samples):
342-
shuffled_choices = random.sample(choices, len(choices)) if shuffle_choices else choices
343-
344-
prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices)
354+
shuf = random.sample(choices, len(choices)) if shuffle_choices else choices
355+
prompt = action_selection_prompt(scenario_description, shuf)
345356
dialog = [{'role': 'system', 'content': positive_system_prompt}]
346357
dialog.extend(positive_icl_examples)
347358
dialog.append({'role': 'user', 'content': prompt})
@@ -350,15 +361,55 @@ def top_level_choose_action(self,
350361

351362
negative_dialogs = []
352363
for _ in range(num_negative_samples):
353-
shuffled_choices = random.sample(choices, len(choices)) if shuffle_choices else choices
354-
355-
prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices)
364+
shuf = random.sample(choices, len(choices)) if shuffle_choices else choices
365+
prompt = action_selection_prompt(scenario_description, shuf)
356366
dialog = [{'role': 'system', 'content': negative_system_prompt}]
357367
dialog.extend(negative_icl_examples)
358368
dialog.append({'role': 'user', 'content': prompt})
359-
360369
negative_dialogs.append(dialog)
361370

371+
return {"scenario_description": scenario_description,
372+
"choices": choices,
373+
"positive_system_prompt": positive_system_prompt,
374+
"negative_system_prompt": negative_system_prompt,
375+
"positive_dialogs": positive_dialogs,
376+
"negative_dialogs": negative_dialogs}
377+
378+
def top_level_choose_action(self,
379+
scenario_state,
380+
available_actions,
381+
alignment_target,
382+
num_positive_samples=1,
383+
num_negative_samples=0,
384+
generator_batch_size=5,
385+
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
386+
reasoning_max_length=512,
387+
generator_seed=-1,
388+
max_generator_tokens=-1,
389+
shuffle_choices=True,
390+
**kwargs):
391+
if self.baseline and num_negative_samples > 0:
392+
raise RuntimeError("No notion of negative samples for baseline run")
393+
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
394+
raise RuntimeError("No notion of incontext examples for baseline run")
395+
396+
dialogs_data = OutlinesTransformersADM.get_dialogs(
397+
scenario_state,
398+
available_actions,
399+
alignment_target,
400+
num_positive_samples,
401+
num_negative_samples,
402+
kdma_descriptions_map,
403+
shuffle_choices,
404+
baseline=self.baseline,
405+
scenario_description_template=self.scenario_description_template,
406+
action_selection_prompt_template=self.action_selection_prompt_template,
407+
baseline_system_prompt=self.baseline_system_prompt,
408+
)
409+
choices = dialogs_data["choices"]
410+
positive_dialogs = dialogs_data["positive_dialogs"]
411+
negative_dialogs = dialogs_data["negative_dialogs"]
412+
362413
# Need to set the whitespace_pattern to prevent the state
363414
# machine from looping indefinitely in some cases, see:
364415
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
@@ -367,12 +418,14 @@ def top_level_choose_action(self,
367418
action_choice_json_schema(json.dumps(choices), reasoning_max_length),
368419
sampler=self.sampler,
369420
whitespace_pattern=r"[ ]?")
421+
422+
if max_generator_tokens >= 0:
423+
generator = partial(generator, max_tokens=max_generator_tokens)
370424

371425
if generator_seed >= 0:
372426
torch.manual_seed(generator_seed)
373427
if torch.cuda.is_available():
374428
torch.cuda.manual_seed(generator_seed)
375-
generator = partial(generator, max_tokens=MAX_GENERATOR_TOKENS)
376429

377430

378431
dialog_texts = [self.dialog_to_prompt(d) for d in

0 commit comments

Comments (0)