[Feature] Pass through the chat_template_kwargs to the data processing module #3421

Merged
14 commits merged on Aug 19, 2025
5 changes: 1 addition & 4 deletions fastdeploy/engine/engine.py
@@ -465,10 +465,7 @@ def add_requests(self, task, sampling_params=None, **kwargs):
request.sampling_params = sampling_params
request.preprocess_start_time = time.time()

- enable_thinking = None
- if kwargs is not None:
-     enable_thinking = kwargs.get("enable_thinking", None)
- request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
+ request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
request.prompt_token_ids_len = len(request.prompt_token_ids)
request.need_prefill_tokens = request.prompt_token_ids_len
input_ids_len = request.prompt_token_ids_len
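For readers skimming the diff, here is a minimal sketch of the pass-through pattern this hunk introduces, using hypothetical DemoEngine/DemoProcessor stand-ins rather than the real FastDeploy classes: the engine no longer plucks enable_thinking out of kwargs; it forwards every keyword argument to the data processor and lets the processor decide what to do with chat_template_kwargs.

```python
# Hypothetical DemoEngine/DemoProcessor stand-ins; not FastDeploy classes.
class DemoProcessor:
    def process_request(self, request: dict, max_model_len: int, **kwargs) -> dict:
        # The processor, not the engine, now decides what to do with chat_template_kwargs.
        request["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
        request["max_model_len"] = max_model_len
        return request


class DemoEngine:
    def __init__(self, processor: DemoProcessor, max_model_len: int = 2048):
        self.data_processor = processor
        self.max_model_len = max_model_len

    def add_requests(self, request: dict, sampling_params=None, **kwargs) -> dict:
        # Before this PR the engine extracted enable_thinking here; now it forwards everything.
        return self.data_processor.process_request(request, self.max_model_len, **kwargs)


engine = DemoEngine(DemoProcessor())
out = engine.add_requests({"messages": []}, chat_template_kwargs={"enable_thinking": False})
print(out["chat_template_kwargs"])  # {'enable_thinking': False}
```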
7 changes: 2 additions & 5 deletions fastdeploy/entrypoints/llm.py
@@ -248,7 +248,7 @@ def _add_request(
self,
prompts,
sampling_params,
- chat_template_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
):
"""
Add a request to the LLM Engine and return the request ID.
@@ -289,10 +289,7 @@ def _add_request(
current_sampling_params = sampling_params[i]
else:
current_sampling_params = sampling_params
- enable_thinking = None
- if chat_template_kwargs is not None:
-     enable_thinking = chat_template_kwargs.get("enable_thinking", None)
- self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
+ self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
return req_ids

def _decode_token(self, token_id: int) -> str:
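Reduced to standalone functions (the _add_request_old/_add_request_new names below are illustrative, not code from the PR), the change in this file widens the forwarding interface: the old signature could only carry chat_template_kwargs and only used its enable_thinking entry, while the new **kwargs form carries the whole dict, and any future per-request option, untouched.

```python
from typing import Any, Optional


def _add_request_old(tasks, sampling_params, chat_template_kwargs: Optional[dict[str, Any]] = None):
    # Old shape: the method knows about chat_template_kwargs and unpacks a single key.
    enable_thinking = None
    if chat_template_kwargs is not None:
        enable_thinking = chat_template_kwargs.get("enable_thinking", None)
    return {"enable_thinking": enable_thinking}


def _add_request_new(tasks, sampling_params, **kwargs):
    # New shape: the method stays agnostic; everything rides through to the engine.
    return kwargs


print(_add_request_old([], None, chat_template_kwargs={"enable_thinking": True}))
# {'enable_thinking': True}
print(_add_request_new([], None, chat_template_kwargs={"enable_thinking": True, "add_generation_prompt": False}))
# {'chat_template_kwargs': {'enable_thinking': True, 'add_generation_prompt': False}}
```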
14 changes: 12 additions & 2 deletions fastdeploy/input/ernie_processor.py
@@ -107,7 +107,13 @@ def process_request(self, request, max_model_len=None, **kwargs):
request.prompt_token_ids = token_ids
data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
- request.prompt_token_ids = self.messages2ids(request.to_dict())
+ task = request.to_dict()
+ chat_template_kwargs = request.get("chat_template_kwargs")
+ if chat_template_kwargs:
+     for k, v in chat_template_kwargs.items():
+         if k not in task:
+             task[k] = v
+ request.prompt_token_ids = self.messages2ids(task)

if len(request.prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
@@ -140,7 +146,11 @@ def process_request_dict(self, request, max_model_len=None):
request = self._apply_default_parameters(request)
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids

+ chat_template_kwargs = request.get("chat_template_kwargs")
+ if chat_template_kwargs:
+     for k, v in chat_template_kwargs.items():
+         if k not in request:
+             request[k] = v
# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
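The block added to both process_request and process_request_dict implements the same merge rule: copy each chat_template_kwargs entry into the task dict only when the key is not already set, so explicit request fields always win. A standalone sketch of that rule (the helper name is ours, not code from the PR):

```python
def merge_chat_template_kwargs(task: dict) -> dict:
    """Copy chat_template_kwargs entries into the task, never overwriting existing keys."""
    chat_template_kwargs = task.get("chat_template_kwargs")
    if chat_template_kwargs:
        for k, v in chat_template_kwargs.items():
            if k not in task:
                task[k] = v
    return task


task = {
    "messages": [{"role": "user", "content": "hi"}],
    "enable_thinking": False,  # already set on the task, so it is kept
    "chat_template_kwargs": {"enable_thinking": True, "add_generation_prompt": True},
}
merge_chat_template_kwargs(task)
print(task["enable_thinking"])        # False: existing task fields take precedence
print(task["add_generation_prompt"])  # True: missing keys are filled from chat_template_kwargs
```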
7 changes: 6 additions & 1 deletion fastdeploy/input/ernie_vl_processor.py
@@ -110,7 +110,7 @@ def set_value(req, key, value):
def process_request(self, request, max_model_len=None, **kwargs):
"""process the input data"""
task = request.to_dict()
task["enable_thinking"] = kwargs.get("enable_thinking", True)
task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
self.process_request_dict(task, max_model_len)
request = Request.from_dict(task)
request = self._apply_default_parameters(request)
@@ -198,6 +198,11 @@ def process_request_dict(self, request, max_model_len=None):
request = self._apply_default_parameters(request)
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
+ chat_template_kwargs = request.get("chat_template_kwargs")
+ if chat_template_kwargs:
+     for k, v in chat_template_kwargs.items():
+         if k not in request:
+             request[k] = v

stop_sequences = request.get("stop", [])
if stop_sequences:
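One behavioral nuance of this file's first hunk: the old code always materialized enable_thinking in the task, defaulting it to True, whereas the new code only attaches whatever the caller passed in chat_template_kwargs (possibly None), leaving the chat template's own default in effect otherwise. A minimal before/after illustration with made-up helper names:

```python
def build_task_old(request: dict, **kwargs) -> dict:
    task = dict(request)
    # Old: enable_thinking is always present, defaulting to True.
    task["enable_thinking"] = kwargs.get("enable_thinking", True)
    return task


def build_task_new(request: dict, **kwargs) -> dict:
    task = dict(request)
    # New: only caller-provided chat_template_kwargs are attached; no implicit default.
    task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
    return task


req = {"messages": []}
print("enable_thinking" in build_task_old(req))                      # True, even if never requested
print(build_task_new(req, chat_template_kwargs={"enable_thinking": True}))
print(build_task_new(req)["chat_template_kwargs"])                   # None when nothing is passed
```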
12 changes: 10 additions & 2 deletions fastdeploy/input/text_processor.py
@@ -207,7 +207,6 @@ def process_request(self, request, max_model_len=None, **kwargs):
request = self._apply_default_parameters(request)
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids

stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -221,7 +220,11 @@ def process_request(self, request, max_model_len=None, **kwargs):
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
task = request.to_dict()
task["enable_thinking"] = kwargs.get("enable_thinking", True)
chat_template_kwargs = kwargs.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
request.prompt_token_ids = self.messages2ids(task)
else:
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -271,6 +274,11 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
elif "messages" in request:
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
+ chat_template_kwargs = request.get("chat_template_kwargs")
+ if chat_template_kwargs:
+     for k, v in chat_template_kwargs.items():
+         if k not in request:
+             request[k] = v
request["prompt_token_ids"] = self.messages2ids(request)
else:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
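As a quick sanity check on the two entry shapes the text processor now supports, the sketch below mimics both paths with simplified stand-in functions (the real methods also tokenize, apply defaults, and validate lengths): process_request receives chat_template_kwargs as a keyword argument, while process_request_dict reads it as a field of the request dict; both merge without overwriting.

```python
# Simplified stand-ins for the two text-processor entry points; not the real methods.

def process_request_kwargs_style(task: dict, **kwargs) -> dict:
    # process_request: chat_template_kwargs arrives as a keyword argument.
    chat_template_kwargs = kwargs.get("chat_template_kwargs")
    if chat_template_kwargs:
        for k, v in chat_template_kwargs.items():
            if k not in task:
                task[k] = v
    return task


def process_request_dict_style(request: dict) -> dict:
    # process_request_dict: chat_template_kwargs is already a field of the request dict.
    chat_template_kwargs = request.get("chat_template_kwargs")
    if chat_template_kwargs:
        for k, v in chat_template_kwargs.items():
            if k not in request:
                request[k] = v
    return request


print(process_request_kwargs_style({"messages": []}, chat_template_kwargs={"enable_thinking": True}))
print(process_request_dict_style({"messages": [], "chat_template_kwargs": {"enable_thinking": True}}))
```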