[Excutor] Experiment-Support Prefill in cudagraph #3459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
@@ -199,6 +199,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -223,14 +230,8 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];

paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/

PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(max_len_kv_cpu.data<int>(), 0, sizeof(int32_t), stream));
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
@@ -240,14 +241,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);

if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());

@@ -259,15 +257,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
block_size, block_size);

kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);

const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -277,19 +272,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}

if (max_just_dec_len_this_time > 0) {
@@ -314,15 +296,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}

return {
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -333,16 +306,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu"
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks_x_cpu"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
"max_len_kv_cpu"

})
.Attrs({
"encoder_block_shape_q: int",
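Note on the kernel change above: get_block_shape_and_split_kv_block no longer allocates and returns the encoder/KV split tensors; it zeroes and refills buffers supplied by the caller, so their device addresses stay fixed across steps, which is what CUDA graph capture requires. Below is a minimal sketch of the sizing contract those buffers must satisfy, mirroring the div_up formulas in the kernel; the helper names and example parameter values are illustrative only, not code from this PR.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def min_split_kv_buffer_sizes(bsz, max_enc_dec_len, group_size,
                              encoder_block_shape_q, block_size):
    # Element counts the kernel memsets, per the hunks above:
    #   encoder_* id buffers: bsz * div_up(max_enc_dec_len * group_size, encoder_block_shape_q)
    #   kv_* id buffers:      bsz * div_up(max_enc_dec_len, block_size)
    encoder_tiles = bsz * ceil_div(max_enc_dec_len * group_size, encoder_block_shape_q)
    kv_tiles = bsz * ceil_div(max_enc_dec_len, block_size)
    return encoder_tiles, kv_tiles

# Example (arbitrary values): bsz=8, max_enc_dec_len=8192, group_size=8,
# encoder_block_shape_q=64, block_size=64
# -> encoder buffers need 8 * 1024 = 8192 int32 slots, kv buffers need 8 * 128 = 1024.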
14 changes: 11 additions & 3 deletions fastdeploy/config.py
@@ -484,6 +484,10 @@ def __init__(
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.cudagraph_only_prefill: bool = False
"""When cudagraph_only_prefill is False, only decode-only graphs are captured;
when True, only prefill-only graphs are captured.
Capturing both decode-only and prefill-only graphs is not yet supported."""
self.full_cuda_graph: bool = True

self.max_capture_size: int = None
@@ -496,13 +500,13 @@ def __init__(

self.check_legality_parameters()

def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(
@@ -950,7 +954,11 @@ def __post_init__(self):
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)

if self.graph_opt_config.cudagraph_only_prefill:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
else:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs)

# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
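Note on the config change above: the two branches pick different ceilings for the capture-size list. A condensed sketch of that rule follows; the helper name and example numbers are illustrative, and the assumption that 512 acts as the prefill-only ceiling comes directly from the hard-coded call above.

def select_capture_sizes(candidate_sizes, cudagraph_only_prefill, max_num_seqs):
    # Mirrors init_with_cudagrpah_size: drop candidates above the ceiling,
    # then deduplicate (the real code also logs when duplicates are dropped).
    max_capture_size = 512 if cudagraph_only_prefill else max_num_seqs
    return sorted({s for s in candidate_sizes if s <= max_capture_size})

# Example: with candidates [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] and max_num_seqs=64,
# decode-only keeps the sizes up to 64, while prefill-only keeps everything up to 512.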
8 changes: 8 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -88,6 +88,14 @@ class ForwardMeta:
# Recorded multiple lengths related to prefill or decode
max_len_tensor_cpu: Optional[paddle.Tensor] = None

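# In-place buffers for the prefill (encoder) and KV-split stages, filled each
# step by get_block_shape_and_split_kv_block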
encoder_batch_ids: Optional[paddle.Tensor] = None
encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None
kv_batch_ids: Optional[paddle.Tensor] = None
kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
max_len_kv_cpu: Optional[paddle.Tensor] = None

# Sequence length of encoder for ever batch
seq_lens_encoder: Optional[paddle.Tensor] = None
# Sequence length of Encoder for ever batch
40 changes: 23 additions & 17 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -48,14 +48,6 @@ class AppendAttentionMetadata(AttentionMetadata):
AppendAttentionMetadata
"""

encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None

_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
@@ -123,6 +115,21 @@ def __init__(

self.rank, self.device_id = init_rank_and_device_id(fd_config)

self.share_inputs = {}
self.share_inputs["encoder_batch_ids"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["encoder_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu
Comment on lines +119 to +125

Collaborator:
Manage these buffers in gpu_model_runner instead, and rework the get_block_shape_and_split_kv kernel so that the encoder-related tensors are written in place; otherwise the copy_ overhead in pre-processing is too high.

Contributor Author:
I've pushed a draft that hasn't been verified yet, in case the server goes down and the code is lost.

self.share_inputs["kv_batch_ids"] = paddle.full(shape=[self.max_seq_len], fill_value=0, dtype="int32") # gpu
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["kv_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu
self.share_inputs["max_len_kv"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = AppendAttentionMetadata()
@@ -139,22 +146,21 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
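Following the review comment above, here is a rough sketch of the suggested direction: the GPU model runner owns these buffers and exposes them on ForwardMeta, so the in-place kernel writes straight into them and no per-step copy_ from backend-owned share_inputs is needed. The class and method names are illustrative, not code from this PR.

import paddle

class GPUModelRunnerSketch:
    def init_split_kv_buffers(self, forward_meta, encoder_tiles, kv_tiles):
        # Allocated once at startup with worst-case sizes (see the sizing sketch
        # after the kernel diff); get_block_shape_and_split_kv_block then zeroes
        # and refills these exact tensors every step, so the addresses baked
        # into a captured CUDA graph stay valid.
        for name, size in (
            ("encoder_batch_ids", encoder_tiles),
            ("encoder_tile_ids_per_batch", encoder_tiles),
            ("kv_batch_ids", kv_tiles),
            ("kv_tile_ids_per_batch", kv_tiles),
        ):
            setattr(forward_meta, name, paddle.zeros([size], dtype="int32"))
        for name in ("encoder_num_blocks_x_cpu", "kv_num_blocks_x_cpu", "max_len_kv_cpu"):
            setattr(forward_meta, name, paddle.zeros([1], dtype="int32").cpu())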
24 changes: 8 additions & 16 deletions fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):

rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None

cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -198,22 +191,21 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -64,14 +64,6 @@ class MLAAttentionMetadata(AttentionMetadata):
MLAAttentionMetadata for Multi-Layer Attention
"""

encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None

_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
@@ -166,22 +158,21 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length

(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,