[Executor] Change cudagraph hashkey from batch size to num_tokens #3454

Merged · 1 commit · Aug 18, 2025
Changes from all commits
20 changes: 10 additions & 10 deletions fastdeploy/config.py
@@ -487,7 +487,7 @@ def __init__(
         self.full_cuda_graph: bool = True

         self.max_capture_size: int = None
-        self.batch_size_to_captured_size: dict[int, int] = None
+        self.real_shape_to_captured_size: dict[int, int] = None
         # CINN Config ...
         if args is not None:
             for key, value in args.items():
@@ -516,26 +516,26 @@ def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
         self.cudagraph_capture_sizes.sort(reverse=True)
         self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0

-        # Pre-compute the mapping from batch size to padded graph size
-        self.batch_size_to_captured_size = {}
+        # Pre-compute the mapping from shape to padded graph size
+        self.real_shape_to_captured_size = {}
         for end, start in zip(self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]):
             for bs in range(start, end):
                 if bs == start:
-                    self.batch_size_to_captured_size[bs] = start
+                    self.real_shape_to_captured_size[bs] = start
                 else:
-                    self.batch_size_to_captured_size[bs] = end
-        self.batch_size_to_captured_size[self.max_capture_size] = self.max_capture_size
+                    self.real_shape_to_captured_size[bs] = end
+        self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size

     def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
         """
-        Calculate a series of candidate capture batch sizes,
+        Calculate a series of candidate capture sizes,
         and then extract a portion of them as the capture list for the CUDA graph based on user input.
         """
-        # Batch Size [1, 2, 4, 8, 16, ... 120, 128]
+        # Shape [1, 2, 4, 8, 16, ... 120, 128]
         draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
-        # Batch Size [128, 144, ... 240, 256]
+        # Shape [128, 144, ... 240, 256]
         draft_capture_sizes += [16 * i for i in range(9, 17)]
-        # Batch Size [256, 288, ... 992, 1024]
+        # Shape [256, 288, ... 992, 1024]
         draft_capture_sizes += [32 * i for i in range(17, 33)]

         draft_capture_sizes.append(max_num_seqs)
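The padding map built in init_with_cudagrpah_size can be exercised on its own. The sketch below is illustrative only and not part of this diff; build_padding_map is a made-up name that reproduces the same loop to show how an arbitrary token count is rounded up to the nearest captured size.

# Standalone sketch of the padding-map logic above (illustrative, not from this PR).
def build_padding_map(capture_sizes: list[int]) -> dict[int, int]:
    capture_sizes = sorted(capture_sizes, reverse=True)
    max_capture_size = capture_sizes[0] if capture_sizes else 0
    mapping: dict[int, int] = {}
    for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
        for shape in range(start, end):
            # An exact captured size maps to itself; anything in between is
            # padded up to the next larger captured size.
            mapping[shape] = start if shape == start else end
    mapping[max_capture_size] = max_capture_size
    return mapping

mapping = build_padding_map([1, 2, 4, 8])
assert mapping[2] == 2  # exact hit on a captured size
assert mapping[3] == 4  # 3 tokens are padded up to the captured shape 4
assert mapping[8] == 8  # the maximum capture size maps to itself

Because the dictionary covers every integer up to the maximum capture size, the lookup in the backend is a plain dict access with no searching at runtime. The second diff below applies the same rename inside the CUDA graph backend that dispatches on these sizes.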
@@ -29,9 +29,9 @@

 @dataclass
 class ConcreteSizeEntry:
-    """Record the concrete information corresponding to the current batch size"""
+    """Record the concrete information corresponding to the current shape(num_tokens)"""

-    # Concrete batch size
+    # Concrete shape
     runtime_bs: int
Review comment by @gongshaotian (Collaborator), Aug 18, 2025:
runtime_bs -> captured_size; let's change that in the next PR.

     # The size is in cudagraph_capture_sizes
     use_cudagraph: bool = True
@@ -42,7 +42,7 @@ class ConcreteSizeEntry:
     runnable: Callable = None  # type: ignore
     # Number of completed warmups
     num_finished_warmup: int = 0
-    # Captured cuda graph object corresponding to the current batch size
+    # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
     # Output buffer of cudagraph
     output_buffer: Optional[paddle.Tensor] = None
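A note on the renamed key, with an illustrative example that is not taken from the PR: ids_remove_padding is the flattened, padding-free token tensor for the whole batch, so its leading dimension counts tokens across all requests rather than the number of requests. The token ids below are made up.

import paddle

# Two requests with 3 and 5 tokens, flattened into one padding-free tensor
# (made-up ids, for illustration only).
req_a = paddle.to_tensor([101, 102, 103], dtype="int64")
req_b = paddle.to_tensor([201, 202, 203, 204, 205], dtype="int64")
ids_remove_padding = paddle.concat([req_a, req_b])

batch_size = 2                            # number of requests
num_tokens = ids_remove_padding.shape[0]  # 8: the shape the CUDA graph actually sees
assert num_tokens == 8 and num_tokens != batch_size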
@@ -60,33 +60,33 @@ def __init__(
         self.runnable = runnable
         self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
-        self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size
+        self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size

-        # Runtime batch size -> ConcreteSizeEntry
+        # Runtime real shape -> ConcreteSizeEntry
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

         for shape in self.cudagraph_capture_sizes:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)

         logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all batch sizes entry."
+            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
         )

     def __call__(self, **kwargs):
-        # Get batch size
+        # Get real shape(all num tokens)
         ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
-        batch_size = ids_remove_padding.shape[0]
-        padding_batch_size = self.batch_size_to_captured_size[batch_size]
+        real_shape = ids_remove_padding.shape[0]
+        padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
-            f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, "
-            f"The padded batch size is :{padding_batch_size}"
+            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
+            f"The padded shape is :{padding_real_shape}"
         )

-        entry = self.concrete_size_entries.get(padding_batch_size)
-        assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
+        entry = self.concrete_size_entries.get(padding_real_shape)
+        assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
         if entry.runnable is None:
             entry.runnable = self.runnable
-            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}")
+            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")

         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)
@@ -98,7 +98,7 @@ def __call__(self, **kwargs):
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
                 logger.debug(
-                    f"[CUDA GRAPH] Warm up for batch size {padding_batch_size}, "
+                    f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
                     f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )

@@ -122,9 +122,9 @@ def __call__(self, **kwargs):
             output._clear

             paddle.device.synchronize()
-            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}")
+            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")

         # Replay
         entry.cuda_graph.replay()
-        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
+        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
         return entry.output_buffer
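For readers unfamiliar with the capture and replay pattern used by this backend, here is a compressed lifecycle sketch: warm up eagerly, capture once for a fixed shape, then replay. It assumes a CUDA build of Paddle and that paddle.device.cuda.graphs.CUDAGraph exposes capture_begin, capture_end, and replay as used above; the runnable and buffer handling are simplified stand-ins rather than FastDeploy's.

import paddle
from paddle.device.cuda import graphs  # same module used by the backend above

def runnable(x: paddle.Tensor) -> paddle.Tensor:
    # Stand-in for the captured model forward.
    return x * 2 + 1

# Fixed "real shape" of 8 tokens; later inputs must be written into this buffer.
static_input = paddle.ones([8], dtype="float32")

# Warm up eagerly so lazy allocations and kernel setup happen outside capture.
for _ in range(3):
    runnable(static_input)
paddle.device.synchronize()

# Capture once for this shape.
graph = graphs.CUDAGraph()
graph.capture_begin()
output_buffer = runnable(static_input)
graph.capture_end()

# Replay with new data copied into the same static input buffer.
paddle.assign(paddle.full([8], 3.0, dtype="float32"), output=static_input)
graph.replay()
paddle.device.synchronize()
print(output_buffer)  # reflects the replayed computation on the updated input

The backend above follows the same lifecycle per captured shape, keyed by the padded num_tokens, which is why each distinct padded shape gets its own ConcreteSizeEntry.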