
Commit c59fb3a

A new post-processing type for verifying hit questions using an LLM has been added (#669)
I think that using an LLM for verification can yield better results in cases with higher quality requirements. In customer-service scenarios and the like, a recalled answer to another user's question may leak private information or mention other brands, causing interference. On the interface side, I kept support for openai==0.28.0 and stayed compatible with openai>=1.0.0. I added a test case and an example file, and updated examples/README.md. I also bumped onnxruntime to 1.21.1, since the previously pinned 1.14.1 is no longer usable.
1 parent fcb7f2b commit c59fb3a
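
The compatibility note above (openai==0.28.0 alongside openai>=1.0.0) comes down to resolving the right chat-completion entry point at runtime. The sketch below illustrates that idea only; it is not code from this commit, and `resolve_chat_create` is a hypothetical helper name.

```python
# Illustrative sketch (not part of this commit): pick a chat-completion callable
# that works for both openai==0.28.x and openai>=1.0.0.
import openai


def resolve_chat_create(client=None):
    client = client or openai
    # openai>=1.0.0: OpenAI() instances and the module-level proxy expose chat.completions
    if hasattr(client, "chat") and hasattr(client.chat, "completions"):
        return client.chat.completions.create
    # openai==0.28.x: the module exposes ChatCompletion instead
    return client.ChatCompletion.create


# Usage (requires OPENAI_API_KEY to be set):
# create = resolve_chat_create()
# resp = create(model="gpt-3.5-turbo",
#               messages=[{"role": "user", "content": "Answer yes or no: is 2 even?"}],
#               temperature=0)
```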

File tree

6 files changed: +291 −57 lines changed


examples/README.md

Lines changed: 37 additions & 8 deletions
@@ -1,13 +1,21 @@
  # Example
 
- - [How to run Visual Question Answering with MiniGPT-4](#How-to-run-Visual-Question-Answering-with-MiniGPT-4)
- - [How to set the **embedding** function](#How-to-set-the-embedding-function)
- - [How to set the **data manager** class](#How-to-set-the-data-manager-class)
- - [How to set the **similarity evaluation** interface](#How-to-set-the-similarity-evaluation-interface)
- - [Other cache init params](#Other-cache-init-params)
- - [How to run with session](#How-to-run-with-session)
- - [How to use GPTCache server](#How-to-use-GPTCache-server)
- - [Benchmark](#Benchmark)
+ - [Example](#example)
+   - [How to run Visual Question Answering with MiniGPT-4](#how-to-run-visual-question-answering-with-minigpt-4)
+   - [How to set the `embedding` function](#how-to-set-the-embedding-function)
+     - [Default embedding function](#default-embedding-function)
+     - [Suitable for embedding methods consisting of a cached storage and vector store](#suitable-for-embedding-methods-consisting-of-a-cached-storage-and-vector-store)
+     - [Custom embedding](#custom-embedding)
+   - [How to set the `data manager` class](#how-to-set-the-data-manager-class)
+   - [How to set the `similarity evaluation` interface](#how-to-set-the-similarity-evaluation-interface)
+   - [Request cache parameter customization](#request-cache-parameter-customization)
+   - [How to run with session](#how-to-run-with-session)
+     - [Run in `with` method](#run-in-with-method)
+     - [Custom Session](#custom-session)
+   - [How to use GPTCache server](#how-to-use-gptcache-server)
+     - [Start server](#start-server)
+   - [Benchmark](#benchmark)
+   - [How to use post-process function](#how-to-use-post-process-function)
 
 
  ## How to run Visual Question Answering with MiniGPT-4

@@ -686,3 +694,24 @@ similarity evaluation func: pair_evaluation (search distance)
  | 0.95 | 0.12s | 425 | 25 | 549 |
  | 0.9 | 0.23s | 804 | 77 | 118 |
  | 0.8 | 0.26s | 904 | 92 | 3 |
+
+ ## How to use post-process function
+
+ You can use the `LlmVerifier` post-processor to process the cached answer list after recall. It is similar to `first` or `random_one`, but it calls an LLM to verify whether the recalled question is truly similar to the user's question. You can define your own system prompt to decide under what circumstances the LLM should reject a hit, and you can choose a small model to perform the verification step, so only a small additional cost is required.
+
+ Example usage:
+
+ ```python
+ from gptcache.processor.post import LlmVerifier
+
+ # ... (init cache, embedding, data_manager, etc.)
+
+ cache.init(
+     embedding_func=onnx.to_embeddings,
+     data_manager=data_manager,
+     similarity_evaluation=SearchDistanceEvaluation(),
+     post_process_messages_func=LlmVerifier(client=None,
+                                            system_prompt=custom_prompt,
+                                            model="gpt-3.5-turbo")
+ )
+ ```
+
+ See [processor/post_example.py](./processor/post_example.py) for a runnable example.
examples/processor/post_example.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+ import time
+ import os
+
+ from gptcache import cache
+ from gptcache.adapter import openai
+ from gptcache.embedding import Onnx
+ from gptcache.manager import manager_factory
+ from gptcache.processor.post import LlmVerifier
+ from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
+
+ print("This example demonstrates how to use LLM verification with OpenAI's GPT-3.5 Turbo model.")
+ cache.set_openai_key()
+
+ onnx = Onnx()
+ data_manager = manager_factory("sqlite,faiss", vector_params={"dimension": onnx.dimension})
+
+ custom_prompt = """You are a helpful assistant. Your task is to verify whether the answer is semantically consistent with the question.
+ If the answer is consistent, respond with "yes". If it is not consistent, respond with "no".
+ You must only respond with "yes" or "no"."""
+
+ verifier = LlmVerifier(client=None,
+                        system_prompt=custom_prompt,
+                        model="gpt-3.5-turbo")
+
+ cache.init(
+     embedding_func=onnx.to_embeddings,
+     data_manager=data_manager,
+     similarity_evaluation=SearchDistanceEvaluation(),
+     post_process_messages_func=verifier
+ )
+
+ question = 'what is github'
+
+ for _ in range(3):
+     start = time.time()
+     response = openai.ChatCompletion.create(
+         model='gpt-3.5-turbo',
+         messages=[{
+             'role': 'user',
+             'content': question
+         }],
+     )
+     print(f"Response: {response['choices'][0]['message']['content']}")
+     print(f"Time: {round(time.time() - start, 2)}s\n")

gptcache/adapter/adapter.py

Lines changed: 63 additions & 47 deletions
@@ -3,7 +3,7 @@
  import numpy as np
 
  from gptcache import cache
- from gptcache.processor.post import temperature_softmax
+ from gptcache.processor.post import temperature_softmax, LlmVerifier
  from gptcache.utils.error import NotInitError
  from gptcache.utils.log import gptcache_log
  from gptcache.utils.time import time_cal
@@ -189,6 +189,12 @@ def post_process():
                  scores=[t[0] for t in cache_answers],
                  temperature=temperature,
              )
+         elif isinstance(chat_cache.post_process_messages_func, LlmVerifier):
+             return_message = chat_cache.post_process_messages_func(
+                 messages=[t[1] for t in cache_answers],
+                 scores=[t[0] for t in cache_answers],
+                 original_question=pre_embedding_data
+             )
          else:
              return_message = chat_cache.post_process_messages_func(
                  [t[1] for t in cache_answers]
@@ -200,29 +206,30 @@ def post_process():
              func_name="post_process",
              report_func=chat_cache.report.post,
          )()
-         chat_cache.report.hint_cache()
-         cache_whole_data = answers_dict.get(str(return_message))
-         if session and cache_whole_data:
-             chat_cache.data_manager.add_session(
-                 cache_whole_data[2], session.name, pre_embedding_data
-             )
-         if cache_whole_data and not chat_cache.config.disable_report:
-             # user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
-             report_cache_data = cache_whole_data[3]
-             report_search_data = cache_whole_data[2]
-             chat_cache.data_manager.report_cache(
-                 pre_store_data if isinstance(pre_store_data, str) else "",
-                 report_cache_data.question
-                 if isinstance(report_cache_data.question, str)
-                 else "",
-                 report_search_data[1],
-                 report_cache_data.answers[0].answer
-                 if isinstance(report_cache_data.answers[0].answer, str)
-                 else "",
-                 cache_whole_data[0],
-                 round(time.time() - start_time, 6),
-             )
-         return cache_data_convert(return_message)
+         if return_message is not None:
+             chat_cache.report.hint_cache()
+             cache_whole_data = answers_dict.get(str(return_message))
+             if session and cache_whole_data:
+                 chat_cache.data_manager.add_session(
+                     cache_whole_data[2], session.name, pre_embedding_data
+                 )
+             if cache_whole_data and not chat_cache.config.disable_report:
+                 # user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
+                 report_cache_data = cache_whole_data[3]
+                 report_search_data = cache_whole_data[2]
+                 chat_cache.data_manager.report_cache(
+                     pre_store_data if isinstance(pre_store_data, str) else "",
+                     report_cache_data.question
+                     if isinstance(report_cache_data.question, str)
+                     else "",
+                     report_search_data[1],
+                     report_cache_data.answers[0].answer
+                     if isinstance(report_cache_data.answers[0].answer, str)
+                     else "",
+                     cache_whole_data[0],
+                     round(time.time() - start_time, 6),
+                 )
+             return cache_data_convert(return_message)
 
      next_cache = chat_cache.next_cache
      if next_cache:
@@ -444,6 +451,13 @@ def post_process():
                  scores=[t[0] for t in cache_answers],
                  temperature=temperature,
              )
+         elif isinstance(chat_cache.post_process_messages_func, LlmVerifier):
+             return_message = chat_cache.post_process_messages_func(
+                 messages=[t[1] for t in cache_answers],
+                 scores=[t[0] for t in cache_answers],
+                 original_question=pre_embedding_data,
+                 temperature=temperature,
+             )
          else:
              return_message = chat_cache.post_process_messages_func(
                  [t[1] for t in cache_answers]
@@ -455,36 +469,38 @@
              func_name="post_process",
              report_func=chat_cache.report.post,
          )()
-         chat_cache.report.hint_cache()
-         cache_whole_data = answers_dict.get(str(return_message))
-         if session and cache_whole_data:
-             chat_cache.data_manager.add_session(
-                 cache_whole_data[2], session.name, pre_embedding_data
-             )
-         if cache_whole_data:
-             # user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
-             report_cache_data = cache_whole_data[3]
-             report_search_data = cache_whole_data[2]
-             chat_cache.data_manager.report_cache(
-                 pre_store_data if isinstance(pre_store_data, str) else "",
-                 report_cache_data.question
-                 if isinstance(report_cache_data.question, str)
-                 else "",
-                 report_search_data[1],
-                 report_cache_data.answers[0].answer
-                 if isinstance(report_cache_data.answers[0].answer, str)
-                 else "",
-                 cache_whole_data[0],
-                 round(time.time() - start_time, 6),
-             )
-         return cache_data_convert(return_message)
+         if return_message is not None:
+             chat_cache.report.hint_cache()
+             cache_whole_data = answers_dict.get(str(return_message))
+             if session and cache_whole_data:
+                 chat_cache.data_manager.add_session(
+                     cache_whole_data[2], session.name, pre_embedding_data
+                 )
+             if cache_whole_data:
+                 # user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
+                 report_cache_data = cache_whole_data[3]
+                 report_search_data = cache_whole_data[2]
+                 chat_cache.data_manager.report_cache(
+                     pre_store_data if isinstance(pre_store_data, str) else "",
+                     report_cache_data.question
+                     if isinstance(report_cache_data.question, str)
+                     else "",
+                     report_search_data[1],
+                     report_cache_data.answers[0].answer
+                     if isinstance(report_cache_data.answers[0].answer, str)
+                     else "",
+                     cache_whole_data[0],
+                     round(time.time() - start_time, 6),
+                 )
+             return cache_data_convert(return_message)
 
      next_cache = chat_cache.next_cache
      if next_cache:
          kwargs["cache_obj"] = next_cache
          kwargs["cache_context"] = context
          kwargs["cache_skip"] = cache_skip
          kwargs["cache_factor"] = cache_factor
+         kwargs["search_only"] = search_only_flag
          llm_data = adapt(
              llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
          )
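
A side effect of wrapping the hit path in `if return_message is not None:` is a general contract: a post-processor may return None to reject every recalled answer, and the adapter then falls through to the next cache or a real LLM call instead of reporting a hit. Below is a minimal sketch of that contract using a stand-in post-processor; it is hypothetical and not part of this commit.

```python
# Stand-in post-processor illustrating the new None-means-miss contract.
# The adapter calls a generic post_process_messages_func positionally with the
# list of recalled answers; returning None now skips the cached result.
def reject_trivial_answers(messages, **kwargs):
    for answer in messages:
        if isinstance(answer, str) and len(answer.strip()) > 20:
            return answer
    return None  # treated as a cache miss by the updated adapter


# cache.init(..., post_process_messages_func=reject_trivial_answers)
```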

gptcache/processor/post.py

Lines changed: 116 additions & 0 deletions
@@ -87,3 +87,119 @@ def temperature_softmax(messages: List[Any], scores: List[float], temperature: f
      else:
          m_s = list(zip(messages, scores))
      return sorted(m_s, key=lambda x: x[1], reverse=True)[0][0]
+
+
+ def llm_semantic_verification(
+     messages: List[Any],
+     scores: List[float] = None,
+     original_question: str = None,
+     *,
+     client=None,
+     system_prompt: str = None,
+     model: str = "gpt-3.5-turbo",
+     **kwargs
+ ) -> Any:
+     """
+     Use an LLM to verify whether the answer is semantically consistent with the question.
+     If the answer passes verification, return it; otherwise, return None (to trigger a real LLM call).
+
+     :param messages: A list of candidate outputs.
+     :type messages: List[Any]
+     :param scores: A list of evaluation scores corresponding to messages.
+     :type scores: List[float], optional
+     :param original_question: The original question string.
+     :type original_question: str, optional
+     :param client: LLM client object, defaults to None.
+     :type client: Any, optional
+     :param system_prompt: System prompt, defaults to None.
+     :type system_prompt: str, optional
+     :param model: LLM model name, defaults to "gpt-3.5-turbo".
+     :type model: str, optional
+     :param kwargs: Other keyword arguments.
+     :return: The answer if it passes semantic verification, otherwise None.
+     :rtype: Any
+
+     Example:
+         .. code-block:: python
+
+             from gptcache.processor.post import llm_semantic_verification
+
+             messages = ["answer1", "answer2"]
+             scores = [0.9, 0.5]
+             question = "original question"
+             answer = llm_semantic_verification(messages, scores, original_question=question)
+     """
+     if not messages or not original_question:
+         return None
+     import openai
+
+     # Select the answer with the highest score
+     best_answer = messages[0] if not scores else messages[scores.index(max(scores))]
+     if client is None:
+         client = openai
+     if system_prompt is None:
+         system_prompt = ("You are a strict semantic verification assistant. "
+                          "… Only answer 'yes' or 'no'. If unsure, answer 'no'.")
+
+     try:
+         # openai>=1.0.0 clients (and the module-level proxy) expose chat.completions;
+         # openai==0.28.x exposes ChatCompletion instead.
+         if hasattr(client, "chat") and hasattr(client.chat, "completions"):
+             create = client.chat.completions.create
+         else:
+             create = client.ChatCompletion.create
+         resp = create(
+             model=model,
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user",
+                  "content": f"Question: {original_question}\n"
+                             f"Answer: {best_answer}\n"
+                             f"Does this answer fully match the question? yes/no"}
+             ],
+             temperature=0,
+             max_tokens=10
+         )
+         verdict = resp.choices[0].message.content.strip().lower()
+         if verdict in {"yes"}:
+             return best_answer
+     except Exception as e:
+         print("LLM verification failed:", e)
+
+     return None
+
+
+ class LlmVerifier:
+     """
+     LlmVerifier is a callable class that wraps the llm_semantic_verification function.
+     It stores the LLM client, system prompt, and model name for repeated semantic verification tasks.
+
+     :param client: LLM client object.
+     :type client: Any
+     :param system_prompt: System prompt for the LLM.
+     :type system_prompt: str
+     :param model: LLM model name, defaults to "gpt-3.5-turbo".
+     :type model: str, optional
+     """
+     def __init__(self, client=None, system_prompt=None, model="gpt-3.5-turbo"):
+         self.client = client
+         self.system_prompt = system_prompt
+         self.model = model
+
+     def __call__(self, messages, scores=None, original_question=None, **kwargs):
+         """
+         Call the verifier to perform semantic verification using the stored client, prompt, and model.
+
+         :param messages: A list of candidate outputs.
+         :param scores: A list of evaluation scores corresponding to messages.
+         :param original_question: The original question string.
+         :param kwargs: Other keyword arguments.
+         :return: The answer if it passes semantic verification, otherwise None.
+         """
+         return llm_semantic_verification(
+             messages, scores=scores, original_question=original_question,
+             client=self.client, system_prompt=self.system_prompt,
+             model=self.model, **kwargs
+         )
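
The verifier can also be exercised directly, which helps when tuning prompts or trying a cheaper verification model. The sketch below is illustrative only; it assumes an openai>=1.0.0 `OpenAI()` client with a valid `OPENAI_API_KEY`, and the candidate answers and scores are made up.

```python
# Standalone use of LlmVerifier outside the cache adapter (assumes openai>=1.0.0
# and OPENAI_API_KEY). Returns the best-scoring answer if the LLM judges it
# consistent with the question, otherwise None.
from openai import OpenAI
from gptcache.processor.post import LlmVerifier

verifier = LlmVerifier(client=OpenAI(),
                       system_prompt=None,  # fall back to the built-in strict prompt
                       model="gpt-3.5-turbo")

candidates = ["GitHub is a platform for hosting and reviewing Git repositories.",
              "Our refund policy allows returns within 30 days."]
scores = [0.92, 0.61]

result = verifier(candidates, scores=scores, original_question="what is github")
print(result)  # the first candidate if verified, otherwise None
```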

gptcache/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def import_huggingface_hub():
 
 
  def import_onnxruntime():
-     _check_library("onnxruntime", package="onnxruntime==1.14.1")
+     _check_library("onnxruntime", package="onnxruntime==1.21.1")
 
 
  def import_faiss():
