Skip to content

Commit 152dd6f

Browse files
authored
Merge pull request #143 from rstudio/spacy
2 parents ebaddb5 + e41d4a8 commit 152dd6f

File tree

8 files changed

+267
-0
lines changed

8 files changed

+267
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ You can use vetiver with:
2020
- [torch](https://pytorch.org/)
2121
- [statsmodels](https://www.statsmodels.org/stable/index.html)
2222
- [xgboost](https://xgboost.readthedocs.io/en/stable/)
23+
- [spacy](https://spacy.io/)
2324
- or utilize [custom handlers](https://rstudio.github.io/vetiver-python/stable/advancedusage/custom_handler.html) to support your own models!
2425

2526
## Installation

docs/_quarto.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ quartodoc:
8888
- TorchHandler
8989
- StatsmodelsHandler
9090
- XGBoostHandler
91+
- SpacyHandler
9192

9293
metadata-files:
9394
- _sidebar.yml

docs/index.qmd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ You can use vetiver with:
2020
- [torch](https://pytorch.org/)
2121
- [statsmodels](https://www.statsmodels.org/stable/index.html)
2222
- [xgboost](https://xgboost.readthedocs.io/en/stable/)
23+
- [spacy](https://spacy.io/)
2324
- or utilize [custom handlers](https://rstudio.github.io/vetiver-python/stable/advancedusage/custom_handler.html) to support your own models!
2425

2526
## Installation

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ all_models =
4848
vetiver[torch]
4949
vetiver[statsmodels]
5050
vetiver[xgboost]
51+
vetiver[spacy]
5152

5253
dev =
5354
pytest
@@ -67,6 +68,9 @@ torch =
6768
xgboost =
6869
xgboost
6970

71+
spacy =
72+
spacy
73+
7074
typecheck =
7175
pyright
7276
pandas-stubs

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .handlers.torch import TorchHandler # noqa
1717
from .handlers.statsmodels import StatsmodelsHandler # noqa
1818
from .handlers.xgboost import XGBoostHandler # noqa
19+
from .handlers.spacy import SpacyHandler # noqa
1920
from .helpers import api_data_to_frame # noqa
2021
from .rsconnect import deploy_rsconnect # noqa
2122
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa

vetiver/handlers/spacy.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from .base import BaseHandler
2+
from ..prototype import vetiver_create_prototype
3+
from ..helpers import api_data_to_frame
4+
5+
import pandas as pd
6+
7+
spacy_exists = True
8+
try:
9+
import spacy
10+
except ImportError:
11+
spacy_exists = False
12+
13+
14+
class SpacyHandler(BaseHandler):
    """Handler class for creating VetiverModels with spacy.

    Parameters
    ----------
    model :
        a trained and fit spacy model
    """

    # Wrapped in a callable so the `spacy` name is only looked up when
    # `model_class()` is actually invoked, not when the class body runs.
    model_class = staticmethod(lambda: spacy.Language)

    # `pip_name` is only defined when the `spacy` import succeeded.
    if spacy_exists:
        pip_name = "spacy"

    def construct_prototype(self):
        """Validate ``self.prototype_data`` and build the input data prototype.

        ``None`` is passed straight through. Otherwise the data must be a
        dict, a pandas Series, or a pandas DataFrame; a dict must have
        exactly one key and a DataFrame exactly one column.

        Returns
        -------
        prototype :
            Result of ``vetiver_create_prototype`` for the prototype data

        Raises
        ------
        TypeError
            If the prototype data is not None, dict, Series, or DataFrame
        ValueError
            If a DataFrame does not have exactly one column, or a dict does
            not have exactly one key
        """
        data = self.prototype_data
        accepted_types = (pd.Series, pd.DataFrame, dict)

        # Guard clauses: each check raises, so their order mirrors an
        # if/elif chain.
        if data is not None and not isinstance(data, accepted_types):
            raise TypeError(
                "Spacy prototype must be a dict, pandas Series, or pandas DataFrame"
            )
        if isinstance(data, pd.DataFrame) and len(data.columns) != 1:
            raise ValueError("Spacy prototype data must be a 1-column pandas DataFrame")
        if isinstance(data, dict) and len(data) != 1:
            raise ValueError("Spacy prototype data must be a dictionary with 1 key")

        return vetiver_create_prototype(data)

    def handler_predict(self, input_data, check_prototype):
        """Generates method for /predict endpoint in VetiverAPI

        The `handler_predict` function executes at each API call. The input
        is converted with ``api_data_to_frame``, its first column is streamed
        through ``self.model.pipe``, and each resulting doc is serialized
        with ``to_json()``.

        Parameters
        ----------
        input_data:
            Test data
        check_prototype:
            Part of the handler signature; not used by this method

        Returns
        -------
        prediction
            pandas Series holding one ``doc.to_json()`` result per input row

        Raises
        ------
        ImportError
            If `spacy` could not be imported
        """
        if not spacy_exists:
            raise ImportError("Cannot import `spacy`")

        frame = api_data_to_frame(input_data)
        texts = frame.iloc[:, 0]  # only the first column is fed to the model

        return pd.Series([doc.to_json() for doc in self.model.pipe(texts)])

vetiver/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def _(pred_data):
2828
return pd.DataFrame([dict(s) for s in pred_data])
2929

3030

31+
@api_data_to_frame.register(pd.DataFrame)
def _pd_frame(pred_data):
    """Return ``pred_data`` unchanged; registered for ``pd.DataFrame`` input."""
    return pred_data
35+
36+
3137
@api_data_to_frame.register(dict)
def _dict(pred_data):
    """Wrap a single dict in a list and dispatch it through ``api_data_to_frame``."""
    single_record = [pred_data]
    return api_data_to_frame(single_record)

vetiver/tests/test_spacy.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import pytest
2+
3+
spacy = pytest.importorskip("spacy", reason="spacy library not installed")
4+
5+
import numpy as np # noqa
6+
import pandas as pd # noqa
7+
from fastapi.testclient import TestClient # noqa
8+
from numpy import nan # noqa
9+
import vetiver # noqa
10+
11+
12+
@spacy.language.Language.component("animals")
def animal_component_function(doc):
    """Label every span found by the module-level ``matcher`` as ANIMAL.

    The doc's entities are replaced with those spans and the doc is returned.
    """
    doc.ents = [
        spacy.tokens.Span(doc, begin, stop, label="ANIMAL")
        for _match_id, begin, stop in matcher(doc)  # noqa
    ]
    return doc
21+
22+
23+
# Blank English pipeline shared by every test in this module.
nlp = spacy.blank("en")

# Phrase patterns are built by running the animal names through the pipeline
# before the custom component is attached.
animals = [*nlp.pipe(["dog", "cat", "turtle"])]

matcher = spacy.matcher.PhraseMatcher(nlp.vocab)
matcher.add("ANIMAL", animals)

# Attach the component registered under the name "animals".
nlp.add_pipe("animals")
28+
29+
30+
@pytest.fixture
def spacy_model():
    """Expose the module-level ``nlp`` pipeline as a fixture."""
    return nlp
33+
34+
35+
@pytest.fixture()
def vetiver_client_with_prototype(spacy_model):  # With check_prototype=True
    """Test client for an API built from a one-column DataFrame prototype."""
    prototype_frame = pd.DataFrame({"new_column": ["one", "two", "three"]})
    model = vetiver.VetiverModel(
        spacy_model, "animals", prototype_data=prototype_frame
    )

    api = vetiver.VetiverAPI(model, check_prototype=True)
    api.app.root_path = "/predict"

    return TestClient(api.app)
44+
45+
46+
@pytest.fixture(scope="function")
def vetiver_client_with_prototype_series(spacy_model):  # With check_prototype=True
    """Test client for an API whose prototype data is a pandas Series."""
    prototype_series = pd.Series({"new_column": ["one", "two", "three"]})
    model = vetiver.VetiverModel(
        spacy_model, "animals", prototype_data=prototype_series
    )

    api = vetiver.VetiverAPI(model, check_prototype=True)
    api.app.root_path = "/predict"

    return TestClient(api.app)
54+
55+
56+
@pytest.fixture
def vetiver_client_no_prototype(spacy_model):  # With check_prototype=False
    """Test client for an API built without any prototype data."""
    model = vetiver.VetiverModel(spacy_model, "animals")

    api = vetiver.VetiverAPI(model, check_prototype=False)
    api.app.root_path = "/predict"

    return TestClient(api.app)
64+
65+
66+
@pytest.mark.parametrize("data", ["a", 1, [1, 2, 3]])
def test_bad_prototype_data(data, spacy_model):
    """A str, int, or list prototype must be rejected with TypeError."""
    build_kwargs = {"prototype_data": data}
    with pytest.raises(TypeError):
        vetiver.VetiverModel(spacy_model, "animals", **build_kwargs)
70+
71+
72+
@pytest.mark.parametrize(
    "data",
    [
        {"col": ["1", "2"], "col2": [1, 2]},
        pd.DataFrame({"col": ["1", "2"], "col2": [1, 2]}),
    ],
)
def test_bad_prototype_shape(data, spacy_model):
    """A two-key dict or two-column DataFrame must be rejected with ValueError."""
    build_kwargs = {"prototype_data": data}
    with pytest.raises(ValueError):
        vetiver.VetiverModel(spacy_model, "animals", **build_kwargs)
82+
83+
84+
@pytest.mark.parametrize("data", [{"col": "1"}, pd.DataFrame({"col": ["1"]})])
def test_good_prototype_shape(data, spacy_model):
    """A one-key dict and a one-column DataFrame yield the same prototype."""
    model = vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)

    constructed = model.prototype.construct()

    assert constructed.dict() == {"col": "1"}
89+
90+
91+
def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
    """Predict through the prototype-checking client and compare the full payload."""
    payload = pd.DataFrame({"new_column": ["turtles", "i have a dog"]})
    expected = {
        "0": {
            "text": "turtles",
            "ents": [],
            "sents": [{"start": 0, "end": 7}],
            "tokens": [{"id": 0, "start": 0, "end": 7}],
        },
        "1": {
            "text": "i have a dog",
            "ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
            "sents": nan,
            "tokens": [
                {"id": 0, "start": 0, "end": 1},
                {"id": 1, "start": 2, "end": 6},
                {"id": 2, "start": 7, "end": 8},
                {"id": 3, "start": 9, "end": 12},
            ],
        },
    }

    response = vetiver.predict(endpoint=vetiver_client_with_prototype, data=payload)

    assert isinstance(response, pd.DataFrame), response
    assert response.to_dict() == expected
116+
117+
118+
def test_vetiver_predict_no_prototype(vetiver_client_no_prototype):
    """Predict through the client built without a prototype and compare the payload."""
    payload = pd.DataFrame({"uhhh": ["turtles", "i have a dog"]})
    expected = {
        "0": {
            "text": "turtles",
            "ents": [],
            "sents": [{"start": 0, "end": 7}],
            "tokens": [{"id": 0, "start": 0, "end": 7}],
        },
        "1": {
            "text": "i have a dog",
            "ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
            "sents": nan,
            "tokens": [
                {"id": 0, "start": 0, "end": 1},
                {"id": 1, "start": 2, "end": 6},
                {"id": 2, "start": 7, "end": 8},
                {"id": 3, "start": 9, "end": 12},
            ],
        },
    }

    response = vetiver.predict(endpoint=vetiver_client_no_prototype, data=payload)

    assert isinstance(response, pd.DataFrame), response
    assert response.to_dict() == expected
143+
144+
145+
def test_serialize_no_prototype(spacy_model):
    """Write a prototype-less model to a temp board and read it back."""
    import pins

    temp_board = pins.board_temp(allow_pickle_read=True)
    original = vetiver.VetiverModel(spacy_model, "animals")
    vetiver.vetiver_pin_write(board=temp_board, model=original)

    restored = vetiver.VetiverModel.from_pin(temp_board, "animals")

    assert isinstance(restored.model, spacy.lang.en.English)
156+
157+
158+
def test_serialize_prototype(spacy_model):
    """Write a model with a DataFrame prototype to a temp board and read it back."""
    import pins

    temp_board = pins.board_temp(allow_pickle_read=True)
    prototype_frame = pd.DataFrame({"text": ["text"]})
    original = vetiver.VetiverModel(
        spacy_model, "animals", prototype_data=prototype_frame
    )
    vetiver.vetiver_pin_write(board=temp_board, model=original)

    restored = vetiver.VetiverModel.from_pin(temp_board, "animals")

    assert isinstance(restored.model, spacy.lang.en.English)

0 commit comments

Comments
 (0)