Skip to content

Commit 227ddba

Browse files
committed
add unit test
1 parent 7b3f43f commit 227ddba

File tree

1 file changed

+60
-1
lines changed

1 file changed

+60
-1
lines changed

test/utils/test_custom_chat_template.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import os
22
import unittest
33
from pathlib import Path
4-
from unittest.mock import MagicMock, mock_open, patch, AsyncMock
4+
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
55

6+
from fastdeploy.engine.request import Request
7+
from fastdeploy.engine.sampling_params import SamplingParams
68
from fastdeploy.entrypoints.chat_utils import load_chat_template
9+
from fastdeploy.entrypoints.llm import LLM
710
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
811
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
12+
from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor
913

1014

1115
class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
@@ -16,6 +20,7 @@ def setUp(self):
1620
"""
1721
self.input_chat_template = "unit test \n"
1822
self.mock_engine = MagicMock()
23+
self.tokenizer = MagicMock()
1924

2025
def test_load_chat_template_non(self):
2126
result = load_chat_template(None)
@@ -87,6 +92,60 @@ def mock_format_and_add_data(current_req_dict):
8792
chat_completion = await self.chat_completion_handler.create_chat_completion(request)
8893
self.assertEqual("hello", chat_completion["chat_template"])
8994

95+
@patch("fastdeploy.input.ernie_vl_processor.ErnieMoEVLProcessor.__init__")
96+
def test_vl_processor(self, mock_class):
97+
mock_class.return_value = None
98+
vl_processor = ErnieMoEVLProcessor()
99+
mock_request = Request.from_dict({"request_id": "123"})
100+
101+
def mock_apply_default_parameters(request):
102+
return request
103+
104+
def mock_process_request(request, max_model_len):
105+
return request
106+
107+
vl_processor._apply_default_parameters = mock_apply_default_parameters
108+
vl_processor.process_request_dict = mock_process_request
109+
result = vl_processor.process_request(mock_request, chat_template="hello")
110+
self.assertEqual("hello", result.chat_template)
111+
112+
@patch("fastdeploy.entrypoints.llm.LLM.__init__")
113+
def test_llm_load(self, mock_class):
114+
mock_class.return_value = None
115+
llm = LLM()
116+
llm.llm_engine = MagicMock()
117+
llm.default_sampling_params = MagicMock()
118+
llm.chat_template = "hello"
119+
120+
def mock_run_engine(req_ids, **kwargs):
121+
return req_ids
122+
123+
def mock_add_request(**kwargs):
124+
return kwargs.get("chat_template")
125+
126+
llm._run_engine = mock_run_engine
127+
llm._add_request = mock_add_request
128+
result = llm.chat(["hello"], sampling_params=SamplingParams(1))
129+
self.assertEqual("hello", result)
130+
131+
@patch("fastdeploy.entrypoints.llm.LLM.__init__")
132+
def test_llm(self, mock_class):
133+
mock_class.return_value = None
134+
llm = LLM()
135+
llm.llm_engine = MagicMock()
136+
llm.default_sampling_params = MagicMock()
137+
138+
def mock_run_engine(req_ids, **kwargs):
139+
return req_ids
140+
141+
def mock_add_request(**kwargs):
142+
return kwargs.get("chat_template")
143+
144+
llm._run_engine = mock_run_engine
145+
llm._add_request = mock_add_request
146+
result = llm.chat(["hello"], sampling_params=SamplingParams(1), chat_template="hello")
147+
self.assertEqual("hello", result)
148+
90149

91150
if __name__ == "__main__":
92151
unittest.main()

0 commit comments

Comments
 (0)