Merged · Changes from 13 commits
3 changes: 3 additions & 0 deletions sample-applications/chat-question-and-answer/.gitignore
@@ -0,0 +1,3 @@
.venv/
__pycache__/
.vscode
144 changes: 80 additions & 64 deletions sample-applications/chat-question-and-answer/app/chain.py
@@ -23,7 +23,6 @@
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
import openlit
from transformers import AutoTokenizer

set_verbose(True)

@@ -66,24 +65,35 @@
engine = create_async_engine(PG_CONNECTION_STRING)

# Init Embeddings via Intel Edge GenerativeAI Suite
embedder = EGAIEmbeddings(
openai_api_key="EMPTY",
openai_api_base="{}".format(EMBEDDING_ENDPOINT_URL),
model=MODEL_NAME,
tiktoken_enabled=False,
)
try:
embedder = EGAIEmbeddings(
openai_api_key="EMPTY",
openai_api_base="{}".format(EMBEDDING_ENDPOINT_URL),
model=MODEL_NAME,
request_timeout=30, # Add timeout
max_retries=3, # Add retries
)
logging.info(f"Embeddings initialized with endpoint: {EMBEDDING_ENDPOINT_URL}")
except Exception as e:
logging.error(f"Failed to initialize embeddings: {str(e)}")
raise

try:
knowledge_base = EGAIVectorDB(
embeddings=embedder,
collection_name=COLLECTION_NAME,
connection=engine,
)
retriever = EGAIVectorStoreRetriever(
vectorstore=knowledge_base,
search_type="mmr",
search_kwargs={"k": 1, "fetch_k": FETCH_K},
)
except Exception as e:
logging.error(f"Failed to initialize vector database or retriever: {str(e)}")
raise


knowledge_base = EGAIVectorDB(
embeddings=embedder,
collection_name=COLLECTION_NAME,
connection=engine,
)
retriever = EGAIVectorStoreRetriever(
vectorstore=knowledge_base,
search_type="mmr",
search_kwargs={"k": 1, "fetch_k": FETCH_K},
)

# Define our prompt
template = """
@@ -122,53 +132,59 @@
LLM_MODEL = os.getenv("LLM_MODEL", "Intel/neural-chat-7b-v3-3")
RERANKER_ENDPOINT = os.getenv("RERANKER_ENDPOINT", "http://localhost:9090/rerank")
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)


async def process_chunks(question_text, max_tokens):
if LLM_BACKEND in ["vllm", "unknown"]:
seed_value = None
model = EGAIModelServing(
openai_api_key="EMPTY",
openai_api_base="{}".format(ENDPOINT_URL),
model_name=LLM_MODEL,
top_p=0.99,
temperature=0.01,
streaming=True,
callbacks=callbacks,
stop=["\n\n"],
)
else:
seed_value = int(os.getenv("SEED", 42))
model = EGAIModelServing(
openai_api_key="EMPTY",
openai_api_base="{}".format(ENDPOINT_URL),
model_name=LLM_MODEL,
top_p=0.99,
temperature=0.01,
streaming=True,
callbacks=callbacks,
seed=seed_value,
max_tokens=max_tokens,
stop=["\n\n"],
try:
# Validate input
if not question_text or not question_text.strip():
raise ValueError("Question text cannot be empty")

if LLM_BACKEND in ["vllm", "unknown"]:
seed_value = None
model = EGAIModelServing(
openai_api_key="EMPTY",
openai_api_base="{}".format(ENDPOINT_URL),
model_name=LLM_MODEL,
top_p=0.99,
temperature=0.01,
streaming=True,
callbacks=callbacks,
stop=["\n\n"],
)
else:
seed_value = int(os.getenv("SEED", 42))
model = EGAIModelServing(
openai_api_key="EMPTY",
openai_api_base="{}".format(ENDPOINT_URL),
model_name=LLM_MODEL,
top_p=0.99,
temperature=0.01,
streaming=True,
callbacks=callbacks,
seed=seed_value,
max_tokens=max_tokens,
stop=["\n\n"],
)

re_ranker = CustomReranker(reranking_endpoint=RERANKER_ENDPOINT)
re_ranker_lambda = RunnableLambda(re_ranker.rerank)

# RAG Chain
chain = (
RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
| re_ranker_lambda
| prompt
| model
| StrOutputParser()
)
tokens = tokenizer.tokenize(str(prompt))
num_tokens = len(tokens)
logging.info(f"Prompt tokens for model {LLM_MODEL}: {num_tokens}")
output_tokens = max_tokens - num_tokens
logging.info(f"Output tokens for model {LLM_MODEL}: {output_tokens}")

re_ranker = CustomReranker(reranking_endpoint=RERANKER_ENDPOINT)
re_ranker_lambda = RunnableLambda(re_ranker.rerank)

# RAG Chain
chain = (
RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
| re_ranker_lambda
| prompt
| model
| StrOutputParser()
)
# Run the chain with the question text
async for log in chain.astream(question_text):
yield f"data: {log}\n\n"

# Run the chain with the question text
async for log in chain.astream(question_text):
yield f"data: {log}\n\n"

except ValueError as ve:
logging.error(f"Validation error in process_chunks: {str(ve)}")
yield f"data: Error: {str(ve)}\n\n"
except Exception as e:
logging.error(f"Error in process_chunks: {str(e)}", exc_info=True)
yield f"data: I apologize, but I encountered an error while processing your request. Please try again.\n\n"
42 changes: 27 additions & 15 deletions sample-applications/chat-question-and-answer/app/custom_reranker.py
@@ -47,19 +47,31 @@ def rerank_tei(self, retrieved_docs: Dict[str, Any]) -> Dict[str, Any]:
"raw_scores": False,
}

response = requests.post(
url=f"{self.reranking_endpoint}",
json=request_body,
headers={"Content-Type": "application/json"},
)
if response.status_code == 200:
logging.info(response.json())
try:
response = requests.post(
url=f"{self.reranking_endpoint}",
json=request_body,
headers={"Content-Type": "application/json"},
timeout=30 # Add timeout to prevent hanging
[Reviewer comment] Hardcoded timeout?
[Author reply] Reverted! Can add this along with the reranker flag.
)
if response.status_code == 200:
[Reviewer comment] Need to check thread safety in concurrent usage of the application.
[Author reply] Didn't see an issue during benchmark testing, where we did concurrent runs and multiple requests.
logging.info(response.json())

res = response.json()
maxRank = max(res, key=lambda x: x["score"])
return {
"question": retrieved_docs["question"],
"context": [retrieved_docs["context"][maxRank["index"]]],
}
else:
raise Exception(f"Error: {response.status_code}, {response.text}")
res = response.json()
maxRank = max(res, key=lambda x: x["score"])
return {
"question": retrieved_docs["question"],
"context": [retrieved_docs["context"][maxRank["index"]]],
}
else:
logging.error(f"Reranker error: {response.status_code}, {response.text}")
# Return original context if reranking fails
return retrieved_docs
except requests.exceptions.RequestException as e:
logging.error(f"Reranker request failed: {str(e)}")
[Reviewer comment] Specific exception handling?
[Author reply] Reverted the change for specific exceptions.

# Return original context if reranking fails
[Reviewer comment] Should we add a log/flag or some information to indicate that results are produced without the reranker?
[Author reply] This we should add, and also allow users to select whether to enable the re-ranker per use case. Can we take this as an enhancement?
[Reviewer reply] Sure.

return retrieved_docs
except Exception as e:
logging.error(f"Unexpected error in reranker: {str(e)}")
# Return original context if reranking fails
return retrieved_docs
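
The review thread above defers two items to a follow-up: making the request timeout configurable and letting users disable reranking per use case. A minimal sketch of that enhancement is shown below; the `RERANKER_ENABLED` and `RERANKER_TIMEOUT` environment variable names are assumptions, not part of this PR or the existing code.

```python
# Hypothetical follow-up sketch (not part of this PR): configurable timeout and
# an opt-out flag for the reranker, falling back to the retrieved context.
import logging
import os
from typing import Any, Dict

import requests

RERANKER_ENABLED = os.getenv("RERANKER_ENABLED", "true").lower() == "true"  # assumed flag name
RERANKER_TIMEOUT = float(os.getenv("RERANKER_TIMEOUT", "30"))  # assumed env var, seconds


def rerank_with_fallback(
    endpoint: str, request_body: Dict[str, Any], retrieved_docs: Dict[str, Any]
) -> Dict[str, Any]:
    if not RERANKER_ENABLED:
        logging.info("Reranker disabled; returning retrieved context unchanged")
        return retrieved_docs
    try:
        response = requests.post(
            endpoint,
            json=request_body,
            headers={"Content-Type": "application/json"},
            timeout=RERANKER_TIMEOUT,
        )
        response.raise_for_status()
        scores = response.json()
        best = max(scores, key=lambda x: x["score"])
        return {
            "question": retrieved_docs["question"],
            "context": [retrieved_docs["context"][best["index"]]],
        }
    except requests.exceptions.RequestException as exc:
        # Log that the answer is produced without reranking, as suggested in the review.
        logging.error(f"Reranker unavailable ({exc}); returning results without reranking")
        return retrieved_docs
```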
29 changes: 20 additions & 9 deletions sample-applications/chat-question-and-answer/app/server.py
@@ -97,10 +97,10 @@ async def get_llm_model():
raise HTTPException(status_code=503, detail="LLM_MODEL is not set")
return {"status": "success", "llm_model": llm_model}

@app.post("/stream_log", response_class=StreamingResponse)
@app.post("/chat/completions", response_class=StreamingResponse)
async def query_chain(payload: QuestionRequest):
"""
Handles POST requests to the /stream_log endpoint.
Handles POST requests to the /chat/completions endpoint.

This endpoint receives a question in the form of a JSON payload, validates the input,
and returns a streaming response with the processed chunks of the question text.
@@ -116,13 +116,24 @@ async def query_chain(payload: QuestionRequest):
Raises:
HTTPException: If the input question text is empty or not provided, a 422 status code is returned.
"""
question_text = payload.input
max_tokens = payload.max_tokens if payload.max_tokens else 512
if max_tokens > 1024:
raise HTTPException(status_code=422, detail="max_tokens cannot be greater than 1024")
if not question_text or question_text == "":
raise HTTPException(status_code=422, detail="Question is required")
return StreamingResponse(process_chunks(question_text,max_tokens), media_type="text/event-stream")
try:
question_text = payload.input
max_tokens = payload.max_tokens if payload.max_tokens else 512
if max_tokens > 1024:
raise HTTPException(status_code=422, detail="MAX_TOKENS cannot be greater than 1024")
if not question_text or question_text == "":
raise HTTPException(status_code=422, detail="Question is required")

# Additional validation
if len(question_text.strip()) == 0:
raise HTTPException(status_code=422, detail="Question cannot be empty or whitespace only")

return StreamingResponse(process_chunks(question_text, max_tokens), media_type="text/event-stream")
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

FastAPIInstrumentor.instrument_app(app)
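
The renamed endpoint streams server-sent events, so a client needs to read the response incrementally. A minimal sketch with `requests` is shown below; the base URL `http://localhost:8080` is an assumption, as is whether the service is reached directly or through the `/v1/chatqna` prefix used in the performance-test example further down.

```python
# Hypothetical client sketch (not part of this PR): call the renamed
# /chat/completions endpoint and print the streamed SSE "data:" lines.
import requests

BASE_URL = "http://localhost:8080"  # assumed host/port

payload = {
    "input": "What is the capital of France?",
    "max_tokens": 512,  # values above 1024 are rejected with HTTP 422
}

with requests.post(f"{BASE_URL}/chat/completions", json=payload, stream=True, timeout=120) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)
```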

@@ -96,4 +96,3 @@ chatqnaui:

nginxService:
annotations: {}

31 changes: 20 additions & 11 deletions sample-applications/chat-question-and-answer/docker-compose.yaml
@@ -32,11 +32,11 @@ services:
- LLM_MODEL=${LLM_MODEL}
- SEED=${SEED}
- OTLP_ENDPOINT=${OTLP_ENDPOINT}
- OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTLP_ENDPOINT}/v1/traces
- OTEL_SERVICE_NAME=${OTLP_SERVICE_NAME}
- OTEL_SERVICE_ENV=${OTLP_SERVICE_ENV}
- OTEL_SERVICE_VERSION=${OTEL_SERVICE_VERSION}
- REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt
- OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTLP_ENDPOINT:+${OTLP_ENDPOINT}/v1/traces}
- OTEL_SERVICE_NAME=${OTLP_ENDPOINT:+${OTLP_SERVICE_NAME}}
- OTEL_SERVICE_ENV=${OTLP_ENDPOINT:+${OTLP_SERVICE_ENV}}
- OTEL_SERVICE_VERSION=${OTLP_ENDPOINT:+${OTEL_SERVICE_VERSION}}
- REQUESTS_CA_BUNDLE=${REQUESTS_CA_BUNDLE:-}
- RERANKER_ENDPOINT=${RERANKER_ENDPOINT}
networks:
- my_network
@@ -165,11 +165,11 @@
- OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf
- OTEL_METRICS_EXPORTER=otlp
- OTEL_TRACES_EXPORTER=otlp
- OTEL_SERVICE_NAME=${OTLP_SERVICE_NAME}
- OTEL_SERVICE_ENV=${OTLP_SERVICE_ENV}
- OTEL_SERVICE_NAME=${OTLP_ENDPOINT:+${OTLP_SERVICE_NAME}}
- OTEL_SERVICE_ENV=${OTLP_ENDPOINT:+${OTLP_SERVICE_ENV}}
- OTLP_ENDPOINT=${OTLP_ENDPOINT}
- OTLP_ENDPOINT_TRACE=${OTLP_ENDPOINT_TRACE}
- REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt
- OTLP_ENDPOINT_TRACE=${OTLP_ENDPOINT:+${OTLP_ENDPOINT_TRACE}}
- REQUESTS_CA_BUNDLE=${REQUESTS_CA_BUNDLE:-}
cap_add:
- SYS_NICE
healthcheck:
@@ -179,7 +179,16 @@
retries: 3
networks:
- my_network
command: /bin/bash -c "echo $OTLP_ENDPOINT && echo $OTLP_ENDPOINT_TRACE && pip install 'opentelemetry-sdk>=1.26.0,<1.27.0' 'opentelemetry-api>=1.26.0,<1.27.0' 'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' 'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0' opentelemetry-instrumentation-fastapi && opentelemetry-instrument python3 -m vllm.entrypoints.openai.api_server --enforce-eager --otlp-traces-endpoint=$OTLP_ENDPOINT_TRACE --model $LLM_MODEL --tensor-parallel-size $TENSOR_PARALLEL_SIZE --host 0.0.0.0 --port 80"
entrypoint:
- /bin/bash
- -c
- |
if [ -n "$OTLP_ENDPOINT" ]; then
pip install "opentelemetry-sdk>=1.26.0,<1.27.0" "opentelemetry-api>=1.26.0,<1.27.0" "opentelemetry-exporter-otlp>=1.26.0,<1.27.0" "opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0" opentelemetry-instrumentation-fastapi && \
opentelemetry-instrument python3 -m vllm.entrypoints.openai.api_server --enforce-eager --model $LLM_MODEL --tensor-parallel-size $TENSOR_PARALLEL_SIZE --host 0.0.0.0 --port 80;
else
python3 -m vllm.entrypoints.openai.api_server --enforce-eager --model $LLM_MODEL --tensor-parallel-size $TENSOR_PARALLEL_SIZE --host 0.0.0.0 --port 80;
fi

text-generation:
image: ghcr.io/huggingface/text-generation-inference:3.0.1-intel-xpu
@@ -193,7 +193,7 @@
- https_proxy=${https_proxy}
- HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
- HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
- HUGGINGFACE_HUB_CACHE='/root/.cache/huggingface/hub'
- HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface/hub
ports:
- "8080:80"
volumes:
@@ -43,10 +43,10 @@ Returns:
content:
application/json:
schema:
/stream_log:
/chat/completions:
post:
summary: "Query Chain"
description: "Handles POST requests to the /stream_log endpoint.
description: "Handles POST requests to the /chat/completions endpoint.

This endpoint receives a question in the form of a JSON payload, validates the input,
and returns a streaming response with the processed chunks of the question text.
@@ -60,7 +60,7 @@

Raises:
HTTPException: If the input question text is empty or not provided, a 422 status code is returned."
operationId: "query_chain_stream_log_post"
operationId: "query_chain_chat_completions_post"
requestBody:
content:
application/json:
@@ -25,7 +25,7 @@ Before you begin, ensure that you have the following prerequisites:

@task
def ask_query(self):
self.client.post("/v1/chatqna/stream_log", json={"input": "What is the capital of France?"})
self.client.post("/v1/chatqna/chat/completions", json={"input": "What is the capital of France?"})
```

3. **Run the Performance Test**:
@@ -8,7 +8,7 @@
- Image Optimization for ChatQnA Backend and Document Ingestion Microservices, reducing image sizes for faster processing times and lower bandwidth usage.
- Update to Run ChatQnA-UI and Nginx Container with Non-Root Access Privileges.
- Security Vulnerabilities Fix for Dependency Packages.
- Max Token Parameter Added to /stream_log API.
- Max Token Parameter Added to /chat/completions API.
- EMF deployment is supported.
- Bug fixes.
