Skip to content

Commit e16f717

Browse files
committed
update TestOpenAIServingCompletion for merge
1 parent 727c7ad commit e16f717

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,6 @@ def request_output_to_completion_response(
436436
token_ids = output["token_ids"]
437437
output_text = output["text"]
438438

439-
num_generated_tokens += final_res["output_token_ids"]
440-
num_prompt_tokens += len(prompt_token_ids)
441-
442439
finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output)
443440

444441
choice_data = CompletionResponseChoice(
@@ -454,6 +451,10 @@ def request_output_to_completion_response(
454451
)
455452
choices.append(choice_data)
456453

454+
num_generated_tokens += final_res["output_token_ids"]
455+
456+
num_prompt_tokens += len(prompt_token_ids)
457+
457458
usage = UsageInfo(
458459
prompt_tokens=num_prompt_tokens,
459460
completion_tokens=num_generated_tokens,

test/entrypoints/openai/test_serving_completion.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import unittest
2+
from typing import List
23
from unittest.mock import Mock
34

4-
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
5+
from fastdeploy.entrypoints.openai.serving_completion import (
6+
CompletionRequest,
7+
OpenAIServingCompletion,
8+
RequestOutput,
9+
)
510

611

712
class TestOpenAIServingCompletion(unittest.TestCase):
@@ -11,7 +16,7 @@ def test_calc_finish_reason_tool_calls(self):
1116
engine_client = Mock()
1217
engine_client.reasoning_parser = "ernie_x1"
1318
# 创建一个OpenAIServingCompletion实例
14-
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips")
19+
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
1520
# 创建一个模拟的output,并设置finish_reason为"tool_calls"
1621
output = {"finish_reason": "tool_calls"}
1722
# 调用calc_finish_reason方法
@@ -24,7 +29,7 @@ def test_calc_finish_reason_stop(self):
2429
engine_client = Mock()
2530
engine_client.reasoning_parser = "ernie_x1"
2631
# 创建一个OpenAIServingCompletion实例
27-
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips")
32+
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
2833
# 创建一个模拟的output,并设置finish_reason为其他值
2934
output = {"finish_reason": "other_reason"}
3035
# 调用calc_finish_reason方法
@@ -36,14 +41,71 @@ def test_calc_finish_reason_length(self):
3641
# 创建一个模拟的engine_client
3742
engine_client = Mock()
3843
# 创建一个OpenAIServingCompletion实例
39-
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips")
44+
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
4045
# 创建一个模拟的output
4146
output = {}
4247
# 调用calc_finish_reason方法
4348
result = serving_completion.calc_finish_reason(100, 100, output)
4449
# 断言结果为"length"
4550
assert result == "length"
4651

52+
def test_request_output_to_completion_response(self):
53+
engine_client = Mock()
54+
# 创建一个OpenAIServingCompletion实例
55+
openai_serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
56+
final_res_batch: List[RequestOutput] = [
57+
{
58+
"prompt": "Hello, world!",
59+
"outputs": {
60+
"token_ids": [1, 2, 3],
61+
"text": " world!",
62+
"top_logprobs": {
63+
"a": 0.1,
64+
"b": 0.2,
65+
},
66+
},
67+
"output_token_ids": 3,
68+
},
69+
{
70+
"prompt": "Hello, world!",
71+
"outputs": {
72+
"token_ids": [4, 5, 6],
73+
"text": " world!",
74+
"top_logprobs": {
75+
"a": 0.3,
76+
"b": 0.4,
77+
},
78+
},
79+
"output_token_ids": 3,
80+
},
81+
]
82+
83+
request: CompletionRequest = Mock()
84+
request_id = "test_request_id"
85+
created_time = 1655136000
86+
model_name = "test_model"
87+
prompt_batched_token_ids = [[1, 2, 3], [4, 5, 6]]
88+
completion_batched_token_ids = [[7, 8, 9], [10, 11, 12]]
89+
90+
completion_response = openai_serving_completion.request_output_to_completion_response(
91+
final_res_batch=final_res_batch,
92+
request=request,
93+
request_id=request_id,
94+
created_time=created_time,
95+
model_name=model_name,
96+
prompt_batched_token_ids=prompt_batched_token_ids,
97+
completion_batched_token_ids=completion_batched_token_ids,
98+
)
99+
100+
assert completion_response.id == request_id
101+
assert completion_response.created == created_time
102+
assert completion_response.model == model_name
103+
assert len(completion_response.choices) == 2
104+
105+
# 验证 choices 的 text 属性
106+
assert completion_response.choices[0].text == "Hello, world! world!"
107+
assert completion_response.choices[1].text == "Hello, world! world!"
108+
47109

48110
if __name__ == "__main__":
49111
unittest.main()

0 commit comments

Comments
 (0)