diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index b9c951d391..27c370a98a 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -199,6 +199,13 @@ std::vector GetBlockShapeAndSplitKVBlock( paddle::Tensor &decoder_tile_ids_per_batch, // Inplace paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory + paddle::Tensor &encoder_batch_ids, // Inplace + paddle::Tensor &encoder_tile_ids_per_batch, // Inplace + paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &kv_batch_ids, // Inplace + paddle::Tensor &kv_tile_ids_per_batch, // Inplace + paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, @@ -223,14 +230,8 @@ std::vector GetBlockShapeAndSplitKVBlock( int max_system_len = max_len_cpu_ptr[6]; int max_just_dec_len_without_system = max_len_cpu_ptr[7]; - paddle::Tensor encoder_batch_ids; - paddle::Tensor encoder_tile_ids_per_batch; - paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor kv_batch_ids; - paddle::Tensor kv_tile_ids_per_batch; - paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor max_len_kv_cpu; /*cpu*/ + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(max_len_kv_cpu.data(), 0, sizeof(int32_t), stream)); auto max_len_kv = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>( @@ -240,14 +241,11 @@ std::vector GetBlockShapeAndSplitKVBlock( max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false); if (max_enc_len_this_time > 0) { - const uint32_t max_tile_size_per_bs_kv = - div_up(max_enc_dec_len_this_time, block_size); - kv_batch_ids = - GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, - seq_lens_encoder.place()); - kv_tile_ids_per_batch = - GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, - seq_lens_encoder.place()); + const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); + const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_num_blocks_x_cpu.data(), 0, sizeof(int32_t), stream)); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); @@ -259,15 +257,12 @@ std::vector GetBlockShapeAndSplitKVBlock( block_size, block_size); kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false); - - const uint32_t encoder_max_tile_size_per_bs_q = - div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); - encoder_batch_ids = - GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_tile_ids_per_batch = - GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); + // Clear buffer + const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); +
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_num_blocks_x_cpu.data(), 0, sizeof(int32_t), stream)); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), nullptr, @@ -277,19 +272,6 @@ std::vector GetBlockShapeAndSplitKVBlock( encoder_block_shape_q, group_size); encoder_num_blocks_x_cpu = encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); - } else { - encoder_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); - kv_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - kv_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - kv_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); } if (max_just_dec_len_this_time > 0) { @@ -314,15 +296,6 @@ std::vector GetBlockShapeAndSplitKVBlock( decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false); } - return { - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks_x_cpu, /*cpu*/ - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks_x_cpu, /*cpu*/ - max_len_kv_cpu, /*cpu*/ - }; } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) @@ -333,16 +306,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) "decoder_batch_ids", "decoder_tile_ids_per_batch", "decoder_num_blocks_x_cpu", - "max_len_tensor_cpu" + "max_len_tensor_cpu", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks_x_cpu", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks_x_cpu", + "max_len_kv_cpu" }) .Outputs({ - paddle::Optional("encoder_batch_ids"), - paddle::Optional("encoder_tile_ids_per_batch"), - paddle::Optional("encoder_num_blocks_x_cpu"), - paddle::Optional("kv_batch_ids"), - paddle::Optional("kv_tile_ids_per_batch"), - paddle::Optional("kv_num_blocks_x_cpu"), - "max_len_kv_cpu" + }) .Attrs({ "encoder_block_shape_q: int", diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 9c59b8bab9..71a3314d4a 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -484,6 +484,10 @@ def __init__( """ Whether to use a full cuda graph for the entire forward pass rather than splitting certain operations such as attention into subgraphs. Thus this flag cannot be used together with splitting_ops.""" + self.cudagraph_only_prefill: bool = False + """When cudagraph_only_prefill is False, only capture decode-only. + When cudagraph_only_prefill is True, only capture prefill-only. 
+ Now don't support capture both decode-only and prefill-only""" self.full_cuda_graph: bool = True self.max_capture_size: int = None @@ -496,13 +500,13 @@ def __init__( self.check_legality_parameters() - def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None: + def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None: """ Initialize cuda graph capture sizes and pre-compute the mapping from batch size to padded graph size """ # Regular capture sizes - self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs] + self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size] dedup_sizes = list(set(self.cudagraph_capture_sizes)) if len(dedup_sizes) < len(self.cudagraph_capture_sizes): logger.info( @@ -950,7 +954,11 @@ def __post_init__(self): # Initialize cuda graph capture list if self.graph_opt_config.cudagraph_capture_sizes is None: self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs) - self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs) + + if self.graph_opt_config.cudagraph_only_prefill: + self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512) + else: + self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs) # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn if self.graph_opt_config.graph_opt_level == 2: diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index ec31c4753e..82bf54e30e 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -88,6 +88,14 @@ class ForwardMeta: # Recorded multiple lengths related to prefill or decode max_len_tensor_cpu: Optional[paddle.Tensor] = None + encoder_batch_ids: Optional[paddle.Tensor] = None + encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None + kv_batch_ids: Optional[paddle.Tensor] = None + kv_tile_ids_per_batch: Optional[paddle.Tensor] = None + kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None + max_len_kv_cpu: Optional[paddle.Tensor] = None + # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None # Sequence length of Encoder for ever batch diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index ea6bdd6ab6..b751689849 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -48,14 +48,6 @@ class AppendAttentionMetadata(AttentionMetadata): AppendAttentionMetadata """ - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - max_len_kv: paddle.Tensor = None - _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 @@ -123,6 +115,21 @@ def __init__( self.rank, self.device_id = init_rank_and_device_id(fd_config) + self.share_inputs = {} + self.share_inputs["encoder_batch_ids"] = paddle.full( + shape=[self.max_seq_len], fill_value=0, dtype="int32" + ) # gpu + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full( + 
shape=[self.max_seq_len], fill_value=0, dtype="int32" + ) # gpu + self.share_inputs["encoder_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu + self.share_inputs["kv_batch_ids"] = paddle.full(shape=[self.max_seq_len], fill_value=0, dtype="int32") # gpu + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full( + shape=[self.max_seq_len], fill_value=0, dtype="int32" + ) # gpu + self.share_inputs["kv_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu + self.share_inputs["max_len_kv"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu + def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = AppendAttentionMetadata() @@ -139,15 +146,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): metadata.rotary_embs = forward_meta.rotary_embs metadata.attn_mask = forward_meta.attn_mask metadata.pre_caches_length = forward_meta.pre_caches_length - ( - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - metadata.max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -155,6 +154,13 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index ed92483932..10801c4879 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata): rotary_embs: Optional[paddle.Tensor] = None block_tables: Optional[paddle.Tensor] = None - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - max_len_kv: paddle.Tensor = None cu_seqlens_q: paddle.Tensor = None cu_seqlens_k: paddle.Tensor = None @@ -198,15 +191,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): metadata.cu_seqlens_q = forward_meta.cu_seqlens_q metadata.rotary_embs = forward_meta.rotary_embs metadata.block_tables = forward_meta.block_tables - ( - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - metadata.max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -214,6 +199,13 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): forward_meta.decoder_tile_ids_per_batch, 
forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 5279b68f6f..faf0ca1515 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -64,14 +64,6 @@ class MLAAttentionMetadata(AttentionMetadata): MLAAttentionMetadata for Multi-Layer Attention """ - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - max_len_kv: paddle.Tensor = None - _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 @@ -166,15 +158,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): metadata.attn_mask = forward_meta.attn_mask metadata.pre_caches_length = forward_meta.pre_caches_length - ( - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - metadata.max_len_kv, - ) = get_block_shape_and_split_kv_block( + get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, @@ -182,6 +166,13 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.max_len_kv_cpu, self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index dd57b52593..c3929175de 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block( decoder_tile_ids_per_batch: paddle.Tensor, decoder_num_blocks_x_cpu: paddle.Tensor, max_len_tensor_cpu: paddle.Tensor, + encoder_batch_ids: paddle.Tensor, + encoder_tile_ids_per_batch: paddle.Tensor, + encoder_num_blocks_x_cpu: paddle.Tensor, + kv_batch_ids: paddle.Tensor, + kv_tile_ids_per_batch: paddle.Tensor, + kv_num_blocks_x_cpu: paddle.Tensor, + max_len_kv_cpu: paddle.Tensor, encoder_block_shape_q: int, decoder_block_shape_q: int, group_size: int, @@ -42,21 +49,20 @@ def get_block_shape_and_split_kv_block( get_block_shape_and_split_kv_block """ if current_platform.is_cuda(): - ( - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - max_len_kv_cpu, - ) = 
get_block_shape_and_split_kv_block_cuda( + get_block_shape_and_split_kv_block_cuda( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_x_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + max_len_kv_cpu, max_len_tensor_cpu, encoder_block_shape_q, decoder_block_shape_q, @@ -64,14 +70,6 @@ def get_block_shape_and_split_kv_block( block_size, decoder_step_token_num, ) - return ( - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - max_len_kv_cpu, - ) + else: raise NotImplementedError diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index d1f8f2c689..3073b9b5a5 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -432,6 +432,13 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["decoder_tile_ids_per_batch"] = None self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory self.share_inputs["max_len_tensor_cpu"] = None # CPU + self.share_inputs["encoder_batch_ids"] = None + self.share_inputs["encoder_tile_ids_per_batch"] = None + self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU + self.share_inputs["kv_batch_ids"] = None + self.share_inputs["kv_tile_ids_per_batch"] = None + self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU + self.share_inputs["max_len_kv_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -603,6 +610,13 @@ def initialize_forward_meta(self): cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"], caches=self.share_inputs["caches"], + encoder_batch_ids=self.share_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.share_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], + max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"], ) # Update Batch type for cuda graph @@ -675,14 +689,31 @@ def initialize_attn_backend(self) -> None: encoder_block_shape_q = 64 decoder_block_shape_q = 16 decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + group_size = np.ceil(num_heads / self.model_config.kv_num_heads) + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( - (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + (decoder_step_token_num * group_size) / decoder_block_shape_q + ) + encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (self.model_config.max_model_len * group_size) / encoder_block_shape_q + ) + kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + self.model_config.max_model_len / self.fd_config.cache_config.block_size ) self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu() self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + 
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2dfe1021c6..1d966b80d4 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -131,6 +131,7 @@ def __init__( self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes + self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill # Initialize share inputs self._init_share_inputs(self.parallel_config.max_num_seqs) @@ -161,10 +162,53 @@ def exist_prefill(self): """ check whether prefill stage exist """ - if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: - return 1 - else: - return 0 + return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0 + + def exist_decode(self): + """ + check whether decode stage exist + """ + return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0 + + def only_prefill(self): + """ + check whether prefill only + """ + only_prefill_batch = True + decode_exists = None + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + # 收集所有 worker 的状态 + only_prefill_batch_list = [] + decode_exists = self.exist_decode() + paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists) + only_prefill_batch = all(only_prefill_batch_list) + + only_prefill_batch = only_prefill_batch and not ( + decode_exists if decode_exists is not None else self.exist_decode() + ) + + return only_prefill_batch + + def only_decode(self): + """ + check whether decode only + """ + # Update Batch type for cuda graph for only_decode_batch + only_decode_batch = True + prefill_exists = None + # mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + + only_decode_batch = only_decode_batch and not ( + prefill_exists if prefill_exists is not None else self.exist_prefill() + ) + + return only_decode_batch def _init_speculative_proposer(self): """ @@ -551,7 +595,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod """Set dummy prefill inputs to share_inputs""" # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token max_dec_len = expected_decode_len + 1 - full_length = min( + input_length = min( num_tokens // batch_size, 
self.parallel_config.max_model_len - max_dec_len, ) @@ -559,9 +603,8 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. if self.fd_config.parallel_config.enable_expert_parallel: - full_length = min(full_length, 32) + input_length = min(input_length, 32) - input_length = int(full_length * self.cache_config.kv_cache_ratio) block_num = ( input_length + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num @@ -690,6 +733,13 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["decoder_tile_ids_per_batch"] = None self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory self.share_inputs["max_len_tensor_cpu"] = None # CPU + self.share_inputs["encoder_batch_ids"] = None + self.share_inputs["encoder_tile_ids_per_batch"] = None + self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU + self.share_inputs["kv_batch_ids"] = None + self.share_inputs["kv_tile_ids_per_batch"] = None + self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU + self.share_inputs["max_len_kv_cpu"] = None # CPU # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -900,23 +950,24 @@ def initialize_forward_meta(self): cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"], caches=self.share_inputs["caches"], + encoder_batch_ids=self.share_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.share_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], + max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"], ) - # Update Batch type for cuda graph - only_decode_batch = True - prefill_exists = None - # mix ep in single node - if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": - only_decode_batch_list = [] - prefill_exists = self.exist_prefill() - paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) - only_decode_batch = all(only_decode_batch_list) - self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + # Update Batch type for cuda graph for only_decode_batch + only_decode_use_cudagraph = self.use_cudagraph and self.only_decode() + + # Update Batch type for cuda graph for only_prefill_batch + only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() + # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph] self.forward_meta.step_use_cudagraph = ( - self.use_cudagraph - and only_decode_batch - and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph ) # Initialzie attention meta data @@ -995,14 +1046,31 @@ def initialize_attn_backend(self) -> None: encoder_block_shape_q = 64 decoder_block_shape_q = 16 decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + group_size = np.ceil(num_heads / 
self.model_config.kv_num_heads) + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( - (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + (decoder_step_token_num * group_size) / decoder_block_shape_q + ) + encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (self.model_config.max_model_len * group_size) / encoder_block_shape_q + ) + kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + self.model_config.max_model_len / self.fd_config.cache_config.block_size ) self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -1230,7 +1298,7 @@ def _update_chunked_prefill(self, tasks): self.proposer.update_task_chunk_prefill(task) task.chunk_idx += 1 - def capture_model(self) -> None: + def capture_model(self, prefill_only: bool = False) -> None: """ Trigger CUDA Graph capture for all shapes in cuda graph capture list """ @@ -1240,14 +1308,29 @@ def capture_model(self) -> None: time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - for batch_size in sorted(capture_sizes, reverse=True): - self._dummy_run( - num_tokens=self.parallel_config.max_num_batched_tokens, - batch_size=batch_size, - in_capturing=True, - expected_decode_len=expected_decode_len, - ) - logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + + if prefill_only: + for num_tokens in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=num_tokens, + batch_size=1, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info( + f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" + ) + else: + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info( + f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}" + ) time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index bfdc92f1dc..a95f5bfea4 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -210,7 +210,7 @@ def 
graph_optimize_and_warm_up_model(self) -> None: if self.model_runner.graph_opt_level >= 1: self.model_runner.sot_warmup() # Triger cuda grpah capture - self.model_runner.capture_model() + self.model_runner.capture_model(prefill_only=self.fd_config.graph_opt_config.cudagraph_only_prefill) def check_health(self) -> bool: """ """
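
Taken together, the diff turns `get_block_shape_and_split_kv_block` into a fully in-place op: the encoder/kv scheduling buffers are no longer allocated and returned inside the op but are preallocated once by the model runner and zero-filled via `cudaMemsetAsync` on every call, so their shapes stay constant across steps, which is what CUDA Graph replay requires. The sketch below is a minimal, illustrative reconstruction of the worst-case sizing math used in `initialize_attn_backend`; the helper name `attn_buffer_sizes` and the example values are assumptions for illustration, not code from the diff.

```python
import math


def attn_buffer_sizes(
    max_num_seqs: int,
    max_model_len: int,
    num_heads: int,
    kv_num_heads: int,
    block_size: int,
    encoder_block_shape_q: int = 64,
    decoder_block_shape_q: int = 16,
    decoder_step_token_num: int = 1,
) -> dict:
    """Worst-case lengths of the preallocated scheduling buffers (mirrors the diff's sizing math)."""
    group_size = math.ceil(num_heads / kv_num_heads)
    return {
        # decoder_batch_ids / decoder_tile_ids_per_batch
        "decode_max_tile_size": max_num_seqs
        * math.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q),
        # encoder_batch_ids / encoder_tile_ids_per_batch
        "encode_max_tile_size": max_num_seqs
        * math.ceil((max_model_len * group_size) / encoder_block_shape_q),
        # kv_batch_ids / kv_tile_ids_per_batch
        "kv_max_tile_size": max_num_seqs * math.ceil(max_model_len / block_size),
    }


if __name__ == "__main__":
    # Illustrative values only: 64 query heads, 8 KV heads, 8K context, block size 64.
    print(attn_buffer_sizes(max_num_seqs=128, max_model_len=8192, num_heads=64,
                            kv_num_heads=8, block_size=64, decoder_step_token_num=2))
```

In the runner these lengths become the shapes of the `encoder_batch_ids`, `kv_batch_ids`, and related tensors created with `paddle.full([...], 0, dtype="int32")`, with the `*_num_blocks_x_cpu` and `max_len_kv_cpu` scalars kept on CPU or pinned memory, and the same tensors are then handed to `get_block_shape_and_split_kv_block` through `ForwardMeta` instead of being returned by it.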
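The `cudagraph_only_prefill` switch also changes what the capture list means: for decode-only capture the entries are batch sizes and are clipped to `max_num_seqs`, whereas for prefill-only capture they are token counts and `__post_init__` clips them to a fixed budget of 512, after which `capture_model(prefill_only=True)` replays each entry through `_dummy_run(num_tokens=size, batch_size=1)`. Below is a rough sketch of that selection step, written as a standalone function rather than the real `GraphOptimizationConfig` method.

```python
def select_capture_sizes(candidate_sizes, cudagraph_only_prefill, max_num_seqs,
                         prefill_token_budget=512):
    """Clip and deduplicate CUDA Graph capture sizes, mirroring init_with_cudagrpah_size.

    Decode-only graphs interpret each size as a batch size (bounded by max_num_seqs);
    prefill-only graphs interpret each size as num_tokens (bounded by a fixed budget).
    """
    cap = prefill_token_budget if cudagraph_only_prefill else max_num_seqs
    return sorted({size for size in candidate_sizes if size <= cap}, reverse=True)


# Decode-only capture: sizes act as batch sizes.
print(select_capture_sizes([1, 2, 4, 8, 16, 32, 128, 256], False, max_num_seqs=64))
# Prefill-only capture: sizes act as num_tokens for _dummy_run(batch_size=1).
print(select_capture_sizes([1, 2, 4, 8, 16, 32, 128, 256, 1024], True, max_num_seqs=64))
```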
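Finally, `step_use_cudagraph` is now driven by the new `only_decode()` / `only_prefill()` helpers: with expert parallelism in the "mixed" splitwise role, every rank gathers the other ranks' local batch type via `paddle.distributed.all_gather_object` and replays a graph only if all ranks agree and the local check still holds. The snippet below is a simplified stand-in for that agreement step, not the runner's actual code path, and it assumes a launched Paddle distributed environment.

```python
import paddle.distributed as dist


def all_ranks_agree(local_flag: bool) -> bool:
    """Return True only if every rank reports local_flag=True (e.g. 'no prefill on this rank')."""
    gathered = []
    dist.all_gather_object(gathered, local_flag)
    return all(gathered)


if __name__ == "__main__":
    # Launch with e.g. `python -m paddle.distributed.launch --gpus 0,1 this_script.py`.
    dist.init_parallel_env()
    no_prefill_locally = True  # stand-in for `not self.exist_prefill()`
    only_decode_batch = all_ranks_agree(no_prefill_locally)
    print(f"rank {dist.get_rank()}: only_decode_batch={only_decode_batch}")
```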