@@ -150,11 +150,12 @@ def test_save_model(self):
150150
151151 self .assertEqual ("classification" , model .mode )
152152
153- with self .assertRaises (ValueError ):
154- model .save ("test_model.params.pkl" )
155-
156- with self .assertRaises (ValueError ):
157- model .save ("test_model.params.joblib" )
153+ model .save ("test_model.params.pkl" )
154+ self .assertTrue (os .path .exists ("test_model.params.pkl" ))
155+ shutil .rmtree ("test_model.params.pkl" )
156+ model .save ("test_model.params.joblib" )
157+ self .assertTrue (os .path .exists ("test_model.params.joblib" ))
158+ shutil .rmtree ("test_model.params.joblib" )
158159
159160 rf = RandomForestClassifier ()
160161 model = SklearnModel (model = rf , mode = "classification" , model_dir = "test_model" )
@@ -164,6 +165,34 @@ def test_save_model(self):
164165
165166 shutil .rmtree ("test_model" )
166167
168+ def test_save_model_with_dots (self ):
169+ rf = RandomForestClassifier ()
170+ model = SklearnModel (model = rf , mode = "classification" , model_dir = "test_model" )
171+ model .fit (self .binary_dataset )
172+ model .save ("../test_model" )
173+ self .assertTrue (os .path .exists ("../test_model" ))
174+ shutil .rmtree ("../test_model" )
175+
176+ def test_load_models_with_dots (self ):
177+ rf = RandomForestClassifier ()
178+ model = SklearnModel (model = rf , mode = "classification" , model_dir = "../test_model" )
179+ model .fit (self .binary_dataset )
180+ model .save ("../test_model" )
181+ self .assertTrue (os .path .exists ("../test_model" ))
182+
183+ predictions_1 = model .predict (self .binary_dataset_test )
184+
185+ new_model = SklearnModel .load ("../test_model" )
186+ y_test = new_model .predict (self .binary_dataset_test )
187+ self .assertEqual (len (y_test ), len (self .binary_dataset_test .y ))
188+ self .assertIsInstance (new_model , SklearnModel )
189+ self .assertEqual ("classification" , new_model .mode )
190+
191+ assert np .array_equal (predictions_1 , y_test )
192+
193+ shutil .rmtree ("../test_model" )
194+
195+
167196 def test_load_model (self ):
168197 rf = RandomForestClassifier ()
169198 model = SklearnModel (model = rf , mode = "classification" )
0 commit comments