add custom chat template #3251


Merged
merged 31 commits on Aug 18, 2025
Changes from 11 commits

Commits (31)
1a569e5
add custom chat_template
luukunn Jul 29, 2025
3b4326a
add custom chat_template
luukunn Jul 29, 2025
53d8beb
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
luukunn Jul 29, 2025
db34a06
Resolve merge conflicts
luukunn Aug 6, 2025
23095e5
add unittest
luukunn Aug 6, 2025
dc42a72
fix
luukunn Aug 6, 2025
d44cb51
add docs
luukunn Aug 6, 2025
0b4db9a
fix comment
luukunn Aug 6, 2025
8927f4b
add offline chat
luukunn Aug 6, 2025
4ad201c
fix unit test
luukunn Aug 6, 2025
f5f2c1f
fix unit test
luukunn Aug 7, 2025
b77da03
fix
luukunn Aug 11, 2025
238149e
Merge branch 'develop' into develop
luukunn Aug 11, 2025
ca72e35
Merge branch 'develop' into develop
luukunn Aug 11, 2025
bc9fb4b
fix pre commit
luukunn Aug 11, 2025
78f5804
Merge branch 'develop' of https://github.com/luukunn/FastDeploy into …
luukunn Aug 11, 2025
7b3f43f
fix unit test
luukunn Aug 11, 2025
227ddba
add unit test
luukunn Aug 11, 2025
8a124ad
add unit test
luukunn Aug 12, 2025
43e70c5
add unit test
luukunn Aug 12, 2025
c2189c6
Merge branch 'develop' into develop
luukunn Aug 12, 2025
1502081
fix pre_commit
luukunn Aug 12, 2025
573d7fa
Merge branch 'develop' into develop
luukunn Aug 12, 2025
23beb89
fix enable_thinking
luukunn Aug 18, 2025
f13d985
Merge branch 'develop' of https://github.com/luukunn/FastDeploy into …
luukunn Aug 18, 2025
3f75173
Merge branch 'develop' into develop
luukunn Aug 18, 2025
cb3eae6
fix pre commit
luukunn Aug 18, 2025
9f98538
fix pre commit
luukunn Aug 18, 2025
a9f3bc0
fix unit test
luukunn Aug 18, 2025
24c59e9
add requirements
luukunn Aug 18, 2025
38c73ff
Merge branch 'develop' into develop
luukunn Aug 18, 2025
3 changes: 3 additions & 0 deletions docs/online_serving/README.md
@@ -161,6 +161,9 @@ The following extra parameters are supported:
chat_template_kwargs: Optional[dict] = None
# Additional parameters passed to the chat template, used for customizing dialogue formats (default None).

chat_template: Optional[str] = None
Collaborator: Please also add a description of this startup parameter in docs/zh/parameters.md and docs/parameters.md.

# Custom chat template will override the model's default chat template (default None).

reasoning_max_tokens: Optional[int] = None
# Maximum number of tokens to generate during reasoning (e.g., CoT, chain of thought) (default None means using global max_tokens).

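For illustration, exercising the new parameter over the OpenAI-compatible endpoint could look like the sketch below; the port, endpoint path, and model name are assumptions based on the server defaults elsewhere in this PR, not part of the diff.

```python
import requests

# Jinja template that flattens the conversation into "role: content" lines.
custom_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

response = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed default port 8000
    json={
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "chat_template": custom_template,  # overrides the model's default template
    },
)
print(response.json())
```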
3 changes: 3 additions & 0 deletions docs/zh/online_serving/README.md
@@ -160,6 +160,9 @@ repetition_penalty: Optional[float] = None
chat_template_kwargs: Optional[dict] = None
# 传递给聊天模板(chat template)的额外参数,用于自定义对话格式(默认 None)。

chat_template: Optional[str] = None
# 自定义聊天模板,会覆盖模型默认的聊天模板(默认 None)。

reasoning_max_tokens: Optional[int] = None
# 推理(如 CoT, 思维链)过程中生成的最大 token 数(默认 None 表示使用全局 max_tokens)。

10 changes: 10 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -92,6 +92,10 @@ class EngineArgs:
"""
specifies the reasoning parser to use for extracting reasoning content from the model output
"""
chat_template: str = None
"""
chat template or chat template file path
"""
enable_mm: bool = False
"""
Flags to enable multi-modal model
@@ -420,6 +424,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Flag specifies the reasoning parser to use for extracting "
"reasoning content from the model output",
)
model_group.add_argument(
"--chat-template",
type=str,
default=EngineArgs.chat_template,
help="chat template or chat template file path",
)
model_group.add_argument(
"--speculative-config",
type=json.loads,
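As a standalone sketch of the flag's contract (plain argparse rather than FastDeploy's FlexibleArgumentParser), the same value may carry either an inline template or a file path; disambiguation happens later in load_chat_template:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--chat-template",
    type=str,
    default=None,
    help="chat template or chat template file path",
)

# Either form is accepted at parse time.
inline = parser.parse_args(["--chat-template", "{{ messages[0]['content'] }}"])
from_file = parser.parse_args(["--chat-template", "./my_template.jinja"])
print(inline.chat_template)
print(from_file.chat_template)
```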
5 changes: 5 additions & 0 deletions fastdeploy/engine/request.py
@@ -71,6 +71,7 @@ def __init__(
guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True,
trace_carrier: dict = dict(),
chat_template: Optional[str] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -110,6 +111,8 @@ def __init__(
self.enable_thinking = enable_thinking
self.trace_carrier = trace_carrier

self.chat_template = chat_template

# token num
self.block_tables = []
self.output_token_ids = []
@@ -151,6 +154,7 @@ def from_dict(cls, d: dict):
guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", True),
trace_carrier=d.get("trace_carrier", {}),
chat_template=d.get("chat_template", None),
)

@property
@@ -190,6 +194,7 @@ def to_dict(self) -> dict:
"draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking,
"trace_carrier": self.trace_carrier,
"chat_template": self.chat_template,
}
add_params = [
"guided_json",
30 changes: 29 additions & 1 deletion fastdeploy/entrypoints/chat_utils.py
@@ -15,7 +15,7 @@
"""

from copy import deepcopy
from typing import List, Literal, Union
from typing import List, Literal, Optional, Union
from urllib.parse import urlparse

import requests
@@ -29,6 +29,7 @@

from fastdeploy.multimodal.image import ImageMediaIO
from fastdeploy.multimodal.video import VideoMediaIO
from pathlib import Path


class VideoURL(TypedDict, total=False):
@@ -156,3 +157,30 @@ def parse_chat_messages(messages):

conversation.append({"role": role, "content": parsed_content})
return conversation

def load_chat_template(chat_template: Optional[Union[Path, str]], is_literal: bool = False) -> Optional[str]:
    """Load a chat template from a literal template string or a file path."""
    if chat_template is None:
        return None
    if is_literal:
        if isinstance(chat_template, Path):
            raise TypeError("chat_template is expected to be read directly from its value")
        return chat_template

    try:
        with open(chat_template) as f:
            return f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (
                f"The supplied chat template ({chat_template}) looks like a "
                f"file path, but it failed to be opened. Reason: {e}"
            )
            raise ValueError(msg) from e

        # If opening the file fails but the string contains Jinja characters,
        # treat it as a literal template so escapes are interpreted correctly.
        return load_chat_template(chat_template, is_literal=True)
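A quick sketch of how the three branches behave (the temporary file name and paths are arbitrary):

```python
import os
import tempfile

from fastdeploy.entrypoints.chat_utils import load_chat_template

template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"

# Literal template: open() fails, but the string contains Jinja characters,
# so it is returned as-is via the is_literal=True recursion.
assert load_chat_template(template) == template

# File path: the file's contents are returned instead.
with tempfile.NamedTemporaryFile("w", suffix=".jinja", delete=False) as f:
    f.write(template)
assert load_chat_template(f.name) == template
os.remove(f.name)

# A path-like string without Jinja characters that cannot be opened raises
# ValueError instead of being silently treated as a template.
try:
    load_chat_template("./missing_template.jinja")
except ValueError as e:
    print(e)
```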
9 changes: 8 additions & 1 deletion fastdeploy/entrypoints/llm.py
@@ -35,6 +35,7 @@
retrive_model_from_server,
)
from fastdeploy.worker.output import Logprob, LogprobsLists
from fastdeploy.entrypoints.chat_utils import load_chat_template

root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
@@ -73,6 +74,7 @@ def __init__(
revision: Optional[str] = "master",
tokenizer: Optional[str] = None,
enable_logprob: Optional[bool] = False,
chat_template: Optional[str] = None,
**kwargs,
):
deprecated_kwargs_warning(**kwargs)
@@ -98,6 +100,7 @@
self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
self._receive_output_thread.start()
self.chat_template = load_chat_template(chat_template)

def _check_master(self):
"""
@@ -192,6 +195,7 @@ def chat(
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
chat_template_kwargs: Optional[dict[str, Any]] = None,
chat_template: Optional[str] = None,
):
"""
Args:
@@ -224,10 +228,13 @@

if sampling_params_len != 1 and len(messages) != sampling_params_len:
raise ValueError("messages and sampling_params must be the same length.")

if chat_template is None:
chat_template = self.chat_template

messages_len = len(messages)
for i in range(messages_len):
messages[i] = {"messages": messages[i]}
messages[i] = {"messages": messages[i], "chat_template": chat_template}
Collaborator: Don't put this parameter in messages; pass it via chat_template_kwargs instead, and the downstream pass-through can continue from there.

req_ids = self._add_request(
prompts=messages,
sampling_params=sampling_params,
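Putting the offline path together, usage could look like this sketch; the model keyword and path are placeholders, and only the chat_template plumbing comes from this PR:

```python
from fastdeploy.entrypoints.llm import LLM

custom_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|> {{ message['content'] }}\n"
    "{% endfor %}"
)

# The constructor-level template becomes the default for every chat() call;
# it may also be given as a file path, which load_chat_template resolves.
llm = LLM(model="/path/to/model", chat_template=custom_template)

# A per-call template overrides the constructor default; passing None
# falls back to self.chat_template.
outputs = llm.chat(
    messages=[[{"role": "user", "content": "Hello!"}]],
    chat_template=custom_template,
)
```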
4 changes: 3 additions & 1 deletion fastdeploy/entrypoints/openai/api_server.py
@@ -54,6 +54,7 @@
is_port_available,
retrive_model_from_server,
)
from fastdeploy.entrypoints.chat_utils import load_chat_template

parser = FlexibleArgumentParser()
parser.add_argument("--port", default=8000, type=int, help="port to the http server")
@@ -104,6 +105,7 @@ async def lifespan(app: FastAPI):
pid = os.getppid()
else:
pid = os.getpid()
chat_template = load_chat_template(args.chat_template)
Collaborator: This load should be done outside, since multiple processes are started in here.

api_server_logger.info(f"{pid}")
engine_client = EngineClient(
args.model,
@@ -119,7 +121,7 @@
args.enable_logprob,
)
app.state.dynamic_load_weight = args.dynamic_load_weight
chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
chat_handler = OpenAIServingChat(engine_client, pid, args.ips, chat_template)
completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
engine_client.pid = pid
1 change: 1 addition & 0 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -505,6 +505,7 @@ class ChatCompletionRequest(BaseModel):

# doc: start-completion-extra-params
chat_template_kwargs: Optional[dict] = None
chat_template: Optional[str] = None
reasoning_max_tokens: Optional[int] = None
structural_tag: Optional[str] = None
guided_json: Optional[Union[str, dict, BaseModel]] = None
5 changes: 4 additions & 1 deletion fastdeploy/entrypoints/openai/serving_chat.py
@@ -49,11 +49,12 @@ class OpenAIServingChat:
OpenAI-style chat completions serving
"""

def __init__(self, engine_client, pid, ips):
def __init__(self, engine_client, pid, ips, chat_template):
self.engine_client = engine_client
self.pid = pid
self.master_ip = ips
self.host_ip = get_host_ip()
self.chat_template = chat_template
if self.master_ip is not None:
if isinstance(self.master_ip, list):
self.master_ip = self.master_ip[0]
@@ -84,6 +85,8 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
api_server_logger.info(f"create chat completion request: {request_id}")

try:
if request.chat_template is None:
request.chat_template = self.chat_template
Collaborator: This needs to go after to_dict_for_infer, because chat_template is not a standard OpenAI parameter and can't be passed directly on the request.

current_req_dict = request.to_dict_for_infer(request_id)
current_req_dict["arrival_time"] = time.time()
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
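A sketch of what the reviewer's suggestion might look like, as a hypothetical helper rather than code from this PR: attach the non-standard field to the internal dict after conversion instead of mutating the request object.

```python
import time

def prepare_infer_dict(request, request_id, server_chat_template):
    """Hypothetical helper: inject chat_template after to_dict_for_infer,
    since it is a FastDeploy extension and not a standard OpenAI field."""
    current_req_dict = request.to_dict_for_infer(request_id)
    if current_req_dict.get("chat_template") is None:
        current_req_dict["chat_template"] = server_chat_template
    current_req_dict["arrival_time"] = time.time()
    return current_req_dict
```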
1 change: 1 addition & 0 deletions fastdeploy/input/ernie_processor.py
@@ -309,6 +309,7 @@ def messages2ids(self, request_or_messages):
tokenize=False,
split_special_tokens=False,
add_special_tokens=False,
chat_template=request_or_messages.get("chat_template", None),
)

req_id = None
2 changes: 2 additions & 0 deletions fastdeploy/input/mm_processor/process.py
@@ -500,6 +500,8 @@ def apply_chat_template(self, request):
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
chat_template=request.get("chat_template", None),
)
.replace("<|image@placeholder|>", "")
.replace("<|video@placeholder|>", "")
1 change: 1 addition & 0 deletions fastdeploy/input/text_processor.py
@@ -471,6 +471,7 @@ def messages2ids(self, request):
split_special_tokens=False,
add_special_tokens=False,
return_tensors="pd",
chat_template=request.get("chat_template", None),
)
req_id = None
tokens = self.tokenizer.tokenize(spliced_message)
22 changes: 22 additions & 0 deletions test/utils/test_custom_chat_template.py
@@ -0,0 +1,22 @@
import os
Collaborator: Besides verifying that the template loads, the unit test should also verify that a custom template is applied when a request is made.

import unittest
from fastdeploy.entrypoints.chat_utils import load_chat_template

input_chat_template = "unit test \n"

class TestChatTemplate(unittest.TestCase):
def test_load_chat_template_str(self):
result = load_chat_template(input_chat_template)
self.assertEqual(input_chat_template, result)

def test_load_chat_template_path(self):
with open("chat_template", 'w', encoding='utf-8') as file:
file.write(input_chat_template)
file_path = os.path.join(os.getcwd(), "chat_template")
result = load_chat_template(file_path)
os.remove(file_path)
self.assertEqual(input_chat_template, result)

if __name__ == "__main__":
unittest.main()
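To the reviewer's point, a rough sketch of a request-time check: mock the tokenizer and assert that the custom template is forwarded to apply_chat_template. The messages2ids_like helper is illustrative only; the real test would go through the DataProcessor.

```python
import unittest
from unittest.mock import MagicMock

from fastdeploy.entrypoints.chat_utils import load_chat_template

custom_template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"

def messages2ids_like(request, tokenizer):
    # Illustrative stand-in for the processors' messages2ids path.
    return tokenizer.apply_chat_template(
        request,
        tokenize=False,
        chat_template=request.get("chat_template", None),
    )

class TestCustomTemplateAtRequestTime(unittest.TestCase):
    def test_custom_template_is_forwarded(self):
        tokenizer = MagicMock()
        request = {
            "messages": [{"role": "user", "content": "hi"}],
            "chat_template": load_chat_template(custom_template),
        }
        messages2ids_like(request, tokenizer)
        _, kwargs = tokenizer.apply_chat_template.call_args
        self.assertEqual(kwargs["chat_template"], custom_template)

if __name__ == "__main__":
    unittest.main()
```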
