
Commit 15778d2

update
[BugFix] Support echo in the completion endpoint (PaddlePaddle#3245)

* wenxin-tools-511: fix the issue that v1/completion could not echo the prompt.
* Support echo for multiple prompts
* Support streaming echo with multiple prompts
* Add unit tests for echo support in the completion endpoint
* pre-commit
* Remove a redundant test file
* Fix the unit-test method for completion echo support
* Add the unit-test file
* Add unit tests
* unittest
* Add unit tests
* Fix unit tests
* Remove an unnecessary assert.
* Resubmit
* Update the test method
* ut
* Verify that the unit-test approach is correct
* Verify that the unit-test approach is correct
* Verify that the unit-test approach is correct (3)
* Optimize the unit-test code to narrow the test scope.
* Optimize the unit-test code to narrow the test scope (2).
* Optimize the unit-test code to narrow the test scope (3).
* support 'echo' in chat/completion.
* update
* update
* update
* update
* update
* update
* Add unit tests for token ids
* update
* Fix index error
* Fix index error
1 parent d07338f commit 15778d2
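
For context, echo=True on the OpenAI-compatible /v1/completions endpoint asks the server to prepend the original prompt to the generated text. Below is a minimal sketch of the behavior this commit fixes, assuming a FastDeploy server is already listening; the URL and model name are placeholders, not part of this commit:

# Sketch: what echo=True means for /v1/completions (URL and model are placeholders).
import requests  # third-party HTTP client

resp = requests.post(
    "http://localhost:8000/v1/completions",  # hypothetical FastDeploy server address
    json={
        "model": "default",        # placeholder model name
        "prompt": "Hello, world",
        "max_tokens": 16,
        "echo": True,              # server should prepend the prompt to the output
    },
)
choice = resp.json()["choices"][0]
# Before this fix the echoed prompt was missing; with echo support the
# returned text starts with the original prompt.
assert choice["text"].startswith("Hello, world")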

File tree

3 files changed: +195 additions, -11 deletions


fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 16 additions & 8 deletions
@@ -240,7 +240,15 @@ async def completion_full_generator(
         dealer.close()
         self.engine_client.semaphore.release()

-    def calc_finish_reason(self, max_tokens, token_num, output):
+    async def _echo_back_prompt(self, request, res, idx):
+        if res["outputs"].get("send_idx", -1) == 0 and request.echo:
+            if isinstance(request.prompt, list):
+                prompt_text = request.prompt[idx]
+            else:
+                prompt_text = request.prompt
+            res["outputs"]["text"] = prompt_text + (res["outputs"]["text"] or "")
+
+    def calc_finish_reason(self, max_tokens, token_num, output, tool_called):
         if max_tokens is None or token_num != max_tokens:
             if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
                 return "tool_calls"
@@ -336,6 +344,7 @@ async def completion_stream_generator(
                 else:
                     arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]

+                await self._echo_back_prompt(request, res, idx)
                 output = res["outputs"]
                 output_top_logprobs = output["top_logprobs"]
                 logprobs_res: Optional[CompletionLogprobs] = None
@@ -430,7 +439,7 @@ def request_output_to_completion_response(
             final_res = final_res_batch[idx]
             prompt_token_ids = prompt_batched_token_ids[idx]
             assert prompt_token_ids is not None
-            prompt_text = final_res["prompt"]
+            prompt_text = request.prompt
             completion_token_ids = completion_batched_token_ids[idx]

             output = final_res["outputs"]
@@ -448,17 +457,16 @@ def request_output_to_completion_response(

             if request.echo:
                 assert prompt_text is not None
-                if request.max_tokens == 0:
-                    token_ids = prompt_token_ids
-                    output_text = prompt_text
+                token_ids = [*prompt_token_ids, *output["token_ids"]]
+                if isinstance(prompt_text, list):
+                    output_text = prompt_text[idx] + output["text"]
                 else:
-                    token_ids = [*prompt_token_ids, *output["token_ids"]]
-                    output_text = prompt_text + output["text"]
+                    output_text = str(prompt_text) + output["text"]
             else:
                 token_ids = output["token_ids"]
                 output_text = output["text"]

-            finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output)
+            finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False)

             choice_data = CompletionResponseChoice(
                 token_ids=token_ids,
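
In the non-streaming path above, echo now always concatenates prompt and completion (the old max_tokens == 0 special case is gone), indexing into the prompt when the request carried a list. A reduced sketch of that branch, with an illustrative function name:

def build_echoed_choice(prompt_text, prompt_token_ids, output, idx):
    # Token ids are always the prompt ids followed by the completion ids.
    token_ids = [*prompt_token_ids, *output["token_ids"]]
    if isinstance(prompt_text, list):
        output_text = prompt_text[idx] + output["text"]  # pick this choice's prompt
    else:
        output_text = str(prompt_text) + output["text"]
    return token_ids, output_text

ids, text = build_echoed_choice(["p1", "p2"], [1, 2], {"token_ids": [3], "text": " out"}, 1)
print(ids, text)  # [1, 2, 3] p2 out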
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
import unittest
from unittest.mock import MagicMock, patch

from fastdeploy.entrypoints.openai.serving_completion import (
    CompletionRequest,
    OpenAIServingCompletion,
)


class YourClass:
    # Standalone copy of the _echo_back_prompt logic; not exercised by the tests below.
    async def _1(self, a, b, c):
        if b["outputs"].get("send_idx", -1) == 0 and a.echo:
            if isinstance(a.prompt, list):
                text = a.prompt[c]
            else:
                text = a.prompt
            b["outputs"]["text"] = text + (b["outputs"]["text"] or "")


class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.mock_engine = MagicMock()
        self.completion_handler = None

    def test_single_prompt_non_streaming(self):
        """Test a non-streaming response with a single prompt."""
        self.completion_handler = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)

        request = CompletionRequest(prompt="test prompt", max_tokens=10, echo=True, logprobs=1)

        mock_output = {
            "outputs": {
                "text": " generated text",
                "token_ids": [1, 2, 3],
                "top_logprobs": {"token1": -0.1, "token2": -0.2},
                "finished": True,
            },
            "output_token_ids": 3,
        }
        self.mock_engine.generate.return_value = [mock_output]

        response = self.completion_handler.request_output_to_completion_response(
            final_res_batch=[mock_output],
            request=request,
            request_id="test_id",
            created_time=12345,
            model_name="test_model",
            prompt_batched_token_ids=[[1, 2]],
            completion_batched_token_ids=[[3, 4, 5]],
            text_after_process_list=["test prompt"],
        )

        self.assertEqual(response.choices[0].text, "test prompt generated text")

    async def test_echo_back_prompt_and_streaming(self):
        """Test _echo_back_prompt and the prompt-concatenation logic of streaming responses."""
        self.completion_handler = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)

        request = CompletionRequest(prompt="test prompt", max_tokens=10, stream=True, echo=True)

        mock_response = {"outputs": {"text": "test output", "token_ids": [1, 2, 3], "finished": True}}

        with patch.object(self.completion_handler, "_echo_back_prompt") as mock_echo:

            def mock_echo_side_effect(req, res, idx):
                res["outputs"]["text"] = req.prompt + res["outputs"]["text"]

            mock_echo.side_effect = mock_echo_side_effect

            await self.completion_handler._echo_back_prompt(request, mock_response, 0)

            mock_echo.assert_called_once_with(request, mock_response, 0)

            self.assertEqual(mock_response["outputs"]["text"], "test prompttest output")
            self.assertEqual(request.prompt, "test prompt")

    def test_multi_prompt_non_streaming(self):
        """Test a non-streaming response with multiple prompts."""
        self.completion_handler = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)

        request = CompletionRequest(prompt=["prompt1", "prompt2"], max_tokens=10, echo=True)

        mock_outputs = [
            {
                "outputs": {"text": " response1", "token_ids": [1, 2], "top_logprobs": None, "finished": True},
                "output_token_ids": 2,
            },
            {
                "outputs": {"text": " response2", "token_ids": [3, 4], "top_logprobs": None, "finished": True},
                "output_token_ids": 2,
            },
        ]
        self.mock_engine.generate.return_value = mock_outputs

        response = self.completion_handler.request_output_to_completion_response(
            final_res_batch=mock_outputs,
            request=request,
            request_id="test_id",
            created_time=12345,
            model_name="test_model",
            prompt_batched_token_ids=[[1], [2]],
            completion_batched_token_ids=[[1, 2], [3, 4]],
            text_after_process_list=["prompt1", "prompt2"],
        )

        self.assertEqual(len(response.choices), 2)
        self.assertEqual(response.choices[0].text, "prompt1 response1")
        self.assertEqual(response.choices[1].text, "prompt2 response2")

    async def test_multi_prompt_streaming(self):
        self.completion_handler = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)

        request = CompletionRequest(prompt=["prompt1", "prompt2"], max_tokens=10, stream=True, echo=True)

        mock_responses = [
            {"outputs": {"text": " response1", "token_ids": [1, 2], "finished": True}},
            {"outputs": {"text": " response2", "token_ids": [3, 4], "finished": True}},
        ]

        with patch.object(self.completion_handler, "_echo_back_prompt") as mock_echo:

            def mock_echo_side_effect(req, res, idx):
                res["outputs"]["text"] = req.prompt[idx] + res["outputs"]["text"]

            mock_echo.side_effect = mock_echo_side_effect

            await self.completion_handler._echo_back_prompt(request, mock_responses[0], 0)
            await self.completion_handler._echo_back_prompt(request, mock_responses[1], 1)

            self.assertEqual(mock_echo.call_count, 2)
            mock_echo.assert_any_call(request, mock_responses[0], 0)
            mock_echo.assert_any_call(request, mock_responses[1], 1)

            self.assertEqual(mock_responses[0]["outputs"]["text"], "prompt1 response1")
            self.assertEqual(mock_responses[1]["outputs"]["text"], "prompt2 response2")
            self.assertEqual(request.prompt, ["prompt1", "prompt2"])

    async def test_echo_back_prompt_and_streaming1(self):
        request = CompletionRequest(echo=True, prompt=["Hello", "World"])
        res = {"outputs": {"send_idx": 0, "text": "!"}}
        idx = 0

        instance = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)
        await instance._echo_back_prompt(request, res, idx)
        self.assertEqual(res["outputs"]["text"], "Hello!")

    async def test_1_prompt_is_string_and_send_idx_is_0(self):
        request = CompletionRequest(echo=True, prompt="Hello")
        res = {"outputs": {"send_idx": 0, "text": "!"}}
        idx = 0

        instance = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)
        await instance._echo_back_prompt(request, res, idx)
        self.assertEqual(res["outputs"]["text"], "Hello!")

    async def test_1_send_idx_is_not_0(self):
        request = CompletionRequest(echo=True, prompt="Hello")
        res = {"outputs": {"send_idx": 1, "text": "!"}}
        idx = 0

        instance = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)
        await instance._echo_back_prompt(request, res, idx)
        self.assertEqual(res["outputs"]["text"], "!")

    async def test_1_echo_is_false(self):
        """When echo is False, _echo_back_prompt must not prepend the prompt."""
        request = CompletionRequest(echo=False, prompt="Hello")
        res = {"outputs": {"send_idx": 0, "text": "!"}}
        idx = 0

        instance = OpenAIServingCompletion(self.mock_engine, pid=123, ips=None, max_waiting_time=30)
        await instance._echo_back_prompt(request, res, idx)
        self.assertEqual(res["outputs"]["text"], "!")


if __name__ == "__main__":
    unittest.main()

test/entrypoints/openai/test_serving_completion.py

Lines changed: 2 additions & 3 deletions
@@ -55,7 +55,6 @@ def test_request_output_to_completion_response(self):
         openai_serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
         final_res_batch: List[RequestOutput] = [
             {
-                "prompt": "Hello, world!",
                 "outputs": {
                     "token_ids": [1, 2, 3],
                     "text": " world!",
@@ -67,7 +66,6 @@ def test_request_output_to_completion_response(self):
                 "output_token_ids": 3,
             },
             {
-                "prompt": "Hello, world!",
                 "outputs": {
                     "token_ids": [4, 5, 6],
                     "text": " world!",
@@ -81,12 +79,13 @@ def test_request_output_to_completion_response(self):
         ]

         request: CompletionRequest = Mock()
+        request.prompt = "Hello, world!"
+        request.echo = True
         request_id = "test_request_id"
         created_time = 1655136000
         model_name = "test_model"
         prompt_batched_token_ids = [[1, 2, 3], [4, 5, 6]]
         completion_batched_token_ids = [[7, 8, 9], [10, 11, 12]]
-
         completion_response = openai_serving_completion.request_output_to_completion_response(
             final_res_batch=final_res_batch,
             request=request,