
Commit 04c2f3c

mm support structured output
1 parent b630031 commit 04c2f3c

File tree

18 files changed: +422 −105 lines changed


docs/features/structured_outputs.md

Lines changed: 62 additions & 0 deletions
@@ -330,3 +330,65 @@ (appended after the existing example output, which ends with "Height: 468")

### Offline Inference

Offline inference allows restricting the model's output format through pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, with usage similar to online inference:

```python
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
structural_tag: Optional[str] = None
```

The following example demonstrates how to use offline inference to generate structured JSON:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams
from pydantic import BaseModel
from enum import Enum

class BookType(str, Enum):
    romance = "Romance"
    historical = "Historical"
    adventure = "Adventure"
    mystery = "Mystery"
    dystopian = "Dystopian"

class BookDescription(BaseModel):
    author: str
    title: str
    genre: BookType

# Constrained decoding parameters
guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())

# Sampling parameters
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=6400,
    guided_decoding=guided_decoding_params,
)

# Load model
llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="Generate a JSON describing a literary work, including author, title and book type.",
    sampling_params=sampling_params,
)

# Output results
for output in outputs:
    print(output.outputs.text)
```

Output:

```
{"author": "George Orwell", "title": "1984", "genre": "Dystopian"}
```
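The appended section documents only the `json` constraint end to end; the other `GuidedDecodingParams` fields follow the same pattern. Below is a minimal sketch of a `choice` constraint, assuming the same API as the example above (the prompt, token budget, and printed result are illustrative, not from this commit):

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Restrict the output to one of a fixed set of strings
guided_decoding_params = GuidedDecodingParams(
    choice=["Romance", "Historical", "Adventure", "Mystery", "Dystopian"])

sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=64,
    guided_decoding=guided_decoding_params,
)

llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192,
          guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="Which genre best describes the novel 1984?",
    sampling_params=sampling_params,
)

for output in outputs:
    print(output.outputs.text)  # e.g. "Dystopian"
```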

docs/zh/features/structured_outputs.md

Lines changed: 64 additions & 0 deletions
@@ -330,3 +330,67 @@ (appended after the existing example output, which ends with "高度: 468")

### Offline Inference

Offline inference allows restricting the model's output format through pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types; their usage follows the online inference examples:

```python
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
structural_tag: Optional[str] = None
```

The following example demonstrates how to use offline inference to generate structured JSON:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams
from pydantic import BaseModel
from enum import Enum

class BookType(str, Enum):
    romance = "Romance"
    historical = "Historical"
    adventure = "Adventure"
    mystery = "Mystery"
    dystopian = "Dystopian"

class BookDescription(BaseModel):
    author: str
    title: str
    genre: BookType

# Constrained decoding parameters
guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())

# Sampling parameters
sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=6400,
    guided_decoding=guided_decoding_params,
)

# Load model
llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。",
    sampling_params=sampling_params,
)

# Output results
for output in outputs:
    print(output.outputs.text)
```

Output:

```
{"author": "曹雪芹", "title": "红楼梦", "genre": "Historical"}
```

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ def __init__(
         self.dtype = ""
         self.enable_logprob = False
         self.enable_mm = False
+        self.reasoning_parser = None

         for key, value in args.items():
             if hasattr(self, key):

fastdeploy/engine/config.py

Lines changed: 8 additions & 6 deletions
@@ -24,7 +24,7 @@
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
 from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
-                              is_port_available, get_random_port, llm_logger)
+                              get_random_port, is_port_available, llm_logger)

 TaskOption = Literal["generate"]

@@ -701,7 +701,7 @@ def __init__(
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
         self.dist_init_ip = dist_init_ip
-
+
         self.nnode = nnodes
         self.node_rank = node_rank
         if self.dist_init_ip is None:

@@ -805,7 +805,8 @@ def postprocess(self):
             self.max_model_len // self.cache_config.block_size)

         if self.guided_decoding_backend == "auto":
-            if self.enable_mm:
+            if current_platform.is_xpu() or self.speculative_config.method is not None:
+                llm_logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
                 self.guided_decoding_backend = "off"
             else:
                 self.guided_decoding_backend = "xgrammar"

@@ -872,10 +873,10 @@ def check(self):
                 f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."

             if self.guided_decoding_backend != "off":
-                # TODO: mm support guided_decoding
-                assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"

                 # TODO: speculative decoding support guided_decoding
+                assert self.speculative_config.method is None, \
+                    "speculative decoding currently do not support guided_decoding"

                 # TODO: xpu support guided_decoding
                 assert not current_platform.is_xpu(

@@ -907,7 +908,8 @@ def print(self, file=None):
                     k == "model_config" or
                     k == "scheduler_config" or
                     k == "parallel_config" or
-                    k == "commit_config"):
+                    k == "commit_config" or
+                    k == "speculative_config"):
                 v.print()
             else:
                 llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
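Taken together, the `postprocess()` and `check()` hunks stop excluding multimodal models from guided decoding; only XPU and speculative decoding now force it off. A condensed restatement of the new `auto` resolution (a sketch based on the hunk above, not the full method):

```python
# Condensed view of guided_decoding_backend == "auto" resolution after this commit
if self.guided_decoding_backend == "auto":
    if current_platform.is_xpu() or self.speculative_config.method is not None:
        # XPU and speculative decoding still lack guided-decoding support
        self.guided_decoding_backend = "off"
    else:
        # multimodal models (enable_mm) no longer fall back to "off"
        self.guided_decoding_backend = "xgrammar"
```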

fastdeploy/engine/engine.py

Lines changed: 17 additions & 3 deletions
@@ -363,10 +363,16 @@ def _insert_zmq_task_to_scheduler(self):
                 request = Request.from_dict(data)
                 start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)

-
                 llm_logger.debug(f"Receive request: {request}")

                 err_msg = None
+                if ((request.guided_json is not None
+                        or request.guided_regex is not None
+                        or request.structural_tag is not None
+                        or request.guided_grammar is not None) and self.guided_decoding_checker is None):
+                    err_msg = "guided_backend is None, use --guided-decoding-backend to " \
+                              "specify the backend at server startup."
+
                 if self.guided_decoding_checker is not None:
                     request, err_msg = self.guided_decoding_checker.schema_format(
                         request)

@@ -455,6 +461,14 @@ def add_requests(self, task, sampling_params=None, **kwargs):
             llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=400)

+        if ((request.guided_json is not None
+                or request.guided_regex is not None
+                or request.structural_tag is not None
+                or request.guided_grammar is not None) and self.guided_decoding_checker is None):
+            err_msg = "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup."
+            llm_logger.error(err_msg)
+            raise EngineError(err_msg, error_code=400)
+
         if self.guided_decoding_checker is not None:
             request, err_msg = self.guided_decoding_checker.schema_format(
                 request)

@@ -1036,8 +1050,8 @@ def _start_worker_service(self):
             f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
             f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
-            f" --load_strategy {self.cfg.model_config.load_strategy}")
-
+            f" --load_strategy {self.cfg.model_config.load_strategy}"
+            f" --reasoning_parser {self.cfg.reasoning_parser}")

         worker_append_flag = {
             "enable_expert_parallel":

fastdeploy/engine/sampling_params.py

Lines changed: 47 additions & 2 deletions
@@ -92,6 +92,7 @@ class SamplingParams:
     min_tokens: int = 1
     logprobs: Optional[int] = None
     bad_words: Optional[List[str]] = None
+    guided_decoding: Optional[GuidedDecodingParams] = None

     @classmethod
     def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":

@@ -121,7 +122,8 @@ def from_optional(cls,
                       reasoning_max_tokens=None,
                       min_tokens=1,
                       logprobs=None,
-                      bad_words=None) -> "SamplingParams":
+                      bad_words=None,
+                      guided_decoding=None) -> "SamplingParams":
         """Create instance from command line arguments"""
         return cls(n=1 if n is None else n,
                    best_of=best_of,

@@ -141,7 +143,8 @@ def from_optional(cls,
                    reasoning_max_tokens=reasoning_max_tokens,
                    min_tokens=min_tokens,
                    logprobs=logprobs,
-                   bad_words=bad_words)
+                   bad_words=bad_words,
+                   guided_decoding=guided_decoding)

     def __post_init__(self):
         if self.seed is None:

@@ -224,3 +227,45 @@ class BeamSearchParams:
     temperature: float = 0.0
     length_penalty: float = 1.0
     include_stop_str_in_output: bool = False
+
+
+@dataclass
+class GuidedDecodingParams:
+    """Guided decoding parameters for text generation."""
+    json: Optional[Union[str, dict]] = None
+    regex: Optional[str] = None
+    choice: Optional[List[str]] = None
+    grammar: Optional[str] = None
+    json_object: Optional[bool] = None
+    structural_tag: Optional[str] = None
+
+    def to_dict(self):
+        """convert to dict"""
+        key_dict = {
+            "guided_json": self.json,
+            "guided_regex": self.regex,
+            "guided_choice": self.choice,
+            "guided_grammar": self.grammar,
+            "structural_tag": self.structural_tag,
+            "guided_json_object": self.json_object,
+        }
+
+        guided_dict = {}
+        for key, value in key_dict.items():
+            if value is not None:
+                guided_dict[key] = value
+        return guided_dict
+
+    def __post_init__(self):
+        """Verify the arguments."""
+        guided_count = sum([
+            self.json is not None, self.regex is not None, self.choice
+            is not None, self.grammar is not None, self.json_object
+            is not None, self.structural_tag is not None
+        ])
+
+        if guided_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
+            )
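A quick check of the new dataclass, based only on the code added above: `to_dict()` renames the populated fields to the request-level `guided_*` keys, and `__post_init__` rejects mixing constraint kinds. A minimal usage sketch:

```python
from fastdeploy.engine.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(regex=r"\d{4}-\d{2}-\d{2}")
print(params.to_dict())          # {'guided_regex': '\\d{4}-\\d{2}-\\d{2}'}

# Combining constraint kinds is rejected by __post_init__:
try:
    GuidedDecodingParams(json={"type": "object"}, regex=r".*")
except ValueError as err:
    print(err)                   # only one kind of guided decoding may be set
```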

fastdeploy/entrypoints/llm.py

Lines changed: 5 additions & 2 deletions
@@ -89,7 +89,7 @@ def __init__(
         self._receive_output_thread = threading.Thread(
             target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
-
+
     def _check_master(self):
         """
         Check if the current node is the master node.

@@ -198,7 +198,7 @@ def chat(
         if not self._check_master():
             err_msg = f"Only master node can accept completion request, please send request to master node: {self.master_node_ip}"
             raise ValueError(err_msg)
-
+
         if sampling_params is None:
             sampling_params = self.default_sampling_params

@@ -275,6 +275,9 @@ def _add_request(
             if chat_template_kwargs is not None:
                 enable_thinking = chat_template_kwargs.get(
                     "enable_thinking", None)
+            if current_sampling_params.guided_decoding is not None:
+                guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
+                tasks.update(guided_decoding_dict)
             self.llm_engine.add_requests(tasks,
                                          current_sampling_params,
                                          enable_thinking=enable_thinking)
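With this hunk, a `guided_decoding` set on `SamplingParams` is flattened into the per-request task dict before `add_requests` is called. A minimal sketch of the effect (the prompt and schema values are illustrative):

```python
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Illustrative task dict as built in LLM._add_request
tasks = {"prompt": "Generate a JSON describing a literary work."}
guided = GuidedDecodingParams(json={"type": "object"})
if guided is not None:
    tasks.update(guided.to_dict())   # adds {"guided_json": {"type": "object"}}
# tasks then flows into llm_engine.add_requests(tasks, sampling_params, ...)
```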

fastdeploy/input/ernie_processor.py

Lines changed: 3 additions & 5 deletions
@@ -60,7 +60,6 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None):
         self.eos_token_ids = [self.tokenizer.eos_token_id]
         self.eos_token_id_len = len(self.eos_token_ids)
         self.pad_token_id = self.get_pad_id()
-        self.reasoning_parser = None
         if reasoning_parser_obj:
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)

@@ -100,7 +99,6 @@ def process_request(self, request, max_model_len=None, **kwargs):

         if request.prompt_token_ids is None or len(
                 request.prompt_token_ids) == 0:
-            system = request.get("system")
             if request.prompt is None and request.messages is None:
                 raise ValueError(
                     f"The request should have `input_ids`, `text` or `messages`: {request}.")

@@ -149,7 +147,6 @@ def process_request_dict(self, request, max_model_len=None):
            request['stop_token_ids'] = stop_seqs
            request['stop_seqs_len'] = stop_seqs_len

-        system = request.get("system")
        # 处理prompt_token_ids
        if not request.get('prompt_token_ids'):
            if request.get('prompt') is None and request.get(

@@ -213,7 +210,7 @@ def process_response(self, response_dict, **kwargs):
            response_dict.outputs.reasoning_content = reasoning_content
        else:
            response_dict.outputs.text = full_text
-        data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
+        data_processor_logger.info(f"req_id:{req_id}, token ids: {token_ids}")
        if response_dict.outputs.text == "" and \
                response_dict.outputs.reasoning_content == "":
            return None

@@ -278,7 +275,6 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
        Returns:
            Dict: response contain text fields
        """
-        enable_thinking = kwargs.get("enable_thinking")
        is_end = response_dict["finished"]
        req_id = response_dict["request_id"]
        token_ids = response_dict["outputs"]["token_ids"]

@@ -288,6 +284,8 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
            token_ids = token_ids[:-1]
        delta_text, previous_token_ids, previous_texts = self.ids2tokens(
            token_ids, req_id)
+
+        enable_thinking = self.get_enable_thinking(kwargs.get("enable_thinking"))
        if enable_thinking and self.reasoning_parser:
            reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
                previous_texts, previous_texts + delta_text, delta_text,
