@@ -245,8 +245,7 @@ def __init__(self, model_runner: ModelRunner):
245
245
)
246
246
else :
247
247
self .encoder_lens = None
248
-
249
- if self .enable_dp_attention :
248
+ if self .enable_dp_attention or self .dp_size == 1 :
250
249
self .gathered_buffer = torch .zeros (
251
250
(
252
251
self .max_bs * self .dp_size * self .num_tokens_per_bs ,
@@ -288,7 +287,7 @@ def model_capture_mode(self):
288
287
self .model_runner .token_to_kv_pool .capture_mode = False
289
288
290
289
def can_run (self , forward_batch : ForwardBatch ):
291
- if self .enable_dp_attention :
290
+ if self .enable_dp_attention or self . dp_size == 1 :
292
291
total_global_tokens = sum (forward_batch .global_num_tokens_cpu )
293
292
294
293
is_bs_supported = forward_batch .can_run_dp_cuda_graph and (
@@ -369,7 +368,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
369
368
encoder_lens = None
370
369
mrope_positions = self .mrope_positions [:, :bs ]
371
370
372
- if self .enable_dp_attention :
371
+ if self .enable_dp_attention or self . dp_size == 1 :
373
372
self .global_num_tokens_gpu .copy_ (
374
373
torch .tensor (
375
374
[
@@ -471,7 +470,7 @@ def replay_prepare(self, forward_batch: ForwardBatch):
471
470
raw_num_token = raw_bs * self .num_tokens_per_bs
472
471
473
472
# Pad
474
- if self .enable_dp_attention :
473
+ if self .enable_dp_attention or self . dp_size == 1 :
475
474
index = bisect .bisect_left (
476
475
self .capture_bs , sum (forward_batch .global_num_tokens_cpu )
477
476
)
@@ -497,7 +496,7 @@ def replay_prepare(self, forward_batch: ForwardBatch):
497
496
self .encoder_lens [:raw_bs ].copy_ (forward_batch .encoder_lens )
498
497
if forward_batch .mrope_positions is not None :
499
498
self .mrope_positions [:, :raw_bs ].copy_ (forward_batch .mrope_positions )
500
- if self .enable_dp_attention :
499
+ if self .enable_dp_attention or self . dp_size == 1 :
501
500
self .global_num_tokens_gpu .copy_ (forward_batch .global_num_tokens_gpu )
502
501
503
502
if hasattr (forward_batch .spec_info , "hidden_states" ):
0 commit comments