
Commit d07338f

[Feature] Pass through the chat_template_kwargs to the data processing module (#3421) (#3469)
* fix chat_template_args
* fix args
* add offline
* add offline
* fix
* fix
* fix default enable_thinking value
* fix default enable_thinking value
* modify condition
* Revert "modify condition". This reverts commit 26430bd.
* fix unit test
1 parent 3ffbc98 commit d07338f
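In practice, this change lets a client forward chat-template options on a per-request basis. A minimal usage sketch against a FastDeploy OpenAI-compatible server, mirroring the `extra_body` pattern used in the updated CI test; the base URL, API key, and model name below are placeholders, not values taken from this commit:

```python
from openai import OpenAI

# Placeholder endpoint and model name; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    # chat_template_kwargs is now passed through to the data processing module,
    # which merges it into the chat-template rendering context.
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
print(response.choices[0].message.content)
```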

File tree

fastdeploy/engine/engine.py
fastdeploy/entrypoints/llm.py
fastdeploy/input/ernie_processor.py
fastdeploy/input/ernie_vl_processor.py
fastdeploy/input/text_processor.py
test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py

6 files changed: +50 -13 lines


fastdeploy/engine/engine.py

Lines changed: 1 addition & 4 deletions
@@ -497,10 +497,7 @@ def add_requests(self, task, sampling_params=None, **kwargs):
         request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()

-        enable_thinking = None
-        if kwargs is not None:
-            enable_thinking = kwargs.get("enable_thinking", None)
-        request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
+        request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
         input_ids_len = request.prompt_token_ids_len

fastdeploy/entrypoints/llm.py

Lines changed: 2 additions & 5 deletions
@@ -238,7 +238,7 @@ def _add_request(
         self,
         prompts,
         sampling_params,
-        chat_template_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ):
         """
         Add a request to the LLM Engine and return the ID of that request.
@@ -279,10 +279,7 @@ def _add_request(
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
-            enable_thinking = None
-            if chat_template_kwargs is not None:
-                enable_thinking = chat_template_kwargs.get("enable_thinking", None)
-            self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
+            self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
         return req_ids

     def _decode_token(self, token_id: int) -> str:
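Combined with the engine.py change above, the entrypoint no longer extracts `enable_thinking` itself; it forwards every extra keyword argument down to the data processor. A simplified, self-contained sketch of that pass-through pattern, using toy stand-ins rather than the real FastDeploy classes:

```python
class ToyProcessor:
    def process_request(self, request, max_model_len, **kwargs):
        # The processor, not the caller, decides how to use chat_template_kwargs
        # (e.g. merging it into the chat-template rendering context).
        request["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
        return request


class ToyEngine:
    def __init__(self):
        self.data_processor = ToyProcessor()

    def add_requests(self, task, sampling_params=None, **kwargs):
        # Forward every extra keyword untouched, mirroring the engine.py change.
        return self.data_processor.process_request(task, 2048, **kwargs)


engine = ToyEngine()
out = engine.add_requests(
    {"messages": [{"role": "user", "content": "hi"}]},
    chat_template_kwargs={"enable_thinking": False},
)
print(out["chat_template_kwargs"])  # {'enable_thinking': False}
```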

fastdeploy/input/ernie_processor.py

Lines changed: 18 additions & 1 deletion
@@ -108,7 +108,16 @@ def process_request(self, request, max_model_len=None, **kwargs):
             request.prompt_token_ids = token_ids
             data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
-            request.prompt_token_ids = self.messages2ids(request.to_dict())
+            task = request.to_dict()
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.prompt_token_ids = self.messages2ids(task)

         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
@@ -163,6 +172,14 @@ def process_request_dict(self, request, max_model_len=None):
             req_id = request.get("request_id", None)
             data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             request["prompt_token_ids"] = self.messages2ids(request)
             if len(request["prompt_token_ids"]) == 0:
                 raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
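The same merge-without-overwrite logic recurs in each processor: keys from `chat_template_kwargs` are copied into the task or request dict only when they are not already present, and anything that is not a dict is rejected. A standalone sketch of that behavior; the helper name is illustrative, not part of the codebase:

```python
def merge_chat_template_kwargs(task: dict, chat_template_kwargs) -> dict:
    """Copy chat_template_kwargs into task without overwriting existing keys."""
    if chat_template_kwargs:
        if isinstance(chat_template_kwargs, dict):
            for k, v in chat_template_kwargs.items():
                if k not in task:
                    task[k] = v
        else:
            raise ValueError("Invalid input: chat_template_kwargs must be a dict")
    return task


task = {"messages": [], "enable_thinking": False}
merge_chat_template_kwargs(task, {"enable_thinking": True, "add_generation_prompt": True})
print(task["enable_thinking"])        # False -- keys already on the request win
print(task["add_generation_prompt"])  # True  -- missing keys are filled in

try:
    merge_chat_template_kwargs({}, "not-a-dict")
except ValueError as e:
    print(e)  # Invalid input: chat_template_kwargs must be a dict
```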

fastdeploy/input/ernie_vl_processor.py

Lines changed: 10 additions & 1 deletion
@@ -109,7 +109,7 @@ def set_value(req, key, value):
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
         task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", True)
+        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
         self.process_request_dict(task, max_model_len)
         request = Request.from_dict(task)
         request = self._apply_default_parameters(request)
@@ -216,6 +216,15 @@ def process_request_dict(self, request, max_model_len=None):
         elif request.get("messages"):
             messages = request["messages"]
             self._check_mm_limits(messages)
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.setdefault("enable_thinking", True)
             outputs = self.ernie_processor.request2ids(request)
         else:
             raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
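A consequence of the VL processor change is that `enable_thinking` is no longer force-set from a dedicated kwarg: an explicit field on the request wins, then a value from `chat_template_kwargs`, and finally the `setdefault` fallback of `True`. A small sketch of that resolution order; the helper is illustrative, not library code:

```python
def resolve_enable_thinking(request: dict, chat_template_kwargs=None) -> bool:
    # An explicit request field wins over chat_template_kwargs,
    # which in turn wins over the default of True.
    if chat_template_kwargs:
        for k, v in chat_template_kwargs.items():
            request.setdefault(k, v)
    request.setdefault("enable_thinking", True)
    return request["enable_thinking"]


print(resolve_enable_thinking({}))                                  # True (default)
print(resolve_enable_thinking({}, {"enable_thinking": False}))      # False (from kwargs)
print(resolve_enable_thinking({"enable_thinking": False},
                              {"enable_thinking": True}))           # False (request wins)
```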

fastdeploy/input/text_processor.py

Lines changed: 18 additions & 2 deletions
@@ -222,7 +222,6 @@ def process_request(self, request, max_model_len=None, **kwargs):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
-
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -236,7 +235,15 @@ def process_request(self, request, max_model_len=None, **kwargs):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
             task = request.to_dict()
-            task["enable_thinking"] = kwargs.get("enable_thinking", True)
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            task.setdefault("enable_thinking", True)
             request.prompt_token_ids = self.messages2ids(task)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -286,6 +293,15 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
         elif "messages" in request:
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.setdefault("enable_thinking", True)
             request["prompt_token_ids"] = self.messages2ids(request)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")

test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py

Lines changed: 1 addition & 0 deletions
@@ -507,6 +507,7 @@ def test_chat_with_thinking(openai_client, capsys):
         extra_body={"chat_template_kwargs": {"enable_thinking": False}},
     )
     assert response.choices[0].message.reasoning_content is None
+    assert "</think>" not in response.choices[0].message.content

     # enable thinking, streaming
     reasoning_max_tokens = 3
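The added assertion covers the disabled-thinking branch; the enabled branch later in the same test (not shown in this diff) would be expected to populate `reasoning_content` instead. A hedged sketch of what such a complementary check could look like, reusing the test's `openai_client` fixture; this is not the verbatim test code, and the model name is a placeholder:

```python
# Sketch only: mirrors the request above but with thinking enabled.
response = openai_client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "Explain why the sky appears blue."}],
    extra_body={"chat_template_kwargs": {"enable_thinking": True}},
)
# With thinking enabled, the reasoning text is expected to be populated
# and the think markup kept out of the visible content.
assert response.choices[0].message.reasoning_content is not None
assert "</think>" not in response.choices[0].message.content
```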
