
Commit 928b826

creating unit test cases

Parent: 82aab67

File tree

9 files changed (+611, -10 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+logs/
 ipynb_checkpoints/
 mlruns
 mlartifacts

src/data/utils.py

Lines changed: 28 additions & 8 deletions
@@ -30,26 +30,44 @@ def load_feature(
 @logger.catch
 def download_dataset(
     name: str,
+    new_name: str,
+    path: pathlib.Path,
+    send_to_aws: bool,
 ) -> None:
     """Download the dataset using Kaggle's API.
 
     Args:
         name (str): the dataset's name.
+        new_name (str): the dataset file's new name.
+        path (pathlib.Path): the path where the dataset will be stored locally.
+        send_to_aws (bool): whether the dataset will be sent to an AWS S3 bucket or not.
     """
-    kaggle_user = kaggle_credentials.KAGGLE_USERNAME
-    kaggle_key = kaggle_credentials.KAGGLE_KEY
-    path = '../data/'
+    os.environ["KAGGLE_USERNAME"] = kaggle_credentials.KAGGLE_USERNAME
+    os.environ["KAGGLE_KEY"] = kaggle_credentials.KAGGLE_KEY
+
     logger.info(f"Downloading dataset {name} and saving into the folder {path}.")
 
     # Downloading data using the Kaggle API through the terminal
-    os.system(f'export KAGGLE_USERNAME={kaggle_user}; export KAGGLE_KEY={kaggle_key};')
-    os.system(f'kaggle datasets download -d {name} -p {path} --unzip')
+    # os.system(f'export KAGGLE_USERNAME={kaggle_user}; export KAGGLE_KEY={kaggle_key};')
+    os.system(f'kaggle datasets download -d {name} --unzip')
+    os.system(
+        f'mv ObesityDataSet.csv {pathlib.Path.joinpath(path, new_name)}'
+    )
 
     # Sending the dataset to the AWS S3 bucket
-    if aws_credentials.S3 != "YOUR_S3_BUCKET_URL":
-        send_dataset_to_s3()
-
+    if send_to_aws:
+        if aws_credentials.S3 != "YOUR_S3_BUCKET_URL":
+            send_dataset_to_s3(
+                file_path=path,
+                file_name=new_name,
+            )
+        else:
+            logger.warning(
+                "The S3 Bucket url was not specified in the 'credentials.yaml' file. " +
+                "Therefore, the dataset will not be sent to S3 and will be kept locally."
+            )
 
+@logger.catch
 def send_dataset_to_s3(
     file_path: pathlib.Path,
     file_name: str,
@@ -71,3 +89,5 @@ def send_dataset_to_s3(
         aws_credentials.S3,
         file_name,
     )
+
+    os.remove(pathlib.Path.joinpath(file_path, file_name))
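For reference, a minimal sketch of how the reworked function might be invoked; the Kaggle dataset identifier and destination folder below are illustrative assumptions, not values taken from this commit:

import pathlib

from src.data.utils import download_dataset

download_dataset(
    name="user/obesity-levels",      # hypothetical Kaggle dataset identifier
    new_name="raw_dataset.csv",      # file name the downloaded CSV is renamed to
    path=pathlib.Path("data/raw"),   # hypothetical local destination folder
    send_to_aws=False,               # keep the file locally; skip the S3 upload
)

With send_to_aws=True and a configured bucket, the file would be uploaded and then removed locally by send_dataset_to_s3's trailing os.remove call.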

src/model/inference.py

Lines changed: 7 additions & 2 deletions
@@ -49,16 +49,21 @@ def load(self) -> None:
             logger.critical(f"Couldn't load the model using the flavor {model_settings.MODEL_FLAVOR}.")
             raise NotImplementedError()
 
-    def predict(self, x: np.ndarray) -> np.ndarray:
+    def predict(self, x: np.ndarray, transform_to_str: bool = True) -> np.ndarray:
         """Uses the trained model to make a prediction on a given feature array.
 
         Args:
             x (np.ndarray): the features array.
+            transform_to_str (bool): whether to transform the prediction integer to
+                string or not. Defaults to True.
 
         Returns:
             np.ndarray: the predictions array.
         """
         prediction = self.model.predict(x)
-        prediction = label_encoder.inverse_transform(prediction)
+
+        if transform_to_str:
+            prediction = label_encoder.inverse_transform(prediction)
+
         logger.info(f"Prediction: {prediction}.")
         return prediction
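The new flag lets callers that need the raw encoded classes (for instance, to compare against encoded targets when computing metrics) skip the inverse transform. A minimal sketch, assuming the wrapper class above is already instantiated and loaded as `model`, and that `x` is an already-preprocessed feature array (shape and values here are purely illustrative):

import numpy as np

x = np.asarray([[0.21, 0.62, 0.64]])  # illustrative, already-scaled feature row

labels = model.predict(x)                            # decoded labels, e.g. ['Normal_Weight']
encoded = model.predict(x, transform_to_str=False)   # raw integer class codes, e.g. [1]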

tests/__init__.py

Whitespace-only changes.

tests/unit/__init__.py

Whitespace-only changes.

tests/unit/test_api.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import json
+from pathlib import Path
+from typing import Dict
+
+import requests
+
+from src.config.model import model_settings
+from src.config.settings import general_settings
+
+with open(
+    f"{Path.joinpath(general_settings.RESEARCH_ENVIRONMENT_PATH, 'VERSION')}",
+    "r",
+    encoding="utf-8"
+) as f:
+    CODE_VERSION = f.readline().strip()
+
+def test_version_endpoint() -> None:
+    """
+    Unit case to test the API's version endpoint.
+    """
+    desired_keys = ["model_version", "code_version"]
+
+    response = requests.get("http://127.0.0.1:8000/version", timeout=100)
+    content = json.loads(response.text)
+
+    assert response.status_code == 200
+    assert isinstance(content, Dict)
+    assert all(dk in content.keys() for dk in desired_keys)
+    assert model_settings.VERSION == content[desired_keys[0]]
+    assert CODE_VERSION == content[desired_keys[1]]
+
+def test_inference_endpoint() -> None:
+    """
+    Unit case to test the API's inference endpoint.
+    """
+    desired_classes = [["Normal_Weight"]]
+    desired_keys = ["predictions"]
+
+    data = {
+        "Age": 21,
+        "CAEC": "Sometimes",
+        "CALC": "no",
+        "FAF": 0,
+        "FCVC": 2,
+        "Gender": "Female",
+        "Height": 1.62,
+        "MTRANS": "Public_Transportation",
+        "SCC": "no",
+        "SMOKE": "False",
+        "TUE": 1,
+        "Weight": 64
+    }
+
+    response = requests.get("http://127.0.0.1:8000/predict", json=data, timeout=100)
+    content = json.loads(response.text)
+
+    assert response.status_code == 200
+    assert isinstance(content, Dict)
+    assert all(dk in content.keys() for dk in desired_keys)
+    assert content[desired_keys[0]] == desired_classes
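Note that both tests hit a live server at http://127.0.0.1:8000, so the API has to be running before the suite executes. A minimal sketch of a guard that skips these tests instead of erroring out when the service is down, assuming pytest is the runner (the tests/unit layout suggests it); the names api_is_up and requires_api are introduced here for illustration:

import pytest
import requests

API_URL = "http://127.0.0.1:8000"  # same address the tests above target

def api_is_up() -> bool:
    """Return True when the local API answers the version endpoint."""
    try:
        return requests.get(f"{API_URL}/version", timeout=5).status_code == 200
    except requests.exceptions.RequestException:
        return False

# Decorate each test with @requires_api (or set `pytestmark = requires_api`
# at module level) to skip the suite when the API is unreachable.
requires_api = pytest.mark.skipif(not api_is_up(), reason="API is not running locally")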
