Skip to content

Commit 8c2be4f

Browse files
authored
chatqna-core: Fix backend unit test (open-edge-platform#212)
Signed-off-by: Yeoh, Hoong Tee <hoong.tee.yeoh@intel.com>
1 parent ea1824a commit 8c2be4f

File tree

3 files changed

+195
-63
lines changed

3 files changed

+195
-63
lines changed

sample-applications/chat-question-and-answer-core/app/chain.py

Lines changed: 73 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,75 +11,85 @@
1111
from langchain_core.output_parsers import StrOutputParser
1212
from langchain.text_splitter import RecursiveCharacterTextSplitter
1313
from langchain_core.prompts import ChatPromptTemplate
14+
import os
1415
import pandas as pd
1516

1617
config = Settings()
1718
vectorstore = None
1819

19-
# login huggingface
20-
login_to_huggingface(config.HF_ACCESS_TOKEN)
21-
22-
# Download convert the model to openvino optimized
23-
download_huggingface_model(config.EMBEDDING_MODEL_ID, config.CACHE_DIR)
24-
download_huggingface_model(config.RERANKER_MODEL_ID, config.CACHE_DIR)
25-
download_huggingface_model(config.LLM_MODEL_ID, config.CACHE_DIR)
26-
27-
# Convert to openvino IR
28-
convert_model(config.EMBEDDING_MODEL_ID, config.CACHE_DIR, "embedding")
29-
convert_model(config.RERANKER_MODEL_ID, config.CACHE_DIR, "reranker")
30-
convert_model(config.LLM_MODEL_ID, config.CACHE_DIR, "llm")
31-
32-
# Define RAG prompt
33-
template = """
34-
Use the following pieces of context from retrieved
35-
dataset to answer the question. Do not make up an answer if there is no
36-
context provided to help answer it.
37-
38-
Context:
39-
---------
40-
{context}
41-
42-
---------
43-
Question: {question}
44-
---------
45-
46-
Answer:
47-
"""
48-
49-
prompt = ChatPromptTemplate.from_template(template)
50-
51-
# Initialize Embedding Model
52-
embedding = OpenVINOBgeEmbeddings(
53-
model_name_or_path=f"{config.CACHE_DIR}/{config.EMBEDDING_MODEL_ID}",
54-
model_kwargs={"device": config.EMBEDDING_DEVICE, "compile": False},
55-
)
56-
embedding.ov_model.compile()
57-
58-
# Initialize Reranker Model
59-
reranker = OpenVINOReranker(
60-
model_name_or_path=f"{config.CACHE_DIR}/{config.RERANKER_MODEL_ID}",
61-
model_kwargs={"device": config.RERANKER_DEVICE},
62-
top_n=2,
63-
)
64-
65-
# Initialize LLM
66-
llm = HuggingFacePipeline.from_model_id(
67-
model_id=f"{config.CACHE_DIR}/{config.LLM_MODEL_ID}",
68-
task="text-generation",
69-
backend="openvino",
70-
model_kwargs={
71-
"device": config.LLM_DEVICE,
72-
"ov_config": {
73-
"PERFORMANCE_HINT": "LATENCY",
74-
"NUM_STREAMS": "1",
75-
"CACHE_DIR": f"{config.CACHE_DIR}/{config.LLM_MODEL_ID}/model_cache",
20+
# The RUN_TEST flag is used to bypass the model download and conversion steps during pytest unit testing.
21+
# By default, the flag is set to 'false', enabling the model download and conversion process in a normal run.
22+
# To skip these steps, set the flag to 'true'.
23+
# Check environment flag
24+
RUN_TEST = os.getenv('RUN_TEST', False)
25+
26+
if not RUN_TEST:
27+
# login huggingface
28+
login_to_huggingface(config.HF_ACCESS_TOKEN)
29+
30+
# Download convert the model to openvino optimized
31+
download_huggingface_model(config.EMBEDDING_MODEL_ID, config.CACHE_DIR)
32+
download_huggingface_model(config.RERANKER_MODEL_ID, config.CACHE_DIR)
33+
download_huggingface_model(config.LLM_MODEL_ID, config.CACHE_DIR)
34+
35+
# Convert to openvino IR
36+
convert_model(config.EMBEDDING_MODEL_ID, config.CACHE_DIR, "embedding")
37+
convert_model(config.RERANKER_MODEL_ID, config.CACHE_DIR, "reranker")
38+
convert_model(config.LLM_MODEL_ID, config.CACHE_DIR, "llm")
39+
40+
# Define RAG prompt
41+
template = """
42+
Use the following pieces of context from retrieved
43+
dataset to answer the question. Do not make up an answer if there is no
44+
context provided to help answer it.
45+
46+
Context:
47+
---------
48+
{context}
49+
50+
---------
51+
Question: {question}
52+
---------
53+
54+
Answer:
55+
"""
56+
57+
prompt = ChatPromptTemplate.from_template(template)
58+
59+
# Initialize Embedding Model
60+
embedding = OpenVINOBgeEmbeddings(
61+
model_name_or_path=f"{config.CACHE_DIR}/{config.EMBEDDING_MODEL_ID}",
62+
model_kwargs={"device": config.EMBEDDING_DEVICE, "compile": False},
63+
)
64+
embedding.ov_model.compile()
65+
66+
# Initialize Reranker Model
67+
reranker = OpenVINOReranker(
68+
model_name_or_path=f"{config.CACHE_DIR}/{config.RERANKER_MODEL_ID}",
69+
model_kwargs={"device": config.RERANKER_DEVICE},
70+
top_n=2,
71+
)
72+
73+
# Initialize LLM
74+
llm = HuggingFacePipeline.from_model_id(
75+
model_id=f"{config.CACHE_DIR}/{config.LLM_MODEL_ID}",
76+
task="text-generation",
77+
backend="openvino",
78+
model_kwargs={
79+
"device": config.LLM_DEVICE,
80+
"ov_config": {
81+
"PERFORMANCE_HINT": "LATENCY",
82+
"NUM_STREAMS": "1",
83+
"CACHE_DIR": f"{config.CACHE_DIR}/{config.LLM_MODEL_ID}/model_cache",
84+
},
85+
"trust_remote_code": True,
7686
},
77-
"trust_remote_code": True,
78-
},
79-
pipeline_kwargs={"max_new_tokens": config.MAX_TOKENS},
80-
)
81-
if llm.pipeline.tokenizer.eos_token_id:
82-
llm.pipeline.tokenizer.pad_token_id = llm.pipeline.tokenizer.eos_token_id
87+
pipeline_kwargs={"max_new_tokens": config.MAX_TOKENS},
88+
)
89+
if llm.pipeline.tokenizer.eos_token_id:
90+
llm.pipeline.tokenizer.pad_token_id = llm.pipeline.tokenizer.eos_token_id
91+
else:
92+
logger.info("Bypassing to mock these functions because RUN_TEST is set to 'True' to run pytest unit test.")
8393

8494

8595
def default_context(docs):

sample-applications/chat-question-and-answer-core/tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import os
2+
3+
# Configure the environment variable prior to importing the app
4+
# This ensures the app operates in test mode, bypassing the startup function responsible for model downloading and conversion
5+
os.environ['RUN_TEST'] = "True"
6+
17
import pytest
28
from fastapi.testclient import TestClient
39

sample-applications/chat-question-and-answer-core/tests/test_server.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,24 @@
33

44

55
def test_chain_response(test_client, mocker):
6+
"""
7+
Tests the chain response functionality of the server by simulating a POST
8+
request to the `/stream_log` endpoint and verifying the streamed response.
9+
Args:
10+
test_client: A test client instance used to simulate HTTP requests.
11+
mocker: A mocking library instance used to patch dependencies.
12+
Mocks:
13+
- `app.server.get_retriever`: Mocked to return `True`.
14+
- `app.server.build_chain`: Mocked to return `True`.
15+
- `app.server.process_query`: Mocked to return an iterator with values `["one", "two"]`.
16+
Raises:
17+
AssertionError: If any of the assertions fail.
18+
"""
619

720
payload = {"input": "What is AI?", "stream": True}
821

22+
mocker.patch("app.server.get_retriever", return_value=True)
23+
mocker.patch("app.server.build_chain", return_value=True)
924
mocker.patch("app.server.process_query", return_value=iter(["one", "two"]))
1025

1126
response = test_client.post("/stream_log", json=payload)
@@ -22,6 +37,23 @@ def test_chain_response(test_client, mocker):
2237

2338

2439
def test_success_upload_and_create_embedding(test_client, mocker):
40+
"""
41+
Tests the successful upload of a document and the creation of embeddings.
42+
This test simulates the process of uploading a text file, validating the document,
43+
saving it, and creating embeddings using a mocked FAISS vector database. It verifies
44+
that the API endpoint responds with the correct status code and response JSON.
45+
Args:
46+
test_client: A test client instance used to simulate HTTP requests to the API.
47+
mocker: A mocking library instance used to patch functions and simulate behavior.
48+
Mocks:
49+
- `app.server.validate_document`: Mocked to return `True`.
50+
- `app.server.save_document`: Mocked to return the temporary file name and `None`.
51+
- `app.server.create_faiss_vectordb`: Mocked to return `True`.
52+
Assertions:
53+
- The response status code is 200.
54+
- The response JSON matches the expected success message and metadata.
55+
"""
56+
2557
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as tmp_file:
2658
tmp_file.write(b"This is sample txt file.")
2759
tmp_file.seek(0)
@@ -44,6 +76,21 @@ def test_success_upload_and_create_embedding(test_client, mocker):
4476

4577

4678
def test_success_get_documents(test_client, mocker):
79+
"""
80+
Test the successful retrieval of documents from the server.
81+
This test verifies that the `/documents` endpoint returns a 200 status code
82+
and the expected JSON response containing a list of documents.
83+
Args:
84+
test_client (TestClient): A test client instance for making HTTP requests.
85+
mocker (MockerFixture): A mocker fixture for patching and mocking dependencies.
86+
Mocks:
87+
- `app.server.get_document_from_vectordb`: Mocked to return a list of documents.
88+
Assertions:
89+
- The response status code is 200.
90+
- The response JSON contains a "status" key with the value "Success".
91+
- The response JSON contains a "metadata" key with a "documents" list matching the mocked documents.
92+
"""
93+
4794
mock_documents = ["test1.txt", "test2.pdf"]
4895
mocker.patch('app.server.get_document_from_vectordb', return_value=mock_documents)
4996

@@ -57,6 +104,20 @@ def test_success_get_documents(test_client, mocker):
57104

58105

59106
def test_delete_embedding_success(test_client, mocker):
107+
"""
108+
Test the successful deletion of an embedding from the vector database.
109+
This test verifies that the `delete_embedding_from_vectordb` function is called
110+
and the API endpoint for deleting a document responds with the expected status code.
111+
Args:
112+
test_client: A test client instance for simulating HTTP requests to the server.
113+
mocker: A mocking library instance used to patch and mock dependencies.
114+
Mocks:
115+
- `app.server.delete_embedding_from_vectordb`: Mocked to return `True`.
116+
Assertions:
117+
- Ensures that the HTTP DELETE request to the "/documents" endpoint with
118+
the specified parameters returns a status code of HTTPStatus.NO_CONTENT.
119+
"""
120+
60121
mocker.patch('app.server.delete_embedding_from_vectordb', return_value=True)
61122

62123
response = test_client.delete("/documents", params={"document": "test1.txt"})
@@ -65,6 +126,20 @@ def test_delete_embedding_success(test_client, mocker):
65126

66127

67128
def test_delete_all_embedding_success(test_client, mocker):
129+
"""
130+
Test the successful deletion of all embeddings from the vector database.
131+
This test verifies that the endpoint for deleting all documents functions
132+
correctly by mocking the `delete_embedding_from_vectordb` function to
133+
return `True` and asserting that the response status code is `HTTPStatus.NO_CONTENT`.
134+
Args:
135+
test_client: A test client instance used to simulate HTTP requests to the server.
136+
mocker: A mocking utility used to patch the `delete_embedding_from_vectordb` function.
137+
Mocks:
138+
- `app.server.delete_embedding_from_vectordb`: Mocked to return `True`.
139+
Assertions:
140+
- The response status code is `HTTPStatus.NO_CONTENT` (204).
141+
"""
142+
68143
mocker.patch('app.server.delete_embedding_from_vectordb', return_value=True)
69144

70145
response = test_client.delete("/documents", params={"delete_all": True})
@@ -73,6 +148,17 @@ def test_delete_all_embedding_success(test_client, mocker):
73148

74149

75150
def test_upload_unsupported_file(test_client):
151+
"""
152+
Tests the upload of an unsupported file format to the server.
153+
This test verifies that the server returns a 400 status code and an appropriate
154+
error message when a file with an unsupported format (e.g., .html) is uploaded.
155+
Args:
156+
test_client: A test client instance used to simulate HTTP requests to the server.
157+
Raises:
158+
AssertionError: If the response status code is not 400 or the error message
159+
does not match the expected output.
160+
"""
161+
76162
with tempfile.NamedTemporaryFile(delete=True, suffix=".html") as tmp_file:
77163
tmp_file.write(b"This is sample html file.")
78164
tmp_file.seek(0)
@@ -86,6 +172,21 @@ def test_upload_unsupported_file(test_client):
86172

87173

88174
def test_fail_get_documents(test_client, mocker):
175+
"""
176+
Test case for handling failure when retrieving documents from the vector database.
177+
This test simulates an exception being raised during the retrieval of documents
178+
from the vector database and verifies that the server responds with the appropriate
179+
HTTP status code and error message.
180+
Args:
181+
test_client: A test client instance used to simulate HTTP requests to the server.
182+
mocker: A mocking library instance used to patch and simulate behavior of dependencies.
183+
Mocks:
184+
- `app.server.get_document_from_vectordb`: Mocked to raise an exception with the message "Error getting documents."
185+
Asserts:
186+
- The HTTP response status code is 500 (Internal Server Error).
187+
- The JSON response contains the expected error message.
188+
"""
189+
89190
mocker.patch('app.server.get_document_from_vectordb', side_effect=Exception("Error getting documents."))
90191

91192
response = test_client.get("/documents")
@@ -97,6 +198,21 @@ def test_fail_get_documents(test_client, mocker):
97198

98199

99200
def test_delete_embedding_failure(test_client, mocker):
201+
"""
202+
Test case for handling failure during the deletion of embeddings from the vector database.
203+
This test simulates a failure scenario where the `delete_embedding_from_vectordb` function
204+
raises an exception. It verifies that the server responds with the appropriate HTTP status
205+
code and error message.
206+
Args:
207+
test_client (TestClient): A test client instance for simulating HTTP requests to the server.
208+
mocker (MockerFixture): A fixture for mocking dependencies and functions.
209+
Mocks:
210+
- `app.server.delete_embedding_from_vectordb`: Mocked to raise an exception with the message "Error deleting embeddings."
211+
Asserts:
212+
- The response status code is 500 (Internal Server Error).
213+
- The response JSON contains the expected error detail message.
214+
"""
215+
100216
mocker.patch('app.server.delete_embedding_from_vectordb', side_effect=Exception("Error deleting embeddings."))
101217

102218
response = test_client.delete("/documents", params={"document": "test1.txt"})

0 commit comments

Comments
 (0)