Skip to content

Commit 7202cd8

Browse files
authored
Merge pull request #101 from rstudio/xgboost
FEAT: xgboost handler
2 parents 22405bd + 8c62f89 commit 7202cd8

File tree

10 files changed

+185
-40
lines changed

10 files changed

+185
-40
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- name: Install dependencies
3333
run: |
3434
python -m pip install --upgrade pip
35-
python -m pip install -e .[dev,torch,statsmodels]
35+
python -m pip install -e .[dev,torch,statsmodels,xgboost]
3636
- name: Run Tests
3737
run: |
3838
pytest -m 'not rsc_test' --cov --cov-report xml
@@ -65,8 +65,8 @@ jobs:
6565
run: |
6666
pytest vetiver -m 'rsc_test'
6767
68-
test-no-torch:
69-
name: "Test no-torch"
68+
test-no-extras:
69+
name: "Test no extra ml frameworks"
7070
runs-on: ubuntu-latest
7171
steps:
7272
- uses: actions/checkout@v2

docs/source/index.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ You can use vetiver with:
55

66
- `scikit-learn <https://scikit-learn.org/stable/>`_
77
- `pytorch <https://pytorch.org/>`_
8+
- `statsmodels <https://www.statsmodels.org/>`_
9+
- `xgboost <https://xgboost.readthedocs.io/>`_
810

911
You can install the released version of vetiver from `PyPI <https://pypi.org/project/vetiver/>`_:
1012

@@ -65,6 +67,18 @@ Monitor
6567
~pin_metrics
6668
~plot_metrics
6769

70+
Model Handlers
71+
==================
72+
.. autosummary::
73+
:toctree: reference/
74+
:caption: Model Handlers
75+
76+
~BaseHandler
77+
~SKLearnHandler
78+
~TorchHandler
79+
~StatsmodelsHandler
80+
~XGBoostHandler
81+
6882
Advanced Usage
6983
==================
7084
.. toctree::

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,6 @@ torch =
5151

5252
statsmodels =
5353
statsmodels
54+
55+
xgboost =
56+
xgboost

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .handlers.sklearn import SKLearnHandler # noqa
1616
from .handlers.torch import TorchHandler # noqa
1717
from .handlers.statsmodels import StatsmodelsHandler # noqa
18+
from .handlers.xgboost import XGBoostHandler # noqa
1819
from .rsconnect import deploy_rsconnect # noqa
1920
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
2021

vetiver/handlers/sklearn.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,12 @@ class SKLearnHandler(BaseHandler):
1616

1717
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1818

19-
def __init__(self, model, ptype_data):
20-
super().__init__(model, ptype_data)
21-
2219
def describe(self):
2320
"""Create description for sklearn model"""
2421
desc = f"Scikit-learn {self.model.__class__} model"
2522
return desc
2623

27-
def construct_meta(
24+
def create_meta(
2825
user: list = None,
2926
version: str = None,
3027
url: str = None,
@@ -54,17 +51,9 @@ def handler_predict(self, input_data, check_ptype):
5451
Prediction from model
5552
"""
5653

57-
if check_ptype:
58-
if isinstance(input_data, pd.DataFrame):
59-
prediction = self.model.predict(input_data)
60-
else:
61-
prediction = self.model.predict([input_data])
62-
63-
# do not check ptype
64-
else:
65-
if not isinstance(input_data, list):
66-
input_data = [input_data.split(",")] # user delimiter ?
67-
54+
if not check_ptype or isinstance(input_data, pd.DataFrame):
6855
prediction = self.model.predict(input_data)
56+
else:
57+
prediction = self.model.predict([input_data])
6958

7059
return prediction

vetiver/handlers/statsmodels.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def handler_predict(self, input_data, check_ptype):
5858
prediction
5959
Prediction from model
6060
"""
61-
if sm_exists:
62-
if isinstance(input_data, (list, pd.DataFrame)):
63-
prediction = self.model.predict(input_data)
64-
else:
65-
prediction = self.model.predict([input_data])
66-
else:
61+
if not sm_exists:
6762
raise ImportError("Cannot import `statsmodels`")
6863

64+
if isinstance(input_data, (list, pd.DataFrame)):
65+
prediction = self.model.predict(input_data)
66+
else:
67+
prediction = self.model.predict([input_data])
68+
6969
return prediction

vetiver/handlers/torch.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class TorchHandler(BaseHandler):
2121

2222
model_class = staticmethod(lambda: torch.nn.Module)
2323

24-
def __init__(self, model, ptype_data):
25-
super().__init__(model, ptype_data)
26-
2724
def describe(self):
2825
"""Create description for torch model"""
2926
desc = f"Pytorch model of type {type(self.model)}"
@@ -58,17 +55,15 @@ def handler_predict(self, input_data, check_ptype):
5855
prediction
5956
Prediction from model
6057
"""
61-
if torch_exists:
62-
if check_ptype:
63-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
64-
prediction = self.model(torch.from_numpy(input_data))
65-
66-
# do not check ptype
67-
else:
68-
input_data = torch.tensor(input_data)
69-
prediction = self.model(input_data)
58+
if not torch_exists:
59+
raise ImportError("Cannot import `torch`.")
60+
if check_ptype:
61+
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
62+
prediction = self.model(torch.from_numpy(input_data))
7063

64+
# do not check ptype
7165
else:
72-
raise ImportError("Cannot import `torch`.")
66+
input_data = torch.tensor(input_data)
67+
prediction = self.model(input_data)
7368

7469
return prediction

vetiver/handlers/xgboost.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pandas as pd
2+
3+
from ..meta import _model_meta
4+
from .base import BaseHandler
5+
6+
xgb_exists = True
7+
try:
8+
import xgboost
9+
except ImportError:
10+
xgb_exists = False
11+
12+
13+
class XGBoostHandler(BaseHandler):
    """Handler class for creating VetiverModels with xgboost.

    Parameters
    ----------
    model :
        a trained and fit xgboost model
    """

    # The lambda defers the `xgboost` name lookup until the attribute is
    # called, so defining this class does not require xgboost to be importable.
    model_class = staticmethod(lambda: xgboost.Booster)

    def describe(self):
        """Create description for xgboost model"""
        desc = f"XGBoost {self.model.__class__} model."
        return desc

    def create_meta(
        user: list = None,
        version: str = None,
        url: str = None,
        required_pkgs: list = [],
    ):
        """Create metadata for xgboost

        Parameters
        ----------
        user : list
            Forwarded unchanged to `_model_meta`
        version : str
            Forwarded unchanged to `_model_meta`
        url : str
            Forwarded unchanged to `_model_meta`
        required_pkgs : list
            Package names; "xgboost" is appended before forwarding

        Returns
        -------
        meta
            The value returned by `_model_meta`
        """
        # `+` builds a new list, so the shared `[]` default is never mutated.
        required_pkgs = required_pkgs + ["xgboost"]
        meta = _model_meta(user, version, url, required_pkgs)

        return meta

    def handler_predict(self, input_data, check_ptype):
        """Generates method for /predict endpoint in VetiverAPI

        The `handler_predict` function executes at each API call. Use this
        function for calling `predict()` and any other tasks that must be executed
        at each API call.

        Parameters
        ----------
        input_data:
            Test data; a DataFrame, or anything `pd.DataFrame()` can build one from
        check_ptype:
            Accepted for a signature matching the other handlers; not used here

        Returns
        -------
        prediction
            Prediction from model

        Raises
        ------
        ImportError
            If `xgboost` could not be imported
        TypeError
            If `input_data` cannot be converted to a DataFrame
        """

        if not xgb_exists:
            raise ImportError("Cannot import `xgboost`")

        if not isinstance(input_data, pd.DataFrame):
            try:
                input_data = pd.DataFrame(input_data)
            except ValueError as err:
                # Raise a real exception: raising a bare string is itself an
                # error in Python 3 and would hide this message.
                raise TypeError(
                    f"Expected a dict or DataFrame, got {type(input_data)}"
                ) from err
        # The model is always handed a DMatrix built from the DataFrame.
        input_data = xgboost.DMatrix(input_data)

        prediction = self.model.predict(input_data)

        return prediction

vetiver/tests/test_sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_predict_endpoint_ptype_error():
4949
def test_predict_endpoint_no_ptype():
5050
np.random.seed(500)
5151
client = TestClient(_start_application(save_ptype=False).app)
52-
data = "0,0,0"
52+
data = [{"B": 0, "C": 0, "D": 0}]
5353
response = client.post("/predict", json=data)
5454
assert response.status_code == 200, response.text
5555
assert response.json() == {"prediction": [44.47]}, response.json()
@@ -58,7 +58,7 @@ def test_predict_endpoint_no_ptype():
5858
def test_predict_endpoint_no_ptype_batch():
5959
np.random.seed(500)
6060
client = TestClient(_start_application(save_ptype=False).app)
61-
data = [["0,0,0"], ["0,0,0"]]
61+
data = [[0, 0, 0], [0, 0, 0]]
6262
response = client.post("/predict", json=data)
6363
assert response.status_code == 200, response.text
6464
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()
@@ -69,4 +69,4 @@ def test_predict_endpoint_no_ptype_error():
6969
client = TestClient(_start_application(save_ptype=False).app)
7070
data = {"hell0", 9, 32.0}
7171
with pytest.raises(TypeError):
72-
client.post("/predictt", json=data)
72+
client.post("/predict", json=data)

vetiver/tests/test_xgboost.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
3+
xgb = pytest.importorskip("xgboost", reason="xgboost library not installed")
4+
5+
from vetiver.data import mtcars # noqa
6+
from vetiver.handlers.xgboost import XGBoostHandler # noqa
7+
import numpy as np # noqa
8+
from fastapi.testclient import TestClient # noqa
9+
10+
import vetiver # noqa
11+
12+
13+
@pytest.fixture
def build_xgb():
    """Train a two-round xgboost regressor on mtcars and wrap it in a VetiverModel."""
    training_matrix = xgb.DMatrix(mtcars.drop(columns="mpg"), label=mtcars["mpg"])
    booster_params = {
        "max_depth": 2,
        "eta": 1,
        "objective": "reg:squarederror",
        "random_state": 0,
    }
    booster = xgb.train(booster_params, training_matrix, 2)

    # The predictor columns (everything except "mpg") are passed as the third argument.
    predictors = mtcars.drop(columns="mpg")
    return vetiver.VetiverModel(booster, "xgb", predictors)
28+
29+
30+
def test_vetiver_build(build_xgb):
    """Predict on the first mtcars row through the API test client."""
    client = TestClient(vetiver.VetiverAPI(build_xgb).app)
    first_row = mtcars.head(1).drop(columns="mpg")

    result = vetiver.predict(endpoint=client, data=first_row)

    assert result.iloc[0, 0] == 21.064373016357422
    assert len(result) == 1
39+
40+
41+
def test_batch(build_xgb):
    """Predict on the first three mtcars rows in a single request."""
    client = TestClient(vetiver.VetiverAPI(build_xgb).app)
    first_rows = mtcars.head(3).drop(columns="mpg")

    result = vetiver.predict(endpoint=client, data=first_rows)

    assert result.iloc[0, 0] == 21.064373016357422
    assert len(result) == 3
50+
51+
52+
def test_no_ptype(build_xgb):
    """Predict on one mtcars row with the API built using check_ptype=False."""
    unchecked_api = vetiver.VetiverAPI(build_xgb, check_ptype=False)
    client = TestClient(unchecked_api.app)
    first_row = mtcars.head(1).drop(columns="mpg")

    result = vetiver.predict(endpoint=client, data=first_row)

    assert result.iloc[0, 0] == 21.064373016357422
    assert len(result) == 1
61+
62+
63+
def test_serialize(build_xgb):
    """Write the model to a temporary pins board and read it back as a Booster."""
    import pins

    temp_board = pins.board_temp(allow_pickle_read=True)
    vetiver.vetiver_pin_write(board=temp_board, model=build_xgb)

    restored = temp_board.pin_read("xgb")
    assert isinstance(restored, xgb.Booster)

    temp_board.pin_delete("xgb")

0 commit comments

Comments
 (0)