
Commit 2bf7738

Merge pull request #7 from cyber-evangelists/dev-branch
PR for Language Issue Fix and Input Validations
2 parents 5c81379 + e503b87 commit 2bf7738

12 files changed: +123 additions, −60 deletions

client-requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -3,4 +3,5 @@ loguru==0.7.2
 websockets
 python-dotenv==1.0.1
 transformers==4.46.2
-torch==2.5.1
+torch==2.5.1
+python-bidi
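
The new python-bidi dependency points at handling right-to-left (Arabic/Hebrew) text on the client, in line with the PR's language fix. As an illustration only (this snippet is not part of the commit), its core API reorders logical-order text into visual order:

# Illustrative python-bidi usage (not from this repository): reorder
# logical-order bidirectional text into visual order for display.
from bidi.algorithm import get_display

logical_text = "مرحبا CAPEC"              # mixed RTL/LTR text in logical order
visual_text = get_display(logical_text)   # reordered for left-to-right rendering
print(visual_text)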

client.py

Lines changed: 47 additions & 11 deletions
@@ -1,8 +1,5 @@
 import gradio as gr
 import websockets
-import json
-import asyncio
-import logging
 from typing import Tuple, List, Optional, Dict, Any
 from loguru import logger

@@ -15,24 +12,41 @@
 guardrails_model = GuardRails()


-async def search_click(msg, history):
+async def search_click(msg: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]], gr.Info]:
+
+    if not msg.strip():
+        logger.error(f"No input provided")
+        return "", history, gr.Warning("Please enter a query.")

     response = int(guardrails_model.classify_prompt(msg))

     if response == 0:
-        return await ws_client.handle_request(
+        result = await ws_client.handle_request(
             "search",
             {"query": msg, "history": history if history else []}
         )
+        if result[2] == "right":
+
+            styled_response = (f"<div style='direction: rtl; text-align: right; direction: right;'>{result[1]}</div>")
+        else:
+            styled_response = f"<div style='direction: ltr; text-align: left; direction: left;'>{result[1]}</div>"
+
+        # Append the styled response to the chat history
+        updated_history = history + [(msg, styled_response)]
+
+
+        return result[0], updated_history, gr.Info("Query Processed")
+
     else:
         return await return_protection_message(msg, history)


 async def return_protection_message(msg, history):

-    new_message = (msg, "Your query appears a prompt injection. I would prefer Not to answer it.")
+    new_message = (msg, "Your query appears inappropriate. Do you have any other question? I am here to help.")
     updated_history = history + [new_message]
-    return "", updated_history
+    return "", updated_history, gr.Warning("Query is inappropriate.")
+


 async def handle_ingest() -> gr.Info:
@@ -74,6 +88,11 @@ async def record_feedback(feedback, msg ) -> gr.Info:
     logger.info(feedback)
     logger.info(msg)

+
+    if not msg.strip():
+        logger.error(f"No Comments provided")
+        return gr.Info("Please enter some feedback first."), ""
+
     message, _ = await ws_client.handle_request(feedback, {"comment": msg})
     return gr.Info(message) if "success" in message.lower() else gr.Warning(message), ""

@@ -107,7 +126,7 @@ async def record_feedback(feedback, msg ) -> gr.Info:
     margin-top: 0.25rem;
     flex: 0 0 auto;
 }
-#chatbot {
+#chatbot-left {
     border: 1px solid #E5E7EB;
     border-radius: 8px;
     background-color: #FFFFFF;
@@ -118,6 +137,24 @@ async def record_feedback(feedback, msg ) -> gr.Info:
     flex-direction: column;
     overflow-y: auto; /* To allow scrolling if content overflows */
     min-height: 62vh;
+    text-direction: left;
+    direction: left;
+    text-align: left;
+}
+#chatbot-right {
+    border: 1px solid #E5E7EB;
+    border-radius: 8px;
+    background-color: #FFFFFF;
+    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
+    flex: 1 1 auto;
+    min-height: 0;
+    display: flex;
+    flex-direction: column;
+    overflow-y: auto; /* To allow scrolling if content overflows */
+    min-height: 62vh;
+    text-direction: right;
+    direction: right;
+    text-align: right;
 }
 #feedback-button {
     max-width: 0.25vh;
@@ -141,7 +178,7 @@ async def record_feedback(feedback, msg ) -> gr.Info:
     chatbot = gr.Chatbot(
         show_label=False,
         container=True,
-        elem_id="chatbot"
+        elem_id="chatbot-left"
     )

     with gr.Row(elem_id="feedback-container"):
@@ -173,7 +210,7 @@ async def record_feedback(feedback, msg ) -> gr.Info:
     send_button.click(
         fn=search_click,
         inputs=[msg, chatbot],
-        outputs=[msg, chatbot]
+        outputs=[msg, chatbot, status_box]
     )
     clear_button.click(
         fn=clear_chat,
@@ -203,4 +240,3 @@ async def record_feedback(feedback, msg ) -> gr.Info:
     share=False,
     debug=True,
     show_error=True,)
-
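
For context on the RTL/LTR handling introduced in search_click, the sketch below shows one way the direction flag and the wrapper markup could be derived. It is an assumption about intent, not code from this repository; note that the CSS direction property only accepts the values ltr and rtl, and text-direction is not a standard CSS property.

# Hypothetical helper mirroring the intent of the styling above; in the
# actual code the direction flag comes from the server response (result[2]).
import unicodedata

def detect_direction(text: str) -> str:
    # Return "right" when the first strongly-directional character is RTL.
    for ch in text:
        bidi_class = unicodedata.bidirectional(ch)
        if bidi_class in ("R", "AL"):
            return "right"
        if bidi_class == "L":
            return "left"
    return "left"

def style_response(text: str, direction_flag: str) -> str:
    # CSS `direction` accepts only ltr/rtl; align the text to match.
    if direction_flag == "right":
        return f"<div style='direction: rtl; text-align: right;'>{text}</div>"
    return f"<div style='direction: ltr; text-align: left;'>{text}</div>"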

docker-compose.yml

Lines changed: 18 additions & 0 deletions
@@ -13,9 +13,22 @@ services:
       - capec-network
     hostname: rag-server
     volumes:
+      - ./src:/app/src
+      - ./capec-dataset:/app/capec-dataset
+      - ./.env:/app/.env
       - ./src/index/index/:/app/src/index/index/
     environment:
       - TOKENIZERS_PARALLELISM=false
+    command:
+      [
+        "uvicorn",
+        "server:app",
+        "--host",
+        "0.0.0.0",
+        "--port",
+        "8000",
+        "--reload"
+      ]

   client:
     build:
@@ -28,6 +41,11 @@ services:
     environment:
       - SERVER_HOST=rag-server
       - SERVER_PORT=8000
+    volumes:
+      - ./src:/app/src
+      - ./client.py:/app/client.py
+      - ./client-requirements.txt:/app/client-requirements.txt
+    command: ["python", "client.py"]

   qdrant:
     image: qdrant/qdrant:v0.10.1

server.py

Lines changed: 27 additions & 29 deletions
@@ -6,38 +6,37 @@
 from typing import Dict, Any, List, Optional

 from src.config.config import Config
+from src.qdrant.qdrant_utils import QdrantWrapper
 from src.embedder.embedder_llama_index import EmbeddingWrapper
-from llama_index.core.retrievers import VectorIndexRetriever
+from src.parser.csv_parser import CsvParser
 from llama_index.core import Settings
 Settings.llm = None

-from src.qdrant.qdrant_manager import QdrantManager
 from src.utils.connections_manager import ConnectionManager
 from src.chatbot.rag_chat_bot import RAGChatBot
 from src.reranker.re_ranking import RerankDocuments

-import os
-
 app = FastAPI()

 chatbot = RAGChatBot()
+file_processor = CsvParser(data_dir = Config.DATA_DIRECTORY)

 collection_name = Config.COLLECTION_NAME
-qdrantManager = QdrantManager(Config.QDRANT_HOST, Config.QDRANT_PORT, collection_name)
-
+qdrant_client = QdrantWrapper()
 embedding_client = EmbeddingWrapper()


-data_dir = Config.CAPEC_DATA_DIR
+try:

-reranker = RerankDocuments()
+    processed_chunks = file_processor.process_directory()
+    qdrant_client.ingest_embeddings(processed_chunks)

-index = qdrantManager.load_index(persist_dir=Config.PERSIST_DIR, embed_model=embedding_client)
+    logger.info("Successfully ingested Data")

-retriever = VectorIndexRetriever(
-    index=index,
-    similarity_top_k=5
-)
+except Exception as e:
+    logger.error(f"Error in data ingestion: {str(e)}")
+
+reranker = RerankDocuments()

 # Manually added file names of the CAPEC daatset. In production, These files will be fetched from database
 database_files = ["333.csv", "658.csv", "659.csv", "1000.csv", "3000.csv"]
@@ -66,27 +65,26 @@ async def handle_search(websocket: WebSocket, query: str) -> None:

     filename = find_file_names(query, database_files)

-    if filename:
-        logger.info("Searching for file names...")
+    query_embeddings = embedding_client.generate_embeddings(query)

-        filters = MetadataFilters(filters=[ExactMatchFilter(key="source_file", value=filename)])
-        relevant_nodes = index.as_retriever(filters=filters).retrieve(query)
-        if not relevant_nodes:
-            logger.info("Searching without file name filter....")
-            relevant_nodes = retriever.retrieve(query)
-    else:
-        logger.info("Searching without file names....")
-        relevant_nodes = retriever.retrieve(query)
+    top_5_results = qdrant_client.search(query_embeddings, 5)
+    logger.info("Retrieved top 5 results")

-
-    context = [node.text for node in relevant_nodes]
-
-    reranked_docs = reranker.rerank_docs(query, context)
+    if not top_5_results:
+        logger.warning("No results found in database")
+        await websocket.send_json({
+            "result": "The database is empty. Please ingest some data first before searching."
+        })
+        return

-    # only top 2 documents are passing as a context
-    response, conversation_id = chatbot.chat(query, reranked_docs[:2])

+    reranked_docs = reranker.rerank_docs(query, top_5_results)
+    reranked_top_5_list = [item['content'] for item in reranked_docs]

+    context = reranked_top_5_list[:2]
+
+    # only top 2 documents are passing as a context
+    response, conversation_id = chatbot.chat(query, context)

     logger.info("Generating response from Groq")

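
The handler now embeds the query, pulls the top 5 hits from Qdrant, reranks them, and passes the top 2 as context. A minimal sketch of what a search helper with the signature used above might look like follows; the class name QdrantWrapper is real, but the body, collection name, port, and payload layout below are assumptions based on qdrant-client's public API, not the actual src/qdrant/qdrant_utils.py.

# Hypothetical sketch of a search helper with the shape used above;
# the real QdrantWrapper in src/qdrant/qdrant_utils.py may differ.
from typing import Any, Dict, List
from qdrant_client import QdrantClient

class QdrantSearchSketch:
    def __init__(self, host: str = "qdrant", port: int = 6333,
                 collection: str = "capec") -> None:
        self.client = QdrantClient(host=host, port=port)
        self.collection = collection

    def search(self, query_embedding: List[float], top_k: int) -> List[Dict[str, Any]]:
        # Nearest-neighbour search; each payload is assumed to carry the
        # chunk text under a "content" key, as the reranker expects.
        hits = self.client.search(
            collection_name=self.collection,
            query_vector=query_embedding,
            limit=top_k,
        )
        return [hit.payload for hit in hits]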

src/config/config.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ class Config:
     GRADIO_SERVER_NAME = "0.0.0.0"
     GRADIO_SERVER_PORT = int(7860)
     WEBSOCKET_URI = "ws://rag-server:8000/ws"
-    DATA_DIRECTORY = "data/"
+    DATA_DIRECTORY = "capec-dataset/"
     WEBSOCKET_TIMEOUT = 300 # 5 minutes
     HEARTBEAT_INTERVAL = 30 # 30 seconds
     MAX_CONNECTIONS = 100

src/docker-files/Dockerfile.client

Lines changed: 0 additions & 2 deletions

@@ -13,8 +13,6 @@ RUN pip install --upgrade pip && \

 # Copy only the required files for the application
 COPY client.py .
-COPY src/ ./src/
-

 # Run the application
 CMD ["python", "client.py"]

src/docker-files/Dockerfile.server

Lines changed: 0 additions & 4 deletions

@@ -10,10 +10,6 @@ RUN pip install --upgrade pip && \
     pip install -r requirements.txt

 COPY server.py .
-COPY src/ ./src/
-COPY .env .
-COPY capec-dataset/ ./capec-dataset/
-

 # Set Python to run in unbuffered mode
 ENV PYTHONUNBUFFERED=1

src/parser/csv_parser.py

Lines changed: 1 addition & 1 deletion
@@ -139,5 +139,5 @@ def process_directory(self) -> List[Document]:
                 logger.error(f"Skipping file {file_path} due to error: {str(e)}")
                 continue

-        logger.info("All .csv files indexed....")
+        logger.info("All .csv files processed. Returning chunks...")
         return all_documents

src/qdrant/qdrant_utils.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def _connect_with_retry(self) -> None:
                 self.client.get_collections()
                 logger.info("Successfully connected to Qdrant")
                 self._create_collection_if_not_exists()
+                self.clear_collection()
                 break
             except Exception as e:
                 logger.error(f"Connection attempt {attempt + 1} failed: {str(e)}")

src/reranker/re_ranking.py

Lines changed: 3 additions & 2 deletions
@@ -27,10 +27,11 @@ def rerank_docs(self,
         """
         # Re-ranking using cross-encoder
         # Prepare pairs for reranking
-        pairs = [[query, doc] for doc in top_5_results]
+        # Prepare pairs for reranking
+        pairs = [[query, doc["content"]] for doc in top_5_results]

         # Get relevance scores
-        scores = self.reranker.predict(pairs)
+        scores = self.reranker.predict(pairs)

         # Sort by new scores
         reranked_results = [
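
rerank_docs now scores [query, doc["content"]] pairs, so the cross-encoder receives plain strings instead of result dicts. A self-contained sketch of this reranking pattern is shown below; the sentence-transformers model name and the {"content": ...} payload shape are assumptions, not taken from this repository.

# Illustrative cross-encoder reranking, mirroring the pair shape used above.
# Model name and payload layout are assumptions.
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "How does an attacker abuse SQL injection?"
top_5_results = [
    {"content": "CAPEC-66: SQL Injection ..."},
    {"content": "CAPEC-7: Blind SQL Injection ..."},
]

# Score each (query, document text) pair, then sort by descending relevance.
pairs = [[query, doc["content"]] for doc in top_5_results]
scores = reranker.predict(pairs)
reranked = [doc for _, doc in sorted(zip(scores, top_5_results),
                                     key=lambda item: item[0], reverse=True)]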
