forked from ChEB-AI/python-chebifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli.py
More file actions
117 lines (103 loc) · 2.87 KB
/
cli.py
File metadata and controls
117 lines (103 loc) · 2.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import click
from chebifier.model_registry import ENSEMBLES
@click.group()
def cli():
    """Command line interface for Chebifier."""
    # The group callback has nothing to do on its own; the actual work is done
    # by sub-commands registered on it via ``@cli.command()`` (e.g. ``predict``).
@cli.command()
@click.option(
    "--ensemble-config",
    "-e",
    type=click.Path(exists=True),
    default=None,
    help="Configuration file for ensemble models",
)
@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
@click.option(
    "--smiles-file",
    "-f",
    type=click.Path(exists=True),
    help="File containing SMILES strings (one per line)",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(),
    help="Output file to save predictions (optional)",
)
@click.option(
    "--ensemble-type",
    "-t",
    # click.Choice expects a sequence; materialize the dict keys view.
    type=click.Choice(list(ENSEMBLES.keys())),
    default="wmv-f1",
    help="Type of ensemble to use (default: Weighted Majority Voting)",
)
@click.option(
    "--chebi-version",
    "-v",
    type=int,
    default=241,
    help="ChEBI version to use for checking consistency (default: 241)",
)
@click.option(
    # A plain ``is_flag=True, default=True`` option can never be switched off.
    # The on/off pair keeps ``--use-confidence`` / ``-c`` working and adds
    # ``--no-use-confidence`` to disable it.
    "--use-confidence/--no-use-confidence",
    "-c",
    default=True,
    help="Weight predictions based on how 'confident' a model is in its prediction (default: True)",
)
@click.option(
    # Same fix as above: allow ``--no-resolve-inconsistencies``.
    "--resolve-inconsistencies/--no-resolve-inconsistencies",
    "-r",
    default=True,
    help="Resolve inconsistencies in predictions automatically (default: True)",
)
def predict(
    ensemble_config,
    smiles,
    smiles_file,
    output,
    ensemble_type,
    chebi_version,
    use_confidence,
    resolve_inconsistencies=True,
):
    """Predict ChEBI classes for SMILES strings using an ensemble model.

    \f
    Args:
        ensemble_config: Path to an ensemble configuration file, or None.
        smiles: SMILES strings passed via (repeatable) ``--smiles``.
        smiles_file: Optional path to a text file with one SMILES per line;
            surrounding whitespace is stripped and blank lines are skipped.
        output: Optional path. When given, predictions are written there as a
            JSON object mapping each SMILES to its prediction; otherwise they
            are printed. Duplicate SMILES collapse to a single JSON key.
        ensemble_type: Key into ``ENSEMBLES`` selecting the ensemble class.
        chebi_version: ChEBI version forwarded to the ensemble constructor.
        use_confidence: Forwarded to ``predict_smiles_list``.
        resolve_inconsistencies: Forwarded to the ensemble constructor.
    """
    # Collect the input first so that a call without any SMILES returns
    # before paying for constructing the ensemble.
    smiles_list = list(smiles)
    if smiles_file:
        with open(smiles_file, "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            smiles_list.extend(s for s in stripped if s)
    if not smiles_list:
        click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
        return

    # Instantiate ensemble model
    ensemble = ENSEMBLES[ensemble_type](
        ensemble_config,
        chebi_version=chebi_version,
        resolve_inconsistencies=resolve_inconsistencies,
    )

    # Make predictions; the result is zipped 1:1 with the input list below.
    predictions = ensemble.predict_smiles_list(
        smiles_list, use_confidence=use_confidence
    )

    if output:
        # save as json
        import json

        with open(output, "w", encoding="utf-8") as f:
            json.dump(dict(zip(smiles_list, predictions)), f, indent=2)
    else:
        # Print results
        for smiles_str, prediction in zip(smiles_list, predictions):
            click.echo(f"Result for: {smiles_str}")
            if prediction:
                click.echo(f"  Predicted classes: {', '.join(map(str, prediction))}")
            else:
                click.echo("  No predictions")
# Running this file directly (rather than importing it) dispatches to the
# ``cli`` command group.
if __name__ == "__main__":
    cli()