Skip to content

Commit db43636

Browse files
author
emmaamblard
committed
add tests for BO/Prediction.py
1 parent be0cd7f commit db43636

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

QA/py/tests/test_prediction.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import logging
2+
import pytest
3+
import pandas as pd
4+
import numpy as np
25

36
from starlette import status
47

5-
from tests.credentials import ADMIN_AUTH
8+
from tests.credentials import ADMIN_AUTH, CREATOR_AUTH
69
from tests.test_jobs import get_job_and_wait_until_ok
10+
from tests.test_classification import _prj_query
11+
12+
from BO.Prediction import DeepFeatures
13+
14+
from API_operations.CRUD.ObjectParents import SamplesService
715

816
OBJECT_SET_PREDICT_URL = "/object_set/predict"
917

@@ -37,3 +45,51 @@ def no_test_basic_prediction(config, database, fastapi, caplog):
3745
assert rsp.status_code == status.HTTP_200_OK
3846

3947
job_id = get_job_and_wait_until_ok(fastapi, rsp)
48+
49+
50+
def test_prediction_functions(config, database, fastapi, caplog):
51+
caplog.set_level(logging.ERROR)
52+
from tests.test_import import test_import
53+
prj_id = test_import(config, database, caplog, "Test Prediction")
54+
55+
obj_ids = _prj_query(fastapi, CREATOR_AUTH, prj_id)
56+
assert len(obj_ids) == 8
57+
58+
# Prepare fake CNN features to insert
59+
features = list()
60+
for i, oi in enumerate(obj_ids):
61+
features.append([(i+1) * .1] * 50)
62+
features_df = pd.DataFrame(features, index=obj_ids)
63+
64+
# Test features insertion
65+
with SamplesService() as sce:
66+
n_inserts = DeepFeatures.save(sce.session, features_df)
67+
assert n_inserts == 8
68+
sce.session.commit()
69+
70+
# Test features retrieval
71+
with SamplesService() as sce:
72+
ret = DeepFeatures.np_read_for_objects(sce.session, obj_ids)
73+
assert (ret == np.array(features, dtype='float32')).all()
74+
75+
# Test find_missing without any missing features
76+
with SamplesService() as sce:
77+
ret = DeepFeatures.find_missing(sce.session, prj_id)
78+
assert ret == {}
79+
80+
# Test deletion
81+
with SamplesService() as sce:
82+
n_deletes = DeepFeatures.delete_all(sce.session, prj_id)
83+
assert n_deletes == 8
84+
sce.session.commit()
85+
86+
# Test find_missing after deletion
87+
with SamplesService() as sce:
88+
ret = DeepFeatures.find_missing(sce.session, prj_id)
89+
assert len(ret) == 8
90+
91+
# Test features retrieval in empty table, should raise an error
92+
with SamplesService() as sce:
93+
with pytest.raises(AssertionError):
94+
ret = DeepFeatures.np_read_for_objects(sce.session, obj_ids)
95+

0 commit comments

Comments
 (0)