import logging

from datasets import config, load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator

logger = logging.getLogger(__name__)

def get_dataloaders(tokenizer, args):
    """
    Args:
        tokenizer: a huggingface/transformers tokenizer
        args: a DataConfig

    Returns:
        a training dataloader and one or more validation dataloaders
        (the validation dataloaders must be in a dictionary)

        Each dataloader should yield batches of the form {str: torch.Tensor(cpu)}.
        A batch may contain a 'metadata_mask' key; all other fields are passed
        directly to the model.

        Note: metadata_mask should be padded like the other fields.

    Example:
        train_dataloader, val_dataloaders = get_dataloaders(...)
        for batch in train_dataloader:
            metadata_mask = batch.get('metadata_mask', None)
            outputs = model(**batch)
            metrics = loss_fn(batch, outputs, metadata_mask)
    """
    # Mostly copy/paste from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    data_files = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    if args.validation_file is not None:
        data_files["validation"] = args.validation_file
    if not data_files:
        data_files = None

logger.info(f"Start to load dataset, the result will be cached at {config.HF_DATASETS_CACHE}")
if args.dataset_name is not None:
logger.info(
"Downloading with arguments: "
f"dataset_name={args.dataset_name}, "
f"dataset_config_name={args.dataset_config_name}, "
f"data_files={data_files}, "
f"cache_dir={args.cache_dir},"
)
# Downloading and loading a dataset from the hub.
datasets = load_dataset(
args.dataset_name,
args.dataset_config_name,
data_files=data_files,
cache_dir=args.cache_dir,
keep_in_memory=False,
)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
cache_dir=args.cache_dir,
)
datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
cache_dir=args.cache_dir,
)
    else:
        logger.info("Loading dataset from extension script")
        extension = args.train_file.split(".")[-1] if not args.extension else args.extension
        if extension == "txt":
            extension = "text"
        if extension == "jsonl":
            extension = "json"
        datasets = load_dataset(extension, data_files=data_files, cache_dir=args.cache_dir)
        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{args.validation_split_percentage}%]",
                cache_dir=args.cache_dir,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{args.validation_split_percentage}%:]",
                cache_dir=args.cache_dir,
            )
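    # Illustrative argument values for the two branches above (hypothetical, not project defaults):
    #   hub mode:  args.dataset_name="wikitext", args.dataset_config_name="wikitext-2-raw-v1"
    #   file mode: args.dataset_name=None, args.train_file="train.jsonl"  -> loaded with the "json" builder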
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    column_names = datasets["train"].column_names if "train" in datasets else datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    logger.info("Tokenize dataset")
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not args.overwrite_cache,
        desc="Running tokenizer on dataset",
        batch_size=args.map_batch_size,
    )
    logger.info("Tokenize dataset finished")
    if args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(args.block_size, tokenizer.model_max_length)
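    # For example, with the GPT-2 tokenizer (`model_max_length` == 1024) and no explicit
    # --block_size, the texts below end up grouped into chunks of 1024 tokens.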
    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could add padding instead if the model supported it.
        # You can customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    logger.info("Group texts")
    datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=not args.overwrite_cache,
        desc=f"Grouping texts in chunks of {block_size}",
        batch_size=args.map_batch_size,
    )
    logger.info("Group texts finished")

    train_dataset = datasets["train"] if "train" in datasets else None
    val_dataset = datasets["validation"]

    if "train" in datasets:
        logger.info(f" Num train examples = {len(train_dataset)}")
    logger.info(f" Num validation examples = {len(val_dataset)}")

    if "train" in datasets:
        logger.info(" Train sample without metadata")
        for idx in range(3):
            logger.info(f" Train sample n°{idx} attention_mask:\n{train_dataset[idx]['attention_mask']}")
            logger.info(f" Train sample n°{idx} input_ids:\n{train_dataset[idx]['input_ids']}")
            logger.info(f" Train sample n°{idx} input_ids decoded:\n{tokenizer.decode(train_dataset[idx]['input_ids'])}")
            logger.info(
                f" Train sample n°{idx} tokens:\n{tokenizer.convert_ids_to_tokens(train_dataset[idx]['input_ids'])}"
            )
    else:
        logger.info(" Validation sample without metadata")
        for idx in range(3):
            logger.info(f" Validation sample n°{idx} attention_mask:\n{val_dataset[idx]['attention_mask']}")
            logger.info(f" Validation sample n°{idx} input_ids:\n{val_dataset[idx]['input_ids']}")
            logger.info(f" Validation sample n°{idx} input_ids decoded:\n{tokenizer.decode(val_dataset[idx]['input_ids'])}")
            logger.info(
                f" Validation sample n°{idx} tokens:\n{tokenizer.convert_ids_to_tokens(val_dataset[idx]['input_ids'])}"
            )

    # DataLoaders creation:
    train_dataloader = (
        DataLoader(
            train_dataset,
            shuffle=True,
            collate_fn=default_data_collator,
            batch_size=args.per_device_train_batch_size,
        )
        if "train" in datasets
        else None
    )
    val_dataloader1 = DataLoader(
        val_dataset,
        collate_fn=default_data_collator,
        batch_size=args.per_device_eval_batch_size,
    )
    return train_dataloader, {"val1": val_dataloader1}
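

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The project's real `DataConfig` is defined
    # elsewhere; the `_Args` dataclass below is a hypothetical stand-in that only covers the
    # attributes read by `get_dataloaders`, with example values.
    from dataclasses import dataclass
    from typing import Optional

    from transformers import AutoTokenizer

    @dataclass
    class _Args:
        dataset_name: Optional[str] = "wikitext"
        dataset_config_name: Optional[str] = "wikitext-2-raw-v1"
        train_file: Optional[str] = None
        validation_file: Optional[str] = None
        extension: Optional[str] = None
        cache_dir: Optional[str] = None
        validation_split_percentage: int = 5
        preprocessing_num_workers: Optional[int] = None
        overwrite_cache: bool = False
        map_batch_size: int = 1000
        block_size: Optional[int] = None
        per_device_train_batch_size: int = 2
        per_device_eval_batch_size: int = 2

    logging.basicConfig(level=logging.INFO)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    train_dataloader, val_dataloaders = get_dataloaders(tokenizer, _Args())

    # Each batch is a {str: torch.Tensor} dict ready to be passed to a causal LM.
    batch = next(iter(train_dataloader))
    print({k: tuple(v.shape) for k, v in batch.items()})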