# See the License for the specific language governing permissions and
# limitations under the License.

+ import datasets.dataset_dict
import lightning.pytorch as pl
import torch
+ from datasets import load_dataset
from torch.utils.data import DataLoader
+
from nemo.lightning.pytorch.plugins import MegatronDataSampler
+ from nemo.utils import logging
+
+
+ def make_dataset_splits(path, split, split_aliases, kwargs):
+     """
+     Loads a dataset with datasets.load_dataset and returns a dictionary
+     mapping each canonical split name (train, test, val) to its dataset split.
+
+     For example, for a dataset where:
+
+     $ ds = load_dataset("dataset-id")
+     $ print(ds)
+     > DatasetDict({
+     >     train: Dataset({
+     >         features: ['id', 'title', 'context', 'question', 'answers'],
+     >         num_rows: 87599
+     >     })
+     >     validation: Dataset({
+     >         features: ['id', 'title', 'context', 'question', 'answers'],
+     >         num_rows: 10570
+     >     })
+     > })
+
+     the dictionary returned by make_dataset_splits (`ans`) will be:
+
+     $ print(ans)
+     > {
+     >     "train": Dataset .. (with 87599 rows),
+     >     "val": Dataset .. (with 10570 rows),
+     >     "test": None,
+     > }
+     """
+     dataset = load_dataset(path, split=split, **kwargs)
+
+     split_names = ['train', 'test', 'val']
+     dataset_splits = {split: None for split in split_names}
+
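+     # Map each accepted alias (e.g. "validation") to its canonical split name (e.g. "val").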
+     alias_to_split = {}
+     for split_name, _split_aliases in split_aliases.items():
+         assert split_name in split_names
+         for alias in _split_aliases:
+             alias_to_split[alias] = split_name
+
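+     # load_dataset returns a DatasetDict when split is None, a list of Datasets
+     # when split is a list, and a single Dataset when split is a str.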
+     if isinstance(dataset, datasets.dataset_dict.DatasetDict):
+         dataset_split_names = dataset.keys()
+         logging.info(f"HF dataset has the following splits: {dataset_split_names}")
+         for alias_split_name, split in dataset.items():
+             split_name = alias_to_split[alias_split_name]
+             assert dataset_splits[split_name] is None
+             dataset_splits[split_name] = split
+     elif isinstance(split, list):
+         logging.info(f"Loaded HF dataset will use {split} splits.")
+         assert isinstance(dataset, list)
+         for i, alias_split_name in enumerate(split):
+             split_name = alias_to_split[alias_split_name]
+             assert dataset_splits[split_name] is None
+             dataset_splits[split_name] = dataset[i]
+     elif isinstance(split, str):
+         logging.info("Loaded HF dataset has a single split.")
+         assert not isinstance(dataset, list)
+         alias_split_name = split
+         if '+' in alias_split_name:
+             raise ValueError("Split concatenation is not supported.")
+         elif '[' in alias_split_name:
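+             # Strip the slice suffix before the alias lookup, e.g. "train[:100]" -> "train".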
+             alias_split_name = alias_split_name.split('[')[0]
+         split_name = alias_to_split[alias_split_name]
+         assert dataset_splits[split_name] is None
+         dataset_splits[split_name] = dataset
+     else:
+         raise ValueError("Expected `split` to be None, a str, or a list.")
+
+     assert (
+         sum(map(lambda x: x is not None, dataset_splits.values())) > 0
+     ), "Expected at least one dataset split to have been initialized."
+     return dataset_splits


class HFDatasetDataModule(pl.LightningDataModule):
99+ """HFDatasetDataModule wraps HF's load_dataset (datasets library)
100+ so that it can be used within NeMo.
101+ Users can select whether to use an mcore-sampler via use_mcore_sampler arg.
102+
103+ Usage examples:
104+
105+ - loading a single split (train) from a dataset
106+ llm.HFDatasetDataModule("rajpurkar/squad", split="train")
107+
108+ - loading multiple splits (train, validation) from a dataset
109+ llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"])
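+
+     - mapping a user-provided function over the loaded splits:
+       llm.HFDatasetDataModule("rajpurkar/squad", split="train").map(lambda x: x)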
110+ """
111+
    def __init__(
        self,
-         dataset,
+         path,
+         collate_fn=None,
+         split=None,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
@@ -31,11 +123,29 @@ def __init__(
        pad_token_id=0,
        use_mcore_sampler=False,
        mcore_dataloader_type='cyclic',
+         train_aliases=["train", "training"],
+         test_aliases=["test", "testing"],
+         val_aliases=["val", "validation", "valid", "eval"],
+         **kwargs,
    ) -> None:
        super().__init__()
        assert pad_token_id is not None

-         self.dataset = dataset
+         logging.info(f"Loading HF dataset from {path}")
+
+         # A dataset usually has several splits (e.g. train, val, test).
+         # We map synonym names to the canonical names (train, test, val).
+         # A synonym can be a prefixed/suffixed word, e.g. train <> training.
+         split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases}
+
+         # self.dataset_splits will hold the actual dataset for each split.
+         self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)
+
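+         # Without a user-provided collate_fn, fall back to the class default,
+         # which pads each batch using pad_token_id.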
+         if collate_fn is None:
+             self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
+         else:
+             self._collate_fn = collate_fn
+
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
@@ -84,17 +194,51 @@ def setup(self, stage: str):
            dataloader_type=self.mcore_dataloader_type,
        )

-     def train_dataloader(self, collate_fn=None):
-         from nemo.lightning.data import add_megatron_sampler
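+     # Shared helper: builds a DataLoader over one split; used by the
+     # train/val/test dataloader methods below.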
+     def _make_dataloader(self, dataset, collate_fn=None):
+         assert dataset is not None

        if collate_fn is None:
            collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)

        return DataLoader(
-             self.dataset,
+             dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=collate_fn,
            batch_size=self.micro_batch_size,
        )
+
+     @property
+     def train(self):
+         return self.dataset_splits['train']
+
+     @property
+     def val(self):
+         return self.dataset_splits['val']
+
+     @property
+     def test(self):
+         return self.dataset_splits['test']
+
+     def train_dataloader(self):
+         return self._make_dataloader(self.train, self._collate_fn)
+
+     def val_dataloader(self):
+         return self._make_dataloader(self.val, self._collate_fn)
+
+     def test_dataloader(self):
+         return self._make_dataloader(self.test, self._collate_fn)
+
+     def map(self, function=None, split_names=None, **kwargs):
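+         """Applies `function` to the requested splits (all loaded splits when
+         split_names is None) via datasets.Dataset.map, updating them in place."""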
+         if isinstance(split_names, str):
+             dataset_splits = {split_names: self.dataset_splits[split_names]}
+         elif isinstance(split_names, list):
+             dataset_splits = {k: self.dataset_splits[k] for k in split_names}
+         else:
+             dataset_splits = self.dataset_splits
+
+         for split_name, subset in dataset_splits.items():
+             if subset is None:
+                 continue
+             # Write back to self.dataset_splits so the mapped result is kept
+             # even when `dataset_splits` is a temporary selection built above.
+             self.dataset_splits[split_name] = subset.map(function, **kwargs)