Commit 5452848
Make HFDatasetDataModule a datasets.load_dataset wrapper (NVIDIA-NeMo#11500)
* Make HfDatasetDataModule a datasets.load_dataset wrapper
* add logging
* fix
* Update HFDatasetDataModule
* refactor
* refactor fixup
* refactor fixup #2
* do not expand
* fix
* fix
* fix
* fix
* fix
* doc
* doc
* add synonym
* fix
* fix
* fix
* typo
* Apply isort and black reformatting
* Add train/val/test attributes
* Add test for hf-datamodule
* Import lazily to avoid breaking with older megatron versions
* bot happy
* Apply isort and black reformatting
* bot happy2
* add doc-strings and collate-fn arg
* Apply isort and black reformatting

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
1 parent cb2302f commit 5452848

6 files changed

Lines changed: 276 additions & 20 deletions
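
In short, HFDatasetDataModule now loads the dataset itself instead of receiving a pre-loaded one. A minimal before/after sketch of the change in caller code, based on the diffs below (the pad_token_id value here is illustrative):

    from nemo.collections import llm

    # Before: callers loaded the dataset themselves and passed the object in.
    # from datasets import load_dataset
    # data = llm.HFDatasetDataModule(load_dataset("rajpurkar/squad", split="train"))

    # After: the datamodule wraps datasets.load_dataset; pass the path/name.
    data = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=0)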

examples/llm/peft/hf.py

Lines changed: 5 additions & 9 deletions
@@ -18,7 +18,7 @@
 from nemo.collections import llm
 
 
-def mk_hf_dataset(tokenizer):
+def make_squad_hf_dataset(tokenizer):
     EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
 
     def formatting_prompts_func(examples):
@@ -45,11 +45,9 @@ def formatting_prompts_func(examples):
             'labels': tokens[1:] + [tokens[-1]],
         }
 
-    from datasets import load_dataset
-
-    dataset = load_dataset("rajpurkar/squad", split="train")
-    dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2)
-    return dataset
+    datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=tokenizer.eos_token_id)
+    datamodule.map(formatting_prompts_func, batched=False, batch_size=2)
+    return datamodule
 
 
 if __name__ == '__main__':
@@ -80,9 +78,7 @@ def formatting_prompts_func(examples):
 
     llm.api.finetune(
         model=llm.HFAutoModelForCausalLM(args.model),
-        data=llm.HFDatasetDataModule(
-            mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
-        ),
+        data=make_squad_hf_dataset(tokenizer.tokenizer),
         trainer=nl.Trainer(
            devices=args.devices,
            max_steps=args.max_steps,
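
The 'labels': tokens[1:] + [tokens[-1]] line kept in context above builds next-token-prediction targets; a tiny worked example (token id values made up):

    # Labels are the input ids shifted left by one, with the final token
    # repeated so lengths match; labels[i] is the target after input_ids[i].
    tokens = [101, 7592, 2088, 102]  # made-up token ids
    labels = tokens[1:] + [tokens[-1]]
    assert labels == [7592, 2088, 102, 102]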

nemo/collections/llm/gpt/data/api.py

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ def dolly() -> pl.LightningDataModule:
 
 @run.cli.factory
 @run.autoconvert
-def hf_dataset(dataset: str) -> pl.LightningDataModule:
-    return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2)
+def hf_dataset(path: str) -> pl.LightningDataModule:
+    return HFDatasetDataModule(path=path, global_batch_size=16, micro_batch_size=2)
 
 
 __all__ = ["mock", "squad", "dolly", "hf_dataset"]
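
The rename only touches the factory's parameter. A hedged sketch of invoking it (under @run.autoconvert the call should yield a config describing the datamodule rather than a live instance, if my reading of NeMo-Run is right):

    from nemo.collections.llm.gpt.data.api import hf_dataset

    # Presumably a run.Config for HFDatasetDataModule; note the keyword
    # is now `path`, not `dataset`.
    dm_cfg = hf_dataset(path="rajpurkar/squad")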

nemo/collections/llm/gpt/data/hf_dataset.py

Lines changed: 149 additions & 5 deletions
@@ -12,16 +12,108 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datasets.dataset_dict
 import lightning.pytorch as pl
 import torch
+from datasets import load_dataset
 from torch.utils.data import DataLoader
+
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
+from nemo.utils import logging
+
+
+def make_dataset_splits(path, split, split_aliases, kwargs):
+    """
+    Loads a dataset with datasets.load_dataset and
+    returns a dictionary containing all dataset splits.
+
+    For example:
+
+    ans = make_dataset_splits("dataset-id")
+    $ ds = load_dataset("dataset-id")
+    $ print(ds)
+    > DatasetDict({
+    >     train: Dataset({
+    >         features: ['id', 'title', 'context', 'question', 'answers'],
+    >         num_rows: 87599
+    >     })
+    >     validation: Dataset({
+    >         features: ['id', 'title', 'context', 'question', 'answers'],
+    >         num_rows: 10570
+    >     })
+    > })
+
+    In this case the value of `ans` (returned value) will be:
+    $ print(ans)
+    > {
+    >     "train": Dataset .. (with 87599 rows),
+    >     "val": Dataset .. (with 10570 rows),
+    > }
+    """
+    dataset = load_dataset(path, split=split, **kwargs)
+
+    split_names = ['train', 'test', 'val']
+    dataset_splits = {split: None for split in split_names}
+
+    alias_to_split = {}
+    for split_name, _split_aliases in split_aliases.items():
+        assert split_name in split_names
+        for alias in _split_aliases:
+            alias_to_split[alias] = split_name
+
+    if isinstance(dataset, datasets.dataset_dict.DatasetDict):
+        dataset_split_names = dataset.keys()
+        logging.info(f"HF dataset has the following splits: {dataset_split_names}")
+        for alias_split_name, split in dataset.items():
+            split_name = alias_to_split[alias_split_name]
+            assert dataset_splits[split_name] is None
+            dataset_splits[split_name] = split
+    elif isinstance(split, list):
+        logging.info("Loaded HF dataset will use " + str(split) + " splits.")
+        assert isinstance(dataset, list)
+        for i, alias_split_name in enumerate(split):
+            split_name = alias_to_split[alias_split_name]
+            assert dataset_splits[split_name] is None
+            dataset_splits[split_name] = dataset[i]
+    elif isinstance(split, str):
+        logging.info("Loaded HF dataset has a single split.")
+        assert not isinstance(dataset, list)
+        alias_split_name = split
+        if '+' in alias_split_name:
+            raise ValueError("Split concatenation not supported")
+        elif '[' in alias_split_name:
+            alias_split_name = alias_split_name.split('[')[0]
+        split_name = alias_to_split[alias_split_name]
+        assert dataset_splits[split_name] is None
+        dataset_splits[split_name] = dataset
+    else:
+        raise ValueError("Expected split name to be None, str or a list")
+
+    assert (
+        sum(map(lambda x: x is not None, dataset_splits.values())) > 0
+    ), "Expected at least one dataset to have been initialized"
+    return dataset_splits
 
 
 class HFDatasetDataModule(pl.LightningDataModule):
+    """HFDatasetDataModule wraps HF's load_dataset (datasets library)
+    so that it can be used within NeMo.
+    Users can select whether to use an mcore-sampler via the use_mcore_sampler arg.
+
+    Usage examples:
+
+    - loading a single split (train) from a dataset
+      llm.HFDatasetDataModule("rajpurkar/squad", split="train")
+
+    - loading multiple splits (train, validation) from a dataset
+      llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"])
+    """
+
     def __init__(
         self,
-        dataset,
+        path,
+        collate_fn=None,
+        split=None,
         num_workers=2,
         pin_memory=True,
         persistent_workers=True,
@@ -31,11 +123,29 @@ def __init__(
         pad_token_id=0,
         use_mcore_sampler=False,
         mcore_dataloader_type='cyclic',
+        train_aliases=["train", "training"],
+        test_aliases=["test", "testing"],
+        val_aliases=["val", "validation", "valid", "eval"],
+        **kwargs,
     ) -> None:
         super().__init__()
         assert pad_token_id is not None
 
-        self.dataset = dataset
+        logging.info(f"Loading HF dataset from {path}")
+
+        # A dataset usually will have several splits (e.g. train, val, test, etc).
+        # We map synonym names to canonical names (train, test, val).
+        # A synonym can be a prefixed/suffixed word, e.g. train <> training.
+        split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases}
+
+        # self.dataset_splits will hold the actual dataset for each split.
+        self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)
+
+        if collate_fn is None:
+            self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
+        else:
+            self._collate_fn = collate_fn
+
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.persistent_workers = persistent_workers
@@ -84,17 +194,51 @@ def setup(self, stage: str):
             dataloader_type=self.mcore_dataloader_type,
         )
 
-    def train_dataloader(self, collate_fn=None):
-        from nemo.lightning.data import add_megatron_sampler
+    def _make_dataloader(self, dataset, collate_fn=None):
+        assert dataset is not None
 
         if collate_fn is None:
             collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
 
         return DataLoader(
-            self.dataset,
+            dataset,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=self.persistent_workers,
             collate_fn=collate_fn,
             batch_size=self.micro_batch_size,
         )
+
+    @property
+    def train(self):
+        return self.dataset_splits['train']
+
+    @property
+    def val(self):
+        return self.dataset_splits['val']
+
+    @property
+    def test(self):
+        return self.dataset_splits['test']
+
+    def train_dataloader(self):
+        return self._make_dataloader(self.train, self._collate_fn)
+
+    def val_dataloader(self):
+        return self._make_dataloader(self.val, self._collate_fn)
+
+    def test_dataloader(self):
+        return self._make_dataloader(self.test, self._collate_fn)
+
+    def map(self, function=None, split_names=None, **kwargs):
+        if isinstance(split_names, str):
+            dataset_splits = {split_names: self.dataset_splits[split_names]}
+        elif isinstance(split_names, list):
+            dataset_splits = {k: self.dataset_splits[k] for k in split_names}
+        else:
+            dataset_splits = self.dataset_splits
+
+        for split_name, subset in dataset_splits.items():
+            if subset is None:
+                continue
+            dataset_splits[split_name] = subset.map(function, **kwargs)
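
Putting the new pieces together, a short sketch of the canonical-split mapping and the new map() helper (squad ships train/validation only, so `test` stays None):

    from nemo.collections import llm

    dm = llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"])
    assert dm.val is not None   # "validation" is aliased to the canonical "val"
    assert dm.test is None      # squad has no test split

    def upcase_title(example):
        example["title"] = example["title"].upper()
        return example

    # map() applies the function across the loaded splits (in place when
    # split_names is omitted, matching the hf.py usage above).
    dm.map(upcase_title)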

nemo/collections/llm/inference/base.py

Lines changed: 4 additions & 3 deletions
@@ -25,9 +25,6 @@
 from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
     AbstractModelInferenceWrapper,
 )
-from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
-    EncoderDecoderTextGenerationController,
-)
 from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
     SimpleTextGenerationController,
 )
@@ -232,6 +229,10 @@ def generate(
     Returns:
         dict: A dictionary containing the generated results.
     """
+    from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
+        EncoderDecoderTextGenerationController,
+    )
+
     if encoder_prompts is not None:
         text_generation_controller = EncoderDecoderTextGenerationController(
             inference_wrapped_model=model, tokenizer=tokenizer

nemo/collections/llm/t5/model/t5.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,7 @@
 import torch
 import torch.distributed
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
-from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper
+
 from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.transformer.spec_utils import ModuleSpec
@@ -319,6 +319,7 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold):
             inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
             padded_vocab_size=self.tokenizer.vocab_size,
         )
+        from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper
 
         model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config)
         return model_inference_wrapper
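
Both of the last two files apply the same compatibility trick: the megatron-core import moves from module scope into the function that needs it, so importing NeMo still succeeds on older megatron versions that lack the symbol. A generic sketch of the pattern, mirroring the t5.py change:

    def get_inference_wrapper(mcore_model, inference_wrapper_config):
        # Imported lazily: module import no longer fails on megatron-core
        # builds that predate T5InferenceWrapper; only this call site would.
        from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import (
            T5InferenceWrapper,
        )
        return T5InferenceWrapper(mcore_model, inference_wrapper_config)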
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections import llm
+
+DATA_PATH = "/home/TestData/lite/hf_cache/squad/"
+
+
+def test_load_single_split():
+    ds = llm.HFDatasetDataModule(
+        path=DATA_PATH,
+        split='train',
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+    )
+    from datasets.arrow_dataset import Dataset
+
+    assert isinstance(ds.dataset_splits, dict)
+    assert len(ds.dataset_splits) == 3
+    assert 'train' in ds.dataset_splits
+    assert ds.dataset_splits['train'] is not None
+    assert ds.train is not None
+    assert isinstance(ds.dataset_splits['train'], Dataset)
+    assert 'val' in ds.dataset_splits
+    assert ds.dataset_splits['val'] is None
+    assert ds.val is None
+    assert 'test' in ds.dataset_splits
+    assert ds.dataset_splits['test'] is None
+    assert ds.test is None
+
+
+def test_load_nonexistent_split():
+    exception_msg = ''
+    expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
+    try:
+        llm.HFDatasetDataModule(
+            path=DATA_PATH,
+            split='this_split_name_should_not_exist',
+            seq_length=512,
+            micro_batch_size=2,
+            global_batch_size=2,
+        )
+    except ValueError as e:
+        exception_msg = str(e)
+    assert exception_msg == expected_msg, exception_msg
+
+
+def test_load_multiple_split():
+    ds = llm.HFDatasetDataModule(
+        path=DATA_PATH,
+        split=['train', 'validation'],
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+    )
+    from datasets.arrow_dataset import Dataset
+
+    assert isinstance(ds.dataset_splits, dict)
+    assert len(ds.dataset_splits) == 3
+    assert 'train' in ds.dataset_splits
+    assert ds.dataset_splits['train'] is not None
+    assert ds.train is not None
+    assert isinstance(ds.dataset_splits['train'], Dataset)
+    assert isinstance(ds.train, Dataset)
+    assert 'val' in ds.dataset_splits
+    assert ds.dataset_splits['val'] is not None
+    assert ds.val is not None
+    assert isinstance(ds.dataset_splits['val'], Dataset)
+    assert isinstance(ds.val, Dataset)
+    assert 'test' in ds.dataset_splits
+    assert ds.dataset_splits['test'] is None
+    assert ds.test is None
+
+
+def test_validate_dataset_asset_accessibility_file_does_not_exist():
+    raised_exception = False
+    try:
+        llm.HFDatasetDataModule(
+            path="/this/path/should/not/exist/",
+            seq_length=512,
+            micro_batch_size=2,
+            global_batch_size=2,
+        )
+    except FileNotFoundError:
+        raised_exception = True
+
+    assert raised_exception == True, "Expected to raise a FileNotFoundError"
+
+
+def test_validate_dataset_asset_accessibility_file_is_none():
+    raised_exception = False
+    try:
+        llm.HFDatasetDataModule(
+            path=None,
+            seq_length=512,
+            micro_batch_size=2,
+            global_batch_size=2,
+        )
+    except TypeError:
+        raised_exception = True
+
+    assert raised_exception == True, "Expected to raise a TypeError"
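
As an aside, the manual try/except/flag pattern in these tests could also be written with pytest.raises; a hedged alternative sketch, not part of the commit:

    import pytest
    from nemo.collections import llm

    def test_load_nonexistent_split_idiomatic():
        # pytest.raises replaces the manual exception bookkeeping above;
        # match= does a regex search against the raised message.
        with pytest.raises(ValueError, match='Unknown split'):
            llm.HFDatasetDataModule(
                path=DATA_PATH,
                split='this_split_name_should_not_exist',
                seq_length=512,
                micro_batch_size=2,
                global_batch_size=2,
            )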
