11
11
import pandas as pd
12
12
import requests
13
13
import uvicorn
14
- from fastapi import FastAPI
14
+ from fastapi import FastAPI , Request
15
15
from fastapi .exceptions import RequestValidationError
16
16
from fastapi .openapi .utils import get_openapi
17
17
from fastapi .responses import HTMLResponse , PlainTextResponse , RedirectResponse
@@ -210,11 +210,13 @@ def vetiver_post(
210
210
211
211
Parameters
212
212
----------
213
- endpoint_fx : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
214
- A callable function that specifies the custom logic to execute when the endpoint is called.
215
- This function should take input data (e.g., a DataFrame or dictionary) and return the desired output
216
- (e.g., predictions or transformed data). For scikit-learn models, endpoint_fx can also be one of
217
- "predict", "predict_proba", or "predict_log_proba" if the model supports these methods.
213
+ endpoint_fx
214
+ : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
215
+ A callable function that specifies the custom logic to execute when the
216
+ endpoint is called. This function should take input data (e.g., a DataFrame
217
+ or dictionary) and return the desired output(e.g., predictions or transformed
218
+ data). For scikit-learn models, endpoint_fx can also be one of "predict",
219
+ "predict_proba", or "predict_log_proba" if the model supports these methods.
218
220
219
221
endpoint_name : str
220
222
The name of the endpoint to be created.
@@ -236,10 +238,20 @@ def sum_values(x):
236
238
```
237
239
"""
238
240
239
- if isinstance (endpoint_fx , SklearnPredictionTypes ):
241
+ if not isinstance (endpoint_fx , Callable ):
242
+ if endpoint_fx not in SklearnPredictionTypes :
243
+ raise ValueError (
244
+ f"""
245
+ Prediction type { endpoint_fx } not available.
246
+ Available prediction types: { SklearnPredictionTypes }
247
+ """
248
+ )
240
249
if not isinstance (self .model , SKLearnHandler ):
241
250
raise ValueError (
242
- "The 'endpoint_fx' parameter can only be a string when using scikit-learn models."
251
+ """
252
+ The 'endpoint_fx' parameter can only be a
253
+ string when using scikit-learn models.
254
+ """
243
255
)
244
256
self .vetiver_post (
245
257
self .model .handler_predict ,
@@ -252,17 +264,24 @@ def sum_values(x):
252
264
endpoint_name = endpoint_name or endpoint_fx .__name__
253
265
endpoint_doc = dedent (endpoint_fx .__doc__ ) if endpoint_fx .__doc__ else None
254
266
267
+ # this must be split up this way to preserve the correct type hints for
268
+ # the input_data schema validation via Pydantic + FastAPI
269
+ input_data_type = (
270
+ List [self .model .prototype ] if self .check_prototype else Request
271
+ )
272
+
255
273
@self .app .post (
256
274
urljoin ("/" , endpoint_name ),
257
275
name = endpoint_name ,
258
276
description = endpoint_doc ,
259
277
)
260
- async def custom_endpoint (input_data : List [self .model .prototype ]):
261
- if self .check_prototype :
262
- served_data = api_data_to_frame (input_data )
263
- else :
264
- served_data = await input_data .json ()
278
+ async def custom_endpoint (input_data : input_data_type ):
265
279
280
+ served_data = (
281
+ api_data_to_frame (input_data )
282
+ if self .check_prototype
283
+ else await input_data .json ()
284
+ )
266
285
predictions = endpoint_fx (served_data , ** kw )
267
286
268
287
if isinstance (predictions , List ):
0 commit comments