
Commit a112c19

feat: dpo dataset new openai chat completion format
1 parent eb2db8b commit a112c19

5 files changed: 226 additions & 85 deletions

examples/nlp/gpt/train_gpt_sft.py

Lines changed: 6 additions & 77 deletions
@@ -39,7 +39,7 @@
     resolve_and_create_trainer,
     retrieve_custom_trainer_state_dict,
 )
-from nemo_aligner.utils.utils import load_from_nemo
+from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo
 
 """Script to start SFT training"""

@@ -49,75 +49,10 @@
     mp.set_start_method("spawn", force=True)
 
 
-def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
-    """
-    This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg).
-    The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
-    """
-    OmegaConf.set_struct(gpt_cfg, True)
-    OmegaConf.resolve(cfg)
-    with open_dict(gpt_cfg):
-        gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False)
-        gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size
-        gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size
-        gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
-        gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
-        gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
-        gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
-        gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
-            "activations_checkpoint_layers_per_pipeline", None
-        )
-        gpt_cfg.peft = cfg.model.peft
-        gpt_cfg.data = cfg.model.data
-        gpt_cfg.optim = cfg.model.optim
-        gpt_cfg.precision = cfg.trainer.precision
-        gpt_cfg.answer_only_loss = cfg.model.answer_only_loss
-        gpt_cfg.restore_from_path = cfg.model.restore_from_path
-        gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
-        gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end
-        gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
-        gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0)
-        gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0)
-        gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
-        gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False)
-        # if TP/PP size is -1, use default TP/PP size as original model
-        if cfg.model.get("tensor_model_parallel_size", 1) > 0:
-            gpt_cfg.tensor_model_parallel_size = cfg.model.get("tensor_model_parallel_size", 1)
-        if cfg.model.get("pipeline_model_parallel_size", 1) > 0:
-            gpt_cfg.pipeline_model_parallel_size = cfg.model.get("pipeline_model_parallel_size", 1)
-        gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get("pipeline_model_parallel_split_rank", 0)
-
-        if cfg.model.data.get("chat", False):
-            # chat model, overwrite the prompt template
-            prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
-            gpt_cfg.data.train_ds.prompt_template = prompt_template
-            gpt_cfg.data.validation_ds.prompt_template = prompt_template
-
-        sft_cls = GPTSFTModel
-        gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"
-
-        if cfg.model.get("use_flash_attention", None) is not None:
-            gpt_cfg.use_flash_attention = cfg.model.use_flash_attention
-
-        if cfg.model.get("seq_len_interpolation_factor", None) is not None:
-            gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
-
-        if cfg.model.get("dist_ckpt_load_strictness", None) is not None:
-            gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness
-
-        gpt_cfg.inference = cfg.model.get("inference", {})
-
-        # This is needed when modifying a hparam file directly to load `.ckpt` files.
-        # This is not needed to modify the cfg in `.nemo` files.
-        if add_cfg_to_tree:
-            OmegaConf.resolve(gpt_cfg)
-            gpt_cfg.cfg = gpt_cfg
-
-    return gpt_cfg
-
-
 @hydra_runner(config_path="conf", config_name="gpt_sft")
 def main(cfg) -> None:
+    cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model)
+
     logging.info("\n\n************** Experiment configuration ***********")
     logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

@@ -129,17 +64,11 @@ def main(cfg) -> None:
     with open_dict(cfg):
         cfg.model.precision = cfg.trainer.precision
 
-    ptl_model, updated_cfg = load_from_nemo(
-        GPTSFTModel,
-        cfg,
-        trainer,
-        strict=True,
-        modify_config_fn=_modify_config,
-        restore_path=cfg.model.restore_from_path,
-        return_updated_cfg=True,
+    ptl_model = load_from_nemo(
+        GPTSFTModel, cfg, trainer, strict=True, restore_path=cfg.model.restore_from_path, return_updated_cfg=False,
     )
 
-    init_peft(ptl_model, updated_cfg)
+    init_peft(ptl_model, cfg.model)
 
     with open_dict(cfg):
         # overwrite the model config with the config from the checkpoint
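
In short, the bespoke `_modify_config` is gone and `load_and_override_model_config` performs the load-then-override step in one call. A minimal sketch of that merge pattern, assuming standard OmegaConf semantics (the helper name below is hypothetical, not NeMo-Aligner's actual implementation):

# Hypothetical sketch of the load-and-override pattern; for illustration only.
from omegaconf import OmegaConf

def override_model_config_sketch(base_cfg, override_cfg):
    # Later arguments win: finetuning overrides replace the values
    # that were stored alongside the checkpoint.
    return OmegaConf.merge(base_cfg, override_cfg)

base = OmegaConf.create({"micro_batch_size": 1, "sequence_parallel": False})
overrides = OmegaConf.create({"micro_batch_size": 4})
merged = override_model_config_sketch(base, overrides)
assert merged.micro_batch_size == 4 and merged.sequence_parallel is False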

nemo_aligner/data/nlp/datasets.py

Lines changed: 95 additions & 8 deletions
@@ -15,13 +15,19 @@
 """Custom datasets for RLHF training"""
 
 import os
+from typing import Dict, List
 
 import numpy as np
 import scipy
 import torch
+from omegaconf import OmegaConf
 
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids
-from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import (
+    GPTSFTChatDataset,
+    _get_header_conversation_type_mask_role,
+    get_prompt_template_example,
+)
 from nemo.core import Dataset
 from nemo.utils import logging

@@ -344,16 +350,97 @@ def encode(self, text, append_eod=False):
 
         return text_ids, len(text_ids)
 
+    @staticmethod
+    def _convert_messages(
+        input_list: List[Dict[str, str]]
+    ) -> Dict:  # TODO: (@adithyare) this method should live elsewhere..
+        """
+        args:
+            input_list: a list of dicts in the OpenAI format, for example:
+                [{"role": "system", "content": "you are helpful"},
+                 {"role": "user", "content": "Why is the sky blue?"},
+                 {"role": "assistant", "content": "Because blablabla"},
+                 ...]
+        returns:
+            output_dict: a dict in NeMo's format {"system": "system prompt",
+                                                  "conversations": [],
+                                                  ...
+                                                 }
+        """
+        output_dict = {
+            "system": "",
+            "conversations": [],
+            "mask": "User",
+            "type": "VALUE_TO_TEXT",
+        }
+
+        # Extract the system message
+        num_system_msg = 0
+        for msg in input_list:
+            if msg["role"] == "system":
+                output_dict["system"] = msg["content"]
+                num_system_msg += 1
+        if num_system_msg > 1:
+            raise RuntimeError("Multiple system messages seen, please consolidate into a single system message.")
+
+        # Build the conversations list
+        for msg in input_list:
+            if msg["role"] != "system":
+                conversation_entry = {
+                    "from": msg["role"].capitalize(),  # Capitalize 'user' and 'assistant'
+                    "value": msg["content"],
+                    "label": None,
+                }
+                output_dict["conversations"].append(conversation_entry)
+
+        return output_dict
+
+    def convert(self, messages):
+        """
+        args:
+            messages: a list of dicts in the OpenAI format (see `_convert_messages` above)
+        returns:
+            conversation: a string formatted with the chat template
+        """
+        if OmegaConf.select(self.cfg, "data.chat_prompt_tokens") is None:
+            raise RuntimeError(
+                "You don't have a model (model_config.yaml) which has chat_prompt_tokens, are you sure this is a Chat/Instruction model?"
+            )
+        special_tokens = self.cfg.data.chat_prompt_tokens
+        nemo_source = self._convert_messages(messages)
+        header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role(
+            nemo_source, special_tokens
+        )
+        return conversation
+
     def __getitem__(self, idx):
         """Returns a pair of chosen/rejected pairs, their respective lengths, and labels."""
         payload = self.data[idx]
-        prompt, prompt_len = self.encode(payload["prompt"], append_eod=False)
-        chosen, chosen_len = self.encode(
-            payload["prompt"] + payload["chosen_response"], append_eod=self.cfg.data.get("append_eod", False)
-        )
-        reject, reject_len = self.encode(
-            payload["prompt"] + payload["rejected_response"], append_eod=self.cfg.data.get("append_eod", False)
-        )
+
+        if isinstance(payload["prompt"], str):
+            # (@adithyare) formatting with hardcoded chat tokens
+            # will be allowed for the time being.
+            prompt_fmtd = payload["prompt"]
+            chosen_fmtd = payload["prompt"] + payload["chosen_response"]
+            rejected_fmtd = payload["prompt"] + payload["rejected_response"]
+            logging.warning(
+                "Pre-formatting chat conversations as strings with hardcoded chat tokens will be deprecated."
+            )  # (@adithyare) this will spam the console for now.
+        else:
+            prompt_fmtd = self.convert(payload["prompt"])  # (@adithyare) read var as "prompt formatted"
+            chosen_fmtd = self.convert(payload["prompt"] + [payload["chosen_response"]])
+            rejected_fmtd = self.convert(payload["prompt"] + [payload["rejected_response"]])
+
+        prompt, prompt_len = self.encode(prompt_fmtd, append_eod=False)
+        chosen, chosen_len = self.encode(chosen_fmtd, append_eod=self.cfg.data.get("append_eod", False))
+        reject, reject_len = self.encode(rejected_fmtd, append_eod=self.cfg.data.get("append_eod", False))
+
         # chosen_response_only, chosen_response_len = self.encode(payload['chosen_response'])
         # reject_response_only, reject_response_len = self.encode(payload['rejected_response'])
         chosen_labels = ([-100] * prompt_len) + chosen[prompt_len:]
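
With this change, a DPO record's "prompt" may be either a pre-formatted string (the deprecated path above) or an OpenAI-style message list, with "chosen_response" and "rejected_response" carried as single assistant messages. A self-contained sketch of the OpenAI-to-NeMo mapping performed by `_convert_messages`, reimplemented here purely for illustration (the dataset class above remains the source of truth):

# Illustrative reimplementation of the _convert_messages mapping; not the shipped code.
def convert_messages_sketch(messages):
    out = {"system": "", "conversations": [], "mask": "User", "type": "VALUE_TO_TEXT"}
    for msg in messages:
        if msg["role"] == "system":
            out["system"] = msg["content"]  # assumes at most one system message
        else:
            # NeMo's chat schema capitalizes the speaker tag
            out["conversations"].append({"from": msg["role"].capitalize(), "value": msg["content"], "label": None})
    return out

messages = [
    {"role": "system", "content": "you are helpful"},
    {"role": "user", "content": "Why is the sky blue?"},
    {"role": "assistant", "content": "Because of Rayleigh scattering."},
]
nemo_source = convert_messages_sketch(messages)
assert nemo_source["system"] == "you are helpful"
assert nemo_source["conversations"][0] == {"from": "User", "value": "Why is the sky blue?", "label": None}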
nemo_aligner/data/nlp/scripts/undo_special_tokens.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to remove special tokens from dpo datasets
+and convert them into list of messages format"""
+
+import argparse
+import json
+import re
+
+
+def format_conversation(input_string):
+    # Define roles and patterns
+    role_patterns = {"<extra_id_0>System": "system", "<extra_id_1>User": "user", "<extra_id_1>Assistant": "assistant"}
+
+    # Initialize an empty output list
+    conversation = []
+
+    # Use regex to find each segment's role and content
+    segments = re.findall(r"(<extra_id_[0-1]>[^\n]+)\n(.*?)((?=<extra_id_)|$)", input_string, re.DOTALL)
+
+    for segment in segments:
+        role_tag, content, _ = segment
+        role = role_patterns.get(role_tag.strip(), "unknown")
+        conversation.append({"role": role, "content": content.strip()})
+
+    return conversation
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Process a JSONL file.")
+    parser.add_argument("input_jsonl", type=str, help="Path to the input JSONL file.")
+    # Parse the arguments
+    args = parser.parse_args()
+
+    input_jsonl = args.input_jsonl
+    output_jsonl = input_jsonl.replace(".jsonl", ".no_special_toks.jsonl")
+
+    with open(input_jsonl, "r") as f, open(output_jsonl, "w") as w:
+        for line in f:
+            j = json.loads(line)
+            prompt = j["prompt"]
+            undo_spl_prompt = format_conversation(prompt)
+            empty_assistant = undo_spl_prompt.pop()
+            chosen, rejected = j["chosen_response"], j["rejected_response"]
+            chosen = chosen.split("\n<extra_id_1>")[0]
+            rejected = rejected.split("\n<extra_id_1>")[0]
+            chosen_message = {"role": empty_assistant["role"], "content": chosen}
+            rejected_message = {"role": empty_assistant["role"], "content": rejected}
+            j_out = {
+                "prompt": undo_spl_prompt,
+                "chosen_response": chosen_message,
+                "rejected_response": rejected_message,
+                "chosen_reward": j["chosen_reward"],
+                "rejected_reward": j["rejected_reward"],
+            }
+            w.write(json.dumps(j_out) + "\n")
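
The script is invoked directly on a legacy DPO JSONL file and writes a converted sibling file next to it (the input file name below is illustrative):

python nemo_aligner/data/nlp/scripts/undo_special_tokens.py train_dpo.jsonl
# writes train_dpo.no_special_toks.jsonl, where each "prompt" becomes a list of
# {"role", "content"} messages and the chosen/rejected responses become single
# assistant messages with trailing special tokens stripped.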

setup/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+Jinja2~=3.1.4
 jsonlines
 megatron_core>=0.8
 nemo_toolkit[nlp]
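
Note on the new pin: `~=3.1.4` is the compatible-release specifier, equivalent to `>=3.1.4, <3.2.0`, so Jinja2 picks up 3.1.x patch releases automatically.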

tests/test_datasets.py

Lines changed: 56 additions & 0 deletions
@@ -22,8 +22,16 @@
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo_aligner.algorithms.dpo import dpo_custom_collate
 from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
+from nemo_aligner.data.nlp.scripts.undo_special_tokens import format_conversation
 from nemo_aligner.utils import parallel_state
 
+try:
+    from jinja2 import Template
+
+    HAS_JINJA2 = True
+except ImportError:
+    HAS_JINJA2 = False
+
 
 @pytest.fixture
 def llama3_tokenizer():

@@ -136,6 +144,54 @@ def test_dpo_loader(init_model_parallel, make_tmp_jsonl, llama3_tokenizer):
     assert num_mini_batches == 2
 
 
+@pytest.mark.run_only_on("GPU")
+def test_dpo_dataset_conversion():
+    prompt_str = """<extra_id_0>System\n\n<extra_id_1>User\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\n<extra_id_1>Assistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\n<extra_id_1>User\nCould you explain in detail both the advantages and disadvantages from different perspectives?\n<extra_id_1>Assistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\n<extra_id_1>User\ncould you add more in a table\n<extra_id_1>Assistant\n"""
+
+    expected_oai_messages = [
+        {"role": "system", "content": ""},
+        {
+            "role": "user",
+            "content": "Does GPT-4 use RNN or Transformer models, and which one is better for this type of project?",
+        },
+        {
+            "role": "assistant",
+            "content": "GPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.",
+        },
+        {
+            "role": "user",
+            "content": "Could you explain in detail both the advantages and disadvantages from different perspectives?",
+        },
+        {
+            "role": "assistant",
+            "content": """Yes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.""",
+        },
+        {"role": "user", "content": "could you add more in a table"},
+        {"role": "assistant", "content": ""},
+    ]
+
+    oai_messages_prompt = format_conversation(prompt_str)
+    assert expected_oai_messages == oai_messages_prompt
+
+    if HAS_JINJA2:
+        # (@adithyare) bonus test! convert oai style messages back into a string using Jinja
+
+        def remove_trailing(s, t):
+            if s.endswith(t):
+                s = s[: -len(t)]
+            return s
+
+        jinja_template = """{% for message in conversation %}{%- if message.role == "system" -%}<extra_id_0>System\n{{ message.content }}\n{% elif message.role == "user" -%}<extra_id_1>User\n{{ message.content }}\n{% elif message.role == "assistant" -%}<extra_id_1>Assistant\n{{ message.content }}\n{% endif %}{% endfor %}"""
+        jinja_template = Template(jinja_template)
+        prompt_str_jinja_rendered = jinja_template.render(conversation=oai_messages_prompt)
+        prompt_str_jinja_rendered = remove_trailing(
+            prompt_str_jinja_rendered, "\n"
+        )  # (@adithyare) jinja adds the end-of-message token, which we remove to recover a prompt.
+        assert prompt_str == prompt_str_jinja_rendered
+
+    return True
+
+
 @pytest.mark.run_only_on("GPU")
 def test_dpo_loader_original(init_model_parallel, make_tmp_jsonl, llama3_tokenizer):
     init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
