1
+ import pytest
2
+ import sys
3
+ import pandas as pd
4
+ import numpy as np
5
+ from fastapi .testclient import TestClient
6
+ from pydantic import BaseModel , conint
7
+
8
+ from vetiver .data import mtcars
1
9
from vetiver import (
2
10
mock ,
3
11
VetiverModel ,
7
15
vetiver_endpoint ,
8
16
predict ,
9
17
)
10
- from pydantic import BaseModel , conint
11
- from fastapi .testclient import TestClient
12
- import numpy as np
13
- import pytest
14
- import sys
15
- import pandas as pd
16
- from vetiver .handlers .sklearn import SKLearnHandler
17
18
18
19
19
20
@pytest .fixture
20
21
def model ():
21
22
np .random .seed (500 )
22
- X , y = mock .get_mock_data ()
23
- model = mock .get_mock_model ().fit (X , y )
23
+ model = mock .get_mtcars_model ()
24
24
v = VetiverModel (
25
25
model = model ,
26
- prototype_data = X ,
26
+ prototype_data = mtcars . drop ( columns = "cyl" ) ,
27
27
model_name = "my_model" ,
28
28
versioned = None ,
29
- description = "A regression model for testing purposes" ,
29
+ description = "A logistic regression model for testing purposes" ,
30
30
)
31
31
return v
32
32
@@ -84,11 +84,29 @@ def test_get_prototype(client, model):
84
84
assert response .status_code == 200 , response .text
85
85
assert response .json () == {
86
86
"properties" : {
87
- "B" : {"example" : 55 , "type" : "integer" },
88
- "C" : {"example" : 65 , "type" : "integer" },
89
- "D" : {"example" : 17 , "type" : "integer" },
87
+ "mpg" : {"example" : 21.0 , "type" : "number" },
88
+ "disp" : {"example" : 160.0 , "type" : "number" },
89
+ "hp" : {"example" : 110.0 , "type" : "number" },
90
+ "drat" : {"example" : 3.9 , "type" : "number" },
91
+ "wt" : {"example" : 2.62 , "type" : "number" },
92
+ "qsec" : {"example" : 16.46 , "type" : "number" },
93
+ "vs" : {"example" : 0.0 , "type" : "number" },
94
+ "am" : {"example" : 1.0 , "type" : "number" },
95
+ "gear" : {"example" : 4.0 , "type" : "number" },
96
+ "carb" : {"example" : 4.0 , "type" : "number" },
90
97
},
91
- "required" : ["B" , "C" , "D" ],
98
+ "required" : [
99
+ "mpg" ,
100
+ "disp" ,
101
+ "hp" ,
102
+ "drat" ,
103
+ "wt" ,
104
+ "qsec" ,
105
+ "vs" ,
106
+ "am" ,
107
+ "gear" ,
108
+ "carb" ,
109
+ ],
92
110
"title" : "prototype" ,
93
111
"type" : "object" ,
94
112
}
@@ -131,14 +149,28 @@ def test_vetiver_endpoint():
131
149
132
150
@pytest .fixture
133
151
def data () -> pd .DataFrame :
134
- return pd .DataFrame ({"B" : [1 , 1 , 1 ], "C" : [2 , 2 , 2 ], "D" : [3 , 3 , 3 ]})
152
+ return pd .DataFrame (
153
+ {
154
+ "mpg" : [20 , 20 ],
155
+ "disp" : [160 , 160 ],
156
+ "hp" : [110 , 110 ],
157
+ "drat" : [3.9 , 3.9 ],
158
+ "wt" : [2.62 , 2.62 ],
159
+ "qsec" : [16.00 , 16.00 ],
160
+ "vs" : [0 , 0 ],
161
+ "am" : [1 , 1 ],
162
+ "gear" : [4 , 4 ],
163
+ "carb" : [4 , 4 ],
164
+ }
165
+ )
135
166
136
167
137
168
def test_endpoint_adds (client , data ):
169
+
138
170
response = client .post ("/sum/" , data = data .to_json (orient = "records" ))
139
171
140
172
assert response .status_code == 200
141
- assert response .json () == {"sum" : [3 , 6 , 9 ]}
173
+ assert response .json () == {"sum" : [40 , 320 , 220 , 7.8 , 5.24 , 32.00 , 0 , 2 , 8 , 8 ]}
142
174
143
175
144
176
def test_endpoint_adds_no_prototype (client_no_prototype , data ):
@@ -150,28 +182,36 @@ def test_endpoint_adds_no_prototype(client_no_prototype, data):
150
182
assert response .json () == {"sum" : [3 , 6 , 9 ]}
151
183
152
184
153
- def test_vetiver_post_sklearn_predict (model ):
154
- vetiver_api = VetiverAPI (model = model )
155
- if not isinstance (vetiver_api .model , SKLearnHandler ):
156
- pytest .skip ("Test only applicable for SKLearnHandler models" )
157
-
158
- vetiver_api .vetiver_post ("predict_proba" )
159
-
160
- client = TestClient (vetiver_api .app )
161
- response = client .post (
162
- "/predict_proba" , json = vetiver_api .model .prototype .construct ().dict ()
163
- )
164
- assert response .status_code == 200
185
+ def test_vetiver_post_sklearn_predict (model , data ):
186
+ api = VetiverAPI (model = model )
187
+ api .vetiver_post ("predict_proba" )
188
+
189
+ client = TestClient (api .app )
190
+ response = predict (endpoint = "/predict_proba/" , data = data , test_client = client )
191
+
192
+ assert isinstance (response , pd .DataFrame )
193
+ assert len (response ) == 2
194
+ assert response .to_dict () == {
195
+ "predict_proba" : {
196
+ 0 : [
197
+ 0.00627480416153554 ,
198
+ 0.9937251958346092 ,
199
+ 3.855256735904704e-12 ,
200
+ ],
201
+ 1 : [
202
+ 0.00627480416153554 ,
203
+ 0.9937251958346092 ,
204
+ 3.855256735904704e-12 ,
205
+ ],
206
+ },
207
+ }
165
208
166
209
167
210
def test_vetiver_post_invalid_sklearn_type (model ):
168
211
vetiver_api = VetiverAPI (model = model )
169
- if not isinstance (vetiver_api .model , SKLearnHandler ):
170
- pytest .skip ("Test only applicable for SKLearnHandler models" )
171
212
172
213
with pytest .raises (
173
214
ValueError ,
174
- match = "The 'endpoint_fx' parameter can only be a string \
175
- when using scikit-learn models." ,
215
+ match = "Prediction type invalid_type not available" ,
176
216
):
177
217
vetiver_api .vetiver_post ("invalid_type" )
0 commit comments