Skip to content

Commit 7a1f7fc

Browse files
authored
[Feature] Hybrid EP and TP (sgl-project#8590)
1 parent: 51c3816 · commit: 7a1f7fc

File tree

14 files changed

+142
-39
lines changed

14 files changed

+142
-39
lines changed

assets/logo.svg

Lines changed: 1 addition & 1 deletion
Loading

assets/logo_square.svg

Lines changed: 1 addition & 1 deletion
Loading

python/sglang/bench_one_batch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def from_cli_args(cls, args: argparse.Namespace):
138138
def load_model(server_args, port_args, tp_rank):
139139
suppress_other_loggers()
140140
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
141+
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
141142

142143
model_config = ModelConfig.from_server_args(server_args)
143144
model_runner = ModelRunner(
@@ -146,6 +147,8 @@ def load_model(server_args, port_args, tp_rank):
146147
gpu_id=tp_rank,
147148
tp_rank=tp_rank,
148149
tp_size=server_args.tp_size,
150+
moe_ep_rank=moe_ep_rank,
151+
moe_ep_size=server_args.ep_size,
149152
pp_rank=0,
150153
pp_size=1,
151154
nccl_port=port_args.nccl_port,

python/sglang/srt/distributed/parallel_state.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,13 @@ def __init__(
354354
self.cpu_group, 1 << 22, 6
355355
)
356356

357+
def __repr__(self):
358+
return (
359+
f"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} "
360+
f"device_group={self.device_group} cpu_group={self.cpu_group} unique_name={self.unique_name} "
361+
f"world_size={self.world_size} rank_in_group={self.rank_in_group}"
362+
)
363+
357364
@property
358365
def first_rank(self):
359366
"""Return the global rank of the first process in the group"""
@@ -1141,6 +1148,20 @@ def get_tp_group() -> GroupCoordinator:
11411148
return _TP
11421149

11431150

1151+
_MOE_EP: Optional[GroupCoordinator] = None
1152+
_MOE_TP: Optional[GroupCoordinator] = None
1153+
1154+
1155+
def get_moe_ep_group() -> GroupCoordinator:
1156+
assert _MOE_EP is not None, "expert model parallel group is not initialized"
1157+
return _MOE_EP
1158+
1159+
1160+
def get_moe_tp_group() -> GroupCoordinator:
1161+
assert _MOE_TP is not None, "expert model parallel group is not initialized"
1162+
return _MOE_TP
1163+
1164+
11441165
# kept for backward compatibility
11451166
get_tensor_model_parallel_group = get_tp_group
11461167

@@ -1250,6 +1271,7 @@ def init_distributed_environment(
12501271

12511272
def initialize_model_parallel(
12521273
tensor_model_parallel_size: int = 1,
1274+
expert_model_parallel_size: int = 1,
12531275
pipeline_model_parallel_size: int = 1,
12541276
backend: Optional[str] = None,
12551277
duplicate_tp_group: bool = False,
@@ -1327,6 +1349,45 @@ def initialize_model_parallel(
13271349
_TP.pynccl_comm.disabled = False
13281350
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
13291351

1352+
moe_ep_size = expert_model_parallel_size
1353+
1354+
moe_tp_size = tensor_model_parallel_size // moe_ep_size
1355+
global _MOE_EP
1356+
assert _MOE_EP is None, "expert model parallel group is already initialized"
1357+
group_ranks = []
1358+
for i in range(num_tensor_model_parallel_groups):
1359+
for j in range(moe_tp_size):
1360+
st = i * tensor_model_parallel_size + j
1361+
en = (i + 1) * tensor_model_parallel_size + j
1362+
ranks = list(range(st, en, moe_tp_size))
1363+
group_ranks.append(ranks)
1364+
1365+
_MOE_EP = init_model_parallel_group(
1366+
group_ranks,
1367+
get_world_group().local_rank,
1368+
backend,
1369+
use_custom_allreduce=False,
1370+
group_name="moe_ep",
1371+
)
1372+
1373+
global _MOE_TP
1374+
assert _MOE_TP is None, "expert model parallel group is already initialized"
1375+
group_ranks = []
1376+
for i in range(num_tensor_model_parallel_groups):
1377+
for j in range(moe_ep_size):
1378+
st = i * tensor_model_parallel_size + j * moe_tp_size
1379+
en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
1380+
ranks = list(range(st, en))
1381+
group_ranks.append(ranks)
1382+
1383+
_MOE_TP = init_model_parallel_group(
1384+
group_ranks,
1385+
get_world_group().local_rank,
1386+
backend,
1387+
use_custom_allreduce=False,
1388+
group_name="moe_tp",
1389+
)
1390+
13301391
# Build the pipeline model-parallel groups.
13311392
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
13321393
global _PP
@@ -1347,6 +1408,7 @@ def initialize_model_parallel(
13471408

13481409
def ensure_model_parallel_initialized(
13491410
tensor_model_parallel_size: int,
1411+
expert_model_parallel_size: int,
13501412
pipeline_model_parallel_size: int,
13511413
backend: Optional[str] = None,
13521414
) -> None:
@@ -1357,7 +1419,10 @@ def ensure_model_parallel_initialized(
13571419
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
13581420
if not model_parallel_is_initialized():
13591421
initialize_model_parallel(
1360-
tensor_model_parallel_size, pipeline_model_parallel_size, backend
1422+
tensor_model_parallel_size,
1423+
expert_model_parallel_size,
1424+
pipeline_model_parallel_size,
1425+
backend,
13611426
)
13621427
return
13631428

@@ -1417,6 +1482,26 @@ def get_tensor_model_parallel_rank():
14171482
return get_tp_group().rank_in_group
14181483

14191484

1485+
def get_moe_expert_parallel_world_size():
1486+
"""Return world size for the moe expert parallel group."""
1487+
return get_moe_ep_group().world_size
1488+
1489+
1490+
def get_moe_expert_parallel_rank():
1491+
"""Return my rank for the moe expert parallel group."""
1492+
return get_moe_ep_group().rank_in_group
1493+
1494+
1495+
def get_moe_tensor_parallel_world_size():
1496+
"""Return world size for the moe tensor parallel group."""
1497+
return get_moe_tp_group().world_size
1498+
1499+
1500+
def get_moe_tensor_parallel_rank():
1501+
"""Return my rank for the moe tensor parallel group."""
1502+
return get_moe_tp_group().rank_in_group
1503+
1504+
14201505
def destroy_model_parallel():
14211506
"""Set the groups to none and destroy them."""
14221507
global _TP

python/sglang/srt/entrypoints/engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,13 +719,15 @@ def _launch_subprocesses(
719719
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
720720
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
721721
)
722+
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
722723
proc = mp.Process(
723724
target=run_scheduler_process,
724725
args=(
725726
server_args,
726727
port_args,
727728
gpu_id,
728729
tp_rank,
730+
moe_ep_rank,
729731
pp_rank,
730732
None,
731733
writer,

python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
enable_ep_moe=True,
136136
)
137137

138-
self.start_expert_id = self.ep_rank * self.num_local_experts
138+
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
139139
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
140140

141141
self.intermediate_size = intermediate_size

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import torch
88

99
from sglang.srt.distributed import (
10+
get_moe_expert_parallel_rank,
11+
get_moe_expert_parallel_world_size,
12+
get_moe_tensor_parallel_rank,
13+
get_moe_tensor_parallel_world_size,
1014
get_tensor_model_parallel_rank,
1115
get_tensor_model_parallel_world_size,
1216
tensor_model_parallel_all_reduce,
@@ -88,10 +92,6 @@ def __init__(
8892
self.layer_id = layer_id
8993
self.top_k = top_k
9094
self.hidden_size = hidden_size
91-
self.tp_size = (
92-
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
93-
)
94-
self.tp_rank = get_tensor_model_parallel_rank()
9595
self.num_experts = num_experts
9696
self.num_fused_shared_experts = num_fused_shared_experts
9797
self.expert_map_cpu = None
@@ -103,30 +103,27 @@ def __init__(
103103
enable_ep_moe = False
104104

105105
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
106+
self.moe_ep_size = get_moe_expert_parallel_world_size()
107+
self.moe_ep_rank = get_moe_expert_parallel_rank()
108+
self.moe_tp_size = get_moe_tensor_parallel_world_size()
109+
self.moe_tp_rank = get_moe_tensor_parallel_rank()
110+
assert num_experts % self.moe_ep_size == 0
111+
self.num_local_experts = num_experts // self.moe_ep_size
106112
if enable_ep_moe:
107113
# TODO(ch-wan): support shared experts fusion
108-
self.ep_size = self.tp_size
109-
self.ep_rank = self.tp_rank
110-
self.tp_size = 1
111-
self.tp_rank = 0
112114
# Create a tensor of size num_experts filled with -1
113115
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
114116
# Create a expert map for the local experts
115-
assert num_experts % self.ep_size == 0
116-
self.num_local_experts = num_experts // self.ep_size
117117
self.expert_map_cpu[
118-
self.ep_rank
119-
* self.num_local_experts : (self.ep_rank + 1)
118+
self.moe_ep_rank
119+
* self.num_local_experts : (self.moe_ep_rank + 1)
120120
* self.num_local_experts
121121
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
122122
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
123-
else:
124-
self.ep_size = 1
125-
self.ep_rank = 0
126-
self.num_local_experts = num_experts
123+
127124
self.routed_scaling_factor = routed_scaling_factor
128-
assert intermediate_size % self.tp_size == 0
129-
self.intermediate_size_per_partition = intermediate_size // self.tp_size
125+
assert intermediate_size % self.moe_tp_size == 0
126+
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
130127
self.reduce_results = reduce_results
131128
self.activation = activation
132129
self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -437,8 +434,7 @@ def _weight_loader_impl(
437434
expert_id: int,
438435
) -> None:
439436

440-
# TP rank is set to 0 if EP is enabled
441-
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
437+
tp_rank = self.moe_tp_rank
442438

443439
# compressed-tensors checkpoints with packed weights are stored flipped
444440
# TODO (mgoin): check self.quant_method.quant_config.quant_format
@@ -630,17 +626,17 @@ def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
630626
routed_scaling_factor=self.routed_scaling_factor,
631627
**(
632628
dict(
633-
tp_rank=self.tp_rank,
634-
tp_size=self.tp_size,
635-
ep_rank=self.ep_rank,
636-
ep_size=self.ep_size,
629+
tp_rank=self.moe_tp_rank,
630+
tp_size=self.moe_tp_size,
631+
ep_rank=self.moe_ep_rank,
632+
ep_size=self.moe_ep_size,
637633
)
638634
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
639635
else {}
640636
),
641637
)
642638

643-
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
639+
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
644640
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
645641

646642
return final_hidden_states

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,15 @@ def launch_tensor_parallel_group(
222222
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
223223
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
224224
)
225+
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
225226
proc = mp.Process(
226227
target=run_scheduler_process,
227228
args=(
228229
server_args,
229230
rank_port_args,
230231
gpu_id,
231232
tp_rank,
233+
moe_ep_rank,
232234
pp_rank,
233235
dp_rank,
234236
writer,

python/sglang/srt/managers/scheduler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,18 @@ def __init__(
200200
port_args: PortArgs,
201201
gpu_id: int,
202202
tp_rank: int,
203+
moe_ep_rank: int,
203204
pp_rank: int,
204205
dp_rank: Optional[int],
205206
):
206207
# Parse args
207208
self.server_args = server_args
208209
self.tp_rank = tp_rank
210+
self.moe_ep_rank = moe_ep_rank
209211
self.pp_rank = pp_rank
210212
self.dp_rank = dp_rank
211213
self.tp_size = server_args.tp_size
214+
self.moe_ep_size = server_args.ep_size
212215
self.pp_size = server_args.pp_size
213216
self.dp_size = server_args.dp_size
214217
self.schedule_policy = server_args.schedule_policy
@@ -310,6 +313,7 @@ def __init__(
310313
server_args=server_args,
311314
gpu_id=gpu_id,
312315
tp_rank=tp_rank,
316+
moe_ep_rank=moe_ep_rank,
313317
pp_rank=pp_rank,
314318
dp_rank=dp_rank,
315319
nccl_port=port_args.nccl_port,
@@ -322,6 +326,7 @@ def __init__(
322326
self.draft_worker = EAGLEWorker(
323327
gpu_id=gpu_id,
324328
tp_rank=tp_rank,
329+
moe_ep_rank=moe_ep_rank,
325330
server_args=server_args,
326331
nccl_port=port_args.nccl_port,
327332
target_worker=self.tp_worker,
@@ -2358,6 +2363,7 @@ def run_scheduler_process(
23582363
port_args: PortArgs,
23592364
gpu_id: int,
23602365
tp_rank: int,
2366+
moe_ep_rank: int,
23612367
pp_rank: int,
23622368
dp_rank: Optional[int],
23632369
pipe_writer,
@@ -2368,6 +2374,8 @@ def run_scheduler_process(
23682374
prefix += f" DP{dp_rank}"
23692375
if server_args.tp_size > 1:
23702376
prefix += f" TP{tp_rank}"
2377+
if server_args.ep_size > 1:
2378+
prefix += f" EP{moe_ep_rank}"
23712379
if server_args.pp_size > 1:
23722380
prefix += f" PP{pp_rank}"
23732381

@@ -2391,7 +2399,9 @@ def run_scheduler_process(
23912399

23922400
# Create a scheduler and run the event loop
23932401
try:
2394-
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
2402+
scheduler = Scheduler(
2403+
server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
2404+
)
23952405
pipe_writer.send(
23962406
{
23972407
"status": "ready",

python/sglang/srt/managers/tp_worker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
server_args: ServerArgs,
5757
gpu_id: int,
5858
tp_rank: int,
59+
moe_ep_rank: int,
5960
pp_rank: int,
6061
dp_rank: Optional[int],
6162
nccl_port: int,
@@ -66,6 +67,7 @@ def __init__(
6667
# Parse args
6768
self.tp_size = server_args.tp_size
6869
self.tp_rank = tp_rank
70+
self.moe_ep_rank = moe_ep_rank
6971
self.pp_rank = pp_rank
7072

7173
# Init model and tokenizer
@@ -85,6 +87,8 @@ def __init__(
8587
gpu_id=gpu_id,
8688
tp_rank=tp_rank,
8789
tp_size=server_args.tp_size,
90+
moe_ep_rank=moe_ep_rank,
91+
moe_ep_size=server_args.ep_size,
8892
pp_rank=pp_rank,
8993
pp_size=server_args.pp_size,
9094
nccl_port=nccl_port,

0 commit comments

Comments (0)