From 3802f59ba88f01f460780d51046eb1cea77a3bfd Mon Sep 17 00:00:00 2001
From: luukunn <83932082+luukunn@users.noreply.github.com>
Date: Tue, 19 Aug 2025 10:50:01 +0800
Subject: [PATCH] [Feature] Pass through the `chat_template_kwargs` to the data processing module (#3421)

* fix chat_template_args

* fix args

* add offline

* add offline

* fix

* fix

* fix default enable_thinking value

* fix default enable_thinking value

* modify condition

* Revert "modify condition"

This reverts commit 26430bdeb1c86963b6fbeefc376a6d30d93262db.

* fix unit test
---
 fastdeploy/engine/engine.py                |  5 +----
 fastdeploy/entrypoints/llm.py              |  7 ++-----
 fastdeploy/input/ernie_processor.py        | 19 +++++++++++++++++-
 fastdeploy/input/ernie_vl_processor.py     | 11 +++++++++-
 fastdeploy/input/text_processor.py         | 20 +++++++++++++++++--
 .../EB_VL_Lite/test_EB_VL_Lite_serving.py  |  1 +
 6 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index fa9fa61750..1ae1af568f 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -497,10 +497,7 @@ def add_requests(self, task, sampling_params=None, **kwargs):
         request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()

-        enable_thinking = None
-        if kwargs is not None:
-            enable_thinking = kwargs.get("enable_thinking", None)
-        request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
+        request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
         input_ids_len = request.prompt_token_ids_len
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 3e150abf2d..3f6d77872e 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -238,7 +238,7 @@ def _add_request(
         self,
         prompts,
         sampling_params,
-        chat_template_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ):
         """
         Add a request to the LLM Engine and return the request ID.
@@ -279,10 +279,7 @@ def _add_request(
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
-            enable_thinking = None
-            if chat_template_kwargs is not None:
-                enable_thinking = chat_template_kwargs.get("enable_thinking", None)
-            self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
+            self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
         return req_ids

     def _decode_token(self, token_id: int) -> str:
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index 7cbb847f79..0077abf252 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -108,7 +108,16 @@ def process_request(self, request, max_model_len=None, **kwargs):
             request.prompt_token_ids = token_ids
             data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
-            request.prompt_token_ids = self.messages2ids(request.to_dict())
+            task = request.to_dict()
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.prompt_token_ids = self.messages2ids(task)

         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
@@ -163,6 +172,14 @@ def process_request_dict(self, request, max_model_len=None):
             req_id = request.get("request_id", None)
             data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
         else:
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             request["prompt_token_ids"] = self.messages2ids(request)
             if len(request["prompt_token_ids"]) == 0:
                 raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py
index d2975c6971..d032a186f7 100644
--- a/fastdeploy/input/ernie_vl_processor.py
+++ b/fastdeploy/input/ernie_vl_processor.py
@@ -109,7 +109,7 @@ def set_value(req, key, value):
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
         task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", True)
+        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
         self.process_request_dict(task, max_model_len)
         request = Request.from_dict(task)
         request = self._apply_default_parameters(request)
@@ -216,6 +216,15 @@ def process_request_dict(self, request, max_model_len=None):
         elif request.get("messages"):
             messages = request["messages"]
             self._check_mm_limits(messages)
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.setdefault("enable_thinking", True)
             outputs = self.ernie_processor.request2ids(request)
         else:
             raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index eec346341a..91def2f2fb 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -222,7 +222,6 @@ def process_request(self, request, max_model_len=None, **kwargs):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
-
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -236,7 +235,15 @@ def process_request(self, request, max_model_len=None, **kwargs):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
             task = request.to_dict()
-            task["enable_thinking"] = kwargs.get("enable_thinking", True)
+            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in task:
+                            task[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            task.setdefault("enable_thinking", True)
             request.prompt_token_ids = self.messages2ids(task)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -286,6 +293,15 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
         elif "messages" in request:
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
+            chat_template_kwargs = request.get("chat_template_kwargs")
+            if chat_template_kwargs:
+                if isinstance(chat_template_kwargs, dict):
+                    for k, v in chat_template_kwargs.items():
+                        if k not in request:
+                            request[k] = v
+                else:
+                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
+            request.setdefault("enable_thinking", True)
             request["prompt_token_ids"] = self.messages2ids(request)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
diff --git a/test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py b/test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
index fb31a655f8..28c0c2c684 100644
--- a/test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
+++ b/test/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
@@ -507,6 +507,7 @@ def test_chat_with_thinking(openai_client, capsys):
         extra_body={"chat_template_kwargs": {"enable_thinking": False}},
     )
     assert response.choices[0].message.reasoning_content is None
+    assert "</think>" not in response.choices[0].message.content

     # enable thinking, streaming
     reasoning_max_tokens = 3
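
Usage sketch: with this patch, a caller-supplied `chat_template_kwargs` dict is forwarded unchanged from `LLM._add_request(**kwargs)` through `LLMEngine.add_requests()` into the input processors, which merge each key into the request before rendering the chat template, and `enable_thinking` now defaults to True only when the caller has not set it. The snippet below illustrates the offline path under those assumptions; the model path is a placeholder, and the exact `LLM.chat()` keywords and result-object layout follow FastDeploy's documented offline API rather than this diff, so treat it as a sketch, not a verified example.

# Illustrative offline call: chat_template_kwargs flows
# LLM.chat() -> _add_request(**kwargs) -> LLMEngine.add_requests()
# -> data_processor.process_request(), which merges the keys into the task.
from fastdeploy import LLM, SamplingParams  # public offline API

llm = LLM(model="./ERNIE-4.5-VL-28B-A3B-Paddle")  # placeholder local model path
messages = [{"role": "user", "content": "Explain KV cache in one sentence."}]
outputs = llm.chat(
    messages,
    SamplingParams(temperature=0.7, max_tokens=128),
    chat_template_kwargs={"enable_thinking": False},  # forwarded to the processor
)
print(outputs[0])  # result-object layout may differ across versions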