
Commit d51a0f4

support splitting static graph into piecewise graph with cuda_graph

1 parent be94bdd

1 file changed, +65 −0 lines

fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

@@ -14,11 +14,13 @@
 # limitations under the License.
 """

+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional

 import paddle.nn.layer
 from paddle.device.cuda import graphs
+from paddle.jit.dy2static.utils import CUDAGraphState

 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
@@ -48,6 +50,35 @@ class ConcreteSizeEntry:
     output_buffer: Optional[paddle.Tensor] = None


+class Dy2StCudaGraphManager:
+    def __init__(self):
+        self.state = CUDAGraphState.DISABLE
+        self.captured_batch_size = set()
+        self.batch_size = -1
+
+    def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        run_state = self.state
+        prog_attrs, cuda_graph_attrs = attrs
+        if run_state == CUDAGraphState.REPLAY:
+            if self.batch_size not in self.captured_batch_size:
+                run_state = CUDAGraphState.DISABLE
+        elif run_state == CUDAGraphState.CAPTURE:
+            self.captured_batch_size.add(self.batch_size)
+
+        cuda_graph_attrs |= {
+            "cuda_graph_state": run_state,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+        }
+        return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
+
+    @contextmanager
+    def run_impl_guard(self):
+        with paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(
+            self.run_impl,
+        ):
+            yield
+
+
 class CudaGraphPiecewiseBackend:
     """Manage the capture and replay of CUDA graphs at the subgraph level."""
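run_impl_guard temporarily swaps the dy2st partial program's run implementation for run_impl above, which stamps each call with a CUDA graph state and a per-batch-size dispatch key; a REPLAY request for a batch size that was never captured degrades to DISABLE, so execution falls back to a plain run instead of replaying a graph that does not exist. Below is a minimal sketch of that state machine that runs without Paddle; the State enum and fake_run_impl are stand-ins invented for illustration, not Paddle APIs:

    from enum import Enum

    class State(Enum):
        # stand-in for paddle.jit.dy2static.utils.CUDAGraphState
        DISABLE = 0
        CAPTURE = 1
        REPLAY = 2

    class Manager:
        def __init__(self):
            self.state = State.DISABLE
            self.captured_batch_size = set()
            self.batch_size = -1

        def run_impl(self, original_run_impl, inputs, parameters, attrs):
            run_state = self.state
            prog_attrs, cuda_graph_attrs = attrs
            # Replaying a never-captured batch size degrades to a plain run.
            if run_state == State.REPLAY and self.batch_size not in self.captured_batch_size:
                run_state = State.DISABLE
            elif run_state == State.CAPTURE:
                self.captured_batch_size.add(self.batch_size)
            cuda_graph_attrs |= {
                "cuda_graph_state": run_state,
                "cuda_graph_dispatch_key": self.batch_size if run_state != State.DISABLE else 0,
            }
            return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

    def fake_run_impl(inputs, parameters, attrs):
        # stands in for the real dy2st run implementation
        return attrs[1]["cuda_graph_state"], attrs[1]["cuda_graph_dispatch_key"]

    m = Manager()
    m.state, m.batch_size = State.REPLAY, 8
    print(m.run_impl(fake_run_impl, [], [], ({}, {})))  # (DISABLE, 0): 8 not yet captured
    m.state = State.CAPTURE
    print(m.run_impl(fake_run_impl, [], [], ({}, {})))  # (CAPTURE, 8): records batch size 8
    m.state = State.REPLAY
    print(m.run_impl(fake_run_impl, [], [], ({}, {})))  # (REPLAY, 8): replay is now legal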

@@ -68,10 +99,41 @@ def __init__(
         for shape in self.cudagraph_capture_sizes:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)

+        self.cuda_graph_manager = None
+        if self.fd_config.graph_opt_config.graph_opt_level > 0:
+            self.cuda_graph_manager = Dy2StCudaGraphManager()
+
         logger.info(
             f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all batch sizes entry."
         )

+    def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        if not entry.captured:
+            # Warmup the model
+            for n in range(entry.num_finished_warmup, self.warm_up_size):
+                entry.num_finished_warmup += 1
+                entry.runnable(**kwargs)
+                logger.debug(
+                    f"[CUDA GRAPH] Warm up for batch size {entry.runtime_bs}, "
+                    f"finished ({n + 1}/{entry.num_finished_warmup}) times"
+                )
+
+            # Store input addresses for debug
+            input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
+            entry.input_addresses = input_addresses
+
+            # Capture
+            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.batch_size = entry.runtime_bs
+            with self.cuda_graph_manager.run_impl_guard():
+                entry.runnable(**kwargs)
+
+        # Replay
+        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.batch_size = entry.runtime_bs
+        with self.cuda_graph_manager.run_impl_guard():
+            return entry.runnable(**kwargs)
+
     def __call__(self, **kwargs):
         # Get batch size
         ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
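run_static_model drives each ConcreteSizeEntry through a fixed lifecycle: eager warm-up passes (resumable across calls via num_finished_warmup), one pass under the capture guard to record the graph for that batch size, then a replay-guarded pass for the current and every later call. A condensed sketch of that control flow follows; FakeEntry, WARM_UP_SIZE, and run are illustrative stand-ins, and flipping entry.captured after the capture pass is an assumption, since the hunk above does not show where captured is set:

    from dataclasses import dataclass

    @dataclass
    class FakeEntry:
        # illustrative stand-in for ConcreteSizeEntry
        runtime_bs: int
        captured: bool = False
        num_finished_warmup: int = 0

    WARM_UP_SIZE = 2  # assumed counterpart of self.warm_up_size

    def run_static_model(entry, run):
        if not entry.captured:
            # Eager warm-up passes, resumable across calls
            for _ in range(entry.num_finished_warmup, WARM_UP_SIZE):
                entry.num_finished_warmup += 1
                run("warmup")
            run("capture")         # one pass under the capture guard records the graph
            entry.captured = True  # assumption: entry is marked captured at this point
        return run("replay")       # all later passes go through the replay guard

    entry = FakeEntry(runtime_bs=8)
    run_static_model(entry, print)  # warmup, warmup, capture, replay
    run_static_model(entry, print)  # replay only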
@@ -91,6 +153,9 @@ def __call__(self, **kwargs):
         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)

+        if self.fd_config.graph_opt_config.graph_opt_level > 0:
+            return self.run_static_model(entry, **kwargs)
+
         # Capture a new cuda graph
         if entry.cuda_graph is None:
             # Warmup the model
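With this hunk, __call__ chooses among three routes: entries that opt out of CUDA graphs run eagerly, graph_opt_level > 0 takes the new dy2st path, and everything else falls through to the pre-existing dynamic-graph capture branch that begins here. A self-contained schematic of that dispatch; Entry and the two stubs are hypothetical stand-ins for the real backend pieces:

    class Entry:
        # minimal stand-in for ConcreteSizeEntry
        def __init__(self, use_cudagraph):
            self.use_cudagraph = use_cudagraph
            self.runnable = lambda **kw: "eager"

    def run_static_model(entry, **kwargs):
        return "static"   # stub: the dy2st path defined above

    def run_dynamic_graph(entry, **kwargs):
        return "dynamic"  # stub: the existing paddle.device.cuda.graphs branch

    def dispatch(entry, graph_opt_level, **kwargs):
        if not entry.use_cudagraph:
            return entry.runnable(**kwargs)            # eager: graphs disabled for this size
        if graph_opt_level > 0:
            return run_static_model(entry, **kwargs)   # new in this commit
        return run_dynamic_graph(entry, **kwargs)      # unchanged default path

    assert dispatch(Entry(use_cudagraph=False), graph_opt_level=1) == "eager"
    assert dispatch(Entry(use_cudagraph=True), graph_opt_level=1) == "static"
    assert dispatch(Entry(use_cudagraph=True), graph_opt_level=0) == "dynamic"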
