Skip to content

Commit 7f19e08

Browse files
RunkaiTaoch-wan
andauthored
Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
1 parent 98a2cfa commit 7f19e08

File tree

10 files changed

+238
-47
lines changed

10 files changed

+238
-47
lines changed

docs/backend/server_arguments.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ Please consult the documentation below to learn more about the parameters you ma
9090
### Expert parallelism
9191
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
9292
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
93-
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP. Currently DeepEP is bind to DP Attention. Please set `--enable-dp-attention --enable-deepep-moe`, perfer `tp_size=dp_size=ep_size`.
93+
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
9494

9595
## Memory and scheduling
9696

@@ -184,7 +184,7 @@ Please consult the documentation below to learn more about the parameters you ma
184184
*Note: Some of these options are still in experimental stage.*
185185

186186
* `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163).
187-
* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this.
187+
* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models.
188188
* `enable_torch_compile`: Torch compile the model. Note that compiling a model takes a long time but have a great performance boost. The compiled model can also be [cached for future use](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile).
189189
* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`.
190190
* `cuda_graph_max_bs`: Adjust the maximum batchsize when using cuda graph. By default this is chosen for you based on GPU specifics.

python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from contextlib import contextmanager
77
from functools import wraps
8-
from typing import Callable, List, Optional, TypeVar, Union
8+
from typing import Any, Callable, List, Optional, TypeVar, Union
99

1010
import torch
1111
import torch.distributed as dist

python/sglang/srt/distributed/parallel_state.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,15 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
439439
else:
440440
torch.distributed.all_reduce(input_, group=self.device_group)
441441

442+
def reduce_scatter(
443+
self,
444+
output: torch.Tensor,
445+
input_list: List[torch.Tensor],
446+
) -> None:
447+
# TODO(ch-wan): support other backends
448+
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
449+
return output
450+
442451
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
443452
pynccl_comm = self.pynccl_comm
444453
if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -456,11 +465,23 @@ def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
456465
output, input, group_name=self.unique_name
457466
)
458467

459-
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
468+
def all_gather(
469+
self,
470+
input_: torch.Tensor,
471+
dim: int = -1,
472+
tensor_list: List[torch.Tensor] = None,
473+
) -> torch.Tensor:
460474
world_size = self.world_size
461475
# Bypass the function if we are using only 1 GPU.
462476
if world_size == 1:
463477
return input_
478+
479+
if tensor_list is not None:
480+
# TODO(ch-wan): support other backends
481+
return torch.distributed.all_gather(
482+
tensor_list, input_, group=self.device_group
483+
)
484+
464485
assert (
465486
-input_.dim() <= dim < input_.dim()
466487
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"

python/sglang/srt/layers/dp_attention.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import logging
55
from contextlib import contextmanager
6-
from typing import TYPE_CHECKING, Union
6+
from typing import TYPE_CHECKING, List
77

88
import torch
99
import triton
@@ -249,3 +249,14 @@ def dp_scatter(
249249
memcpy_triton(
250250
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
251251
)
252+
253+
254+
def tp_reduce_scatter(
255+
output: torch.Tensor,
256+
input_list: List[torch.Tensor],
257+
):
258+
return get_attention_tp_group().reduce_scatter(output, input_list)
259+
260+
261+
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
262+
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
11861186
ret = None
11871187

11881188
# Handle DP attention
1189-
if self.server_args.enable_dp_attention:
1189+
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
11901190
ret, _ = self.prepare_dp_attn_batch(ret)
11911191

11921192
return ret

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def __init__(self, model_runner: ModelRunner):
174174
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
175175
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
176176
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
177+
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
177178
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
178179
self.tp_size = model_runner.server_args.tp_size
179180
self.dp_size = model_runner.server_args.dp_size
@@ -245,8 +246,8 @@ def __init__(self, model_runner: ModelRunner):
245246
)
246247
else:
247248
self.encoder_lens = None
248-
249-
if self.enable_dp_attention:
249+
if self.enable_dp_attention or self.enable_sp_layernorm:
250+
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
250251
self.gathered_buffer = torch.zeros(
251252
(
252253
self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -288,7 +289,7 @@ def model_capture_mode(self):
288289
self.model_runner.token_to_kv_pool.capture_mode = False
289290

290291
def can_run(self, forward_batch: ForwardBatch):
291-
if self.enable_dp_attention:
292+
if self.enable_dp_attention or self.enable_sp_layernorm:
292293
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
293294

294295
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -369,7 +370,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
369370
encoder_lens = None
370371
mrope_positions = self.mrope_positions[:, :bs]
371372

372-
if self.enable_dp_attention:
373+
if self.enable_dp_attention or self.enable_sp_layernorm:
373374
self.global_num_tokens_gpu.copy_(
374375
torch.tensor(
375376
[
@@ -471,7 +472,7 @@ def replay_prepare(self, forward_batch: ForwardBatch):
471472
raw_num_token = raw_bs * self.num_tokens_per_bs
472473

473474
# Pad
474-
if self.enable_dp_attention:
475+
if self.enable_dp_attention or self.enable_sp_layernorm:
475476
index = bisect.bisect_left(
476477
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
477478
)
@@ -497,7 +498,7 @@ def replay_prepare(self, forward_batch: ForwardBatch):
497498
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
498499
if forward_batch.mrope_positions is not None:
499500
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
500-
if self.enable_dp_attention:
501+
if self.enable_dp_attention or self.enable_sp_layernorm:
501502
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
502503

503504
if hasattr(forward_batch.spec_info, "hidden_states"):

python/sglang/srt/model_executor/model_runner.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,6 @@ def model_specific_adjustment(self):
281281

282282
if server_args.enable_deepep_moe:
283283
logger.info("DeepEP is turned on.")
284-
assert (
285-
server_args.enable_dp_attention == True
286-
), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
287284

288285
def init_torch_distributed(self):
289286
logger.info("Init torch distributed begin.")

0 commit comments

Comments
 (0)