-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathseq.py
More file actions
98 lines (78 loc) · 3.7 KB
/
seq.py
File metadata and controls
98 lines (78 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import json
from rnaglib.dataset_transforms.cd_hit import CDHitComputer
from rnaglib.dataset_transforms.structure_distance_computer import StructureDistanceComputer
from rnaglib.tasks import get_task
from rnaglib.transforms import SequenceRepresentation, RNAFMTransform
from rnaglib.encoders import ListEncoder
from rnaglib.dataset_transforms import ClusterSplitter, RandomSplitter
from rnaglib.tasks import RNAGo
from exp import RNATrainer
from model_seq import SequenceModel
# Tasks to benchmark in this run; uncomment entries to include them.
TASKS_TODO = [
    'rna_cm',
    # 'rna_prot',
    # 'rna_site',
]
# Use this if you are submitting one job per task
# TASKS_TODO = [os.environ.get('TASK')]
# TASKS_TODO = ['rna_site_redundant']

# Run both with and without RNA-FM embeddings.
RNA_FM = [True, False]
# Per-task keyword arguments forwarded to SequenceModel.from_task().
MODEL_ARGS = {
    "rna_cm": {"num_layers": 2, "use_bilstm": True, "hidden_channels": 32},
    "rna_go": {"num_layers": 2},
    "rna_if": {"num_layers": 2, "hidden_channels": 128},
    "rna_ligand": {"num_layers": 4},
    "rna_prot": {"num_layers": 2, "hidden_channels": 64, "dropout_rate": 0.2},
    "rna_site": {"num_layers": 2, "hidden_channels": 256},
    "rna_go_struc_0.6": {"num_layers": 3},
    "rna_site_redundant": {"num_layers": 3, "hidden_channels": 256},
}
# Per-task keyword arguments forwarded to RNATrainer.
TRAINER_ARGS = {
    "rna_cm": {'epochs': 50, "batch_size": 8},
    "rna_go": {"epochs": 20, "learning_rate": 0.0001},  # 0.001 (original)
    # Only marginal improvements past 40 epochs, so keep 40 for the splitting analysis.
    "rna_if": {"epochs": 40, "learning_rate": 0.0001},
    "rna_ligand": {"epochs": 40, "learning_rate": 1e-5},
    "rna_prot": {"epochs": 100, "learning_rate": 0.001},  # 0.01 (original)
    "rna_site": {"batch_size": 8, "epochs": 50, "learning_rate": 1e-5},
    "rna_go_struc_0.6": {"epochs": 100, "learning_rate": 0.0001},
    "rna_site_redundant": {"epochs": 100, "learning_rate": 0.001},
}
# When True, rerun every (task, seed) combination and overwrite any existing
# result files; when False, combinations with an existing result file are skipped.
recompute = True

for tid in TASKS_TODO:
    root = f"roots/{tid}_seq"
    task = get_task(task_id=tid, root=root)

    # Annotate every RNA in the dataset with RNA-FM embeddings. This is a
    # side-effecting call per item, so use a plain loop (not a throwaway
    # list comprehension as before).
    rnafm = RNAFMTransform()
    for rna in task.dataset:
        rnafm(rna)
    task.dataset.features_computer.add_feature(
        feature_names=["rnafm"], custom_encoders={"rnafm": ListEncoder(640)})

    # Representation needs to be added here as the loaders are not updated when the rep is added later.
    task.add_representation(SequenceRepresentation(framework="pyg"))
    task.get_split_loaders(recompute=False)

    # One training run per random seed; results are written per (task, seed, model config).
    for seed in [0, 1, 2, 3, 4, 5, 6]:
        model = SequenceModel.from_task(task, **MODEL_ARGS[tid], num_node_features=644)
        rep = SequenceRepresentation(framework="pyg")
        model_string = '_'.join(f'{k}-{v}' for k, v in MODEL_ARGS[tid].items())
        result_file = f"results/workshop_{tid}_seq_{seed}_{model_string}.json"
        if os.path.exists(result_file) and not recompute:
            continue
        exp_name = f"{tid}_seq_{seed}_{model_string}"
        trainer = RNATrainer(task, model, rep, seed=seed,
                             wandb_project="rnaglib-seq", exp_name=exp_name,
                             **TRAINER_ARGS[tid])
        trainer.train()
        metrics = model.evaluate(task, split="test")
        # Fix: ensure the output directory exists before writing, otherwise
        # json.dump raises FileNotFoundError on a fresh checkout without results/.
        os.makedirs(os.path.dirname(result_file), exist_ok=True)
        with open(result_file, "w") as j:
            json.dump(metrics, j)