Skip to content

Commit 403e79a

Browse files
committed
adapt to pure tp
1 parent 08021ee commit 403e79a

File tree

3 files changed

+7
-11
lines changed

3 files changed

+7
-11
lines changed

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         ret = None

         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.dp_size == 1:
             ret, _ = self.prepare_dp_attn_batch(ret)

         return ret

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,7 @@ def __init__(self, model_runner: ModelRunner):
             )
         else:
             self.encoder_lens = None
-
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.dp_size == 1:
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -288,7 +287,7 @@ def model_capture_mode(self):
         self.model_runner.token_to_kv_pool.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.dp_size == 1:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)

             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -369,7 +368,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.dp_size == 1:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -471,7 +470,7 @@ def replay_prepare(self, forward_batch: ForwardBatch):
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.dp_size == 1:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
@@ -497,7 +496,7 @@ def replay_prepare(self, forward_batch: ForwardBatch):
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.dp_size == 1:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,7 @@ def model_specific_adjustment(self):

         if server_args.enable_deepep_moe:
             logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
-
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

0 commit comments

Comments
 (0)