Skip to content

Commit 0e7d931

Browse files
committed
[FIX] base featurizer + [FIX] requirements to include right version of tensorflow probability
1 parent 9c56f44 commit 0e7d931

3 files changed

Lines changed: 8 additions & 3 deletions

File tree

extra-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ dgllife==0.3.2: deep-learning
2525
deepchem==2.8.0: deep-learning
2626
torch_geometric==2.6.1: deep-learning
2727
tensorflow==2.15.0: deep-learning
28-
tensorflow-probability==0.25.0: deep-learning
28+
tensorflow-probability==0.23.0: deep-learning
2929
boruta==0.4.3: preprocessing, machine-learning
3030
scikit-learn<1.6: machine-learning, deep-learning
3131
optuna==4.1.0: machine-learning, deep-learning

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ numpy
2323
dgllife==0.3.2
2424
dill==0.3.6
2525
optuna==4.1.0
26-
tensorflow-probability==0.25.0
26+
tensorflow-probability==0.23.0
2727
torch_geometric==2.6.1
2828
scikit-multilearn==0.2.0
2929
scikeras==0.12.0

src/deepmol/compound_featurization/base_featurizer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,14 @@ def featurize(self,
9898

9999
remove_mols_list = np.array(remove_mols)
100100
dataset.remove_elements(np.array(dataset.ids)[remove_mols_list], inplace=True)
101-
101+
102102
features = np.array(features, dtype=object)
103103
features = features[~remove_mols_list]
104+
105+
try:
106+
features = features.astype('float64')
107+
except:
108+
pass
104109

105110
if (isinstance(features[0], np.ndarray) and len(features[0].shape) == 2) or not isinstance(features[0],
106111
np.ndarray):

0 commit comments

Comments
 (0)