This repository was archived by the owner on Jan 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path learn.py
More file actions
executable file
·84 lines (56 loc) · 2.4 KB
/
learn.py
File metadata and controls
executable file
·84 lines (56 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python
import sys
import os
import logging
import torch.optim as optim
# Ensure Python can find the deeprank package: the repository root is the
# parent of the directory containing this script, and it is inserted at the
# front of sys.path so it is searched before any other location.
_script_dir = os.path.dirname(os.path.abspath(__file__))
deeprank_root = os.path.dirname(_script_dir)
sys.path.insert(0, deeprank_root)
from deeprank.learn.NeuralNet import NeuralNet
from deeprank.learn.DataSet import DataSet
from deeprank.learn.model3d import cnn_class
from deeprank.models.metrics import OutputExporter, LabelExporter, TensorboardVariantClassificationExporter
# Send INFO-and-above log records to a per-process file, learn-<pid>.log;
# filemode="w" truncates a leftover file with the same name.
logging.basicConfig(
    level=logging.INFO,
    filemode="w",
    filename="learn-{}.log".format(os.getpid()),
)
def interpret_args(args, usage):
    """Convert a list of commandline arguments into positional and keyword arguments.

    Recognised forms (every option takes exactly one value):
      * ``--key value`` and ``--key=value``
      * ``-k value`` and ``-kvalue`` (only the first character after the dash
        is the key; the rest, if any, is the value)
      * ``--`` ends option parsing; every later argument is positional
      * a lone ``-`` is a positional argument

    Args:
        args (list of str): the commandline arguments, without the program name
        usage (str): usage text printed for ``-h``/``--help`` or on errors

    Returns:
        (tuple(list of str, dict of str to str)): the positional arguments in
        order, and a mapping from option key (without dashes) to its value.
        When an option is repeated, its last value wins.

    Exits:
        status 0 after printing ``usage`` if ``-h`` or ``--help`` is present;
        status 1 after printing ``usage`` if ``args`` is empty or an option is
        the last argument and therefore has no value.
    """
    if len(args) == 0:
        print(usage)
        sys.exit(1)

    if "--help" in args or "-h" in args:
        print(usage)
        sys.exit(0)

    positional_args = []
    kwargs = {}

    # A shared iterator lets an option consume the following argument as its
    # value, and lets "--" hand the remainder over to the positionals.
    remaining = iter(args)
    for arg in remaining:
        if arg == "--":
            positional_args.extend(remaining)
            break

        if arg.startswith("--"):
            key, separator, value = arg[2:].partition("=")
            has_inline_value = bool(separator)
        elif arg.startswith("-") and len(arg) > 1:
            key, value = arg[1], arg[2:]
            has_inline_value = bool(value)
        else:
            positional_args.append(arg)
            continue

        if not has_inline_value:
            value = next(remaining, None)
            if value is None:
                # Previously this indexed past the end of args (IndexError).
                print("Option %s requires a value" % arg, file=sys.stderr)
                print(usage)
                sys.exit(1)

        kwargs[key] = value

    return (positional_args, kwargs)
if __name__ == "__main__":
    usage = "Usage: %s [-e EPOCH_COUNT] *preprocessed_hdf5_files" % sys.argv[0]
    args, kwargs = interpret_args(sys.argv[1:], usage)

    # Validate the commandline before any expensive work. Previously this
    # check ran only after DataSet had already been built from the (possibly
    # empty) file list, e.g. for "learn.py -e 5".
    if len(args) == 0:
        raise RuntimeError("No preprocessed HDF5 files given")

    try:
        epoch_count = int(kwargs.get('e', 50))
    except ValueError:
        # Only a user-supplied value can fail here; the default 50 parses.
        print("EPOCH_COUNT must be an integer, got %r" % kwargs['e'], file=sys.stderr)
        print(usage)
        sys.exit(1)

    dataset = DataSet(args, normalize_features=True)

    # Per-process output directory handed to the metrics exporters
    # (presumably where they write their results -- confirm in deeprank).
    run_directory = "run-{}".format(os.getpid())
    neural_net = NeuralNet(
        dataset,
        cnn_class,
        model_type='3d',
        task='class',
        cuda=False,
        metrics_exporters=[
            OutputExporter(run_directory),
            LabelExporter(run_directory),
            TensorboardVariantClassificationExporter(run_directory),
        ],
    )

    # Install AdamW over the network's parameters after construction
    # (presumably replacing NeuralNet's default optimizer -- verify in NeuralNet).
    neural_net.optimizer = optim.AdamW(neural_net.net.parameters(), lr=0.001, weight_decay=0.005)
    neural_net.train(nepoch=epoch_count, divide_trainset=None, train_batch_size=5, num_workers=0)