29
29
30
30
@dataclass
31
31
class ConcreteSizeEntry :
32
- """Record the concrete information corresponding to the current batch size """
32
+ """Record the concrete information corresponding to the current shape(num_tokens) """
33
33
34
- # Concrete batch size
34
+ # Concrete shape
35
35
runtime_bs : int
36
36
# The size is in cudagraph_capture_sizes
37
37
use_cudagraph : bool = True
@@ -42,7 +42,7 @@ class ConcreteSizeEntry:
42
42
runnable : Callable = None # type: ignore
43
43
# Number of completed warmups
44
44
num_finished_warmup : int = 0
45
- # Captured cuda graph object corresponding to the current batch size
45
+ # Captured cuda graph object corresponding to the current real shape
46
46
cuda_graph : Optional [graphs .CUDAGraph ] = None
47
47
# Output buffer of cudagraph
48
48
output_buffer : Optional [paddle .Tensor ] = None
@@ -60,33 +60,33 @@ def __init__(
60
60
self .runnable = runnable
61
61
self .cudagraph_capture_sizes = fd_config .graph_opt_config .cudagraph_capture_sizes
62
62
self .warm_up_size = fd_config .graph_opt_config .cudagraph_num_of_warmups
63
- self .batch_size_to_captured_size = fd_config .graph_opt_config .batch_size_to_captured_size
63
+ self .real_shape_to_captured_size = fd_config .graph_opt_config .real_shape_to_captured_size
64
64
65
- # Runtime batch size -> ConcreteSizeEntry
65
+ # Runtime real shape -> ConcreteSizeEntry
66
66
self .concrete_size_entries : Dict [int , ConcreteSizeEntry ] = {}
67
67
68
68
for shape in self .cudagraph_capture_sizes :
69
69
self .concrete_size_entries [shape ] = ConcreteSizeEntry (runtime_bs = shape )
70
70
71
71
logger .info (
72
- f"[CUDA GRAPH] CUDAGraph capture list { self .cudagraph_capture_sizes } , " "Created all batch sizes entry."
72
+ f"[CUDA GRAPH] CUDAGraph capture list { self .cudagraph_capture_sizes } , " "Created all real shape entry."
73
73
)
74
74
75
75
def __call__ (self , ** kwargs ):
76
- # Get batch size
76
+ # Get real shape(all num tokens)
77
77
ids_remove_padding : paddle .Tensor = kwargs ["ids_remove_padding" ]
78
- batch_size = ids_remove_padding .shape [0 ]
79
- padding_batch_size = self .batch_size_to_captured_size [ batch_size ]
78
+ real_shape = ids_remove_padding .shape [0 ]
79
+ padding_real_shape = self .real_shape_to_captured_size [ real_shape ]
80
80
logger .debug (
81
- f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{ batch_size } , "
82
- f"The padded batch size is :{ padding_batch_size } "
81
+ f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{ real_shape } , "
82
+ f"The padded shape is :{ padding_real_shape } "
83
83
)
84
84
85
- entry = self .concrete_size_entries .get (padding_batch_size )
86
- assert entry is not None , f"Batch size: { padding_batch_size } is not in cuda graph capture list."
85
+ entry = self .concrete_size_entries .get (padding_real_shape )
86
+ assert entry is not None , f"real shape: { padding_real_shape } is not in cuda graph capture list."
87
87
if entry .runnable is None :
88
88
entry .runnable = self .runnable
89
- logger .debug (f"[CUDA GRAPH] New entry lazy initialize with batch size { padding_batch_size } " )
89
+ logger .debug (f"[CUDA GRAPH] New entry lazy initialize with real shape { padding_real_shape } " )
90
90
91
91
if not entry .use_cudagraph :
92
92
return entry .runnable (** kwargs )
@@ -98,7 +98,7 @@ def __call__(self, **kwargs):
98
98
entry .num_finished_warmup += 1
99
99
entry .runnable (** kwargs )
100
100
logger .debug (
101
- f"[CUDA GRAPH] Warm up for batch size { padding_batch_size } , "
101
+ f"[CUDA GRAPH] Warm up for real shape { padding_real_shape } , "
102
102
f"finished ({ n + 1 } /{ entry .num_finished_warmup } ) times"
103
103
)
104
104
@@ -122,9 +122,9 @@ def __call__(self, **kwargs):
122
122
output ._clear
123
123
124
124
paddle .device .synchronize ()
125
- logger .debug (f"[CUDA GRAPH] CUDAGraph captured for batch size { padding_batch_size } " )
125
+ logger .debug (f"[CUDA GRAPH] CUDAGraph captured for real shape { padding_real_shape } " )
126
126
127
127
# Replay
128
128
entry .cuda_graph .replay ()
129
- logger .debug (f"[CUDA GRAPH] CUDAGraph replayed for batch size { padding_batch_size } " )
129
+ logger .debug (f"[CUDA GRAPH] CUDAGraph replayed for real shape { padding_real_shape } " )
130
130
return entry .output_buffer
0 commit comments