|
6 | 6 | from chebai.preprocessing.datasets.chebi import ChEBIOver50 |
7 | 7 | from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph |
8 | 8 |
|
| 9 | +from chebifier.check_env import check_package_installed |
9 | 10 | from chebifier.prediction_models.base_predictor import BasePredictor |
10 | | -from functools import lru_cache |
| 11 | + |
11 | 12 |
|
12 | 13 | class BaseEnsemble: |
13 | 14 |
|
def __init__(
    self,
    model_configs: dict,
    chebi_version: int = 241,
    resolve_inconsistencies: bool = True,
):
    """Set up the ChEBI data, the disjointness files and all ensemble members.

    Args:
        model_configs: Mapping of model name to that model's config dict.
            Each config must contain a ``"type"`` key that is looked up in
            ``MODEL_TYPES``. An optional ``"hugging_face"`` entry is passed
            to ``download_model_files`` and the returned mapping is forwarded
            to the model constructor as extra keyword arguments. An optional
            ``"package_name"`` entry is passed to
            ``check_package_installed``. The whole config dict is also
            forwarded to the model constructor as keyword arguments.
        chebi_version: Version number handed to ``ChEBIOver50``.
        resolve_inconsistencies: If true, a ``PredictionSmoother`` is built
            and stored in ``self.smoother``; otherwise ``self.smoother`` is
            ``None``.
    """
    # Deferred import: importing this at module level causes a circular import.
    from chebifier.model_registry import MODEL_TYPES

    dataset = ChEBIOver50(chebi_version=chebi_version)
    self.chebi_dataset = dataset
    dataset._download_required_data()  # download chebi if not already downloaded
    self.chebi_graph = get_chebi_graph(dataset, None)

    # Prefer the copies under ./data (relative to the working directory);
    # any file missing there is fetched from the Hugging Face dataset repo.
    self.disjoint_files = []
    for disjoint_name in ("disjoint_chebi.csv", "disjoint_additional.csv"):
        local_path = os.path.join("data", disjoint_name)
        if os.path.isfile(local_path):
            self.disjoint_files.append(local_path)
            continue

        print(f"Disjoint axiom file {local_path} not found. Loading from huggingface instead...")
        from chebifier.hugging_face import download_model_files

        request = {
            "repo_id": "chebai/chebifier",
            "repo_type": "dataset",
            "files": {"disjoint_file": os.path.basename(local_path)},
        }
        fetched = download_model_files(request)
        self.disjoint_files.append(fetched["disjoint_file"])

    self.models = []
    self.positive_prediction_threshold = 0.5
    for model_name, model_config in model_configs.items():
        model_cls = MODEL_TYPES[model_config["type"]]

        # Extra constructor kwargs produced by the Hugging Face download, if any.
        hugging_face_kwargs = {}
        if "hugging_face" in model_config:
            from chebifier.hugging_face import download_model_files

            hugging_face_kwargs = download_model_files(model_config["hugging_face"])

        if "package_name" in model_config:
            check_package_installed(model_config["package_name"])

        model_instance = model_cls(
            model_name,
            **model_config,
            **hugging_face_kwargs,
            chebi_graph=self.chebi_graph,
        )
        assert isinstance(model_instance, BasePredictor)
        self.models.append(model_instance)

    # The smoother is optional; code using self.smoother must handle None.
    self.smoother = (
        PredictionSmoother(
            self.chebi_dataset,
            label_names=None,
            disjoint_files=self.disjoint_files,
        )
        if resolve_inconsistencies
        else None
    )
48 | 66 |
|
49 | 67 | def gather_predictions(self, smiles_list): |
50 | 68 | # get predictions from all models for the SMILES list |
@@ -131,15 +149,15 @@ def consolidate_predictions(self, predictions, classwise_weights, predicted_clas |
131 | 149 | # Smooth predictions |
132 | 150 | start_time = time.perf_counter() |
133 | 151 | class_names = list(predicted_classes.keys()) |
134 | | - self.smoother.set_label_names(class_names) |
135 | | - smooth_net_score = self.smoother(net_score) |
| 152 | + if self.smoother is not None: |
| 153 | + self.smoother.set_label_names(class_names) |
| 154 | + smooth_net_score = self.smoother(net_score) |
| 155 | + class_decisions = (smooth_net_score > 0.5) & has_valid_predictions # Shape: (num_smiles, num_classes) |
| 156 | + else: |
| 157 | + class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) |
136 | 158 | end_time = time.perf_counter() |
137 | 159 | print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") |
138 | 160 |
|
139 | | - class_decisions = ( |
140 | | - smooth_net_score > 0.5 |
141 | | - ) & has_valid_predictions # Shape: (num_smiles, num_classes) |
142 | | - |
143 | 161 | complete_failure = torch.all(~has_valid_predictions, dim=1) |
144 | 162 | return class_decisions, complete_failure |
145 | 163 |
|
|
0 commit comments