Skip to content

Commit 74f9a44

Browse files
committed
clean up vetiver_post
1 parent 9bec20c commit 74f9a44

File tree

3 files changed

+82
-45
lines changed

3 files changed

+82
-45
lines changed

vetiver/server.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pandas as pd
1212
import requests
1313
import uvicorn
14-
from fastapi import FastAPI
14+
from fastapi import FastAPI, Request
1515
from fastapi.exceptions import RequestValidationError
1616
from fastapi.openapi.utils import get_openapi
1717
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
@@ -210,11 +210,13 @@ def vetiver_post(
210210
211211
Parameters
212212
----------
213-
endpoint_fx : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
214-
A callable function that specifies the custom logic to execute when the endpoint is called.
215-
This function should take input data (e.g., a DataFrame or dictionary) and return the desired output
216-
(e.g., predictions or transformed data). For scikit-learn models, endpoint_fx can also be one of
217-
"predict", "predict_proba", or "predict_log_proba" if the model supports these methods.
213+
endpoint_fx
214+
: Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
215+
A callable function that specifies the custom logic to execute when the
216+
endpoint is called. This function should take input data (e.g., a DataFrame
217+
or dictionary) and return the desired output(e.g., predictions or transformed
218+
data). For scikit-learn models, endpoint_fx can also be one of "predict",
219+
"predict_proba", or "predict_log_proba" if the model supports these methods.
218220
219221
endpoint_name : str
220222
The name of the endpoint to be created.
@@ -236,10 +238,20 @@ def sum_values(x):
236238
```
237239
"""
238240

239-
if isinstance(endpoint_fx, SklearnPredictionTypes):
241+
if not isinstance(endpoint_fx, Callable):
242+
if endpoint_fx not in SklearnPredictionTypes:
243+
raise ValueError(
244+
f"""
245+
Prediction type {endpoint_fx} not available.
246+
Available prediction types: {SklearnPredictionTypes}
247+
"""
248+
)
240249
if not isinstance(self.model, SKLearnHandler):
241250
raise ValueError(
242-
"The 'endpoint_fx' parameter can only be a string when using scikit-learn models."
251+
"""
252+
The 'endpoint_fx' parameter can only be a
253+
string when using scikit-learn models.
254+
"""
243255
)
244256
self.vetiver_post(
245257
self.model.handler_predict,
@@ -252,17 +264,24 @@ def sum_values(x):
252264
endpoint_name = endpoint_name or endpoint_fx.__name__
253265
endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None
254266

267+
# this must be split up this way to preserve the correct type hints for
268+
# the input_data schema validation via Pydantic + FastAPI
269+
input_data_type = (
270+
List[self.model.prototype] if self.check_prototype else Request
271+
)
272+
255273
@self.app.post(
256274
urljoin("/", endpoint_name),
257275
name=endpoint_name,
258276
description=endpoint_doc,
259277
)
260-
async def custom_endpoint(input_data: List[self.model.prototype]):
261-
if self.check_prototype:
262-
served_data = api_data_to_frame(input_data)
263-
else:
264-
served_data = await input_data.json()
278+
async def custom_endpoint(input_data: input_data_type):
265279

280+
served_data = (
281+
api_data_to_frame(input_data)
282+
if self.check_prototype
283+
else await input_data.json()
284+
)
266285
predictions = endpoint_fx(served_data, **kw)
267286

268287
if isinstance(predictions, List):

vetiver/tests/test_add_endpoint.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

vetiver/tests/test_server.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import numpy as np
1313
import pytest
1414
import sys
15+
import pandas as pd
16+
from vetiver.handlers.sklearn import SKLearnHandler
1517

1618

1719
@pytest.fixture
@@ -125,3 +127,51 @@ def test_vetiver_endpoint():
125127
url = vetiver_endpoint(url_raw)
126128

127129
assert url == "http://127.0.0.1:8000/predict"
130+
131+
132+
@pytest.fixture
133+
def data() -> pd.DataFrame:
134+
return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
135+
136+
137+
def test_endpoint_adds(client, data):
138+
response = client.post("/sum/", data=data.to_json(orient="records"))
139+
140+
assert response.status_code == 200
141+
assert response.json() == {"sum": [3, 6, 9]}
142+
143+
144+
def test_endpoint_adds_no_prototype(client_no_prototype, data):
145+
146+
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
147+
response = client_no_prototype.post("/sum/", data=data.to_json(orient="records"))
148+
149+
assert response.status_code == 200
150+
assert response.json() == {"sum": [3, 6, 9]}
151+
152+
153+
def test_vetiver_post_sklearn_predict(model):
154+
vetiver_api = VetiverAPI(model=model)
155+
if not isinstance(vetiver_api.model, SKLearnHandler):
156+
pytest.skip("Test only applicable for SKLearnHandler models")
157+
158+
vetiver_api.vetiver_post("predict_proba")
159+
160+
client = TestClient(vetiver_api.app)
161+
response = client.post(
162+
"/predict_proba", json=vetiver_api.model.prototype.construct().dict()
163+
)
164+
assert response.status_code == 200
165+
166+
167+
def test_vetiver_post_invalid_sklearn_type(model):
168+
vetiver_api = VetiverAPI(model=model)
169+
if not isinstance(vetiver_api.model, SKLearnHandler):
170+
pytest.skip("Test only applicable for SKLearnHandler models")
171+
172+
with pytest.raises(
173+
ValueError,
174+
match="The 'endpoint_fx' parameter can only be a string \
175+
when using scikit-learn models.",
176+
):
177+
vetiver_api.vetiver_post("invalid_type")

0 commit comments

Comments
 (0)