-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy path test_custom_handler.py
More file actions
63 lines (49 loc) · 1.96 KB
/
test_custom_handler.py
File metadata and controls
63 lines (49 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pandas as pd
import pydantic
import sklearn
# `import sklearn` alone is not guaranteed to load submodules; this file uses
# sklearn.dummy.DummyRegressor directly, so import it explicitly.
import sklearn.dummy

from vetiver import mock, VetiverModel, BaseHandler
class CustomHandler(BaseHandler):
    """Minimal user-defined vetiver handler for a scikit-learn DummyRegressor.

    The tests in this module pass an instance of it as ``model=`` to
    ``VetiverModel`` to exercise custom-handler support.
    """

    def __init__(self, model, prototype_data):
        # Only forwards to BaseHandler.__init__; this handler adds no state.
        super().__init__(model, prototype_data)

    # Zero-argument callable returning the model class this handler targets.
    # The lambda defers the ``sklearn.dummy`` attribute lookup until call time.
    model_type = staticmethod(lambda: sklearn.dummy.DummyRegressor)

    def handler_predict(self, input_data, check_ptype):
        """Run ``self.model.predict`` on ``input_data``.

        Parameters
        ----------
        input_data
            When ``check_ptype`` is truthy: a ``pd.DataFrame`` (passed through
            unchanged) or any other object, which is wrapped in a one-element
            list. Otherwise: a comma-separated ``str`` (split into a single
            row) or anything else, which is passed through unchanged.
        check_ptype
            Flag selecting which of the two input conventions above applies.
            Tested for truthiness, so ``numpy.True_`` or ``1`` behave like
            ``True``.

        Returns
        -------
        Whatever ``self.model.predict`` returns.
        """
        if check_ptype:
            if isinstance(input_data, pd.DataFrame):
                return self.model.predict(input_data)
            # Presumably a single row; wrap it so predict receives a batch of one.
            return self.model.predict([input_data])

        if isinstance(input_data, str):
            # NOTE(review): the raw row is assumed to be comma-delimited; a
            # user-configurable delimiter may be wanted here.
            input_data = [input_data.split(",")]
        return self.model.predict(input_data)
def test_custom_vetiver_model():
    """A custom handler built with prototype data yields a usable VetiverModel.

    No ``description`` is passed, so the default one is checked. Also checks
    that no required packages are recorded in the metadata, that ``v.model``
    is a DummyRegressor, and that a pydantic prototype was produced.
    """
    X, y = mock.get_mock_data()
    model = mock.get_mock_model().fit(X, y)
    custom_handler = CustomHandler(model, X)

    v = VetiverModel(
        model=custom_handler,
        prototype_data=X,
        model_name="my_model",
        versioned=None,
    )

    assert v.description == "A DummyRegressor model"
    assert not v.metadata.required_pkgs
    assert isinstance(v.model, sklearn.dummy.DummyRegressor)
    # ``model_construct()`` is the pydantic v2 name for v1's ``construct()``:
    # it instantiates the prototype model class without running validation.
    assert isinstance(v.prototype.model_construct(), pydantic.BaseModel)
def test_custom_vetiver_model_no_ptype():
    """A handler built without prototype data still yields a usable VetiverModel.

    The handler itself receives ``prototype_data=None``; the prototype data is
    passed to ``VetiverModel`` instead. An explicit ``description`` is also
    supplied and must be kept verbatim.
    """
    X, y = mock.get_mock_data()
    model = mock.get_mock_model().fit(X, y)
    custom_handler = CustomHandler(model, None)

    v = VetiverModel(
        model=custom_handler,
        prototype_data=X,
        model_name="my_model",
        versioned=None,
        description="A regression model for testing purposes",
    )

    assert v.description == "A regression model for testing purposes"
    assert isinstance(v.model, sklearn.dummy.DummyRegressor)
    # NOTE(review): since the handler has no prototype data, the prototype is
    # presumably built from the ``prototype_data`` given to VetiverModel;
    # confirm against VetiverModel's handler dispatch.
    # ``model_construct()`` (pydantic v2; formerly ``construct()``) builds an
    # instance of the prototype model class without validation.
    assert isinstance(v.prototype.model_construct(), pydantic.BaseModel)