diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 9c59b8bab9..034bf837b6 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -18,17 +18,21 @@ import json import os -from dataclasses import dataclass, field +from dataclasses import dataclass +from datetime import datetime from enum import Enum -from typing import Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union +import paddle from paddleformers.transformers.configuration_utils import PretrainedConfig import fastdeploy from fastdeploy import envs from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase +from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform -from fastdeploy.utils import check_unified_ckpt, get_logger +from fastdeploy.scheduler import SchedulerConfig +from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger logger = get_logger("config", "config.log") @@ -120,7 +124,6 @@ def __init__( self.max_model_len = 0 self.dtype = "" self.enable_logprob = False - self.enable_mm = False self.enable_redundant_experts = False self.redundant_experts_num = 0 self.seed = 0 @@ -154,6 +157,12 @@ def __init__( if ErnieArchitectures.contains_ernie_arch(self.architectures): self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size) + architectures = self.architectures[0] + if MultimodalRegistry.contains_model(architectures): + self.enable_mm = True + else: + self.enable_mm = False + self.is_unified_ckpt = check_unified_ckpt(self.model) self.override_name_from_config() @@ -934,19 +943,54 @@ class FDConfig: simplifies passing around the distinct configurations in the codebase. """ - model_config: ModelConfig = field(default=None, init=True) # type: ignore - - parallel_config: ParallelConfig = field(default=None, init=True) - speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore - device_config: DeviceConfig = field(default=None, init=True) # type: ignore - load_config: LoadConfig = field(default=None, init=True) - quant_config: Optional[QuantConfigBase] = None - graph_opt_config: Optional[GraphOptimizationConfig] = None - early_stop_config: Optional[EarlyStopConfig] = None - decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore - cache_config: CacheConfig = field(default=None, init=True) # type: ignore - - def __post_init__(self): + def __init__( + self, + model_config: ModelConfig = None, + cache_config: CacheConfig = None, + parallel_config: ParallelConfig = None, + load_config: LoadConfig = None, + commit_config: CommitConfig = CommitConfig(), + scheduler_config: SchedulerConfig = None, + device_config: DeviceConfig = None, + decoding_config: DecodingConfig = None, + quant_config: QuantConfigBase = None, + graph_opt_config: GraphOptimizationConfig = None, + speculative_config: SpeculativeConfig = None, + tokenizer: str = None, + max_model_len: int = 8192, + max_num_seqs: int = 8, + max_num_batched_tokens: Optional[int] = None, + ips: str = None, + use_warmup: bool = False, + engine_worker_queue_port: int = 8002, + limit_mm_per_prompt: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + splitwise_role: str = "mixed", + innode_prefill_ports: Optional[List[int]] = None, + max_num_partial_prefills: int = 1, + max_long_partial_prefills: int = 1, + long_prefill_token_threshold: int = 0, + reasoning_parser: str = None, + guided_decoding_backend: Optional[str] = None, + 
+        disable_any_whitespace: bool = False,
+        early_stop_config: Optional[Dict[str, Any]] = None,
+        tool_parser: str = None,
+        test_mode=False,
+    ):
+        self.model_config: ModelConfig = model_config  # type: ignore
+        self.cache_config: CacheConfig = cache_config  # type: ignore
+        self.scheduler_config: SchedulerConfig = scheduler_config  # type: ignore
+        self.parallel_config = parallel_config  # type: ignore
+        self.speculative_config: SpeculativeConfig = speculative_config
+        self.device_config: DeviceConfig = device_config  # type: ignore
+        self.load_config: LoadConfig = load_config
+        self.commit_config: CommitConfig = commit_config
+        self.quant_config: Optional[QuantConfigBase] = quant_config
+        self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config
+        self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
+        self.decoding_config: DecodingConfig = decoding_config  # type: ignore
+        if test_mode:
+            return
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
@@ -955,3 +999,289 @@ def __post_init__(self):
         # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
         if self.graph_opt_config.graph_opt_level == 2:
             self.graph_opt_config.graph_opt_level = 1
+
+        self.tokenizer = tokenizer
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.ips = ips
+        self.tool_parser = tool_parser
+
+        if self.ips is None:
+            self.master_ip = "0.0.0.0"
+        elif isinstance(self.ips, list):
+            self.master_ip = self.ips[0]
+        else:
+            self.ips = self.ips.split(",")
+            self.master_ip = self.ips[0]
+
+        if self.ips is None:
+            self.nnode = 1
+            self.node_rank = 0
+        else:
+            self.nnode = len(self.ips)
+
+            for idx, ip in enumerate(self.ips):
+                if ip == self.master_ip:
+                    self.node_rank = idx
+
+        self.max_model_len = max_model_len
+        self.max_num_seqs = max_num_seqs
+        self.limit_mm_per_prompt = limit_mm_per_prompt
+        self.mm_processor_kwargs = mm_processor_kwargs
+        self.use_warmup = use_warmup
+        self.splitwise_role = splitwise_role
+        self.innode_prefill_ports = innode_prefill_ports
+        self.max_num_partial_prefills = max_num_partial_prefills
+        self.max_long_partial_prefills = max_long_partial_prefills
+        self.long_prefill_token_threshold = long_prefill_token_threshold
+        self.reasoning_parser = reasoning_parser
+        self.guided_decoding_backend = guided_decoding_backend
+        self.disable_any_whitespace = disable_any_whitespace
+        self._str_to_list("innode_prefill_ports", int)
+
+        assert self.splitwise_role in ["mixed", "prefill", "decode"]
+        import fastdeploy.model_executor.models  # noqa: F401
+
+        # TODO
+        self.max_prefill_batch = 3
+        if current_platform.is_xpu():
+            self.max_prefill_batch = 1
+        if self.model_config.enable_mm:
+            self.max_prefill_batch = 1  # TODO: multimodal prefill currently only supports parallelism of 1; to be optimized
+
+        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
+        assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
+            self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
+        ), "TP and EP cannot be enabled at the same time"
+
+        num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size
+        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        if num_ranks > self.max_chips_per_node:
+            self.worker_num_per_node = self.max_chips_per_node
+            nnode = ceil_div(num_ranks, self.worker_num_per_node)
+            assert nnode == self.nnode, f"expected nnode: {nnode}, but got {self.nnode}"
+        else:
+            self.worker_num_per_node = num_ranks
+
+        self.engine_worker_queue_port = engine_worker_queue_port
+        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
+        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
+        if current_platform.is_xpu():
+            self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
+
+        self.read_from_config()
+        self.postprocess()
+        self.check()
+        self.print()
+
+    def postprocess(self):
+        """
+        Calculate derived parameters.
+        """
+        assert (
+            len(self.device_ids.split(",")) == self.worker_num_per_node
+        ), f"invalid CUDA_VISIBLE_DEVICES, number of visible devices should equal {self.worker_num_per_node}"
+
+        self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]
+
+        self.host_ip = get_host_ip()
+
+        if self.ips is None or self.host_ip == self.master_ip:
+            self.is_master = True
+        else:
+            self.is_master = False
+
+        if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node:
+            self.is_master = True
+
+        self.paddle_commit_id = paddle.version.commit
+
+        if self.max_num_batched_tokens is None:
+            if self.cache_config.enable_chunked_prefill:
+                self.max_num_batched_tokens = 2048
+            else:
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192  # setting this to max_model_len can easily cause OOM
+
+        if self.long_prefill_token_threshold == 0:
+            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+
+        self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
+        self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
+
+        if self.guided_decoding_backend == "auto":
+            if self.model_config.enable_mm:
+                self.guided_decoding_backend = "off"
+            else:
+                self.guided_decoding_backend = "xgrammar"
+
+    def check(self):
+        """
+        Check the validity of the config.
+        """
+        assert self.max_num_seqs <= 256, (
+            "The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
+ ) + assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1" + assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16" + assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1" + assert self.max_num_batched_tokens >= self.max_num_seqs, ( + f"max_num_batched_tokens: {self.max_num_batched_tokens} " + f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}" + ) + assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, ( + f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" + f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}" + ) + assert ( + self.max_num_partial_prefills >= 1 + ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1" + + assert ( + self.max_long_partial_prefills >= 1 + ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1" + assert self.max_long_partial_prefills <= self.max_num_partial_prefills, ( + f"max_long_partial_prefills: {self.max_long_partial_prefills} should " + f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}" + ) + + if not self.cache_config.enable_chunked_prefill: + if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")): + assert self.max_num_batched_tokens >= self.max_model_len, ( + f"max_num_batched_tokens: {self.max_num_batched_tokens} " + f"should be larger than or equal to max_model_len: {self.max_model_len}" + ) + else: + assert self.max_num_batched_tokens >= self.cache_config.block_size, ( + f"max_num_batched_tokens: {self.max_num_batched_tokens} " + f"should be larger than or equal to block_size: {self.cache_config.block_size}" + ) + + if self.max_num_partial_prefills > 1: + assert ( + self.cache_config.enable_chunked_prefill is True + ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1" + assert self.long_prefill_token_threshold < self.max_model_len, ( + f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than" + f" max_model_len: {self.max_model_len}" + ) + + if self.guided_decoding_backend is not None: + assert self.guided_decoding_backend in [ + "xgrammar", + "XGrammar", + "auto", + "off", + ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}." + + if self.guided_decoding_backend != "off": + # TODO: mm support guided_decoding + assert ( + self.model_config.enable_mm is False + ), "Multimodal model currently do not support guided_decoding" + + # TODO: speculative decoding support guided_decoding + + # TODO: xpu support guided_decoding + assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding" + + try: + import xgrammar # noqa + except Exception as e: + raise Exception( + f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. 
\n\t {e}" + ) + if self.scheduler_config is not None: + self.scheduler_config.check() + + def print(self, file=None): + """ + print all config + + Args: + file (str): the path of file to save config + """ + logger.info("=================== Configuration Information ===============") + for k, v in self.__dict__.items(): + if k == "generation_config" and v is not None: + for gck, gcv in v.to_dict().items(): + logger.info("{:<20}:{:<6}{}".format(gck, "", gcv)) + elif ( + k == "cache_config" + or k == "model_config" + or k == "scheduler_config" + or k == "parallel_config" + or k == "commit_config" + ): + if v is not None: + v.print() + else: + logger.info("{:<20}:{:<6}{}".format(k, "", v)) + logger.info("=============================================================") + if file is not None: + f = open(file, "a") + now_time = datetime.now() + f.write(f"{now_time} configuration information as below,\n") + for k, v in self.__dict__.items(): + f.write("{:<20}:{:<6}{}\n".format(k, "", v)) + f.close() + + def init_cache_info(self): + """ + initialize cache info + """ + disaggregate_info = {} + if self.splitwise_role != "mixed": + disaggregate_info["role"] = self.splitwise_role + disaggregate_info["cache_info"] = dict() + current_protocol = self.cache_config.cache_transfer_protocol.split(",") + disaggregate_info["transfer_protocol"] = current_protocol + for protocol in current_protocol: + if protocol == "ipc": + disaggregate_info["cache_info"][protocol] = { + "ip": self.host_ip, + "port": self.engine_worker_queue_port, + "device_ids": self.local_device_ids, + } + elif protocol == "rdma": + disaggregate_info["cache_info"][protocol] = { + "ip": self.host_ip, + "port": self.cache_config.pd_comm_port[0], + "rdma_port": self.cache_config.rdma_comm_ports, + } + self.disaggregate_info = disaggregate_info + logger.info(f"disaggregate_info: {self.disaggregate_info}") + + def read_from_config(self): + """ + reset model config from json file + """ + + def reset_value(cls, value_name, key): + if hasattr(cls, key): + value = getattr(cls, key) + setattr(cls, value_name, value) + logger.info(f"Reset parameter {value_name} = {value} from configuration.") + + reset_value(self.cache_config, "block_size", "infer_model_block_size") + reset_value( + self.model_config, + "return_full_hidden_states", + "return_full_hidden_states", + ) + reset_value(self.cache_config, "cache_dtype", "infer_model_dtype") + + def _check_master(self): + return self.is_master + + def _str_to_list(self, attr_name, default_type): + if hasattr(self, attr_name): + val = getattr(self, attr_name) + if type(val) is str: + setattr(self, attr_name, [default_type(i) for i in val.split(",")]) + else: + setattr(self, attr_name, val) + + def __str__(self) -> str: + return json.dumps(self.__dict__, indent=4) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index af7b3ffb08..e3f2d95aa7 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -23,6 +23,7 @@ from fastdeploy.config import ( CacheConfig, EarlyStopConfig, + FDConfig, GraphOptimizationConfig, LoadConfig, ModelConfig, @@ -30,10 +31,13 @@ SpeculativeConfig, TaskOption, ) -from fastdeploy.engine.config import Config from fastdeploy.platforms import current_platform from fastdeploy.scheduler.config import SchedulerConfig -from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser +from fastdeploy.utils import ( + DeprecatedOptionWarning, + FlexibleArgumentParser, + is_port_available, +) def nullable_str(x: str) -> 
Optional[str]: @@ -902,7 +906,7 @@ def create_early_stop_config(self) -> EarlyStopConfig: early_stop_args[k] = v return EarlyStopConfig(early_stop_args) - def create_engine_config(self) -> Config: + def create_engine_config(self) -> FDConfig: """ Create and return a Config object based on the current settings. """ @@ -937,8 +941,11 @@ def create_engine_config(self) -> Config: self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce ), "enable_custom_all_reduce must be used with tensor_parallel_size>1" - return Config( - model_name_or_path=self.model, + assert is_port_available( + "0.0.0.0", self.engine_worker_queue_port + ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." + + return FDConfig( model_config=model_cfg, scheduler_config=scheduler_cfg, tokenizer=self.tokenizer, @@ -946,7 +953,6 @@ def create_engine_config(self) -> Config: load_config=load_cfg, parallel_config=parallel_cfg, max_model_len=self.max_model_len, - tensor_parallel_size=self.tensor_parallel_size, max_num_seqs=self.max_num_seqs, speculative_config=speculative_cfg, max_num_batched_tokens=self.max_num_batched_tokens, @@ -955,7 +961,6 @@ def create_engine_config(self) -> Config: engine_worker_queue_port=self.engine_worker_queue_port, limit_mm_per_prompt=self.limit_mm_per_prompt, mm_processor_kwargs=self.mm_processor_kwargs, - # enable_mm=self.enable_mm, reasoning_parser=self.reasoning_parser, tool_parser=self.tool_call_parser, splitwise_role=self.splitwise_role, @@ -963,10 +968,8 @@ def create_engine_config(self) -> Config: max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, - graph_optimization_config=graph_opt_cfg, + graph_opt_config=graph_opt_cfg, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, - enable_logprob=self.enable_logprob, early_stop_config=early_stop_cfg, - load_choices=self.load_choices, ) diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py deleted file mode 100644 index 7b6d1bffb4..0000000000 --- a/fastdeploy/engine/config.py +++ /dev/null @@ -1,435 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import json -import os -from datetime import datetime -from typing import Any, Dict, List, Optional - -from fastdeploy.config import ( - CacheConfig, - CommitConfig, - LoadConfig, - ModelConfig, - ParallelConfig, -) -from fastdeploy.multimodal.registry import MultimodalRegistry -from fastdeploy.platforms import current_platform -from fastdeploy.scheduler import SchedulerConfig -from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger - - -class Config: - """ - Initial configuration class. - - Attributes: - model_config (ModelConfig): Model configuration object. - cache_config (CacheConfig): Cache configuration object. 
- model_name_or_path (str): Directory path to the model or the model name. - tokenizer (Optional[str]): Default is the model. - max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. - tensor_parallel_size (int): Tensor parallel size. - nnode (int): Number of nodes. - max_model_len (int): Maximum model length. Default is 8192. - max_num_seqs (int): Maximum number of sequences. Default is 8. - mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. - speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. - use_warmup (bool): Flag to use warmup. - engine_worker_queue_port (int): Port for engine worker queue. - enable_mm (bool): Flag to enable multi-modal processing. - reasoning_parser(str): Flag specifies the reasoning parser to use for - extracting reasoning content from the model output - splitwise_role (str): Splitwise role. - innode_prefill_ports (Optional[List[int]]): Innode prefill ports. - Temporary configuration, will be removed in the future. - load_choices(str):The format of the model weights to load. .Default is default - """ - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - scheduler_config: SchedulerConfig, - parallel_config: ParallelConfig, - load_config: LoadConfig, - commit_config: CommitConfig = CommitConfig(), - model_name_or_path: str = None, - tokenizer: str = None, - tensor_parallel_size: int = 8, - max_model_len: int = 8192, - max_num_seqs: int = 8, - max_num_batched_tokens: Optional[int] = None, - ips: str = None, - speculative_config: Optional[Dict[str, Any]] = None, - graph_optimization_config: Optional[Dict[str, Any]] = None, - use_warmup: bool = False, - engine_worker_queue_port: int = 8002, - limit_mm_per_prompt: Optional[Dict[str, Any]] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - # enable_mm: bool = False, - splitwise_role: str = "mixed", - innode_prefill_ports: Optional[List[int]] = None, - max_num_partial_prefills: int = 1, - max_long_partial_prefills: int = 1, - long_prefill_token_threshold: int = 0, - reasoning_parser: str = None, - tool_parser: str = None, - guided_decoding_backend: Optional[str] = None, - disable_any_whitespace: bool = False, - enable_logprob: bool = False, - early_stop_config: Optional[Dict[str, Any]] = None, - load_choices: str = "default", - ): - """ - Initialize the Config class. - - Args: - model_config (ModelConfig): Model configuration object. - cache_config (CacheConfig): Cache configuration object. - parallel_config (ParallelConfig): Parallel configuration object. - scheduler_config (SchedulerConfig): Scheduler configuration object. - model_name_or_path (str): Model directory path or model name. - tokenizer (str): Default is the model. - tensor_parallel_size (int): Tensor parallel size. Default is 8. - max_model_len (int): Maximum model length. Default is 8192. - max_num_seqs (int): Maximum number of sequences. Default is 8. - max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None. - mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None. - speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None. - graph_optimization_config (Optional[Dict[str, Any]]): Graph optimizaion backend execution configuration. Default is None. - use_warmup (bool): Flag to use warmup. Default is False. - engine_worker_queue_port (int): Engine worker queue port. Default is 8002. 
- enable_mm (bool): Flag to enable multi-modal processing. Default is False. - splitwise_role (str): Splitwise role. Default is "mixed". - innode_prefill_ports (Optional[List[int]]): Innode prefill ports. Default is None. - reasoning_parser (str): Flag specifies the reasoning parser to use for - extracting reasoning content from the model output. Default is None. - guided_decoding_backend(str): Guided decoding backend. Default is None. - disable_any_whitespace(bool): Disable any whitespace when using guided decoding. - Default is False. - enable_logprob(bool): Enable logprob. Default is False. - early_stop_config (Optional[Dict[str, Any]]): Early stop configuration. Default is None. - load_choices(str):The format of the model weights to load. .Default is default - """ - self.model_config = model_config - self.cache_config = cache_config - self.scheduler_config = scheduler_config - self.parallel_config = parallel_config - self.load_config = load_config - self.commit_config = commit_config - self.model_name_or_path = model_name_or_path - self.tokenizer = tokenizer - self.max_num_batched_tokens = max_num_batched_tokens - self.tensor_parallel_size = tensor_parallel_size - self.ips = ips - - if self.ips is None: - self.master_ip = "0.0.0.0" - elif isinstance(self.ips, list): - self.master_ip = self.ips[0] - else: - self.ips = self.ips.split(",") - self.master_ip = self.ips[0] - - if self.ips is None: - self.nnode = 1 - self.node_rank = 0 - else: - self.nnode = len(self.ips) - - for idx, ip in enumerate(self.ips): - if ip == self.master_ip: - self.node_rank = idx - - self.max_model_len = max_model_len - self.max_num_seqs = max_num_seqs - self.limit_mm_per_prompt = limit_mm_per_prompt - self.mm_processor_kwargs = mm_processor_kwargs - # self.enable_mm = enable_mm - self.speculative_config = speculative_config - self.use_warmup = use_warmup - self.splitwise_role = splitwise_role - self.innode_prefill_ports = innode_prefill_ports - self.max_num_partial_prefills = max_num_partial_prefills - self.max_long_partial_prefills = max_long_partial_prefills - self.long_prefill_token_threshold = long_prefill_token_threshold - self.reasoning_parser = reasoning_parser - self.tool_parser = tool_parser - self.graph_optimization_config = graph_optimization_config - self.early_stop_config = early_stop_config - self.guided_decoding_backend = guided_decoding_backend - self.disable_any_whitespace = disable_any_whitespace - self._str_to_list("innode_prefill_ports", int) - self.load_choices = load_choices - - assert self.splitwise_role in ["mixed", "prefill", "decode"] - - import fastdeploy.model_executor.models # noqa: F401 - - architectures = self.model_config.architectures[0] - if MultimodalRegistry.contains_model(architectures): - self.enable_mm = True - else: - self.enable_mm = False - - # TODO - self.max_prefill_batch = 3 - if current_platform.is_xpu(): - self.max_prefill_batch = 1 - if self.enable_mm: - self.max_prefill_batch = 1 # TODO:当前多模prefill阶段只支持并行度为1,待优化 - - # TODO(@wufeisheng): TP and EP need to be supported simultaneously. 
- assert (self.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or ( - self.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1 - ), "TP and EP cannot be enabled at the same time" - - num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size - self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 - if num_ranks > self.max_chips_per_node: - self.worker_num_per_node = self.max_chips_per_node - nnode = ceil_div(num_ranks, self.worker_num_per_node) - assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}" - else: - self.worker_num_per_node = num_ranks - - self.engine_worker_queue_port = engine_worker_queue_port - self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)]) - self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids) - if current_platform.is_xpu(): - self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids) - - self.enable_logprob = enable_logprob - - self.read_from_config() - self.postprocess() - self.check() - self.print() - - def postprocess(self): - """ - calculate some parameters - """ - assert ( - self.device_ids.split(",").__len__() == self.worker_num_per_node - ), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}" - - self.local_device_ids = self.device_ids.split(",")[: self.tensor_parallel_size] - - self.host_ip = get_host_ip() - - if self.ips is None or self.host_ip == self.master_ip: - self.is_master = True - else: - self.is_master = False - - if self.tensor_parallel_size <= self.worker_num_per_node: - self.is_master = True - - import paddle - - self.paddle_commit_id = paddle.version.commit - - if self.max_num_batched_tokens is None: - if self.cache_config.enable_chunked_prefill: - self.max_num_batched_tokens = 2048 - else: - if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")): - self.max_num_batched_tokens = self.max_model_len - else: - self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM - - if self.long_prefill_token_threshold == 0: - self.long_prefill_token_threshold = int(self.max_model_len * 0.04) - - self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs) - self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size) - - if self.guided_decoding_backend == "auto": - if self.enable_mm: - self.guided_decoding_backend = "off" - else: - self.guided_decoding_backend = "xgrammar" - - def check(self): - """ - check the legality of config - """ - assert self.max_num_seqs <= 256, ( - "The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}." - ) - assert is_port_available( - "0.0.0.0", self.engine_worker_queue_port - ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." 
- assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1" - assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16" - assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1" - assert self.max_num_batched_tokens >= self.max_num_seqs, ( - f"max_num_batched_tokens: {self.max_num_batched_tokens} " - f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}" - ) - assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, ( - f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" - f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}" - ) - assert ( - self.max_num_partial_prefills >= 1 - ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1" - - assert ( - self.max_long_partial_prefills >= 1 - ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1" - assert self.max_long_partial_prefills <= self.max_num_partial_prefills, ( - f"max_long_partial_prefills: {self.max_long_partial_prefills} should " - f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}" - ) - - if not self.cache_config.enable_chunked_prefill: - if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")): - assert self.max_num_batched_tokens >= self.max_model_len, ( - f"max_num_batched_tokens: {self.max_num_batched_tokens} " - f"should be larger than or equal to max_model_len: {self.max_model_len}" - ) - else: - assert self.max_num_batched_tokens >= self.cache_config.block_size, ( - f"max_num_batched_tokens: {self.max_num_batched_tokens} " - f"should be larger than or equal to block_size: {self.cache_config.block_size}" - ) - - if self.max_num_partial_prefills > 1: - assert ( - self.cache_config.enable_chunked_prefill is True - ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1" - assert self.long_prefill_token_threshold < self.max_model_len, ( - f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than" - f" max_model_len: {self.max_model_len}" - ) - - if self.guided_decoding_backend is not None: - assert self.guided_decoding_backend in [ - "xgrammar", - "XGrammar", - "auto", - "off", - ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}." - - if self.guided_decoding_backend != "off": - # TODO: mm support guided_decoding - assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding" - - # TODO: speculative decoding support guided_decoding - - # TODO: xpu support guided_decoding - assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding" - - try: - import xgrammar # noqa - except Exception as e: - raise Exception( - f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. 
\n\t {e}" - ) - - self.scheduler_config.check() - - def print(self, file=None): - """ - print all config - - Args: - file (str): the path of file to save config - """ - llm_logger.info("=================== Configuration Information ===============") - for k, v in self.__dict__.items(): - if k == "generation_config" and v is not None: - for gck, gcv in v.to_dict().items(): - llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv)) - elif ( - k == "cache_config" - or k == "model_config" - or k == "scheduler_config" - or k == "parallel_config" - or k == "commit_config" - ): - v.print() - else: - llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) - llm_logger.info("=============================================================") - if file is not None: - f = open(file, "a") - now_time = datetime.now() - f.write(f"{now_time} configuration information as below,\n") - for k, v in self.__dict__.items(): - f.write("{:<20}:{:<6}{}\n".format(k, "", v)) - f.close() - - def init_cache_info(self): - """ - initialize cache info - """ - disaggregate_info = {} - if self.splitwise_role != "mixed": - disaggregate_info["role"] = self.splitwise_role - disaggregate_info["cache_info"] = dict() - current_protocol = self.cache_config.cache_transfer_protocol.split(",") - disaggregate_info["transfer_protocol"] = current_protocol - for protocol in current_protocol: - if protocol == "ipc": - disaggregate_info["cache_info"][protocol] = { - "ip": self.host_ip, - "port": self.engine_worker_queue_port, - "device_ids": self.local_device_ids, - } - elif protocol == "rdma": - disaggregate_info["cache_info"][protocol] = { - "ip": self.host_ip, - "port": self.cache_config.pd_comm_port[0], - "rdma_port": self.cache_config.rdma_comm_ports, - } - self.disaggregate_info = disaggregate_info - llm_logger.info(f"disaggregate_info: {self.disaggregate_info}") - - def read_from_config(self): - """ - reset model config from json file - """ - - def reset_value(cls, value_name, key): - if hasattr(cls, key): - value = getattr(cls, key) - setattr(cls, value_name, value) - llm_logger.info(f"Reset parameter {value_name} = {value} from configuration.") - - reset_value(self.cache_config, "block_size", "infer_model_block_size") - reset_value( - self.model_config, - "return_full_hidden_states", - "return_full_hidden_states", - ) - reset_value(self.cache_config, "cache_dtype", "infer_model_dtype") - - def _check_master(self): - return self.is_master - - def _str_to_list(self, attr_name, default_type): - if hasattr(self, attr_name): - val = getattr(self, attr_name) - if type(val) is str: - setattr(self, attr_name, [default_type(i) for i in val.split(",")]) - else: - setattr(self, attr_name, val) - - def __str__(self) -> str: - return json.dumps(self.__dict__, indent=4) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index db3bdefffe..3426396c3c 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -105,7 +105,7 @@ def __init__(self, cfg): cfg.reasoning_parser, cfg.limit_mm_per_prompt, cfg.mm_processor_kwargs, - cfg.enable_mm, + cfg.model_config.enable_mm, cfg.tool_parser, ) @@ -113,7 +113,7 @@ def __init__(self, cfg): if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager = ResourceManagerV1( - cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role + cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role ) if cfg.splitwise_role != "mixed": raise NotImplementedError( @@ -121,7 +121,7 @@ def __init__(self, cfg): ) else: self.resource_manager = ResourceManager( - 
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role + cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role ) os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port) @@ -191,7 +191,7 @@ def start(self, api_server_pid=None): device_ids = self.cfg.device_ids.split(",") self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( cache_config=self.cfg.cache_config, - tensor_parallel_size=self.cfg.tensor_parallel_size, + tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, device_ids=device_ids, pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, @@ -387,7 +387,7 @@ def _insert_zmq_task_to_scheduler(self): while self.running: try: block = True if len(added_requests) == 0 else False - if not self.cfg.enable_mm: + if not self.cfg.model_config.enable_mm: err, data = self.zmq_server.receive_json_once(block) else: err, data = self.zmq_server.receive_pyobj_once(block) @@ -809,7 +809,7 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False): for task in tasks: task.inference_start_time = time.time() if not is_prefill: - if not self.cfg.enable_mm: + if not self.cfg.model_config.enable_mm: self.update_requests_chunk_size(tasks) else: self.update_mm_requests_chunk_size(tasks) @@ -1049,7 +1049,7 @@ def _setting_environ_variables(self): if self.cfg.splitwise_role == "prefill": variables["FLAGS_fmt_write_cache_completed_signal"] = 1 - if self.cfg.enable_mm: + if self.cfg.model_config.enable_mm: variables["FLAGS_max_partition_size"] = 1024 command_prefix = "" @@ -1084,9 +1084,9 @@ def _start_worker_service(self): f" --devices {self.cfg.device_ids} {py_script}" f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}" f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}" - f" --model {self.cfg.model_name_or_path!s}" + f" --model {self.cfg.model_config.model!s}" f" --device_ids {self.cfg.device_ids}" - f" --tensor_parallel_size {self.cfg.tensor_parallel_size}" + f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}" f" --engine_worker_queue_port {self.cfg.engine_worker_queue_port!s}" f" --pod_ip {self.cfg.master_ip}" f" --total_block_num {self.cfg.cache_config.total_block_num}" @@ -1103,11 +1103,11 @@ def _start_worker_service(self): f" --quantization {self.cfg.model_config.quantization}" f" --ori_vocab_size {ori_vocab_size}" f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'" - f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'" + f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" f" --load_strategy {self.cfg.load_config.load_strategy}" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" - f" --load_choices {self.cfg.load_choices}" + f" --load_choices {self.cfg.load_config.load_choices}" ) worker_append_flag = { @@ -1118,8 +1118,7 @@ def _start_worker_service(self): "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight, "disable_any_whitespace": self.cfg.disable_any_whitespace, "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce, - "enable_logprob": self.cfg.enable_logprob, - "enable_mm": self.cfg.enable_mm, + "enable_logprob": self.cfg.model_config.enable_logprob, } for worker_flag, value in worker_append_flag.items(): if value: @@ -1216,7 +1215,7 @@ def _stop_profile(self): device_ids 
= self.cfg.device_ids.split(",") self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( cache_config=self.cfg.cache_config, - tensor_parallel_size=self.cfg.tensor_parallel_size, + tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, device_ids=device_ids, pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, @@ -1370,7 +1369,7 @@ def start_queue_service(self): self.engine_worker_queue_server = EngineWorkerQueue( address=address, is_server=True, - num_client=self.cfg.tensor_parallel_size, + num_client=self.cfg.parallel_config.tensor_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) @@ -1382,7 +1381,7 @@ def start_queue_service(self): ), authkey=b"cache_queue_service", is_server=True, - num_client=self.cfg.tensor_parallel_size, + num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=-1, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) @@ -1390,7 +1389,7 @@ def start_queue_service(self): self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, - num_client=self.cfg.tensor_parallel_size, + num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, local_data_parallel_id=min( diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 9cf5f97f7f..b6a12b7927 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -50,8 +50,8 @@ def __init__(self, cfg, local_data_parallel_id): cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg - start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node - end_pos = start_pos + self.cfg.tensor_parallel_size + start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node + end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size if cfg.splitwise_role != "mixed": self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos] self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos] @@ -69,13 +69,13 @@ def __init__(self, cfg, local_data_parallel_id): address=address, is_server=False, client_id=0, - num_client=cfg.tensor_parallel_size, + num_client=cfg.parallel_config.tensor_parallel_size, local_data_parallel_id=local_data_parallel_id, ) self.resource_manager = ResourceManager( cfg.max_num_seqs, cfg, - cfg.tensor_parallel_size, + cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id, ) @@ -125,7 +125,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): if self.cfg.splitwise_role != "mixed": self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( cache_config=self.cfg.cache_config, - tensor_parallel_size=self.cfg.tensor_parallel_size, + tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, device_ids=self.cfg.local_device_ids, pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, @@ -343,7 +343,7 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False): if not is_decode: llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") if not is_prefill and self.cfg.cache_config.enable_chunked_prefill: - if not self.cfg.enable_mm: + if not self.cfg.model_config.enable_mm: self.update_requests_chunk_size(tasks) else: 
self.update_mm_requests_chunk_size(tasks) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index daed93b8f9..ccebe5b375 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -20,7 +20,7 @@ import numpy as np from fastdeploy import envs -from fastdeploy.engine.config import ModelConfig +from fastdeploy.config import ModelConfig from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import IPCSignal, ZmqClient diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index d702090868..62ed8d62dd 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -16,8 +16,7 @@ from typing import Any, Dict, Optional -from fastdeploy.config import ErnieArchitectures -from fastdeploy.engine.config import ModelConfig +from fastdeploy.config import ErnieArchitectures, ModelConfig from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager from fastdeploy.reasoning import ReasoningParserManager diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 430ae64ae1..3789320fd9 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -43,7 +43,6 @@ Ernie4_5_MLP, ) from fastdeploy.model_executor.models.model_base import ModelForCasualLM -from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform if current_platform.is_cuda(): @@ -504,7 +503,6 @@ def forward( return out -@MultimodalRegistry.register_model() class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): """ Ernie4_5_VLMoeForConditionalGeneration diff --git a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py index 402e8d2040..74de853cce 100644 --- a/fastdeploy/multimodal/registry.py +++ b/fastdeploy/multimodal/registry.py @@ -22,7 +22,7 @@ class MultimodalRegistry: A registry for multimodal models """ - mm_models: set[str] = set() + mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration"} @classmethod def register_model(cls, name: str = "") -> Callable: diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index ebb64cebc7..344c781778 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -57,7 +57,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.split_connector = split_connector self.speculative_decoding = self.cfg.speculative_config.method is not None - self.use_logprobs = self.cfg.enable_logprob + self.use_logprobs = self.cfg.model_config.enable_logprob if self.speculative_decoding: self.output_tokens = paddle.full( diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 6b4c8ce04d..4dfc4ca6dc 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -319,7 +319,7 @@ def create_connection(self, port): """ self.connect_innode_instances[port] = EngineWorkerQueue( address=("0.0.0.0", int(port)), - num_client=self.cfg.tensor_parallel_size, + num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, ) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 2c80c27107..b9ac096489 
100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -587,7 +587,6 @@ def parse_args(): "'ipc': real-time IPC streaming with automatic resharding, " "'ipc_snapshot': load from disk snapshot of IPC weights.", ) - parser.add_argument("--enable_mm", action="store_true", help="Whether to enable vl model") parser.add_argument( "--enable_logprob", action="store_true", @@ -708,8 +707,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: else: logger.info("No quantization config found and use original weight and act dtype.") - # Set VL tag - model_config.enable_mm = args.enable_mm logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}") logger.info(f"- Load strategy: {load_config.load_strategy}") diff --git a/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py b/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py index 9e28240bf1..029dabd781 100644 --- a/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py +++ b/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py @@ -144,7 +144,10 @@ def run_test_case(): graph_opt_config.use_cudagraph = True parallel_config = ParallelConfig(args={}) parallel_config.max_num_seqs = 1 - fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config) + # Initialize cuda graph capture list + graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) + graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) + fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config, test_mode=True) # Run Test Case1 test_model1 = TestModel1(fd_config=fd_config) diff --git a/test/graph_optimization/test_cuda_graph_spec_decode.py b/test/graph_optimization/test_cuda_graph_spec_decode.py index 8e8fcf4886..14b3597d33 100644 --- a/test/graph_optimization/test_cuda_graph_spec_decode.py +++ b/test/graph_optimization/test_cuda_graph_spec_decode.py @@ -90,7 +90,10 @@ def run_test_case(): graph_opt_config.use_cudagraph = True parallel_config = ParallelConfig(args={}) parallel_config.max_num_seqs = 1 - fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config) + # Initialize cuda graph capture list + graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) + graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) + fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config, test_mode=True) # Run Test Case1 test_model1 = TestModel1(fd_config=fd_config)