-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute_contamination_metrics.py
More file actions
347 lines (300 loc) · 14.5 KB
/
compute_contamination_metrics.py
File metadata and controls
347 lines (300 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import json
import argparse
import os
import glob
from typing import List, Tuple, Set, Any, Optional
from nltk import ngrams
from typing import Dict
from tqdm import tqdm
from dataclasses import dataclass
from light_scenario import LightInstance, LightScenario, LightScenarioKey
from light_tokenizer import LightTokenizer, DefaultTokenizer
from load_documents import get_document_iterator
from contamination_stats import (
ContaminationStats,
ContaminationStatsKey,
PART_INPUT,
PART_REF,
)
from common.hierarchical_logger import hlog, htrack_block
from common.general import asdict_without_nones, write
from scenarios.scenario import ScenarioSpec
# The n values of the ngrams to be computed.
# The script entry point passes this list to both the stats initialization and the
# ngram index construction, so every scenario is analysed once per value of n.
N_VALUES: List[int] = [5, 9, 13] # TODO: Pick the N values
@dataclass(frozen=True)
class EntryContaminationKey:
    """Identify one side (input or references) of one instance of a scenario.

    The dataclass is frozen, which makes instances hashable: they are stored in the
    sets of the ngram index and used as dictionary keys of the ngram counter.
    """

    # Key of the ContaminationStats object (one per scenario and n) this entry reports to.
    stats_key: ContaminationStatsKey

    # Which side of the instance the entry covers: either PART_INPUT or PART_REF.
    part: str

    # Position of the instance inside its scenario's list of light instances.
    instance_id: int
# type alias for contamination-related data structures
# An ngram is represented as a tuple of tokens.
Ngram = Tuple[str, ...]
# n -> ngram -> set of scenario entries (instance input or references) that contain the ngram.
NgramIndex = Dict[int, Dict[Ngram, Set[EntryContaminationKey]]]
# stats key -> the ContaminationStats object that accumulates results for that key.
AllContaminationStats = Dict[ContaminationStatsKey, ContaminationStats]
# scenario entry -> overlapping ngram -> number of times it was seen in the training data.
NgramCounter = Dict[EntryContaminationKey, Dict[Ngram, int]]
def load_light_scenarios_from_jsonl(path: str) -> List[LightScenario]:
    """
    Create a list of light scenarios from a jsonl file, where each line is the JSON
    representation of one LightScenario object.

    Input file format:

    LightScenario JSON 1
    LightScenario JSON 2
    LightScenario JSON 3
    ...

    Each line is a json and each json looks like:
    {
        "light_scenario_key": {
            "metadata": {
                "split": "SPLIT",
                "scenario_attribute_1": "ATTRIBUTE1",
                "scenario_attribute_2": "ATTRIBUTE2"
            }
        },
        "light_instances": [
            {
                "input": "INPUT_TEXT1",
                "references": [
                    "REFERENCE_TEXT_1",
                    "REFERENCE_TEXT_2"
                ]
            },
            {
                "input": "INPUT_TEXT2",
                "references": [
                    "REFERENCE_TEXT_3",
                    "REFERENCE_TEXT_4"
                ]
            }
        ]
    }

    Note that the values of light_scenario_key.metadata need to be hashable.

    Args:
        path: Path to the jsonl file to read.

    Returns:
        One LightScenario per non-blank line of the file, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a line lacks one of the fields shown above.
    """
    light_scenarios: List[LightScenario] = []

    # The context manager guarantees the handle is closed, and iterating the file
    # object streams it line by line instead of loading every line up front.
    with open(path, "r") as scenario_file:
        for light_scenario_json in scenario_file:
            # Tolerate blank lines (e.g. a trailing newline at the end of the file).
            if not light_scenario_json.strip():
                continue

            light_scenario_dict: dict = json.loads(light_scenario_json)
            light_scenario_metadata: dict = light_scenario_dict["light_scenario_key"]["metadata"]

            # if the light_scenarios are exported from helm, they will have a scenario_spec field
            if "scenario_spec" in light_scenario_metadata:
                light_scenario_metadata["scenario_spec"] = ScenarioSpec(**light_scenario_metadata["scenario_spec"])

            light_scenario_key = LightScenarioKey(metadata=light_scenario_metadata)
            light_instances: List[LightInstance] = [
                LightInstance(input=instance_dict["input"], references=instance_dict["references"])
                for instance_dict in light_scenario_dict["light_instances"]
            ]
            light_scenarios.append(
                LightScenario(light_scenario_key=light_scenario_key, light_instances=light_instances)
            )

    return light_scenarios
def create_ngram_index(
    light_scenarios: List[LightScenario], n_values: List[int], tokenizer: LightTokenizer
) -> NgramIndex:
    """Given a list of scenarios and n values, build the ngram index.

    For every n, the index maps each ngram found in a scenario instance's input or in
    one of its references to the set of entries (stats key, instance position, part)
    that contain it.

    Args:
        light_scenarios: The scenarios whose instances are indexed.
        n_values: The ngram sizes to index; one sub-index is created per value.
        tokenizer: Tokenizer applied to every input and reference text.

    Returns:
        The populated index, with one (possibly empty) sub-index per value in n_values.
    """
    ngram_index: NgramIndex = {n: {} for n in n_values}

    def _index_text(text: str, instance_id: int, part: str, stats_keys: dict) -> None:
        """Register every ngram of `text`, for every n, under the given entry."""
        # Tokenize once per text: the tokens do not depend on n, and the same token
        # sequence is reused for every n (the per-document computation in this file
        # reuses its tokens across n values in the same way).
        tokens = tokenizer.tokenize(text)
        for n in n_values:
            # Equal frozen keys are interchangeable inside the sets, so one key
            # object per (text, n) is enough.
            entry_key = EntryContaminationKey(stats_key=stats_keys[n], instance_id=instance_id, part=part)
            index_for_n = ngram_index[n]
            for ngram in ngrams(tokens, n):
                index_for_n.setdefault(ngram, set()).add(entry_key)

    for scenario in light_scenarios:
        hlog(f"Building ngram indexes for {scenario.light_scenario_key}")
        # One stats key per (scenario, n).
        # NOTE(review): these keys are later used to look up `all_contamination_stats`,
        # so they must equal the keys produced by ContaminationStats.from_scenario — confirm.
        stats_keys = {
            n: ContaminationStatsKey(metadata={"light_scenario_key": scenario.light_scenario_key, "N": n})
            for n in n_values
        }
        for instance_id, instance in enumerate(scenario.light_instances):
            _index_text(instance.input, instance_id, PART_INPUT, stats_keys)
            # compute reference ngrams
            for reference in instance.references:
                _index_text(reference, instance_id, PART_REF, stats_keys)

    return ngram_index
def create_all_contamination_stats(light_scenarios: List[LightScenario], n_values: List[int]) -> AllContaminationStats:
    """Build one fresh ContaminationStats object for every (scenario, n) combination.

    Returns a dictionary keyed by each stats object's own `stats_key`.
    Raises ValueError when two combinations yield the same key.
    """
    hlog("Initializing all contamination stats")

    stats_by_key: AllContaminationStats = {}
    for light_scenario in light_scenarios:
        for ngram_size in n_values:
            # Each scenario gets a separate stats object per ngram size, tagged with that size.
            scenario_stats: ContaminationStats = ContaminationStats.from_scenario(
                light_scenario, stats_tags={"N": ngram_size}
            )
            key = scenario_stats.stats_key
            if key in stats_by_key:
                raise ValueError("Duplicated settings detected.")
            stats_by_key[key] = scenario_stats

    return stats_by_key
def compute_scenario_file_contamination(
    training_file_path: str,
    file_format: str,
    ngram_index: NgramIndex,
    all_contamination_stats: AllContaminationStats,
    tokenizer: LightTokenizer,
    ngram_counter: Optional[NgramCounter] = None,
    max_contaminated_ngrams: int = 0,
):
    """
    Given an input file, compute the contamination stats for each n and each scenario by calling
    `compute_scenario_document_contamination()` for each document in the file. The function writes
    to the contamination stats (and, if enabled, the ngram counter) directly and does not return
    anything.

    training_file_path: Path of the training data file to scan.
    file_format: Format of the file, forwarded to `get_document_iterator()`.
    ngram_index: The ngram index that maps, for each n, from ngrams to the scenario entries
        containing them.
    all_contamination_stats: The contamination stats for each scenario and n. The variable to write to.
    tokenizer: The tokenizer used to break documents in the file into tokens.
    ngram_counter: The ngrams that are overlapped between the training file and the scenario data
        and their counts. Maps from a scenario entry to an inner dict, which maps from ngram to
        count. Only written to when max_contaminated_ngrams != 0.
    max_contaminated_ngrams: 0 disables ngram recording; any other value is forwarded to
        `compute_scenario_document_contamination()`, which interprets it.
    """
    for document in get_document_iterator(file_path=training_file_path, file_format=file_format):
        compute_scenario_document_contamination(
            document=document,
            ngram_index=ngram_index,
            all_contamination_stats=all_contamination_stats,
            tokenizer=tokenizer,
            ngram_counter=ngram_counter,
            max_contaminated_ngrams=max_contaminated_ngrams,
        )
def compute_scenario_document_contamination(
    document: str,
    ngram_index: NgramIndex,
    all_contamination_stats: AllContaminationStats,
    tokenizer: LightTokenizer,
    ngram_counter: Optional[NgramCounter] = None,
    max_contaminated_ngrams: int = 0,
):
    """
    Given a document, compute the contamination stats for each n and each scenario. The function
    writes to the contamination stats (and, if enabled, the ngram counter) directly and does not
    return anything.

    document: The text of one training document.
    ngram_index: The ngram index that maps, for each n, from ngrams to the scenario entries
        containing them.
    all_contamination_stats: The contamination stats for each scenario and n. The variable to write to.
    tokenizer: The tokenizer used to break the document into tokens.
    ngram_counter: The ngrams that are overlapped between the training data and the scenario data
        and their counts. Maps from a scenario entry to an inner dict, which maps from ngram to
        count. Required when max_contaminated_ngrams != 0.
    max_contaminated_ngrams: 0 disables ngram recording. -1 records every overlapping ngram.
        Any other value caps the number of distinct ngrams recorded per scenario entry; ngrams
        already recorded keep being counted after the cap is reached.

    Raises ValueError if max_contaminated_ngrams != 0 but no ngram_counter is given. The check
    runs up front, before any stats are modified, so the misuse is reported even for documents
    that share no ngram with the scenarios.
    """
    if max_contaminated_ngrams != 0 and ngram_counter is None:
        raise ValueError("ngram_counter must be not none when max_contaminated_ngrams != 0")

    document_tokens = tokenizer.tokenize(document)
    for n, index_for_n in ngram_index.items():
        for document_ngram in ngrams(document_tokens, n):
            entry_contamination_keys = index_for_n.get(document_ngram)
            if entry_contamination_keys is None:
                # This ngram does not occur in any scenario entry.
                continue

            for entry_contamination_key in entry_contamination_keys:
                # update contamination_stats
                stats: ContaminationStats = all_contamination_stats[entry_contamination_key.stats_key]
                stats.write_one_to_bit(entry_contamination_key.instance_id, entry_contamination_key.part)

                # skip the rest if ngram recording is disabled
                if max_contaminated_ngrams == 0 or ngram_counter is None:
                    continue

                # update ngram_counter
                entry_counts = ngram_counter.setdefault(entry_contamination_key, {})
                if document_ngram in entry_counts:
                    entry_counts[document_ngram] += 1
                elif max_contaminated_ngrams == -1 or len(entry_counts) < max_contaminated_ngrams:
                    # Only start tracking a new ngram while the per-entry cap is not reached.
                    entry_counts[document_ngram] = 1
if __name__ == "__main__":
    # Script entry point: parse the arguments, index the scenario data, scan the training
    # data for overlapping ngrams, then write the stats (and optionally the ngrams) to disk.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-data", type=str, required=True, help="Path to your training data")
    parser.add_argument("--scenario-data", type=str, required=True, help="Path to scenario data (benchmarking data)")
    parser.add_argument("--output-stats", type=str, required=True, help="Path to the output file")
    parser.add_argument(
        "--input-format",
        type=str,
        required=True,
        help="The format of your input file for your training data, e.g. raw, custom, the_pile",
    )
    parser.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Other tags, such as whether the input data is for pretraining or instruction tuning",
    )
    parser.add_argument(
        "--normalization", type=str, default="default", help="What normalization and tokenization strategy to apply"
    )
    parser.add_argument(
        "--output-ngrams",
        type=str,
        default=None,
        help="Path to the file of contaminated ngrams. To output the ngrams, you must also specify --max-output-ngrams",
    )
    parser.add_argument(
        "--max-output-ngrams",
        type=int,
        default=0,
        help=(
            # The trailing space keeps the two sentences apart once the literals are joined.
            "The max number of contaminated ngrams to be stored for each (n, light_instance, part). "
            "Set to -1 to store all"
        ),
    )

    args = parser.parse_args()

    # Pick the tokenizer matching the requested normalization strategy.
    tokenizer: LightTokenizer
    if args.normalization == "none":
        tokenizer = LightTokenizer()
    elif args.normalization == "default":
        tokenizer = DefaultTokenizer()
    else:
        raise ValueError(f"Normalization strategy {args.normalization} is not defined.")

    # --output-ngrams and a non-zero --max-output-ngrams only make sense together.
    if args.max_output_ngrams != 0 and args.output_ngrams is None:
        raise ValueError("You must specify --output-ngrams if you want to output ngrams.")
    if args.max_output_ngrams == 0 and args.output_ngrams is not None:
        raise ValueError("You must specify --max-output-ngrams != 0 if you want to output ngrams.")

    # A directory is expanded recursively into the regular files it contains;
    # anything else is treated as a single input file.
    input_file_paths: List[str]
    if os.path.isdir(args.input_data):
        input_file_paths = [
            file_path
            for file_path in glob.iglob(os.path.join(args.input_data, "**/*"), recursive=True)
            if os.path.isfile(file_path)
        ]
    else:
        input_file_paths = [args.input_data]
    hlog(f"The input data will be loaded from {input_file_paths}")

    hlog(f"Loading scenario data from {args.scenario_data}")
    light_scenarios = load_light_scenarios_from_jsonl(args.scenario_data)

    with htrack_block("Initializing the stats, ngram_index, and ngram_counter"):
        all_contamination_stats: AllContaminationStats
        ngram_index: NgramIndex
        all_contamination_stats = create_all_contamination_stats(light_scenarios=light_scenarios, n_values=N_VALUES)
        ngram_index = create_ngram_index(light_scenarios=light_scenarios, n_values=N_VALUES, tokenizer=tokenizer)
        ngram_counter: NgramCounter = {}

    # compute the stats
    with htrack_block("Computing contamination stats"):
        for input_file_path in tqdm(
            input_file_paths, desc="Computing contamination stats for input files", disable=None
        ):
            compute_scenario_file_contamination(
                training_file_path=input_file_path,
                file_format=args.input_format,
                ngram_index=ngram_index,
                all_contamination_stats=all_contamination_stats,
                tokenizer=tokenizer,
                ngram_counter=ngram_counter,
                max_contaminated_ngrams=args.max_output_ngrams,
            )

    # NOTE(review): the key "tags:" ends with a colon, which looks like a typo for "tags".
    # It is left unchanged because the consumers of generate_summary() are not visible here — confirm.
    stats_summaries: List[Dict[str, Any]] = [
        contamination_stats.generate_summary({"tags:": args.tags})
        for contamination_stats in all_contamination_stats.values()
    ]

    # One JSON object per line (jsonl).
    with open(args.output_stats, "w") as f:
        f.writelines(f"{json.dumps(stats_summary)}\n" for stats_summary in stats_summaries)
    hlog(f"Written {len(stats_summaries)} results to {args.output_stats}")

    if args.output_ngrams is not None:
        # convert the ngram counter to json format; ngram tuples become space-joined strings
        # because JSON object keys must be strings
        ngram_entries = [
            {
                "entry_contamination_key": asdict_without_nones(entry_contamination_key),
                "contaminated_ngrams": {" ".join(ngram): count for ngram, count in contaminated_ngrams.items()},
            }
            for entry_contamination_key, contaminated_ngrams in ngram_counter.items()
        ]
        write(args.output_ngrams, json.dumps(ngram_entries))
    else:
        hlog(
            "Contaminated ngrams are not written to disk. "
            "Set --output-ngrams and --max-output-ngrams if you want output the data."
        )