-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathbenchmark.py
More file actions
194 lines (156 loc) · 6.17 KB
/
benchmark.py
File metadata and controls
194 lines (156 loc) · 6.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""
Entrypoint for running all tasks in `biobench`.
Most of this script is self documenting.
Run `python benchmark.py --help` to see all the options.
Note that you must download the datasets yourself; each dataset includes its own download script with instructions.
For example, see `biobench.newt.download`.
"""
import collections
import importlib
import logging
import os
import beartype
import submitit
import tyro
from biobench import config, helpers, jobkit, reporting
@beartype.beartype
def main(cfgs: list[str], dry_run: bool = True, n_parallel: int = 1):
    """
    Launch all jobs, using either a local GPU or a Slurm cluster. Then report results and save to disk.

    Args:
        cfgs: List of paths to TOML config files.
        dry_run: If --no-dry-run, actually run experiment.
        n_parallel: Number of jobs that can be claimed by any one launcher process.

    Raises:
        ValueError: if the loaded configs disagree on slurm_acct, log_to, or ssl.
    """
    # Load all experiment configs from the provided paths. `cfgs` holds file paths
    # (it is the CLI argument); `exps` holds the experiments parsed from them.
    exps = [exp for path in cfgs for exp in config.load(path)]
    if not exps:
        print("No configurations loaded.")
        return

    # Execution settings are shared by the whole launcher, so every config must
    # agree with the first one; `first` then stands in for all of them below.
    _check_consistent_execution(exps)
    first = exps[0]

    logger = _setup_logging(first.debug)
    executor = _make_executor(first)

    db = reporting.get_db(first)
    # Clear old (5 days+) runs.
    cleared = reporting.clear_stale_claims(db, max_age_hours=24 * 5)
    logger.info("Cleared %d stale jobs from 'runs' table.", cleared)

    job_stats = collections.defaultdict(int)
    model_stats = collections.defaultdict(int)

    fq = jobkit.FutureQueue(max_size=n_parallel)
    # Any (cfg, task) still registered with this hook when the process exits (or is
    # signalled) has its claim released via reporting.release_run.
    exit_hook = jobkit.ExitHook(
        lambda args: reporting.release_run(db, *args)
    ).register()

    def flush_one():
        """
        Get the next finished job from the queue (blocking if necessary), write its report, and relinquish the claim.

        Job failures are logged, not raised, so one bad job does not stop the launcher.
        """
        job, cfg, task = fq.pop()
        try:
            report: reporting.Report = job.result()
            report.write()
            logger.info("%s+%s/%s done", task, cfg.model.org, cfg.model.ckpt)
        except Exception:
            logger.exception("%s+%s/%s failed", task, cfg.model.org, cfg.model.ckpt)
        finally:
            # Release the claim with the same call the exit hook would use. Merely
            # discarding it from the hook would leave the run claimed, so
            # get_skip_reason() would keep skipping it as "queued" until
            # clear_stale_claims() expires it. Release *before* discarding so that,
            # if release_run raises, the exit hook still releases it at exit.
            reporting.release_run(db, cfg, task)
            exit_hook.discard((cfg, task))

    for cfg in exps:
        for task, data_root in cfg.data.to_dict().items():
            reason = get_skip_reason(db, cfg, task, data_root, dry_run)
            if reason:
                job_stats[reason] += 1
                continue

            if dry_run:
                job_stats["todo"] += 1
                model_stats[cfg.model.ckpt] += 1
                continue  # no side-effect

            if not reporting.claim_run(db, cfg, task):
                job_stats["queued"] += 1  # someone else just grabbed it
                continue

            exit_hook.add((cfg, task))  # for signal/atexit handler
            job = executor.submit(worker, task, cfg)
            fq.submit((job, cfg, task))
            job_stats["submitted"] += 1

            # Keep at most n_parallel jobs in flight.
            while fq.full():
                flush_one()

    if dry_run:
        _log_dry_run_summary(logger, job_stats, model_stats)
        return

    # Drain every job that is still outstanding.
    while fq:
        flush_one()

    logger.info("Finished.")


def _check_consistent_execution(exps) -> None:
    """Raise ValueError unless every experiment matches the first on slurm_acct, log_to and ssl."""
    first = exps[0]
    for exp in exps[1:]:
        if exp.slurm_acct != first.slurm_acct:
            raise ValueError("All configs must have the same slurm_acct")
        if exp.log_to != first.log_to:
            raise ValueError("All configs must have the same log_to directory")
        if exp.ssl != first.ssl:
            raise ValueError("All configs must have the same ssl setting")


def _setup_logging(debug) -> logging.Logger:
    """Configure root logging (DEBUG if `debug` is truthy, else INFO) and return this script's logger."""
    log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(level=level, format=log_format)
    # Keep PIL's TIFF plugin at INFO even in debug mode (presumably it is noisy at DEBUG).
    logging.getLogger("PIL.TiffImagePlugin").setLevel(logging.INFO)
    return logging.getLogger("benchmark.py")


def _make_executor(first):
    """
    Build the job executor described by `first`.

    Returns a submitit.SlurmExecutor if a Slurm account is set, otherwise a
    submitit.DebugExecutor in debug mode, otherwise a jobkit.SerialExecutor.
    All executors write their logs to `first.log_to`.
    """
    if first.slurm_acct:
        executor = submitit.SlurmExecutor(folder=first.log_to)
        executor.update_parameters(
            time=30,
            gpus_per_node=1,
            cpus_per_task=8,
            stderr_to_stdout=True,
            partition="debug",
            account=first.slurm_acct,
        )
        # See biobench.third_party_models.get_ssl() for a discussion of this variable.
        # Slurm jobs get it through their setup commands rather than this process's env.
        if not first.ssl:
            executor.update_parameters(setup=["export BIOBENCH_DISABLE_SSL=1"])
        return executor

    if first.debug:
        executor = submitit.DebugExecutor(folder=first.log_to)
    else:
        executor = jobkit.SerialExecutor(folder=first.log_to)

    # See biobench.third_party_models.get_ssl() for a discussion of this variable.
    # Local executors are given the setting via this process's environment.
    if not first.ssl:
        os.environ["BIOBENCH_DISABLE_SSL"] = "1"
    return executor


def _log_dry_run_summary(logger, job_stats, model_stats) -> None:
    """Log a table of job counts per skip/todo reason, then a table of pending job counts per model checkpoint."""
    logger.info("Job Summary:")
    logger.info("%-20s | %-5s", "Reason", "Count")
    logger.info("-" * 31)
    for reason, count in sorted(job_stats.items()):
        logger.info("%-20s | %5d", reason, count)
    logger.info("-" * 31)

    logger.info("Model Summary:")
    logger.info("%-70s | %-5s", "Model", "Count")
    logger.info("-" * 79)
    for model, count in sorted(model_stats.items()):
        logger.info("%-70s | %5d", model, count)
    logger.info("-" * 79)
@beartype.beartype
def worker(task_name: str, cfg: config.Experiment) -> reporting.Report:
    """
    Run one benchmark task for one experiment config and return its report.

    This is the callable main() hands to the executor. It dynamically imports
    `biobench.<task_name>` and delegates to that module's `benchmark(cfg)`.
    """
    # Presumably raises the open-file-descriptor limit to at least 512; see helpers.bump_nofile.
    helpers.bump_nofile(512)
    task_module = importlib.import_module("biobench." + task_name)
    return task_module.benchmark(cfg)
@beartype.beartype
def get_skip_reason(
    db, cfg: config.Experiment, task: str, data_root: str, dry_run: bool
) -> str | None:
    """
    Return a short reason string if we should skip (None -> keep).

    The checks run in order: the task's code can be imported, the config provides
    a data root for it, results already exist in `db`, and whether the run is
    currently claimed.

    Args:
        db: Reporting database handle (the value returned by reporting.get_db()).
        cfg: Experiment config being scheduled.
        task: Task name; its code must live in the module `biobench.{task}`.
        data_root: Dataset location for this task; a falsy value (e.g. "") means no data is configured.
        dry_run: Currently unused; kept so the call signature stays stable.

    Returns:
        "no code", "missing dependency", "no data", "done", "queued", or None if the task should run.
    """
    module_name = f"biobench.{task}"
    try:
        importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        # Only a missing task module (or a missing parent package of it) means there
        # is no code for this task. If the module exists but one of *its* imports is
        # missing, report that instead of hiding a broken environment behind "no code".
        missing = err.name
        if (
            missing is None
            or missing == module_name
            or module_name.startswith(missing + ".")
        ):
            return "no code"
        return "missing dependency"

    if not data_root:
        return "no data"

    if reporting.already_ran(db, cfg, task):
        return "done"

    if reporting.is_claimed(db, cfg, task):
        return "queued"

    return None
if __name__ == "__main__":
    # tyro builds the command-line interface (including --help) from main()'s signature.
    tyro.cli(main)