-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathcli.py
More file actions
131 lines (110 loc) · 4.73 KB
/
cli.py
File metadata and controls
131 lines (110 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from typing import Dict, Set, Type
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.trainer.CustomTrainer import CustomTrainer
class ChebaiCLI(LightningCLI):
    """
    Command-line interface of the Chebai project, built on ``LightningCLI``.

    The class always runs with ``CustomTrainer`` as trainer class and registers
    argument links so that values provided by the data module (number of
    labels, feature vector size, classes file path, and the data module itself)
    are passed on to the model, its metrics, the trainer callbacks and the
    criterion.

    Args:
        *args: Positional arguments handed to ``LightningCLI``.
        **kwargs: Keyword arguments handed to ``LightningCLI``, e.g.
            ``save_config_kwargs`` (options for saving the configuration) and
            ``parser_kwargs`` (options for the argument parser).
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Initialize the CLI with ``CustomTrainer`` fixed as the trainer class.

        Args:
            *args: Positional arguments forwarded unchanged to ``LightningCLI``.
            **kwargs: Keyword arguments forwarded unchanged to ``LightningCLI``
                (for example ``save_config_kwargs`` and ``parser_kwargs``).
                ``trainer_class`` must not be passed; it is set here.
        """
        # Star-unpacking is written before the keyword so the call reads in
        # the order in which Python binds the arguments.
        super().__init__(*args, trainer_class=CustomTrainer, **kwargs)

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """
        Link input parameters that are shared between classes (e.g. number of labels).

        See https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html#argument-linking

        Args:
            parser (LightningArgumentParser): Parser on which the links are registered.
        """

        def call_data_methods(data: XYBaseDataModule):
            """Return ``data.num_of_labels``, preparing the data module first if needed."""
            # ``data`` is the data module object itself (instance methods are
            # called on it), hence the annotation is the class, not ``Type[...]``.
            if data._num_of_labels is None:
                # The label count is not known yet; run the data module's own
                # preparation steps before reading it.
                data.prepare_data()
                data.setup()
            return data.num_of_labels

        # Model output dimension <- number of labels computed from the data module.
        parser.link_arguments(
            "data",
            "model.init_args.out_dim",
            apply_on="instantiate",
            compute_fn=call_data_methods,
        )

        # Model input dimension <- feature vector size of the data module.
        parser.link_arguments(
            "data.feature_vector_size",
            "model.init_args.input_dim",
            apply_on="instantiate",
        )

        parser.link_arguments(
            "data.classes_txt_file_path",
            "model.init_args.classes_txt_file_path",
            apply_on="instantiate",
        )

        stages = ("train", "val", "test")
        # NOTE: the original note reads "When using lightning > 2.5.1 then need
        # to uncomment all metrics that are not used" - presumably meaning the
        # unused metrics have to be commented out, i.e. this tuple reduced to:
        #   ("mse", "rmse", "r2")                                     # regression
        #   ("f1", "roc-auc")                                         # binary classification
        #   ("micro-f1", "macro-f1", "roc-auc")                       # multilabel classification
        #   ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc")  # multilabel classification using balanced-accuracy
        metric_names = (
            "micro-f1",
            "macro-f1",
            "balanced-accuracy",
            "roc-auc",
            "f1",
            "mse",
            "rmse",
            "r2",
        )
        # Every metric of every stage gets the number of labels from the data module.
        for kind in stages:
            for metric_name in metric_names:
                parser.link_arguments(
                    "data.num_of_labels",
                    f"model.init_args.{kind}_metrics.init_args.metrics.{metric_name}.init_args.num_labels",
                    apply_on="instantiate",
                )

        # Unlike the links above, this one is declared without ``apply_on``.
        parser.link_arguments(
            "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
        )
        # Disabled alternatives, kept for reference:
        # parser.link_arguments(
        #     "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
        # )
        # parser.link_arguments(
        #     "data", "model.init_args.criterion.init_args.data_extractor"
        # )
        # parser.link_arguments(
        #     "data.init_args.chebi_version",
        #     "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
        # )

        # The criterion's data extractor is linked to the data module configuration.
        parser.link_arguments(
            "data", "model.init_args.criterion.init_args.data_extractor"
        )

    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        """
        Define the available subcommands and the arguments to skip for each.

        Returns:
            Dict[str, Set[str]]: Mapping from subcommand name to the set of
                argument names to skip for that subcommand.
        """
        return {
            "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
            "validate": {"model", "dataloaders", "datamodule"},
            "test": {"model", "dataloaders", "datamodule"},
            "predict": {"model", "dataloaders", "datamodule"},
        }
def cli():
    """
    Entry point that instantiates and runs the ChebaiCLI.

    The configuration is saved under the file name ``lightning_config.yaml``
    and the argument parser is used in ``omegaconf`` mode.
    """
    config_saving_options = {"config_filename": "lightning_config.yaml"}
    parser_options = {"parser_mode": "omegaconf"}
    ChebaiCLI(
        save_config_kwargs=config_saving_options,
        parser_kwargs=parser_options,
    )