Skip to content

Commit e5a71ef

Browse files
authored
Merge pull request #116 from BioSystemsUM/issue_105
[FIX] saving models with dots in the path to save the model + [BUMP] version to 1.1.15
2 parents 16f7b6d + ea4ea9c commit e5a71ef

6 files changed

Lines changed: 94 additions & 10 deletions

File tree

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = deepmol
3-
version = 1.1.14
3+
version = 1.1.15
44
description = DeepMol: a python-based machine and deep learning framework for drug discovery
55
keywords = machine-learning, deep-learning, cheminformatics, drug-discovery
66
author = DeepMol Team

src/deepmol/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,2 @@
11

2-
__version__ = '1.1.14'
2+
__version__ = '1.1.15'

src/deepmol/models/sklearn_models.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -190,8 +190,8 @@ def save(self, folder_path: str = None):
190190

191191
save_to_disk(self.model, model_path)
192192

193-
# change file path to keep the extension but add _params
194-
parameters_file_path = model_path.split('.')[0] + '_params.' + model_path.split('.')[1]
193+
base, ext = os.path.splitext(model_path)
194+
parameters_file_path = f"{base}_params{ext}"
195195
save_to_disk(self.parameters_to_save, parameters_file_path)
196196

197197
@classmethod
@@ -213,7 +213,8 @@ def load(cls, folder_path: str, **kwargs) -> 'SklearnModel':
213213
model_path = cls.get_model_filename(folder_path)
214214
model = load_from_disk(model_path)
215215
# change file path to keep the extension but add _params
216-
parameters_file_path = ".".join(model_path.split('.')[:-1]) + '_params.' + model_path.split('.')[-1]
216+
base, ext = os.path.splitext(model_path)
217+
parameters_file_path = f"{base}_params{ext}"
217218
params = load_from_disk(parameters_file_path)
218219
instance = cls(model=model, model_dir=model_path, **params)
219220
return instance

tests/unit_tests/models/test_deepchem_model.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import os
2+
import shutil
23
from unittest import TestCase
34
from unittest.mock import MagicMock
45

@@ -109,6 +110,25 @@ def test_save(self):
109110
new_predictions = model_graph_loaded.predict(ds_test)
110111
self.assertTrue(np.array_equal(test_preds, new_predictions))
111112

113+
def test_save_load_with_dots(self):
114+
ds_train = self.multitask_dataset
115+
ds_train.X = ConvMolFeaturizer().featurize([MolFromSmiles('CCC')] * 10)
116+
ds_test = self.multitask_dataset_test
117+
ds_test.X = ConvMolFeaturizer().featurize([MolFromSmiles('CCC')] * 10)
118+
119+
model_graph = DeepChemModel(GraphConvModel, n_tasks=ds_train.n_tasks, mode='classification', epochs=3)
120+
121+
model_graph.fit(ds_train)
122+
test_preds = model_graph.predict(ds_test)
123+
124+
model_graph.save("../test_model")
125+
model_graph_loaded = DeepChemModel.load("../test_model")
126+
self.assertEqual(model_graph.n_tasks, ds_train.n_tasks)
127+
self.assertEqual(model_graph.epochs, 3)
128+
new_predictions = model_graph_loaded.predict(ds_test)
129+
self.assertTrue(np.array_equal(test_preds, new_predictions))
130+
shutil.rmtree("../test_model")
131+
112132
def test_cross_validate(self):
113133
ds_train = self.binary_dataset
114134
ds_train.X = ConvMolFeaturizer().featurize([MolFromSmiles('CCC')] * 100)

tests/unit_tests/models/test_keras_model.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,40 @@ def test_save_model(self):
9696

9797
shutil.rmtree("test_model")
9898

99+
def test_save_model_with_dots(self):
100+
model = KerasModel(model_builder=make_cnn_model,
101+
epochs=2, input_dim=self.binary_dataset.X.shape[1])
102+
model.fit(self.binary_dataset)
103+
104+
first_predictions = model.predict(self.binary_dataset_test)
105+
106+
model.save("../test_model")
107+
loaded_model = KerasModel.load("../test_model")
108+
self.assertEqual(2, loaded_model.epochs)
109+
self.assertEqual(50, loaded_model.parameters_to_save["input_dim"])
110+
loaded_model_predictions = loaded_model.predict(self.binary_dataset_test)
111+
112+
assert np.array_equal(first_predictions, loaded_model_predictions)
113+
114+
shutil.rmtree("../test_model")
115+
116+
def test_load_models_with_dots(self):
117+
model = KerasModel(model_builder=make_cnn_model,
118+
epochs=2, input_dim=self.binary_dataset.X.shape[1])
119+
model.fit(self.binary_dataset)
120+
121+
first_predictions = model.predict(self.binary_dataset_test)
122+
123+
model.save("../test_model")
124+
loaded_model = KerasModel.load("../test_model")
125+
self.assertEqual(2, loaded_model.epochs)
126+
self.assertEqual(50, loaded_model.parameters_to_save["input_dim"])
127+
loaded_model_predictions = loaded_model.predict(self.binary_dataset_test)
128+
129+
assert np.array_equal(first_predictions, loaded_model_predictions)
130+
131+
shutil.rmtree("../test_model")
132+
99133
def test_baseline_models(self):
100134
model_kwargs = {'input_dim': 50}
101135
keras_kwargs = {}

tests/unit_tests/models/test_sklearn_models.py

Lines changed: 34 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -150,11 +150,12 @@ def test_save_model(self):
150150

151151
self.assertEqual("classification", model.mode)
152152

153-
with self.assertRaises(ValueError):
154-
model.save("test_model.params.pkl")
155-
156-
with self.assertRaises(ValueError):
157-
model.save("test_model.params.joblib")
153+
model.save("test_model.params.pkl")
154+
self.assertTrue(os.path.exists("test_model.params.pkl"))
155+
shutil.rmtree("test_model.params.pkl")
156+
model.save("test_model.params.joblib")
157+
self.assertTrue(os.path.exists("test_model.params.joblib"))
158+
shutil.rmtree("test_model.params.joblib")
158159

159160
rf = RandomForestClassifier()
160161
model = SklearnModel(model=rf, mode="classification", model_dir="test_model")
@@ -164,6 +165,34 @@ def test_save_model(self):
164165

165166
shutil.rmtree("test_model")
166167

168+
def test_save_model_with_dots(self):
169+
rf = RandomForestClassifier()
170+
model = SklearnModel(model=rf, mode="classification", model_dir="test_model")
171+
model.fit(self.binary_dataset)
172+
model.save("../test_model")
173+
self.assertTrue(os.path.exists("../test_model"))
174+
shutil.rmtree("../test_model")
175+
176+
def test_load_models_with_dots(self):
177+
rf = RandomForestClassifier()
178+
model = SklearnModel(model=rf, mode="classification", model_dir="../test_model")
179+
model.fit(self.binary_dataset)
180+
model.save("../test_model")
181+
self.assertTrue(os.path.exists("../test_model"))
182+
183+
predictions_1 = model.predict(self.binary_dataset_test)
184+
185+
new_model = SklearnModel.load("../test_model")
186+
y_test = new_model.predict(self.binary_dataset_test)
187+
self.assertEqual(len(y_test), len(self.binary_dataset_test.y))
188+
self.assertIsInstance(new_model, SklearnModel)
189+
self.assertEqual("classification", new_model.mode)
190+
191+
assert np.array_equal(predictions_1, y_test)
192+
193+
shutil.rmtree("../test_model")
194+
195+
167196
def test_load_model(self):
168197
rf = RandomForestClassifier()
169198
model = SklearnModel(model=rf, mode="classification")

0 commit comments

Comments
 (0)