Skip to content

Commit bab0b2a

Browse files
committed
better errors and handling of data types
1 parent 067235c commit bab0b2a

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

vetiver/handlers/spacy.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,24 @@ def construct_prototype(self):
3333
prototype :
3434
Input data prototype for spacy model
3535
"""
36-
if self.prototype_data is None:
36+
if self.prototype_data is not None and not isinstance(
37+
self.prototype_data, (pd.Series, pd.DataFrame, dict)
38+
): # wrong type
39+
raise TypeError(
40+
"Spacy prototype must be a dict, pandas Series, or pandas DataFrame"
41+
)
42+
elif (
43+
isinstance(self.prototype_data, pd.DataFrame)
44+
and len(self.prototype_data.columns) != 1
45+
): # is dataframe, more than one column
46+
raise ValueError("Spacy prototype data must be a 1-column pandas DataFrame")
47+
elif (
48+
isinstance(self.prototype_data, dict) and len(self.prototype_data) != 1
49+
): # is dict, more than one key
50+
raise ValueError("Spacy prototype data must dictionary with 1 key")
51+
elif self.prototype_data is None:
3752
text_column_name = "text"
38-
3953
else:
40-
if (
41-
isinstance(self.prototype_data, pd.DataFrame)
42-
and len(self.prototype_data.columns) != 1
43-
):
44-
raise TypeError("Expected 1 column of text data")
45-
4654
text_column_name = (
4755
self.prototype_data.columns[0]
4856
if isinstance(self.prototype_data, pd.DataFrame)

vetiver/tests/test_spacy.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,26 @@ def spacy_model():
3232
return nlp
3333

3434

35-
@pytest.fixture
35+
@pytest.mark.parametrize("data", ["a", 1, [1, 2, 3]])
def test_bad_prototype_data(data, spacy_model):
    """Check that unsupported prototype types are rejected with ``TypeError``.

    Each parametrized value (a str, an int, a list) is neither a dict, a
    pandas Series, nor a pandas DataFrame, so constructing the model is
    expected to fail.

    Parameters
    ----------
    data :
        The invalid prototype value supplied by ``parametrize``.
    spacy_model :
        Fixture providing the spacy pipeline under test.
    """
    # NOTE: this function must not be decorated with ``@pytest.fixture``.
    # A fixture is not collected as a test, so the check below would
    # silently never run.
    with pytest.raises(TypeError):
        vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
40+
41+
42+
@pytest.mark.parametrize(
    "data",
    [
        # dict with more than one key
        {"col": ["1", "2"], "col2": [1, 2]},
        # DataFrame with more than one column
        pd.DataFrame({"col": ["1", "2"], "col2": [1, 2]}),
    ],
)
def test_bad_prototype_shape(data, spacy_model):
    """Check that prototypes of a valid type but wrong shape raise ``ValueError``.

    The two cases are a dict with two keys and a DataFrame with two columns;
    the spacy handler accepts exactly one key / one column.

    Parameters
    ----------
    data :
        The badly shaped prototype supplied by ``parametrize``.
    spacy_model :
        Fixture providing the spacy pipeline under test.
    """
    # NOTE: ``parametrize`` needs the argument name ("data") followed by a
    # list of values. This function must also stay a plain test, not a
    # ``@pytest.fixture``, or pytest will never collect it.
    with pytest.raises(ValueError):
        vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
52+
53+
54+
@pytest.fixture()
3655
def vetiver_client_with_prototype(spacy_model): # With check_prototype=True
3756
df = pd.DataFrame({"new_column": ["one", "two", "three"]})
3857
v = vetiver.VetiverModel(spacy_model, "animals", prototype_data=df)
@@ -43,6 +62,16 @@ def vetiver_client_with_prototype(spacy_model): # With check_prototype=True
4362
return client
4463

4564

65+
@pytest.fixture(scope="function")
def vetiver_client_with_prototype_series(spacy_model):  # With check_prototype=True
    """Return a test client for an API whose prototype is a pandas Series.

    The model is built with a Series keyed by ``"new_column"`` as its
    prototype data, wrapped in a ``VetiverAPI`` with ``check_prototype=True``,
    and served through a ``TestClient`` whose root path is ``/predict``.
    """
    series_prototype = pd.Series({"new_column": ["one", "two", "three"]})
    api = vetiver.VetiverAPI(
        vetiver.VetiverModel(
            spacy_model, "animals", prototype_data=series_prototype
        ),
        check_prototype=True,
    )
    api.app.root_path = "/predict"
    return TestClient(api.app)
73+
74+
4675
@pytest.fixture
4776
def vetiver_client_no_prototype(spacy_model): # With check_prototype=False
4877
v = vetiver.VetiverModel(spacy_model, "animals")
@@ -80,6 +109,33 @@ def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
80109
}
81110

82111

112+
def test_vetiver_predict_with_prototype_series(vetiver_client_with_prototype_series):
    """Predict through the Series-prototype client and compare the full output.

    Two texts are sent in a ``new_column`` DataFrame; the response must be a
    DataFrame whose ``to_dict()`` form equals the expected per-row records.
    """
    payload = pd.DataFrame({"new_column": ["turtles", "i have a dog"]})

    result = vetiver.predict(
        endpoint=vetiver_client_with_prototype_series, data=payload
    )

    # Passing the response as the assertion message shows it on failure.
    assert isinstance(result, pd.DataFrame), result

    turtles_record = {
        "text": "turtles",
        "ents": [],
        "sents": [{"start": 0, "end": 7}],
        "tokens": [{"id": 0, "start": 0, "end": 7}],
    }
    dog_record = {
        "text": "i have a dog",
        "ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
        "sents": nan,
        "tokens": [
            {"id": 0, "start": 0, "end": 1},
            {"id": 1, "start": 2, "end": 6},
            {"id": 2, "start": 7, "end": 8},
            {"id": 3, "start": 9, "end": 12},
        ],
    }
    assert result.to_dict() == {"0": turtles_record, "1": dog_record}
137+
138+
83139
def test_vetiver_predict_no_prototype(vetiver_client_no_prototype):
84140
df = pd.DataFrame({"uhhh": ["turtles", "i have a dog"]})
85141

0 commit comments

Comments
 (0)