import logging
from typing import Any, List, Optional, Tuple

import pandas as pd
import torch
from lightning import LightningModule, Trainer
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.loops.fit_loop import _FitLoop
from lightning.pytorch.trainer import call
from torch.nn.utils.rnn import pad_sequence

from chebai.loggers.custom import CustomLogger
from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader

log = logging.getLogger(__name__)


class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        """
        Initializes the CustomTrainer class, logging additional hyperparameters to the custom logger if specified.

        Args:
            *args: Positional arguments for the Trainer class.
            **kwargs: Keyword arguments for the Trainer class.
        """
        self.init_args = args
        self.init_kwargs = kwargs
        super().__init__(*args, **kwargs, deterministic=True)
        # instantiate the custom logger connector
        self._logger_connector.on_trainer_init(self.logger, 1)
        # log additional hyperparameters to wandb
        if isinstance(self.logger, CustomLogger):
            custom_logger = self.logger
            assert isinstance(custom_logger, CustomLogger)
            if custom_logger.verbose_hyperparameters:
                log_kwargs = {}
                for key, value in self.init_kwargs.items():
                    log_key, log_value = self._resolve_logging_argument(key, value)
                    log_kwargs[log_key] = log_value
                self.logger.log_hyperparams(log_kwargs)
        # use a custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops)
        self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs)
    def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
        """
        Resolves logging arguments, handling nested structures such as lists and complex objects.

        Args:
            key: The key of the argument.
            value: The value of the argument.

        Returns:
            A tuple containing the resolved key and value.
        """
        if isinstance(value, list):
            key_value_pairs = [
                self._resolve_logging_argument(f"{key}_{i}", v)
                for i, v in enumerate(value)
            ]
            return key, {k: v for k, v in key_value_pairs}
        if not (
            isinstance(value, str)
            or isinstance(value, float)
            or isinstance(value, int)
            or value is None
        ):
            params = {"class": value.__class__}
            params.update(value.__dict__)
            return key, params
        else:
            return key, value
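    # Illustrative example (the callback object below is hypothetical, not taken from
    # this file): a trainer kwarg such as callbacks=[SomeCallback(patience=3)] would be
    # resolved recursively into roughly
    #   "callbacks" -> {"callbacks_0": {"class": SomeCallback, "patience": 3}}
    # so that non-primitive constructor arguments still show up in the logged config.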
    def predict_from_file(
        self,
        model: LightningModule,
        checkpoint_path: _PATH,
        input_path: _PATH,
        save_to: _PATH = "predictions.csv",
        classes_path: Optional[_PATH] = None,
        **kwargs,
    ) -> None:
        """
        Loads a model from a checkpoint and makes predictions on input data from a file.

        Args:
            model: The model to use for predictions.
            checkpoint_path: Path to the model checkpoint.
            input_path: Path to the input file containing SMILES strings.
            save_to: Path to save the predictions CSV file.
            classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
        """
        loaded_model = model.__class__.load_from_checkpoint(checkpoint_path)
        with open(input_path, "r") as input:
            smiles_strings = [inp.strip() for inp in input.readlines()]
        loaded_model.eval()
        predictions = self._predict_smiles(loaded_model, smiles_strings)
        predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
        if classes_path is not None:
            with open(classes_path, "r") as f:
                predictions_df.columns = [cls.strip() for cls in f.readlines()]
        predictions_df.index = smiles_strings
        predictions_df.to_csv(save_to)
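    # Usage sketch (paths and the model instance are placeholders; the input file is
    # assumed to contain one SMILES string per line, e.g. "CCO" or "c1ccccc1"):
    #
    #   trainer = CustomTrainer()
    #   trainer.predict_from_file(
    #       some_lightning_module,
    #       checkpoint_path="model.ckpt",
    #       input_path="smiles.txt",
    #       classes_path="classes.txt",  # optional, one class label per line
    #   )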
    def _predict_smiles(
        self, model: LightningModule, smiles: List[str]
    ) -> torch.Tensor:
        """
        Predicts the output for a list of SMILES strings using the model.

        Args:
            model: The model to use for predictions.
            smiles: A list of SMILES strings.

        Returns:
            A tensor containing the predictions.
        """
        reader = ChemDataReader()
        parsed_smiles = [reader._read_data(s) for s in smiles]
        x = pad_sequence(
            [torch.tensor(a, device=model.device) for a in parsed_smiles],
            batch_first=True,
        )
        cls_tokens = (
            torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1)
            * CLS_TOKEN
        )
        features = torch.cat((cls_tokens, x), dim=1)
        model_output = model({"features": features})
        if model.model_type == "regression":
            preds = model_output["logits"]
        else:
            preds = torch.sigmoid(model_output["logits"])
        return preds
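    # For reference: the `features` tensor built above has shape
    # (len(smiles), 1 + max_token_sequence_length), i.e. one CLS-token column
    # followed by the padded token ids produced by ChemDataReader.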
    @property
    def log_dir(self) -> Optional[str]:
        """
        Returns the logging directory.

        Returns:
            The path to the logging directory if available, else the default root directory.
        """
        if len(self.loggers) > 0:
            logger = self.loggers[0]
            if isinstance(logger, WandbLogger):
                dirpath = logger.experiment.dir
            else:
                dirpath = self.loggers[0].log_dir
        else:
            dirpath = self.default_root_dir
        dirpath = self.strategy.broadcast(dirpath)
        return dirpath


class LoadDataLaterFitLoop(_FitLoop):
    def on_advance_start(self) -> None:
        """Calls the ``on_train_epoch_start`` hooks **before** the dataloaders are set up. This is necessary
        so that the dataloaders can get information from the model. For example, the on_train_epoch_start
        hook sets the curr_epoch attribute of the PubChemBatched dataset. With the default Lightning
        configuration, the dataloaders would always load batch 0 first, run an epoch, then get the epoch
        number (usually 0, unless resuming from a checkpoint), then load batch 0 again (or some other
        batch). With this implementation, the dataloaders are set up after the epoch number is set, so
        that the correct batch is loaded."""
        trainer = self.trainer

        # update the epoch value for all samplers
        assert self._combined_loader is not None
        for i, dl in enumerate(self._combined_loader.flattened):
            _set_sampler_epoch(dl, self.epoch_progress.current.processed)

        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
            if not self.restarted_on_epoch_start:
                self.epoch_progress.increment_ready()

            call._call_callback_hooks(trainer, "on_train_epoch_start")
            call._call_lightning_module_hook(trainer, "on_train_epoch_start")

            self.epoch_progress.increment_started()

        # this is usually at the front of advance_start, but here we need it at the end
        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
        self.setup_data()
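

# Usage sketch (the data-module name is a placeholder; any LightningDataModule whose
# dataset reads `curr_epoch` in `on_train_epoch_start`, such as the PubChemBatched
# dataset mentioned above, benefits from the reordered hooks):
#
#   trainer = CustomTrainer(max_epochs=10)        # installs LoadDataLaterFitLoop
#   trainer.fit(model, datamodule=pubchem_data)   # dataloaders are (re)built only after
#                                                 # on_train_epoch_start has set the epoch,
#                                                 # so the correct batch is loaded first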