Skip to content

Commit 3389f4a

Browse files
committed
[FIX] for predictions with dataset.n_tasks = 0
1 parent c8a1d78 commit 3389f4a

4 files changed

Lines changed: 8 additions & 6 deletions

File tree

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = deepmol
3-
version = 1.1.16
3+
version = 1.1.17
44
description = DeepMol: a python-based machine and deep learning framework for drug discovery
55
keywords = machine-learning, deep-learning, cheminformatics, drug-discovery
66
author = DeepMol Team

src/deepmol/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11

2-
__version__ = '1.1.16'
2+
__version__ = '1.1.17'

src/deepmol/models/keras_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ def predict_proba(self, dataset: Dataset, return_invalid: bool = False) -> np.nd
214214
self.logger.info(str(self.model))
215215
self.logger.info(str(type(self.model)))
216216
predictions = self.model.predict(dataset.X.astype('float32'))
217-
217+
218218
predictions = np.array(predictions)
219219
if predictions.shape != (len(dataset.mols), dataset.n_tasks):
220220
predictions = normalize_labels_shape(predictions, dataset.n_tasks)
221-
221+
222222
if len(predictions.shape) > 1:
223223
if predictions.shape[1] == len(dataset.mols) and predictions.shape[0] == dataset.n_tasks:
224224
predictions = predictions.T

src/deepmol/utils/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ def normalize_labels_shape(y_pred: Union[List, np.ndarray], n_tasks: int) -> np.
167167
"""
168168
if not isinstance(y_pred, np.ndarray):
169169
y_pred = np.array(y_pred)
170-
171-
if n_tasks == 1:
170+
171+
if n_tasks == 0 and len(y_pred.shape) < 2:
172+
labels = _normalize_singletask_labels_shape(y_pred)
173+
elif n_tasks == 1:
172174
labels = _normalize_singletask_labels_shape(y_pred)
173175
else:
174176
if len(y_pred.shape) == 3:

0 commit comments

Comments
 (0)