Forked from ChEB-AI/python-chebifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path base_predictor.py
More file actions
41 lines (33 loc) · 1.15 KB
/
base_predictor.py
File metadata and controls
41 lines (33 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import json
from abc import ABC
from .._custom_cache import modelwise_smiles_lru_cache
class BasePredictor(ABC):
    """Common base class for SMILES-based predictors.

    Holds the model name, the model weight and optional class-wise weights
    loaded from JSON. Provides default implementations for single-SMILES
    prediction, the description text and explanations. Subclasses must
    override :meth:`predict_smiles_list`.
    """

    def __init__(
        self,
        model_name: str,
        model_weight: int = 1,
        classwise_weights_path: str = None,
        **kwargs,
    ):
        """Initialise the predictor.

        Args:
            model_name: Identifier of the model, stored as ``self.model_name``.
            model_weight: Weight of this model, stored as
                ``self.model_weight`` (defaults to 1).
                NOTE(review): how the weight is consumed is decided outside
                this file.
            classwise_weights_path: Optional path to a UTF-8 encoded JSON file.
                Its parsed content is stored as ``self.classwise_weights``.
                When ``None``, ``self.classwise_weights`` is ``None``.
            **kwargs: Extra options. Only ``description`` is read here; it is
                returned by :attr:`info_text`. Any other keys are ignored by
                this base class.

        Raises:
            OSError: If ``classwise_weights_path`` cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        self.model_name = model_name
        self.model_weight = model_weight
        if classwise_weights_path is not None:
            # Use a context manager so the file handle is closed right after
            # parsing instead of being left open until garbage collection.
            with open(classwise_weights_path, encoding="utf-8") as weights_file:
                self.classwise_weights = json.load(weights_file)
        else:
            self.classwise_weights = None
        self._description = kwargs.get("description", None)

    @modelwise_smiles_lru_cache.batch_decorator
    def predict_smiles_list(self, smiles_list: list[str]) -> dict:
        """Predict for a batch of SMILES strings.

        Subclasses must override this method; the base implementation only
        raises. It is wrapped by ``modelwise_smiles_lru_cache.batch_decorator``
        from ``_custom_cache``. That decorator presumably caches results per
        model and per SMILES; confirm in that module.

        NOTE(review): the return annotation is ``dict``, but
        :meth:`predict_smiles` indexes the result with ``[0]``. This suggests
        the method actually returns one entry per input SMILES; verify and
        fix the annotation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def predict_smiles(self, smiles: str) -> dict:
        """Predict for a single SMILES string.

        By default, this calls :meth:`predict_smiles_list` with a one-element
        list and returns the first element of its result.
        """
        return self.predict_smiles_list([smiles])[0]

    @property
    def info_text(self):
        """Return the description passed as ``description=`` at construction.

        Falls back to a fixed placeholder message when no description was
        given (or it was ``None``).
        """
        if self._description is None:
            return "No description is available for this model."
        return self._description

    def explain_smiles(self, smiles):
        """Return an explanation for ``smiles``; the base implementation returns ``None``.

        Subclasses may override this to provide model-specific explanations.
        """
        return None