-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
79 lines (61 loc) · 2.44 KB
/
main.py
File metadata and controls
79 lines (61 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from torch.utils.data import DataLoader
from util import run
from config import RunConfig, OptimizeConfig
import argparse
def _parse_args(argv=None):
    """Parse this script's command-line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``run_config`` (required), ``optimize_config``
        and ``device`` (both ``None`` when not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run_config", type=str, required=True, help="Path to the YAML config file"
    )
    parser.add_argument(
        "--optimize_config", type=str, help="Path to the optimization YAML config file"
    )
    parser.add_argument(
        "--device", type=str, help="Device to run on (e.g. 'cuda:0' or 'cpu')"
    )
    return parser.parse_args(argv)


def _build_loaders(config):
    """Load the train/validation datasets described by *config* and wrap them.

    The training loader shuffles each epoch; the validation loader keeps a fixed
    order. Both use ``config.batch_size``.

    Returns:
        ``(dl_train, dl_val)`` DataLoader pair.
    """
    ds_train, ds_val = config.load_data()
    dl_train = DataLoader(ds_train, batch_size=config.batch_size, shuffle=True)
    dl_val = DataLoader(ds_val, batch_size=config.batch_size, shuffle=False)
    return dl_train, dl_val


def _report_best_trial(study, opt_project):
    """Print the best trial's value, parameters and run directory.

    If no trial completed (for example every trial was pruned or failed),
    a short notice is printed instead of letting the lookup crash.
    """
    try:
        best = study.best_trial
    except ValueError:
        # NOTE(review): Optuna's Study.best_trial raises ValueError when no
        # trial has completed; this assumes create_study() returns an Optuna
        # Study — confirm against config.OptimizeConfig.
        print("Best trial: none (no trial completed)")
        return
    print("Best trial:")
    print(f" Value: {best.value}")
    print(" Params: ")
    for key, value in best.params.items():
        print(f" {key}: {value}")
    # "group_name" is stored on every trial by the objective in _optimize.
    print(f" Path: runs/{opt_project}/{best.user_attrs['group_name']}")


def _optimize(base_config, optimize_config_path, dl_train, dl_val):
    """Run a hyper-parameter search over *base_config* and report the best trial.

    Args:
        base_config: RunConfig that every trial starts from.
        optimize_config_path: Path to the optimization YAML file.
        dl_train: Training DataLoader, shared by all trials.
        dl_val: Validation DataLoader, shared by all trials.
    """
    optimize_config = OptimizeConfig.from_yaml(optimize_config_path)
    # A single pruner instance is created once and shared by every trial.
    pruner = optimize_config.create_pruner()
    opt_project = f"{base_config.project}_Opt"

    def objective(trial):
        """Build a trial-specific config, train it, and return run()'s result."""
        params = optimize_config.suggest_params(trial)
        # Suggested parameters are grouped by category; each category replaces
        # the corresponding section of the base config.
        overrides = {"project": opt_project}
        for category, category_params in params.items():
            overrides[category] = category_params
        run_config = base_config.with_overrides(**overrides)
        group_name = run_config.gen_group_name()
        # Suffix with the trial number so each trial's group name is distinct.
        group_name += f"[{trial.number}]"
        trial.set_user_attr("group_name", group_name)
        return run(
            run_config, dl_train, dl_val, group_name, trial=trial, pruner=pruner
        )

    study = optimize_config.create_study(project=opt_project)
    study.optimize(objective, n_trials=optimize_config.trials)
    _report_best_trial(study, opt_project)


def main(argv=None):
    """Train once from ``--run_config``, or run a search with ``--optimize_config``.

    Args:
        argv: Optional argument list for testing. ``None`` reads the real
            command line.
    """
    args = _parse_args(argv)
    base_config = RunConfig.from_yaml(args.run_config)
    if args.device:
        # with_overrides returns the config to use; the original is reassigned,
        # so it is presumably non-mutating — confirm in config.RunConfig.
        base_config = base_config.with_overrides(device=args.device)
    dl_train, dl_val = _build_loaders(base_config)
    if args.optimize_config:
        _optimize(base_config, args.optimize_config, dl_train, dl_val)
    else:
        run(base_config, dl_train, dl_val)
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()