Commit d1390ee

Merge remote-tracking branch 'upstream/develop' into develop
2 parents eea3877 + 6735626

9 files changed (+143, −69 lines)

fastdeploy/engine/engine.py

Lines changed: 5 additions & 4 deletions

@@ -734,10 +734,6 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
         """
         Insert tasks to engine.
         """
-        for task in tasks:
-            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
-            if task.sampling_params.bad_words is not None:
-                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
         # TODO return to the scheduler
         if allocated:
             current_tasks = []
@@ -764,6 +760,11 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
             self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
             return True

+        for task in tasks:
+            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if task.sampling_params.bad_words is not None:
+                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
+
         self.resource_manager.check_and_free_block_tables()

         if not isinstance(tasks, list):
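
The two hunks together defer the per-task DEQUEUE tracing and bad_words token resolution until after the `allocated` fast path, which re-dispatches already-allocated tasks and returns early. A condensed view of the net effect (names taken from the diff, unrelated bodies elided; not a runnable excerpt of the full method):

```python
def insert_tasks(self, tasks, current_id=-1, allocated=False):
    """Insert tasks to engine."""
    if allocated:
        # Fast path for already-allocated tasks: re-dispatch and return early,
        # so the per-task preprocessing below is never run on this path.
        ...
        return True

    # Normal path only: trace the dequeue and resolve bad_words token ids.
    for task in tasks:
        start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
        if task.sampling_params.bad_words is not None:
            task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)

    self.resource_manager.check_and_free_block_tables()
    ...
```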

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 6 additions & 4 deletions

@@ -176,10 +176,10 @@ async def connection_manager():
         await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
         yield
     except asyncio.TimeoutError:
-        api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}")
-        if connection_semaphore.locked():
-            connection_semaphore.release()
-        raise HTTPException(status_code=429, detail="Too many requests")
+        api_server_logger.info(f"Reach max request concurrency, semaphore status: {connection_semaphore.status()}")
+        raise HTTPException(
+            status_code=429, detail=f"Too many requests,current max concurrency is {args.max_concurrency}"
+        )


 # TODO pass the real engine values; get status via pid
@@ -266,9 +266,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
     inject_to_metadata(request)
     generator = await app.state.chat_handler.create_chat_completion(request)
     if isinstance(generator, ErrorResponse):
+        api_server_logger.debug(f"release: {connection_semaphore.status()}")
         connection_semaphore.release()
         return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
     elif isinstance(generator, ChatCompletionResponse):
+        api_server_logger.debug(f"release: {connection_semaphore.status()}")
        connection_semaphore.release()
         return JSONResponse(content=generator.model_dump())
     else:
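
The timeout path now simply raises 429 and reports the configured limit; the defensive release is dropped because a timed-out acquire never took a slot, and the endpoint releases the semaphore itself once it has a result. A minimal sketch of this connection-limit pattern, using a plain `asyncio.Semaphore` in place of the server's `connection_semaphore` (whose `status()` helper is a FastDeploy wrapper, not standard asyncio):

```python
import asyncio
from contextlib import asynccontextmanager

from fastapi import HTTPException

MAX_CONCURRENCY = 512  # illustrative; the real value comes from args.max_concurrency
connection_semaphore = asyncio.Semaphore(MAX_CONCURRENCY)


@asynccontextmanager
async def connection_manager():
    try:
        # Acquire with a ~1 ms budget: if no slot frees almost immediately, reject.
        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
        yield
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=429,
            detail=f"Too many requests, current max concurrency is {MAX_CONCURRENCY}",
        )
```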

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 40 additions & 40 deletions

@@ -78,34 +78,48 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
             err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
             api_server_logger.error(err_msg)
             return ErrorResponse(message=err_msg, code=400)
-
-        if request.user is not None:
-            request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
-        else:
-            request_id = f"chatcmpl-{uuid.uuid4()}"
-        api_server_logger.info(f"create chat completion request: {request_id}")
-        text_after_process = None
         try:
-            current_req_dict = request.to_dict_for_infer(request_id)
-            if "chat_template" not in current_req_dict:
-                current_req_dict["chat_template"] = self.chat_template
-            current_req_dict["arrival_time"] = time.time()
-            prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
-            text_after_process = current_req_dict.get("text_after_process")
-            if isinstance(prompt_token_ids, np.ndarray):
-                prompt_token_ids = prompt_token_ids.tolist()
-        except Exception as e:
-            error_msg = f"request[{request_id}] send to infer error: {str(e)}, {str(traceback.format_exc())}"
-            api_server_logger.error(error_msg)
-            return ErrorResponse(code=400, message=error_msg)
-
-        del current_req_dict
-        try:
-            api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
             if self.max_waiting_time < 0:
                 await self.engine_client.semaphore.acquire()
             else:
                 await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+            api_server_logger.info(f"current {self.engine_client.semaphore.status()}")
+
+            if request.user is not None:
+                request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
+            else:
+                request_id = f"chatcmpl-{uuid.uuid4()}"
+            api_server_logger.info(f"create chat completion request: {request_id}")
+            text_after_process = None
+            try:
+                current_req_dict = request.to_dict_for_infer(request_id)
+                if "chat_template" not in current_req_dict:
+                    current_req_dict["chat_template"] = self.chat_template
+                current_req_dict["arrival_time"] = time.time()
+                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                text_after_process = current_req_dict.get("text_after_process")
+                if isinstance(prompt_token_ids, np.ndarray):
+                    prompt_token_ids = prompt_token_ids.tolist()
+            except Exception as e:
+                error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
+                api_server_logger.error(error_msg)
+                return ErrorResponse(code=400, message=error_msg)
+
+            del current_req_dict
+
+            if request.stream:
+                return self.chat_completion_stream_generator(
+                    request, request_id, request.model, prompt_token_ids, text_after_process
+                )
+            else:
+                try:
+                    return await self.chat_completion_full_generator(
+                        request, request_id, request.model, prompt_token_ids, text_after_process
+                    )
+                except Exception as e:
+                    error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
+                    api_server_logger.error(error_msg)
+                    return ErrorResponse(code=408, message=error_msg)
         except Exception as e:
             error_msg = (
                 f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
@@ -114,20 +128,6 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
             api_server_logger.error(error_msg)
             return ErrorResponse(code=408, message=error_msg)

-        if request.stream:
-            return self.chat_completion_stream_generator(
-                request, request_id, request.model, prompt_token_ids, text_after_process
-            )
-        else:
-            try:
-                return await self.chat_completion_full_generator(
-                    request, request_id, request.model, prompt_token_ids, text_after_process
-                )
-            except Exception as e:
-                error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
-                api_server_logger.error(error_msg)
-                return ErrorResponse(code=400, message=error_msg)
-
     def _create_streaming_error_response(self, message: str) -> str:
         api_server_logger.error(message)
         error_response = ErrorResponse(
@@ -264,6 +264,7 @@ async def chat_completion_stream_generator(
                         logprobs_res = self._create_chat_logprobs(
                             output_top_logprobs, request.logprobs, request.top_logprobs
                         )
+
                     if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
                         tool_delta_message = output["tool_delta_message"]
                         if tool_delta_message is None:
@@ -287,7 +288,6 @@
                         logprobs=logprobs_res,
                         arrival_time=arrival_time,
                     )
-
                     if res["finished"]:
                         num_choices -= 1
                         work_process_metrics.e2e_request_latency.observe(
@@ -319,7 +319,6 @@
                     if len(choices) == max_streaming_response_tokens or res["finished"]:
                         chunk.choices = choices
                         yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                        # log the final chunk
                         if res["finished"]:
                             api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}")
                         choices = []
@@ -429,8 +428,9 @@ async def chat_completion_full_generator(
                 if task_is_finished:
                     break
         finally:
-            self.engine_client.semaphore.release()
             dealer.close()
+            self.engine_client.semaphore.release()
+            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

         choices = []
         output = final_res["outputs"]
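
With this reordering, the concurrency slot is acquired before any request building or tokenization, so a request that waits longer than `max_waiting_time` is rejected with 408 without touching the data processor, and the slot is handed back in the full generator's `finally` only after the dealer socket is closed. A small standalone sketch of that pattern, using asyncio streams as a stand-in for the ZMQ dealer (host, port, and limits are illustrative):

```python
import asyncio

semaphore = asyncio.Semaphore(4)   # illustrative concurrency limit
MAX_WAITING_TIME = 30.0            # seconds; a value < 0 would mean "wait forever"


async def handle_request(payload: str) -> str:
    # Admission control first: a rejected request never pays any preprocessing cost.
    if MAX_WAITING_TIME < 0:
        await semaphore.acquire()
    else:
        await asyncio.wait_for(semaphore.acquire(), timeout=MAX_WAITING_TIME)
    writer = None
    try:
        # Stand-in for the ZMQ dealer used by the real generator.
        reader, writer = await asyncio.open_connection("127.0.0.1", 8500)
        writer.write(payload.encode())
        await writer.drain()
        return (await reader.read(1024)).decode()
    finally:
        if writer is not None:
            writer.close()         # close the transport first...
            await writer.wait_closed()
        semaphore.release()        # ...then hand the slot back
```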

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 13 additions & 13 deletions

@@ -104,6 +104,19 @@ async def create_completion(self, request: CompletionRequest):
         api_server_logger.info(f"start inference for request {num_choices}")
         prompt_batched_token_ids = []
         text_after_process_list = []
+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception as e:
+            error_msg = (
+                f"OpenAIServingCompletion waiting error: {e}, {str(traceback.format_exc())}, "
+                f"max waiting time: {self.max_waiting_time}"
+            )
+            api_server_logger.error(error_msg)
+            return ErrorResponse(code=408, message=error_msg)
+
         try:
             for idx, prompt in enumerate(request_prompts):
                 request_id_idx = f"{request_id}-{idx}"
@@ -122,19 +135,6 @@ async def create_completion(self, request: CompletionRequest):

             del current_req_dict

-            try:
-                if self.max_waiting_time < 0:
-                    await self.engine_client.semaphore.acquire()
-                else:
-                    await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
-            except Exception as e:
-                error_msg = (
-                    f"OpenAIServingCompletion waiting error: {e}, {str(traceback.format_exc())}, "
-                    f"max waiting time: {self.max_waiting_time}"
-                )
-                api_server_logger.error(error_msg)
-                return ErrorResponse(code=408, message=error_msg)
-
             if request.stream:
                 return self.completion_stream_generator(
                     request=request,

fastdeploy/inter_communicator/zmq_client.py

Lines changed: 10 additions & 2 deletions

@@ -32,7 +32,7 @@ class ZmqClient:
     """

     def __init__(self, name, mode):
-        self.context = zmq.Context()
+        self.context = zmq.Context(4)
         self.socket = self.context.socket(mode)
         self.file_name = f"/dev/shm/{name}.socket"
         self.router_path = f"/dev/shm/router_{name}.ipc"
@@ -68,6 +68,7 @@ def create_router(self):
         """
         self.router = self.context.socket(zmq.ROUTER)
         self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
+        self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
         self.router.setsockopt(zmq.SNDTIMEO, -1)
         self.router.bind(f"ipc://{self.router_path}")

@@ -126,6 +127,11 @@ def send_multipart(self, req_id, data):
             else:
                 break

+        if self.req_dict[req_id] == -1:
+            if data[-1].finished:
+                with self.mutex:
+                    self.req_dict.pop(req_id, None)
+            return
         try:
             start_send = time.time()
             if self.aggregate_send:
@@ -134,7 +140,9 @@ def send_multipart(self, req_id, data):
                 result = msgpack.packb([response.to_dict() for response in data])
             self.router.send_multipart([self.req_dict[req_id], b"", result])
             llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
-
+        except zmq.ZMQError as e:
+            llm_logger.error(f"[{req_id}] zmq error: {e}")
+            self.req_dict[req_id] = -1
         except Exception as e:
             llm_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")

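
With `ROUTER_MANDATORY` set, a ROUTER socket raises `zmq.ZMQError` when the destination identity is unknown or disconnected instead of silently dropping the message; the new except branch records a tombstone (`req_dict[req_id] = -1`) so later results for that client are skipped and cleaned up. A small standalone pyzmq sketch of the same socket setup and failure handling (endpoint path, high-water mark, and bookkeeping are illustrative):

```python
import zmq

ctx = zmq.Context(io_threads=4)                  # more IO threads, as in the patch
router = ctx.socket(zmq.ROUTER)
router.setsockopt(zmq.SNDHWM, 10000)             # illustrative send high-water mark
router.setsockopt(zmq.ROUTER_MANDATORY, 1)       # unroutable peers raise instead of dropping
router.setsockopt(zmq.SNDTIMEO, -1)
router.bind("ipc:///tmp/example_router.ipc")     # illustrative endpoint

dead_clients = set()


def send_result(client_id: bytes, payload: bytes) -> None:
    if client_id in dead_clients:
        return                                   # skip peers we already failed to reach
    try:
        router.send_multipart([client_id, b"", payload])
    except zmq.ZMQError as e:
        # Raised immediately when the client identity is not connected.
        print(f"zmq error for {client_id!r}: {e}")
        dead_clients.add(client_id)
```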

fastdeploy/model_executor/layers/sample/early_stopper.py

Lines changed: 4 additions & 3 deletions

@@ -67,16 +67,17 @@ def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags:
     def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
         # Get the probability score corresponding to next_tokens in this step
         next_scores = paddle.index_sample(probs, next_tokens)
+        real_bsz = probs.shape[0]

         # Sliding window: Move left one grid and insert new score
-        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
-        self.trunc_scores[:, -1:] = next_scores
+        self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
+        self.trunc_scores[:real_bsz, -1:] = next_scores

         # Determine which samples need to be terminated: all trunc_scores are greater than threshold
         need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)

         # Add the stop flags
-        stop_flags[need_trunc_all] = True
+        stop_flags[need_trunc_all[:real_bsz]] = True

         # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
         reset_mask = need_trunc_all.tile([1, self.window_size])
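
The stopper keeps a `[max_batch_size, window_size]` buffer of the sampled token's probability at each step and flags a sample once every score in its window exceeds the threshold; the fix slices by the real batch size because `probs` and `stop_flags` only cover the live rows while `trunc_scores` is allocated for the maximum batch. A NumPy sketch of the same sliding-window logic (not FastDeploy's API; shapes and constants are illustrative):

```python
import numpy as np

MAX_BSZ, WINDOW, THRESHOLD = 8, 4, 0.9
trunc_scores = np.zeros((MAX_BSZ, WINDOW), dtype=np.float32)   # persistent across steps


def repetition_early_stop_step(probs, next_tokens, stop_flags):
    """probs: [real_bsz, vocab]; next_tokens and stop_flags: [real_bsz, 1]."""
    real_bsz = probs.shape[0]
    # Probability assigned to the token that was actually sampled.
    next_scores = np.take_along_axis(probs, next_tokens, axis=1)
    # Slide the window left and append the newest score, only for live rows.
    trunc_scores[:real_bsz, :-1] = trunc_scores[:real_bsz, 1:]
    trunc_scores[:real_bsz, -1:] = next_scores
    # Stop a sample once every score in its window is above the threshold.
    need_stop = np.all(trunc_scores > THRESHOLD, axis=-1, keepdims=True)
    stop_flags[need_stop[:real_bsz]] = True
    # Clear stopped rows so they do not retrigger on the next step.
    trunc_scores[np.tile(need_stop, (1, WINDOW))] = 0.0
    return stop_flags


# Usage: two live rows in a padded batch of MAX_BSZ; flags flip after WINDOW steps.
probs = np.full((2, 10), 0.95, dtype=np.float32)
tokens = np.zeros((2, 1), dtype=np.int64)
flags = np.zeros((2, 1), dtype=bool)
for _ in range(WINDOW):
    repetition_early_stop_step(probs, tokens, flags)
```

Padding rows beyond `real_bsz` keep zero scores in this sketch, so they can never cross the threshold.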

scripts/coverage_run.sh

Lines changed: 0 additions & 2 deletions

@@ -26,7 +26,6 @@ done
 failed_tests_file="failed_tests.log"
 > "$failed_tests_file"
 disabled_tests=(
-layers/test_sampler.py
 layers/test_append_attention.py
 layers/test_attention.py
 operators/test_rejection_top_p_sampling.py
@@ -36,7 +35,6 @@ disabled_tests=(
 operators/test_stop_generation.py
 operators/test_air_topp_sampling.py
 operators/test_fused_moe.py
-layers/test_repetition_early_stopper.py
 operators/test_stop_generation_multi_ends.py
 graph_optimization/test_cuda_graph.py
 )

test/layers/test_repetition_early_stopper.py

Lines changed: 64 additions & 1 deletion

@@ -170,7 +170,69 @@ def test_consistency():
         actual = triggered_step_triton[i]
         assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"

-    print("Triton vs Normal: All tokens, states, and trigger timings match.")
+    print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")
+
+
+def test_consistency_with_real_batch_size():
+    batch_size = 20
+    real_batch_size = 15
+    vocab_size = 103424
+    window_size = 3000
+    threshold = 0.9
+    eos_token_id = vocab_size
+    max_steps = 10
+
+    fixed_token_id = np.random.randint(0, vocab_size)
+    early_stop_batch_id = np.random.randint(0, real_batch_size)
+
+    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
+    trigger_step_flags = dict(trigger_step_flags)
+    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
+    stopper_normal = RepetitionEarlyStopper()
+    stopper_normal.initialize(batch_size, cfg)
+    stopper_triton = RepetitionEarlyStopper()
+    stopper_triton.initialize(batch_size, cfg)
+
+    next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
+    next_tokens_triton = next_tokens_normal.clone()
+
+    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
+    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
+
+    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
+    stop_flags_triton = stop_flags_normal.clone()
+
+    triggered_step_normal = [None] * batch_size
+    triggered_step_triton = [None] * batch_size
+
+    for step in range(max_steps):
+
+        flags = [trigger_step_flags[i] for i in range(real_batch_size)]
+        probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+
+        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
+        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
+
+        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
+
+        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
+        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
+
+        out_normal = stop_flags_normal.numpy()
+        out_triton = stop_flags_triton.numpy()
+        for i in range(real_batch_size):
+            if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
+                triggered_step_normal[i] = step
+            if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
+                triggered_step_triton[i] = step
+
+    for i in range(batch_size):
+        expected = triggered_step_normal[i]
+        actual = triggered_step_triton[i]
+        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
+
+    print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")


 def test_performance():
@@ -232,4 +294,5 @@ def test_performance():
 if __name__ == "__main__":
     test_repetition_early_stopper()
     test_consistency()
+    test_consistency_with_real_batch_size()
     test_performance()

test/layers/test_sampler.py

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@ def _create_default_sampling_metadata(
         bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
         eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
         min_p=paddle.randn([batch_size]),
+        seed=paddle.to_tensor([[2025]]),
     )
     return fake_sampling_metadata
