
[Executor] Experiment-Support Prefill in cudagraph #3459


Open

littledgg wants to merge 3 commits into develop

Conversation

littledgg
Contributor

This currently supports running prefill-only batches through CUDA graph. Until it is confirmed that the graphs can be shared, you can only choose to capture either decode-only or prefill-only batches.
1. To enable it, launch with the following arguments; the key point is setting both use_cudagraph and cudagraph_only_prefill to true:

python -m fastdeploy.entrypoints.openai.api_server --model ${model_path} \
    --max-num-seqs 64 --max-model-len 32768 \
    --port 8988 --engine-worker-queue-port 7732 \
    --metrics-port 7733 --tensor-parallel-size 1 \
    --graph-optimization-config '{"use_cudagraph":true,"cudagraph_only_prefill":true}'

2. With the current dynamic-insertion behavior, suppose 4 prompts of 80 tokens each are sent. Then seq_lens_this_time is [80] in the first step and [1, 80, 80, 80] in the second. Clearly only the first step is pure prefill and can go through CUDA graph; the second step is mixed and cannot. As a workaround, in fastdeploy/engine/engine.py the function _insert_task_to_worker can be changed from

tasks = self.scheduler.get_requests(
                    available_blocks=self.resource_manager.available_block_num(),
                    block_size=self.cfg.cache_config.block_size,
                    reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                    max_num_batched_tokens=self.cfg.max_num_batched_tokens,
                    batch=num_prefill_batch,
                )

to

                tasks = list()
                while len(tasks) < 8:
                    print("===RyanDebug, Begin to collect tasks ===")
                    print("====The self.resource_manager.available_block_num is:", self.resource_manager.available_block_num())
                    print("====The self.cfg.cache_config.block_size is:", self.cfg.cache_config.block_size)
                    print("====The self.cfg.cache_config.enc_dec_block_num is:", self.cfg.cache_config.enc_dec_block_num)
                    print("====The self.cfg.max_num_batched_tokens is:", self.cfg.max_num_batched_tokens)
                    print("===RyanDebug, num_prefill_batch is:", 8)

                    tmp_task = self.scheduler.get_requests(
                        available_blocks=5000,
                        block_size=self.cfg.cache_config.block_size,
                        reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                        max_num_batched_tokens=self.cfg.max_num_batched_tokens,
                        batch=8,
                    )
                    print("===RyanDebug, the tmp_task is:", tmp_task)

                    if isinstance(tmp_task, list):
                        tasks.extend(tmp_task)
                    elif tmp_task is not None:
                        tasks.append(tmp_task)

                print(f"===RyanDebug, Finish Fix task, the len of tasks is {len(tasks)} ===")
                print(f"===RyanDebug, Finish Fix task, the tasks is {tasks} ===")

This disables the dynamic-insertion logic: the engine waits until 8 prompts have arrived (the number can be changed), and those 8 prompts then enter prefill together (pure-prefill speedup across multiple prompts) and later enter decode together.
3. In init_with_cudagrpah_size in fastdeploy/config.py, 512 is the maximum capture size when capturing prefill; it can be changed manually:

if self.graph_opt_config.cudagraph_only_prefill:
    self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)

TODO: the value of buffer_size still needs further confirmation.


paddle-bot bot commented Aug 18, 2025

Thanks for your contribution!

paddle-bot added the contributor (External developers) label on Aug 18, 2025
Comment on lines +127 to +133
self.share_inputs["encoder_batch_ids"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["encoder_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu
Collaborator

Move these buffers to be managed inside gpu_model_runner, and also rework the get_block_shape_and_split_kv kernel so that the encoder-related tensors are filled in place; otherwise the copy_ overhead in preprocessing is too high.

Contributor Author

I've pushed an unverified draft version for now, in case the server goes down and the code is lost.

Comment on lines 187 to 207
self.share_inputs["encoder_batch_ids"].copy_(temp_encoder_batch_ids, False)
metadata.encoder_batch_ids = self.share_inputs["encoder_batch_ids"]

self.share_inputs["encoder_tile_ids_per_batch"].copy_(temp_encoder_tile_ids_per_batch, False)
metadata.encoder_tile_ids_per_batch = self.share_inputs["encoder_tile_ids_per_batch"]

self.share_inputs["encoder_num_blocks"].copy_(temp_encoder_num_blocks, False)
metadata.encoder_num_blocks = self.share_inputs["encoder_num_blocks"]

self.share_inputs["kv_batch_ids"].copy_(temp_kv_batch_ids, False)
metadata.kv_batch_ids = self.share_inputs["kv_batch_ids"]

self.share_inputs["kv_tile_ids_per_batch"].copy_(temp_kv_tile_ids_per_batch, False)
metadata.kv_tile_ids_per_batch = self.share_inputs["kv_tile_ids_per_batch"]

self.share_inputs["kv_num_blocks"].copy_(temp_kv_num_blocks, False)
metadata.kv_num_blocks = self.share_inputs["kv_num_blocks"]

self.share_inputs["max_len_kv"].copy_(temp_max_len_kv, False)
metadata.max_len_kv = self.share_inputs["max_len_kv"]

Collaborator

The copy_ overhead is too high.

Comment on lines 174 to 177
if int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0:
return 1
else:
return 0
Collaborator

Write this as a single line; it looks a bit ugly as-is.
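
A minimal sketch of the one-liner being suggested, using the same names as the snippet above:

return 1 if int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0 else 0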

@@ -561,7 +571,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
    if self.fd_config.parallel_config.enable_expert_parallel:
        full_length = min(full_length, 32)

    input_length = int(full_length * self.cache_config.kv_cache_ratio)
    input_length = int(full_length)
Collaborator

Delete this line.

Comment on lines 925 to 953

self.forward_meta.step_use_cudagraph = (
only_decode_use_cudagraph = (
    self.use_cudagraph
    and only_decode_batch
    and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
)

# Update batch type for CUDA graph for only_prefill_batch
only_prefill_batch = True
decode_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
    # Collect the state of all workers
    only_prefill_batch_list = []
    decode_exists = self.exist_decode()
    paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
    only_prefill_batch = all(only_prefill_batch_list)

only_prefill_use_cudagraph = (
    self.use_cudagraph
    and self.cudagraph_only_prefill
    and only_prefill_batch
    and not (decode_exists if decode_exists is not None else self.exist_decode())
)

# Once capturing both prefill-only and decode-only is supported, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
self.forward_meta.step_use_cudagraph = (
    only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
)

Collaborator

There are too many branches; can the Only Decode and Only Prefill logic be merged or wrapped in a helper? (A rough sketch follows.)
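
One possible way to fold the two paths into a single helper, as a rough sketch only (the method name _should_use_cudagraph is hypothetical; the other names come from the diff above, and the expert-parallel all_gather that computes only_prefill_batch would stay at the call site):

def _should_use_cudagraph(self, only_decode_batch, only_prefill_batch, prefill_exists=None, decode_exists=None):
    """Decide whether this step can replay a captured graph (sketch, not the PR's final code)."""
    if not self.use_cudagraph:
        return False
    if self.cudagraph_only_prefill:
        # Prefill-only capture: the whole batch must be in prefill, with no decode request present.
        no_decode = not (decode_exists if decode_exists is not None else self.exist_decode())
        return only_prefill_batch and no_decode
    # Decode-only capture: the whole batch must be in decode, with no prefill request present.
    no_prefill = not (prefill_exists if prefill_exists is not None else self.exist_prefill())
    return only_decode_batch and no_prefill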

@@ -1230,7 +1262,7 @@ def _update_chunked_prefill(self, tasks):
self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1

def capture_model(self) -> None:
def capture_model(self, capture_prefill: bool = False) -> None:
Collaborator

The parameter name capture_prefill is ambiguous -> use only_prefill or PurePrefill instead.
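
For example, the signature could read (a sketch of the suggested rename only, not the PR's final code):

def capture_model(self, only_prefill: bool = False) -> None:
    ...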

Comment on lines 1294 to 1295
logger.info(
f"Warm up the model with the batch size:{batch_size}, expected_decode_len:{expected_decode_len}"
Collaborator

The logger message wasn't updated.

Comment on lines +1276 to +1293
    for num_tokens in sorted(capture_sizes, reverse=True):
        self._dummy_run(
            num_tokens=num_tokens,
            batch_size=1,
            in_capturing=True,
            expected_decode_len=expected_decode_len,
        )
        logger.info(
            f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
        )
else:
    for batch_size in sorted(capture_sizes, reverse=True):
        self._dummy_run(
            num_tokens=self.parallel_config.max_num_batched_tokens,
            batch_size=batch_size,
            in_capturing=True,
            expected_decode_len=expected_decode_len,
        )
Collaborator

Merge these two loops; only batch_size is different. See the sketch below.
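
A rough sketch of how the merged loop could look, assuming a boolean flag (called only_prefill here, an assumed name) distinguishes the two capture modes; the other names come from the diff above:

for capture_size in sorted(capture_sizes, reverse=True):
    # In prefill-only capture, capture_size is the token count; otherwise it is the batch size.
    num_tokens, batch_size = (
        (capture_size, 1)
        if only_prefill
        else (self.parallel_config.max_num_batched_tokens, capture_size)
    )
    self._dummy_run(
        num_tokens=num_tokens,
        batch_size=batch_size,
        in_capturing=True,
        expected_decode_len=expected_decode_len,
    )
    logger.info(
        f"Warm up the model with num_tokens:{num_tokens}, batch_size:{batch_size}, "
        f"expected_decode_len:{expected_decode_len}"
    )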

Comment on lines 1374 to 1376
print("传递给model的seq_lens_this_time", self.forward_meta.seq_lens_this_time)
print("input_ids", self.forward_meta.input_ids.shape)
print("self.share_inputs[ids_remove_padding].shape:", self.share_inputs["ids_remove_padding"].shape)
Collaborator

The print statements weren't removed.
