Skip to content

Commit 1cacffc

Browse files
committed
can have no prototype
1 parent 5d30dd3 commit 1cacffc

File tree

2 files changed

+24
-53
lines changed

2 files changed

+24
-53
lines changed

vetiver/handlers/spacy.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,8 @@ def construct_prototype(self):
4848
isinstance(self.prototype_data, dict) and len(self.prototype_data) != 1
4949
): # is dict, more than one key
5050
raise ValueError("Spacy prototype data must dictionary with 1 key")
51-
elif self.prototype_data is None:
52-
text_column_name = "text"
53-
else:
54-
text_column_name = (
55-
self.prototype_data.columns[0]
56-
if isinstance(self.prototype_data, pd.DataFrame)
57-
else list(self.prototype_data.keys())[0]
58-
)
5951

60-
prototype = vetiver_create_prototype(pd.DataFrame({text_column_name: ["text"]}))
52+
prototype = vetiver_create_prototype(self.prototype_data)
6153

6254
return prototype
6355

vetiver/tests/test_spacy.py

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,6 @@ def spacy_model():
3232
return nlp
3333

3434

35-
@pytest.fixture(scope="function")
36-
@pytest.mark.parametrize("data", ["a", 1, [1, 2, 3]])
37-
def test_bad_prototype_data(data, spacy_model):
38-
with pytest.raises(TypeError):
39-
vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
40-
41-
42-
@pytest.fixture(scope="function")
43-
@pytest.mark.parametrize(
44-
pd.DataFrame(
45-
{"col": ["1", "2"], "col2": [1, 2]},
46-
pd.DataFrame({"col": ["1", "2"], "col2": [1, 2]}),
47-
)
48-
)
49-
def test_bad_prototype_shape(data, spacy_model):
50-
with pytest.raises(ValueError):
51-
vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
52-
53-
5435
@pytest.fixture()
5536
def vetiver_client_with_prototype(spacy_model): # With check_prototype=True
5637
df = pd.DataFrame({"new_column": ["one", "two", "three"]})
@@ -82,37 +63,35 @@ def vetiver_client_no_prototype(spacy_model): # With check_prototype=False
8263
return client
8364

8465

85-
def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
86-
df = pd.DataFrame({"new_column": ["turtles", "i have a dog"]})
66+
@pytest.mark.parametrize("data", ["a", 1, [1, 2, 3]])
67+
def test_bad_prototype_data(data, spacy_model):
68+
with pytest.raises(TypeError):
69+
vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
8770

88-
response = vetiver.predict(endpoint=vetiver_client_with_prototype, data=df)
8971

90-
assert isinstance(response, pd.DataFrame), response
91-
assert response.to_dict() == {
92-
"0": {
93-
"text": "turtles",
94-
"ents": [],
95-
"sents": [{"start": 0, "end": 7}],
96-
"tokens": [{"id": 0, "start": 0, "end": 7}],
97-
},
98-
"1": {
99-
"text": "i have a dog",
100-
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
101-
"sents": nan,
102-
"tokens": [
103-
{"id": 0, "start": 0, "end": 1},
104-
{"id": 1, "start": 2, "end": 6},
105-
{"id": 2, "start": 7, "end": 8},
106-
{"id": 3, "start": 9, "end": 12},
107-
],
108-
},
109-
}
72+
@pytest.mark.parametrize(
73+
"data",
74+
[
75+
{"col": ["1", "2"], "col2": [1, 2]},
76+
pd.DataFrame({"col": ["1", "2"], "col2": [1, 2]}),
77+
],
78+
)
79+
def test_bad_prototype_shape(data, spacy_model):
80+
with pytest.raises(ValueError):
81+
vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
82+
83+
84+
@pytest.mark.parametrize("data", [{"col": "1"}, pd.DataFrame({"col": ["1"]})])
85+
def test_good_prototype_shape(data, spacy_model):
86+
v = vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
11087

88+
assert v.prototype.construct().dict() == {"col": "1"}
11189

112-
def test_vetiver_predict_with_prototype_series(vetiver_client_with_prototype_series):
90+
91+
def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
11392
df = pd.DataFrame({"new_column": ["turtles", "i have a dog"]})
11493

115-
response = vetiver.predict(endpoint=vetiver_client_with_prototype_series, data=df)
94+
response = vetiver.predict(endpoint=vetiver_client_with_prototype, data=df)
11695

11796
assert isinstance(response, pd.DataFrame), response
11897
assert response.to_dict() == {

0 commit comments

Comments
 (0)