Skip to content

Commit ebaddb5

Browse files
authored
Merge pull request #170 from rstudio/get-meta
2 parents a12326a + 4947843 commit ebaddb5

File tree

3 files changed

+52
-31
lines changed

3 files changed

+52
-31
lines changed

vetiver/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ def pin_url():
100100
async def ping():
101101
return {"ping": "pong"}
102102

103+
@app.get("/metadata")
104+
async def get_metadata():
105+
return self.model.metadata.to_dict()
106+
103107
self.vetiver_post(
104108
self.model.handler_predict, "predict", check_prototype=self.check_prototype
105109
)

vetiver/tests/test_ping_server.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

vetiver/tests/test_server.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from vetiver import mock, VetiverModel, VetiverAPI
2+
from fastapi.testclient import TestClient
3+
import pytest
4+
import sys
5+
6+
7+
@pytest.fixture
8+
def vetiver_model():
9+
X, y = mock.get_mock_data()
10+
model = mock.get_mock_model().fit(X, y)
11+
v = VetiverModel(
12+
model=model,
13+
prototype_data=X,
14+
model_name="my_model",
15+
versioned=None,
16+
description="A regression model for testing purposes",
17+
)
18+
return v
19+
20+
21+
@pytest.fixture
22+
def client(vetiver_model):
23+
app = VetiverAPI(vetiver_model)
24+
25+
return TestClient(app.app)
26+
27+
28+
def test_get_ping(client):
29+
response = client.get("/ping")
30+
assert response.status_code == 200, response.text
31+
assert response.json() == {"ping": "pong"}
32+
33+
34+
def test_get_docs(client):
35+
response = client.get("/__docs__")
36+
assert response.status_code == 200, response.text
37+
38+
39+
def test_get_metadata(client):
40+
response = client.get("/metadata")
41+
assert response.status_code == 200, response.text
42+
assert response.json() == {
43+
"user": {},
44+
"version": None,
45+
"url": None,
46+
"required_pkgs": ["scikit-learn"],
47+
"python_version": list(sys.version_info), # JSON will return a list
48+
}

0 commit comments

Comments
 (0)