Skip to content

Commit 85e93ac

Browse files
committed
refactor: configure with hydra
1 parent 316f9e0 commit 85e93ac

6 files changed

Lines changed: 246 additions & 490 deletions

File tree

install_env.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ conda install -y -c nvidia/label/cuda-${cuda_version} \
5353
conda install -y -c conda-forge \
5454
biopython \
5555
einops \
56+
hydra-core \
5657
tensorboard \
5758
tqdm
5859

profold2/command/evaluator.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
for further help.
66
"""
77
import os
8+
from dataclasses import dataclass, make_dataclass
89
import logging
910
import pickle
1011

@@ -17,18 +18,23 @@
1718
from profold2.model import functional, profiler, snapshot, FeatureBuilder, ReturnValues
1819
from profold2.utils import exists, timing
1920

20-
from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU
21+
from profold2.command import worker
2122

2223

23-
def evaluate(rank, args): # pylint: disable=redefined-outer-name
24-
worker = WorkerModel(rank, args)
25-
feats, model = worker.load(args.model)
26-
features = FeatureBuilder(feats).to(worker.device())
24+
@dataclass
class Args(worker.Args):
  # Per-command argument dataclass for the evaluator: it extends `worker.Args`
  # and declares no additional fields of its own.
  # NOTE(review): `run()` reads `args.model`, `args.eval_data`, `args.nnodes`
  # and `args.amp_enabled` - confirm that all of them are declared on
  # `worker.Args`; otherwise they have to be declared here.
  pass
27+
28+
29+
def run(rank, args): # pylint: disable=redefined-outer-name
30+
runner = worker.WorkerModel(rank, args)
31+
feats, model = runner.load(args.model)
32+
features = FeatureBuilder(feats).to(runner.device())
2733
logging.info('feats: %s', feats)
2834

2935
kwargs = {}
30-
if rank.is_available() and WorkerXPU.world_size(args.nnodes) > 1:
31-
kwargs['num_replicas'] = WorkerXPU.world_size(args.nnodes)
36+
if rank.is_available() and worker.world_size(args.nnodes) > 1:
37+
kwargs['num_replicas'] = worker.world_size(args.nnodes)
3238
kwargs['rank'] = rank.rank
3339
test_loader = dataset.load(
3440
data_dir=args.eval_data,
@@ -68,7 +74,7 @@ def data_eval(idx, batch):
6874
# predict - out is (batch, L * 3, 3)
6975
with timing(f'Running model on {fasta_name} {fasta_len}', logging.debug):
7076
with torch.no_grad():
71-
with autocast_ctx(args.amp_enabled):
77+
with worker.autocast_ctx(args.amp_enabled):
7278
r = ReturnValues(
7379
**model(
7480
batch=batch, # pylint: disable=not-callable
@@ -320,34 +326,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
320326

321327
if __name__ == '__main__':
  import argparse
  import hydra

  # Stand-alone entry point: an optional config file plus hydra KEY=VAL overrides.
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
  parser.add_argument(
      'overrides',
      nargs='*',
      metavar='KEY=VAL',
      help='override configs, see: https://hydra.cc'
  )
  args = parser.parse_args()

  if exists(args.config):
    # split the config file into its (absolute) directory and its file name
    config_dir, config_name = os.path.split(os.path.abspath(args.config))
  else:
    # no config file given: use the current directory and compose without a
    # primary config
    config_dir, config_name = os.getcwd(), None

  # worker.main() is handed an object that exposes `Args` and `run`
  task = make_dataclass('t', [], namespace={'Args': Args, 'run': run})
  with hydra.initialize_config_dir(
      version_base=None, config_dir=config_dir, job_name=__file__
  ):
    worker.main(task, hydra.compose(config_name, args.overrides))

profold2/command/main.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,75 +4,56 @@
44
```
55
for further help.
66
"""
7+
import os
78
import argparse
89

10+
import hydra
11+
912
from profold2.command import (evaluator, predictor, trainer, worker)
10-
from profold2.utils import env
13+
from profold2.utils import env, exists
1114

12-
_COMMANDS = [
13-
('train', trainer.train, trainer.add_arguments),
14-
('evaluate', evaluator.evaluate, evaluator.add_arguments),
15-
('predict', predictor.predict, predictor.add_arguments),
16-
]
15+
_COMMANDS = [('train', trainer), ('evaluate', evaluator), ('predict', predictor)]
1716

1817

1918
def create_args():
  """Parse the command line and return the resulting namespace.

  One sub-command is registered per entry of `_COMMANDS`; the selected one is
  stored in `command`. Every sub-command takes an optional `-c/--config` file
  and any number of positional `KEY=VAL` overrides.
  """
  help_formatter = argparse.ArgumentDefaultsHelpFormatter
  root = argparse.ArgumentParser(formatter_class=help_formatter)

  # command args
  commands = root.add_subparsers(dest='command', required=True)
  for name, _ in _COMMANDS:
    sub = commands.add_parser(name, formatter_class=help_formatter)
    sub.add_argument('-c', '--config', type=str, default=None, help='config file.')
    sub.add_argument(
        'overrides',
        nargs='*',
        metavar='KEY=VAL',
        help='override configs, see: https://hydra.cc'
    )

  return root.parse_args()
6337

6438

65-
def create_fn(args): # pylint: disable=redefined-outer-name
66-
for cmd, fn, _ in _COMMANDS:
39+
def create_task(args):  # pylint: disable=redefined-outer-name
  """Return the entry of `_COMMANDS` registered under `args.command`.

  The first matching entry wins; `None` is returned when nothing matches.
  """
  matches = (task for name, task in _COMMANDS if name == args.command)
  return next(matches, None)
7044

7145

7246
def main():
  """Command line entry point: dispatch the selected sub-command via hydra.

  The config file given with `-c/--config` is split into its directory (used
  as hydra's config search dir) and its file name (used as the primary
  config). Without a config file, the configuration is composed from the
  command line overrides only.
  """
  args = create_args()
  if exists(args.config):
    config_dir, config_name = os.path.split(os.path.abspath(args.config))
  else:
    # hydra.initialize_config_dir() only accepts an absolute directory and
    # raises on a relative one such as '.', so fall back to the absolute cwd
    # (same fallback as the evaluator/predictor stand-alone entry points).
    config_dir, config_name = os.getcwd(), None

  with hydra.initialize_config_dir(
      version_base=None, config_dir=config_dir, job_name=args.command
  ):
    task = create_task(args)
    worker.main(task, hydra.compose(config_name, args.overrides))
7657

7758

7859
if __name__ == '__main__':

profold2/command/predictor.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
for further help.
66
"""
77
import os
8+
from dataclasses import dataclass, make_dataclass
89
import functools
910
import glob
1011
import json
@@ -23,6 +24,7 @@
2324
from profold2.model import profiler, snapshot, FeatureBuilder, ReturnValues
2425
from profold2.utils import exists, timing
2526

27+
from profold2.command import worker
2628
from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU
2729

2830

@@ -128,7 +130,21 @@ def _location_split(model_location):
128130
yield model_name, (features, model)
129131

130132

131-
def predict(rank, args): # pylint: disable=redefined-outer-name
133+
@dataclass
class Args(worker.Args):
  # Per-command argument dataclass for the predictor; extends `worker.Args`
  # with the model and dataset options below.
  # NOTE(review): `Optional` must be in scope (`from typing import Optional`);
  # its import is not visible in this hunk - confirm.
  # NOTE(review): `models` has no default. If `worker.Args` declares any field
  # with a default value, `@dataclass` raises TypeError (non-default argument
  # follows default argument) when this class is created - confirm against
  # `worker.Args`.
  models: list[str]  # models to be loaded using[model_name=model_location] format
  model_recycles: int = 0  # number of recycles
  model_shard_size: Optional[int] = 0  # shard size in the evoformer model
  map_location: str = 'cpu'  # remapped to an alternative set of devices

  data_dir: Optional[str] = None  # dataset dir
  data_idx: Optional[str] = None  # dataset idx
  add_pseudo_linker: bool = False  # enable loading complex data

  # NOTE(review): presumably an upper bound on the MSA size fed to the model -
  # confirm against the code that reads it.
  max_msa_size: int = 1024
145+
146+
147+
def run(rank, args): # pylint: disable=redefined-outer-name
132148
model_runners = dict(_load_models(rank, args))
133149
logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys()))
134150

@@ -360,35 +376,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
360376

361377
if __name__ == '__main__':
  import argparse
  import hydra

  # Stand-alone entry point: an optional config file plus hydra KEY=VAL overrides.
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
  parser.add_argument(
      'overrides',
      nargs='*',
      metavar='KEY=VAL',
      help='override configs, see: https://hydra.cc'
  )
  args = parser.parse_args()

  if exists(args.config):
    # split the config file into its (absolute) directory and its file name
    config_dir, config_name = os.path.split(os.path.abspath(args.config))
  else:
    # no config file given: use the current directory and compose without a
    # primary config
    config_dir, config_name = os.getcwd(), None

  # worker.main() is handed an object that exposes `Args` and `run`
  task = make_dataclass('t', [], namespace={'Args': Args, 'run': run})
  with hydra.initialize_config_dir(
      version_base=None, config_dir=config_dir, job_name=__file__
  ):
    worker.main(task, hydra.compose(config_name, args.overrides))

0 commit comments

Comments
 (0)