Skip to content

Commit 1554296

Browse files
committed
update tests
1 parent 25479ae commit 1554296

File tree

5 files changed

+101
-46
lines changed

5 files changed

+101
-46
lines changed

vetiver/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
) # noqa
1111
from .vetiver_model import VetiverModel # noqa
1212
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
13-
from .mock import get_mock_data, get_mock_model # noqa
13+
from .mock import get_mock_data, get_mock_model, get_mtcars_model # noqa
1414
from .pin_read_write import vetiver_pin_write # noqa
1515
from .attach_pkgs import load_pkgs, get_board_pkgs # noqa
1616
from .meta import VetiverMeta # noqa

vetiver/handlers/sklearn.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,14 @@ def handler_predict(self, input_data, check_prototype: bool, **kw):
3939
Prediction from model
4040
"""
4141
prediction_type = kw.get("prediction_type", "predict")
42-
if prediction_type not in ["predict", "predict_proba", "predict_log_proba"]:
43-
raise ValueError(
44-
'prediction_type must be "predict", "predict_proba", \
45-
or "predict_log_proba"'
46-
)
4742

4843
input_data = (
4944
[input_data]
5045
if check_prototype and not isinstance(input_data, pd.DataFrame)
5146
else input_data
5247
)
53-
return getattr(self.model, prediction_type)(input_data).tolist()
48+
49+
if prediction_type in ["predict_proba", "predict_log_proba"]:
50+
return getattr(self.model, prediction_type)(input_data).tolist()
51+
52+
return self.model.predict(input_data).to_list()

vetiver/mock.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from sklearn.dummy import DummyRegressor
21
import pandas as pd
32
import numpy as np
43

4+
from sklearn.dummy import DummyRegressor
5+
from sklearn.linear_model import LogisticRegression
6+
7+
from .data import mtcars
8+
59

610
def get_mock_data():
711
"""Create mock data for testing
@@ -26,5 +30,17 @@ def get_mock_model():
2630
model : sklearn.dummy.DummyRegressor
2731
Arbitrary model for testing purposes
2832
"""
29-
model = DummyRegressor()
30-
return model
33+
return DummyRegressor()
34+
35+
36+
def get_mtcars_model():
37+
"""Create mock model for testing
38+
39+
Returns
40+
-------
41+
model : sklearn.dummy.DummyRegressor
42+
Arbitrary model for testing purposes
43+
"""
44+
return LogisticRegression(max_iter=1000).fit(
45+
mtcars.drop(columns="cyl"), mtcars["cyl"]
46+
)

vetiver/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,14 +239,14 @@ def sum_values(x):
239239
"""
240240

241241
if not isinstance(endpoint_fx, Callable):
242-
if endpoint_fx not in SklearnPredictionTypes:
242+
if endpoint_fx not in ["predict", "predict_proba", "predict_log_proba"]:
243243
raise ValueError(
244244
f"""
245245
Prediction type {endpoint_fx} not available.
246246
Available prediction types: {SklearnPredictionTypes}
247247
"""
248248
)
249-
if not isinstance(self.model, SKLearnHandler):
249+
if not isinstance(self.model.handler_predict.__self__, SKLearnHandler):
250250
raise ValueError(
251251
"""
252252
The 'endpoint_fx' parameter can only be a
@@ -255,7 +255,7 @@ def sum_values(x):
255255
)
256256
self.vetiver_post(
257257
self.model.handler_predict,
258-
SklearnPredictionTypes,
258+
endpoint_fx,
259259
check_prototype=self.check_prototype,
260260
prediction_type=endpoint_fx,
261261
)

vetiver/tests/test_server.py

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
import pytest
2+
import sys
3+
import pandas as pd
4+
import numpy as np
5+
from fastapi.testclient import TestClient
6+
from pydantic import BaseModel, conint
7+
8+
from vetiver.data import mtcars
19
from vetiver import (
210
mock,
311
VetiverModel,
@@ -7,26 +15,18 @@
715
vetiver_endpoint,
816
predict,
917
)
10-
from pydantic import BaseModel, conint
11-
from fastapi.testclient import TestClient
12-
import numpy as np
13-
import pytest
14-
import sys
15-
import pandas as pd
16-
from vetiver.handlers.sklearn import SKLearnHandler
1718

1819

1920
@pytest.fixture
2021
def model():
2122
np.random.seed(500)
22-
X, y = mock.get_mock_data()
23-
model = mock.get_mock_model().fit(X, y)
23+
model = mock.get_mtcars_model()
2424
v = VetiverModel(
2525
model=model,
26-
prototype_data=X,
26+
prototype_data=mtcars.drop(columns="cyl"),
2727
model_name="my_model",
2828
versioned=None,
29-
description="A regression model for testing purposes",
29+
description="A logistic regression model for testing purposes",
3030
)
3131
return v
3232

@@ -84,11 +84,29 @@ def test_get_prototype(client, model):
8484
assert response.status_code == 200, response.text
8585
assert response.json() == {
8686
"properties": {
87-
"B": {"example": 55, "type": "integer"},
88-
"C": {"example": 65, "type": "integer"},
89-
"D": {"example": 17, "type": "integer"},
87+
"mpg": {"example": 21.0, "type": "number"},
88+
"disp": {"example": 160.0, "type": "number"},
89+
"hp": {"example": 110.0, "type": "number"},
90+
"drat": {"example": 3.9, "type": "number"},
91+
"wt": {"example": 2.62, "type": "number"},
92+
"qsec": {"example": 16.46, "type": "number"},
93+
"vs": {"example": 0.0, "type": "number"},
94+
"am": {"example": 1.0, "type": "number"},
95+
"gear": {"example": 4.0, "type": "number"},
96+
"carb": {"example": 4.0, "type": "number"},
9097
},
91-
"required": ["B", "C", "D"],
98+
"required": [
99+
"mpg",
100+
"disp",
101+
"hp",
102+
"drat",
103+
"wt",
104+
"qsec",
105+
"vs",
106+
"am",
107+
"gear",
108+
"carb",
109+
],
92110
"title": "prototype",
93111
"type": "object",
94112
}
@@ -131,14 +149,28 @@ def test_vetiver_endpoint():
131149

132150
@pytest.fixture
133151
def data() -> pd.DataFrame:
134-
return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
152+
return pd.DataFrame(
153+
{
154+
"mpg": [20, 20],
155+
"disp": [160, 160],
156+
"hp": [110, 110],
157+
"drat": [3.9, 3.9],
158+
"wt": [2.62, 2.62],
159+
"qsec": [16.00, 16.00],
160+
"vs": [0, 0],
161+
"am": [1, 1],
162+
"gear": [4, 4],
163+
"carb": [4, 4],
164+
}
165+
)
135166

136167

137168
def test_endpoint_adds(client, data):
169+
138170
response = client.post("/sum/", data=data.to_json(orient="records"))
139171

140172
assert response.status_code == 200
141-
assert response.json() == {"sum": [3, 6, 9]}
173+
assert response.json() == {"sum": [40, 320, 220, 7.8, 5.24, 32.00, 0, 2, 8, 8]}
142174

143175

144176
def test_endpoint_adds_no_prototype(client_no_prototype, data):
@@ -150,28 +182,36 @@ def test_endpoint_adds_no_prototype(client_no_prototype, data):
150182
assert response.json() == {"sum": [3, 6, 9]}
151183

152184

153-
def test_vetiver_post_sklearn_predict(model):
154-
vetiver_api = VetiverAPI(model=model)
155-
if not isinstance(vetiver_api.model, SKLearnHandler):
156-
pytest.skip("Test only applicable for SKLearnHandler models")
157-
158-
vetiver_api.vetiver_post("predict_proba")
159-
160-
client = TestClient(vetiver_api.app)
161-
response = client.post(
162-
"/predict_proba", json=vetiver_api.model.prototype.construct().dict()
163-
)
164-
assert response.status_code == 200
185+
def test_vetiver_post_sklearn_predict(model, data):
186+
api = VetiverAPI(model=model)
187+
api.vetiver_post("predict_proba")
188+
189+
client = TestClient(api.app)
190+
response = predict(endpoint="/predict_proba/", data=data, test_client=client)
191+
192+
assert isinstance(response, pd.DataFrame)
193+
assert len(response) == 2
194+
assert response.to_dict() == {
195+
"predict_proba": {
196+
0: [
197+
0.00627480416153554,
198+
0.9937251958346092,
199+
3.855256735904704e-12,
200+
],
201+
1: [
202+
0.00627480416153554,
203+
0.9937251958346092,
204+
3.855256735904704e-12,
205+
],
206+
},
207+
}
165208

166209

167210
def test_vetiver_post_invalid_sklearn_type(model):
168211
vetiver_api = VetiverAPI(model=model)
169-
if not isinstance(vetiver_api.model, SKLearnHandler):
170-
pytest.skip("Test only applicable for SKLearnHandler models")
171212

172213
with pytest.raises(
173214
ValueError,
174-
match="The 'endpoint_fx' parameter can only be a string \
175-
when using scikit-learn models.",
215+
match="Prediction type invalid_type not available",
176216
):
177217
vetiver_api.vetiver_post("invalid_type")

0 commit comments

Comments
 (0)