Skip to content

Commit a265da0

Browse files
authored
Merge pull request #155 from rstudio/server-refactor
2 parents 53e9000 + 0495c31 commit a265da0

File tree

14 files changed

+221
-259
lines changed

14 files changed

+221
-259
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ lint:
5252
test: clean-test
5353
pytest -m 'not rsc_test and not docker'
5454

55+
test-pdb: clean-test
56+
pytest -m 'not rsc_test and not docker' --pdb
57+
5558
test-rsc: clean-test
5659
pytest
5760

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .handlers.torch import TorchHandler # noqa
1717
from .handlers.statsmodels import StatsmodelsHandler # noqa
1818
from .handlers.xgboost import XGBoostHandler # noqa
19+
from .helpers import api_data_to_frame # noqa
1920
from .rsconnect import deploy_rsconnect # noqa
2021
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
2122
from .model_card import model_card # noqa

vetiver/handlers/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from vetiver.handlers import base
21
from functools import singledispatch
32
from contextlib import suppress
43

@@ -145,7 +144,7 @@ def handler_predict(self, input_data, check_prototype):
145144

146145

147146
@create_handler.register
148-
def _(model: base.BaseHandler, prototype_data):
147+
def _(model: BaseHandler, prototype_data):
149148
if model.prototype_data is None and prototype_data is not None:
150149
model.prototype_data = prototype_data
151150

vetiver/handlers/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ def handler_predict(self, input_data, check_prototype):
3939
else:
4040
prediction = self.model.predict([input_data])
4141

42-
return prediction
42+
return prediction.tolist()

vetiver/handlers/statsmodels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,4 @@ def handler_predict(self, input_data, check_prototype):
4747
else:
4848
prediction = self.model.predict([input_data])
4949

50-
return prediction
50+
return prediction.tolist()

vetiver/handlers/torch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,12 @@ def handler_predict(self, input_data, check_prototype):
4141
"""
4242
if not torch_exists:
4343
raise ImportError("Cannot import `torch`.")
44+
4445
if check_prototype:
4546
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
4647
prediction = self.model(torch.from_numpy(input_data))
47-
48-
# do not check ptype
4948
else:
5049
input_data = torch.tensor(input_data)
5150
prediction = self.model(input_data)
5251

53-
return prediction
52+
return prediction.tolist()

vetiver/handlers/xgboost.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ def handler_predict(self, input_data, check_prototype):
4848
input_data = pd.DataFrame(input_data)
4949
except ValueError:
5050
raise (f"Expected a dict or DataFrame, got {type(input_data)}")
51+
5152
input_data = xgboost.DMatrix(input_data)
5253

5354
prediction = self.model.predict(input_data)
5455

55-
return prediction
56+
return prediction.tolist()

vetiver/helpers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from functools import singledispatch
2+
import pandas as pd
3+
import pydantic
4+
5+
6+
@singledispatch
7+
def api_data_to_frame(pred_data) -> pd.DataFrame:
8+
"""Convert prototype to dataframe data
9+
10+
Parameters
11+
----------
12+
pred_data : pydantic.BaseModel
13+
User data from given to API endpoint
14+
15+
Returns
16+
-------
17+
pd.DataFrame
18+
BaseModel data translated into DataFrame
19+
"""
20+
21+
raise TypeError("Data should be list, pydantic.BaseModel, pd.DataFrame")
22+
23+
24+
@api_data_to_frame.register(pydantic.BaseModel)
25+
@api_data_to_frame.register(list)
26+
def _(pred_data):
27+
28+
return pd.DataFrame([dict(s) for s in pred_data])
29+
30+
31+
@api_data_to_frame.register(dict)
32+
def _dict(pred_data):
33+
return api_data_to_frame([pred_data])
34+
35+
36+
def response_to_frame(response: dict) -> pd.DataFrame:
37+
"""Convert API JSON response to data frame
38+
39+
Parameters
40+
----------
41+
response : dict
42+
Response from API endpoint
43+
44+
Returns
45+
-------
46+
pd.DataFrame
47+
Response translated into DataFrame
48+
"""
49+
response_df = pd.DataFrame.from_dict(response.json())
50+
51+
return response_df

vetiver/rsconnect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import typing
55

66
from rsconnect.actions import deploy_python_fastapi
7+
from rsconnect.api import RSConnectServer as ConnectServer
78

89
from .write_fastapi import write_app
910

1011

1112
def deploy_rsconnect(
12-
connect_server,
13+
connect_server: ConnectServer,
1314
board,
1415
pin_name: str,
1516
version: str = None,

vetiver/server.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1-
from typing import Any, Callable, Dict, List, Union
1+
from typing import Callable, List, Union
22
from urllib.parse import urljoin
33

4+
import re
45
import httpx
56
import pandas as pd
67
import requests
78
import uvicorn
89
from fastapi import FastAPI, Request, testclient
10+
from fastapi.exceptions import RequestValidationError
911
from fastapi.openapi.utils import get_openapi
1012
from fastapi.responses import HTMLResponse, RedirectResponse
13+
from fastapi.responses import PlainTextResponse
1114
from warnings import warn
1215

1316
from .utils import _jupyter_nb
1417
from .vetiver_model import VetiverModel
1518
from .meta import VetiverMeta
19+
from .helpers import api_data_to_frame, response_to_frame
1620

1721

1822
class VetiverAPI:
@@ -138,6 +142,10 @@ async def rapidoc():
138142
</html>
139143
"""
140144

145+
@app.exception_handler(RequestValidationError)
146+
async def validation_exception_handler(request, exc):
147+
return PlainTextResponse(str(exc), status_code=422)
148+
141149
return app
142150

143151
def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
@@ -167,26 +175,26 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
167175
if self.check_prototype is True:
168176

169177
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
170-
async def custom_endpoint(
171-
input_data: Union[self.model.prototype, List[self.model.prototype]]
172-
):
178+
async def custom_endpoint(input_data: List[self.model.prototype]):
179+
_to_frame = api_data_to_frame(input_data)
180+
predictions = endpoint_fx(_to_frame, **kw)
173181

174-
if isinstance(input_data, List):
175-
served_data = _batch_data(input_data)
182+
if isinstance(predictions, List):
183+
return {endpoint_name: predictions}
176184
else:
177-
served_data = _prepare_data(input_data)
178-
179-
new = endpoint_fx(served_data, **kw)
180-
return {endpoint_name: new.tolist()}
185+
return predictions
181186

182187
else:
183188

184189
@self.app.post(urljoin("/", endpoint_name))
185190
async def custom_endpoint(input_data: Request):
186191
served_data = await input_data.json()
187-
new = endpoint_fx(served_data, **kw)
192+
predictions = endpoint_fx(served_data, **kw)
188193

189-
return {endpoint_name: new.tolist()}
194+
if isinstance(predictions, List):
195+
return {endpoint_name: predictions}
196+
else:
197+
return predictions
190198

191199
def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
192200
"""
@@ -261,46 +269,28 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
261269
# TO DO: dispatch
262270

263271
if isinstance(data, pd.DataFrame):
264-
data_json = data.to_json(orient="records")
265-
response = requester.post(endpoint, data=data_json, **kw)
272+
response = requester.post(
273+
endpoint, data=data.to_json(orient="records"), **kw
274+
) # TO DO: httpx deprecating data in favor of content for TestClient
266275
elif isinstance(data, pd.Series):
267-
data_dict = data.to_json()
268-
response = requester.post(endpoint, data=data_dict, **kw)
276+
response = requester.post(endpoint, json=[data.to_dict()], **kw)
269277
elif isinstance(data, dict):
270-
response = requester.post(endpoint, json=data, **kw)
278+
response = requester.post(endpoint, json=[data], **kw)
271279
else:
272280
response = requester.post(endpoint, json=data, **kw)
273281

274282
try:
275283
response.raise_for_status()
276284
except (requests.exceptions.HTTPError, httpx.HTTPStatusError) as e:
277285
if response.status_code == 422:
278-
raise TypeError(
279-
f"Predict expects DataFrame, Series, or dict. Given type is {type(data)}"
280-
)
286+
raise TypeError(re.sub(r"\n", ": ", response.text))
281287
raise requests.exceptions.HTTPError(
282288
f"Could not obtain data from endpoint with error: {e}"
283289
)
284290

285-
response_df = pd.DataFrame.from_dict(response.json())
286-
287-
return response_df
288-
289-
290-
def _prepare_data(pred_data: Dict[str, Any]) -> List[Any]:
291-
served_data = []
292-
for key, value in pred_data:
293-
served_data.append(value)
294-
return served_data
295-
296-
297-
def _batch_data(pred_data: List[Any]) -> pd.DataFrame:
298-
columns = pred_data[0].dict().keys()
299-
300-
data = [line.dict() for line in pred_data]
291+
response_frame = response_to_frame(response)
301292

302-
served_data = pd.DataFrame(data, columns=columns)
303-
return served_data
293+
return response_frame
304294

305295

306296
def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str:

0 commit comments

Comments
 (0)