Skip to content

[CudaGraph] [SOT] Support splitting static graph into piecewise graph with cuda_graph #3478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# limitations under the License.
"""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import paddle.nn.layer
from paddle.device.cuda import graphs
from paddle.jit.dy2static.utils import CUDAGraphState

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
Expand Down Expand Up @@ -48,6 +50,35 @@ class ConcreteSizeEntry:
output_buffer: Optional[paddle.Tensor] = None


class Dy2StCudaGraphManager:
    """Coordinate CUDA graph capture/replay for dynamic-to-static (dy2st) programs.

    Wraps paddle's partial-program ``run_impl`` so that each execution is tagged
    with the current CUDA graph state and a dispatch key (the batch size),
    allowing one graph to be captured per batch size and replayed later.
    """

    def __init__(self):
        # Requested graph mode for the next run (DISABLE / CAPTURE / REPLAY).
        self.state = CUDAGraphState.DISABLE
        # Batch sizes for which a CUDA graph has already been captured.
        # NOTE: fixed spelling of the original `captrued_batch_size`.
        self.captured_batch_size = set()
        # Batch size of the in-flight call; -1 means "not set yet".
        self.batch_size = -1

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
        """Run one partial-program step with CUDA graph attributes injected.

        Downgrades REPLAY to DISABLE when no graph has been captured for the
        current batch size; records the batch size when capturing.
        """
        run_state = self.state
        prog_attrs, cuda_graph_attrs = attrs
        if run_state == CUDAGraphState.REPLAY:
            if self.batch_size not in self.captured_batch_size:
                # Nothing captured for this batch size yet: fall back to eager.
                run_state = CUDAGraphState.DISABLE
        elif run_state == CUDAGraphState.CAPTURE:
            self.captured_batch_size.add(self.batch_size)

        # The dispatch key selects which captured graph to use; 0 when disabled.
        cuda_graph_attrs |= {
            "cuda_graph_state": run_state,
            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
        }
        return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

    @contextmanager
    def run_impl_guard(self):
        """Temporarily route paddle's partial-program execution through run_impl."""
        with paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(
            self.run_impl,
        ):
            yield


class CudaGraphPiecewiseBackend:
"""Manage the capture and replay of CUDA graphs at the subgraph level."""

Expand All @@ -68,10 +99,41 @@ def __init__(
for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)

self.cuda_graph_manager = None
if self.fd_config.graph_opt_config.graph_opt_level > 0:
self.cuda_graph_manager = Dy2StCudaGraphManager()

logger.info(
f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all batch sizes entry."
)

def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
    """Run the dy2st static model for this entry, capturing a CUDA graph on first use.

    On the first call for a given entry: warm the model up, record input tensor
    addresses for debugging, capture the CUDA graph, and mark the entry as
    captured. Every call (first or later) finishes by replaying the graph for
    this batch size (the manager falls back to eager execution if nothing was
    captured for it).
    """
    if not entry.captured:
        # Warmup the model so allocator/internal state stabilize before capture.
        for n in range(entry.num_finished_warmup, self.warm_up_size):
            entry.num_finished_warmup += 1
            entry.runnable(**kwargs)
            logger.debug(
                f"[CUDA GRAPH] Warm up for batch size {entry.runtime_bs}, "
                # FIX: denominator was `entry.num_finished_warmup`, which always
                # equals n + 1 here, so the log printed "(k/k)". Use the target
                # warmup count instead.
                f"finished ({n + 1}/{self.warm_up_size}) times"
            )

        # Store input addresses for debug
        input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
        entry.input_addresses = input_addresses

        # Capture once, and mark the entry captured so subsequent calls skip
        # warmup/capture. FIX: the original never set `entry.captured`, so
        # every invocation re-warmed and re-captured the graph.
        self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
        self.cuda_graph_manager.batch_size = entry.runtime_bs
        entry.captured = True
        with self.cuda_graph_manager.run_impl_guard():
            entry.runnable(**kwargs)

    # Replay the captured graph for this batch size. FIX: moved out of the
    # `if not entry.captured` branch so an already-captured entry replays
    # instead of returning None.
    self.cuda_graph_manager.state = CUDAGraphState.REPLAY
    self.cuda_graph_manager.batch_size = entry.runtime_bs
    with self.cuda_graph_manager.run_impl_guard():
        return entry.runnable(**kwargs)

def __call__(self, **kwargs):
# Get batch size
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
Expand All @@ -91,6 +153,9 @@ def __call__(self, **kwargs):
if not entry.use_cudagraph:
return entry.runnable(**kwargs)

if self.fd_config.graph_opt_config.graph_opt_level > 0:
return self.run_static_model(entry, **kwargs)

# Capture a new cuda graph
if entry.cuda_graph is None:
# Warmup the model
Expand Down
Loading