14
14
# limitations under the License.
15
15
"""
16
16
17
+ from contextlib import contextmanager
17
18
from dataclasses import dataclass
18
19
from typing import Callable , Dict , Optional
19
20
20
21
import paddle .nn .layer
21
22
from paddle .device .cuda import graphs
23
+ from paddle .jit .dy2static .utils import CUDAGraphState
22
24
23
25
from fastdeploy .config import FDConfig
24
26
from fastdeploy .distributed .communication import capture_custom_allreduce
@@ -48,6 +50,35 @@ class ConcreteSizeEntry:
48
50
output_buffer : Optional [paddle .Tensor ] = None
49
51
50
52
53
class Dy2StCudaGraphManager:
    """Inject CUDA graph state into dynamic-to-static program runs.

    The manager holds the requested CUDA graph state and the batch size of
    the current run. ``run_impl`` is installed in place of the partial
    program's run implementation (see ``run_impl_guard``) and forwards every
    call to the original implementation with two extra CUDA graph attributes.
    """

    def __init__(self):
        # Requested state for the next run; callers overwrite it before
        # entering ``run_impl_guard``.
        self.state = CUDAGraphState.DISABLE
        # Batch sizes that have gone through a CAPTURE run.
        # (Attribute name kept as-is, including its spelling, for callers.)
        self.captrued_batch_size = set()
        # Batch size of the current run; -1 until a caller sets it.
        self.batch_size = -1

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
        """Call ``original_run_impl`` with CUDA graph attributes filled in.

        Args:
            original_run_impl: The run implementation being wrapped.
            inputs: Forwarded unchanged.
            parameters: Forwarded unchanged.
            attrs: A ``(prog_attrs, cuda_graph_attrs)`` pair;
                ``cuda_graph_attrs`` is updated in place.

        Returns:
            Whatever ``original_run_impl`` returns.
        """
        prog_attrs, cuda_graph_attrs = attrs
        effective_state = self.state
        current_bs = self.batch_size

        if effective_state == CUDAGraphState.CAPTURE:
            # Remember this batch size so later REPLAY requests are honoured.
            self.captrued_batch_size.add(current_bs)
        elif effective_state == CUDAGraphState.REPLAY and current_bs not in self.captrued_batch_size:
            # Nothing was captured for this batch size: run without a graph.
            effective_state = CUDAGraphState.DISABLE

        # Disabled runs all use dispatch key 0.
        dispatch_key = 0 if effective_state == CUDAGraphState.DISABLE else current_bs
        cuda_graph_attrs |= {
            "cuda_graph_state": effective_state,
            "cuda_graph_dispatch_key": dispatch_key,
        }
        return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

    @contextmanager
    def run_impl_guard(self):
        """Context manager that routes program runs through ``run_impl``."""
        guard = paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(self.run_impl)
        with guard:
            yield
+
81
+
51
82
class CudaGraphPiecewiseBackend :
52
83
"""Manage the capture and replay of CUDA graphs at the subgraph level."""
53
84
@@ -68,10 +99,41 @@ def __init__(
68
99
for shape in self .cudagraph_capture_sizes :
69
100
self .concrete_size_entries [shape ] = ConcreteSizeEntry (runtime_bs = shape )
70
101
102
+ self .cuda_graph_manager = None
103
+ if self .fd_config .graph_opt_config .graph_opt_level > 0 :
104
+ self .cuda_graph_manager = Dy2StCudaGraphManager ()
105
+
71
106
logger .info (
72
107
f"[CUDA GRAPH] CUDAGraph capture list { self .cudagraph_capture_sizes } , " "Created all batch sizes entry."
73
108
)
74
109
110
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
    """Run the static-graph model for ``entry``, capturing on first use.

    On the first call for an entry the model is warmed up, the input tensor
    addresses are recorded, and one run is made in CAPTURE state. Every call
    (including the first) then finishes with a run in REPLAY state, whose
    result is returned.

    Args:
        entry: Per-batch-size bookkeeping; ``entry.runnable`` is the callable
            that is executed.
        **kwargs: Forwarded unchanged to ``entry.runnable``.

    Returns:
        The result of ``entry.runnable(**kwargs)`` from the REPLAY run.
    """
    if not entry.captured:
        # Warmup the model, resuming from any warm-up runs already done.
        for n in range(entry.num_finished_warmup, self.warm_up_size):
            entry.num_finished_warmup += 1
            entry.runnable(**kwargs)
            # Report progress against the target count; the running counter
            # always equals n + 1 and would print a meaningless "k/k".
            logger.debug(
                f"[CUDA GRAPH] Warm up for batch size {entry.runtime_bs}, "
                f"finished ({n + 1}/{self.warm_up_size}) times"
            )

        # Store input addresses for debug
        entry.input_addresses = [x.data_ptr() for x in kwargs.values() if isinstance(x, paddle.Tensor)]

        # Capture
        self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
        self.cuda_graph_manager.batch_size = entry.runtime_bs
        with self.cuda_graph_manager.run_impl_guard():
            entry.runnable(**kwargs)
        # Mark the entry only after the capture run returned, so later calls
        # skip warm-up/capture and go straight to replay.
        entry.captured = True

    # Replay
    self.cuda_graph_manager.state = CUDAGraphState.REPLAY
    self.cuda_graph_manager.batch_size = entry.runtime_bs
    with self.cuda_graph_manager.run_impl_guard():
        return entry.runnable(**kwargs)
136
+
75
137
def __call__ (self , ** kwargs ):
76
138
# Get batch size
77
139
ids_remove_padding : paddle .Tensor = kwargs ["ids_remove_padding" ]
@@ -91,6 +153,9 @@ def __call__(self, **kwargs):
91
153
if not entry .use_cudagraph :
92
154
return entry .runnable (** kwargs )
93
155
156
+ if self .fd_config .graph_opt_config .graph_opt_level > 0 :
157
+ return self .run_static_model (entry , ** kwargs )
158
+
94
159
# Capture a new cuda graph
95
160
if entry .cuda_graph is None :
96
161
# Warmup the model
0 commit comments