add custom chat template #3251
Changes from 11 commits
Changed file 1 of 4: the LLM offline-inference entrypoint (adds a chat_template parameter to LLM.__init__ and LLM.chat()).
@@ -35,6 +35,7 @@
     retrive_model_from_server,
 )
 from fastdeploy.worker.output import Logprob, LogprobsLists
+from fastdeploy.entrypoints.chat_utils import load_chat_template

 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
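For readers of the diff: the new unit test at the bottom of this PR implies a simple contract for load_chat_template, returning a literal template string unchanged and reading the file contents when given a path to an existing file. A minimal sketch of that assumed contract (not the actual implementation in fastdeploy.entrypoints.chat_utils, which may handle more cases):

```python
# Sketch only: contract of load_chat_template as implied by the new unit tests,
# not the real FastDeploy implementation.
import os
from typing import Optional


def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
    if chat_template is None:
        return None  # assumed: no custom template configured
    if os.path.isfile(chat_template):
        # Treat the argument as a path and return the file contents.
        with open(chat_template, "r", encoding="utf-8") as f:
            return f.read()
    # Otherwise treat it as the template text itself.
    return chat_template
```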
@@ -73,6 +74,7 @@ def __init__(
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
         enable_logprob: Optional[bool] = False,
+        chat_template: Optional[str] = None,
         **kwargs,
     ):
         deprecated_kwargs_warning(**kwargs)
@@ -98,6 +100,7 @@ def __init__(
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
+        self.chat_template = load_chat_template(chat_template)

     def _check_master(self):
         """
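With the constructor change above, a template can be supplied once when creating the LLM object. A hypothetical offline usage sketch: the model path, template path, and message are placeholders, the batch-of-one message structure follows the loop shown further down in this diff, and leaving chat_template unset keeps the model's built-in template:

```python
# Hypothetical usage of the new chat_template argument; paths are placeholders.
from fastdeploy import LLM

llm = LLM(model="./my_model", chat_template="./my_chat_template.jinja")
outputs = llm.chat(messages=[[{"role": "user", "content": "Hello"}]])
```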
@@ -192,6 +195,7 @@ def chat(
         sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
         use_tqdm: bool = True,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
+        chat_template: Optional[str] = None,
     ):
         """
         Args:
@@ -224,10 +228,13 @@ def chat(

         if sampling_params_len != 1 and len(messages) != sampling_params_len:
             raise ValueError("messages and sampling_params must be the same length.")

+        if chat_template is None:
+            chat_template = self.chat_template
+
         messages_len = len(messages)
         for i in range(messages_len):
-            messages[i] = {"messages": messages[i]}
+            messages[i] = {"messages": messages[i], "chat_template": chat_template}
Review comment (on the messages[i] assignment above): This parameter is better not placed in messages; put it in chat_template_kwargs instead, and the existing pass-through can keep forwarding it downstream.
         req_ids = self._add_request(
             prompts=messages,
             sampling_params=sampling_params,
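A minimal sketch of the reviewer's alternative above: keep each messages entry untouched and route the custom template through chat_template_kwargs, which is already passed through toward template rendering. The key name used inside the dict is an assumption, not an established FastDeploy convention:

```python
# Sketch of the reviewer's suggestion; "chat_template" as a kwargs key is assumed.
if chat_template is None:
    chat_template = self.chat_template

chat_template_kwargs = dict(chat_template_kwargs or {})
if chat_template is not None:
    chat_template_kwargs.setdefault("chat_template", chat_template)

messages_len = len(messages)
for i in range(messages_len):
    messages[i] = {"messages": messages[i]}
```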
Changed file 2 of 4: the OpenAI-compatible API server (loads the template during startup and passes it to OpenAIServingChat).
@@ -54,6 +54,7 @@
     is_port_available,
     retrive_model_from_server,
 )
+from fastdeploy.entrypoints.chat_utils import load_chat_template

 parser = FlexibleArgumentParser()
 parser.add_argument("--port", default=8000, type=int, help="port to the http server")
@@ -104,6 +105,7 @@ async def lifespan(app: FastAPI):
         pid = os.getppid()
     else:
         pid = os.getpid()
+    chat_template = load_chat_template(args.chat_template)
Review comment (on the load_chat_template call above): This is better loaded outside; multiple processes are started in here.
     api_server_logger.info(f"{pid}")
     engine_client = EngineClient(
         args.model,
@@ -119,7 +121,7 @@ async def lifespan(app: FastAPI):
         args.enable_logprob,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, chat_template)
     completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
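To illustrate the review comment about the lifespan hook, here is a runnable toy (plain FastAPI, not FastDeploy code): anything every worker process needs, such as the loaded template, is created once at module scope, while lifespan only wires per-process state. The names and placeholder string are illustrative:

```python
# Toy example of the reviewer's point, not FastDeploy code: load shared,
# read-only resources at module import time, before worker processes start.
from contextlib import asynccontextmanager

from fastapi import FastAPI

CHAT_TEMPLATE = "stand-in for load_chat_template(args.chat_template)"  # loaded once


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Per-process setup only; the template was already loaded above.
    app.state.chat_template = CHAT_TEMPLATE
    yield


app = FastAPI(lifespan=lifespan)
```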
Changed file 3 of 4: OpenAIServingChat (takes the template in its constructor and uses it as the per-request fallback).
@@ -49,11 +49,12 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, chat_template):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
+        self.chat_template = chat_template
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -84,6 +85,8 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
         api_server_logger.info(f"create chat completion request: {request_id}")

         try:
+            if request.chat_template is None:
+                request.chat_template = self.chat_template
Review comment (on the request.chat_template fallback above): This needs to go after to_dict_for_infer, because chat_template is not a standard OpenAI parameter and cannot be passed directly on the request.
             current_req_dict = request.to_dict_for_infer(request_id)
             current_req_dict["arrival_time"] = time.time()
             prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
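A minimal sketch of the ordering the reviewer asks for above: leave the OpenAI-style request object alone and apply the server-level fallback to the internal dict returned by to_dict_for_infer. Whether that dict uses the key "chat_template" is an assumption here:

```python
# Sketch of the reviewer's suggestion; the "chat_template" key in the internal
# dict is assumed, since chat_template is not a standard OpenAI request field.
current_req_dict = request.to_dict_for_infer(request_id)
if current_req_dict.get("chat_template") is None:
    current_req_dict["chat_template"] = self.chat_template
current_req_dict["arrival_time"] = time.time()
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
```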
Changed file 4 of 4: a new unit test for load_chat_template.
@@ -0,0 +1,22 @@
+import os
Review comment (on this new test file): Besides verifying that the template loads, the unit test also needs to verify that a custom template is applied when making a request.
+
+import unittest
+from fastdeploy.entrypoints.chat_utils import load_chat_template
+
+input_chat_template = "unit test \n"
+
+class TestChatTemplate(unittest.TestCase):
+    def test_load_chat_template_str(self):
+        result = load_chat_template(input_chat_template)
+        self.assertEqual(input_chat_template, result)
+
+    def test_load_chat_template_path(self):
+        with open("chat_template", 'w', encoding='utf-8') as file:
+            file.write(input_chat_template)
+        file_path = os.path.join(os.getcwd(), "chat_template")
+        result = load_chat_template(file_path)
+        os.remove(file_path)
+        self.assertEqual(input_chat_template, result)
+
+if __name__ == "__main__":
+    unittest.main()
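The review comment on this test file asks for request-time coverage as well. A hypothetical extra case that only checks the fallback wiring of OpenAIServingChat; the placeholder constructor arguments may need adjusting, and a true end-to-end request test would require a running engine:

```python
# Hypothetical follow-up test sketched from the review comment; not part of this PR.
# None and 0 are placeholders for a real engine client, pid, and ips.
class TestChatTemplateRequestFallback(unittest.TestCase):
    def test_serving_chat_uses_default_template(self):
        from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat

        handler = OpenAIServingChat(None, 0, None, input_chat_template)
        self.assertEqual(handler.chat_template, input_chat_template)
```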
Review comment: Add a description of this startup parameter to docs/zh/parameters.md and docs/parameters.md.