Skip to content

Commit 9effeb5

Browse files
authored
Support EPLB in FusedMoE (sgl-project#8448)
1 parent 1992ef9 commit 9effeb5

File tree

15 files changed

+107
-11
lines changed

15 files changed

+107
-11
lines changed

python/sglang/srt/eplb/expert_distribution.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def init_new(
4747
rank: int,
4848
):
4949
if server_args.expert_distribution_recorder_mode is not None:
50+
assert (
51+
expert_location_metadata is not None
52+
), "ExpertLocationMetadata is required for expert distribution recording. One possible"
53+
"reason is that you are using a model that does not support expert distribution"
54+
"recording. Try setting `get_model_config_for_expert_location` in your model."
5055
return _ExpertDistributionRecorderReal(
5156
server_args, expert_location_metadata, rank
5257
)

python/sglang/srt/eplb/expert_location.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def __post_init__(self):
8282
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
8383
"""Trivial location - logical expert i corresponds to physical expert i"""
8484
common = ExpertLocationMetadata._init_common(server_args, model_config)
85+
86+
if common is None:
87+
return None
88+
8589
num_physical_experts = common["num_physical_experts"]
8690
model_config_for_expert_location = common["model_config_for_expert_location"]
8791
num_layers = model_config_for_expert_location.num_layers
@@ -109,6 +113,10 @@ def init_by_mapping(
109113
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
110114

111115
common = ExpertLocationMetadata._init_common(server_args, model_config)
116+
117+
if common is None:
118+
return None
119+
112120
model_config_for_expert_location = common["model_config_for_expert_location"]
113121
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
114122
physical_to_logical_map,
@@ -133,6 +141,10 @@ def init_by_eplb(
133141
logical_count = logical_count.to(server_args.device)
134142

135143
common = ExpertLocationMetadata._init_common(server_args, model_config)
144+
145+
if common is None:
146+
return None
147+
136148
model_config_for_expert_location = common["model_config_for_expert_location"]
137149
num_physical_experts = common["num_physical_experts"]
138150
num_groups = model_config_for_expert_location.num_groups
@@ -168,6 +180,9 @@ def _init_common(server_args: ServerArgs, model_config: ModelConfig):
168180
ModelConfigForExpertLocation.from_model_config(model_config)
169181
)
170182

183+
if model_config_for_expert_location is None:
184+
return None
185+
171186
num_physical_experts = (
172187
model_config_for_expert_location.num_logical_experts
173188
+ server_args.ep_num_redundant_experts
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
398413
num_logical_experts: int
399414
num_groups: Optional[int] = None
400415

401-
@staticmethod
402-
def init_dummy():
403-
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
404-
405416
@staticmethod
406417
def from_model_config(model_config: ModelConfig):
407418
model_class, _ = get_model_architecture(model_config)
@@ -410,12 +421,12 @@ def from_model_config(model_config: ModelConfig):
410421
model_config.hf_config
411422
)
412423
else:
413-
return ModelConfigForExpertLocation.init_dummy()
424+
return None
414425

415426

416427
def compute_initial_expert_location_metadata(
417428
server_args: ServerArgs, model_config: ModelConfig
418-
) -> ExpertLocationMetadata:
429+
) -> Optional[ExpertLocationMetadata]:
419430
data = server_args.init_expert_location
420431
if data == "trivial":
421432
return ExpertLocationMetadata.init_trivial(server_args, model_config)

python/sglang/srt/eplb/expert_location_dispatch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
3636
def init_new(cls, layer_id: int):
3737
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
3838
expert_location_metadata = get_global_expert_location_metadata()
39+
assert expert_location_metadata is not None
3940

4041
if ep_dispatch_algorithm is None:
4142
return None

python/sglang/srt/eplb/expert_location_updater.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def update(
5050
torch.cuda.empty_cache()
5151

5252
old_expert_location_metadata = get_global_expert_location_metadata()
53+
assert old_expert_location_metadata is not None
54+
5355
_update_expert_weights(
5456
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
5557
old_expert_location_metadata=old_expert_location_metadata,

python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def __init__(
183183
hidden_size: int,
184184
intermediate_size: int,
185185
layer_id: int,
186+
num_fused_shared_experts: int = 0,
186187
params_dtype: Optional[torch.dtype] = None,
187188
quant_config: Optional[QuantizationConfig] = None,
188189
tp_size: Optional[int] = None,
@@ -196,6 +197,7 @@ def __init__(
196197
hidden_size=hidden_size,
197198
intermediate_size=intermediate_size,
198199
top_k=top_k,
200+
num_fused_shared_experts=num_fused_shared_experts,
199201
layer_id=layer_id,
200202
params_dtype=params_dtype,
201203
quant_config=quant_config,
@@ -728,10 +730,19 @@ def weight_loader(
728730
shard_id: str,
729731
expert_id: int,
730732
) -> None:
731-
physical_expert_ids = (
732-
get_global_expert_location_metadata().logical_to_all_physical(
733-
self.layer_id, expert_id
733+
global_expert_location_metadata = get_global_expert_location_metadata()
734+
if global_expert_location_metadata is None:
735+
self._weight_loader_impl(
736+
param=param,
737+
loaded_weight=loaded_weight,
738+
weight_name=weight_name,
739+
shard_id=shard_id,
740+
expert_id=expert_id,
734741
)
742+
return
743+
744+
physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
745+
self.layer_id, expert_id
735746
)
736747
for physical_expert_id in physical_expert_ids:
737748
self._weight_loader_physical(
@@ -778,6 +789,7 @@ def __init__(
778789
hidden_size: int,
779790
intermediate_size: int,
780791
layer_id: int,
792+
num_fused_shared_experts: int = 0,
781793
params_dtype: Optional[torch.dtype] = None,
782794
quant_config: Optional[QuantizationConfig] = None,
783795
tp_size: Optional[int] = None,
@@ -792,6 +804,7 @@ def __init__(
792804
hidden_size=hidden_size,
793805
intermediate_size=intermediate_size,
794806
layer_id=layer_id,
807+
num_fused_shared_experts=num_fused_shared_experts,
795808
params_dtype=params_dtype,
796809
quant_config=quant_config,
797810
tp_size=tp_size,

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
get_tensor_model_parallel_world_size,
1212
tensor_model_parallel_all_reduce,
1313
)
14+
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
1415
from sglang.srt.layers.moe.topk import TopKOutput
1516
from sglang.srt.layers.quantization.base_config import (
1617
QuantizationConfig,
@@ -62,8 +63,9 @@ def __init__(
6263
num_experts: int,
6364
hidden_size: int,
6465
intermediate_size: int,
66+
layer_id: int,
6567
top_k: Optional[int] = None,
66-
layer_id: Optional[int] = None,
68+
num_fused_shared_experts: int = 0,
6769
params_dtype: Optional[torch.dtype] = None,
6870
reduce_results: bool = False,
6971
quant_config: Optional[QuantizationConfig] = None,
@@ -84,13 +86,15 @@ def __init__(
8486
if params_dtype is None:
8587
params_dtype = torch.get_default_dtype()
8688

89+
self.layer_id = layer_id
8790
self.top_k = top_k
8891
self.hidden_size = hidden_size
8992
self.tp_size = (
9093
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
9194
)
9295
self.tp_rank = get_tensor_model_parallel_rank()
9396
self.num_experts = num_experts
97+
self.num_fused_shared_experts = num_fused_shared_experts
9498
self.expert_map = None
9599

96100
if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -375,6 +379,45 @@ def weight_loader(
375379
shard_id: str,
376380
expert_id: int,
377381
) -> None:
382+
383+
global_expert_location_metadata = get_global_expert_location_metadata()
384+
if global_expert_location_metadata is None:
385+
self._weight_loader_impl(
386+
param=param,
387+
loaded_weight=loaded_weight,
388+
weight_name=weight_name,
389+
shard_id=shard_id,
390+
expert_id=expert_id,
391+
)
392+
return
393+
394+
if expert_id >= self.num_experts - self.num_fused_shared_experts:
395+
# This is a shared expert.
396+
physical_expert_ids = [expert_id]
397+
else:
398+
physical_expert_ids = (
399+
global_expert_location_metadata.logical_to_all_physical(
400+
self.layer_id, expert_id
401+
)
402+
)
403+
404+
for physical_expert_id in physical_expert_ids:
405+
self._weight_loader_physical(
406+
param=param,
407+
loaded_weight=loaded_weight,
408+
weight_name=weight_name,
409+
shard_id=shard_id,
410+
expert_id=physical_expert_id,
411+
)
412+
413+
def _weight_loader_physical(
414+
self,
415+
param: torch.nn.Parameter,
416+
loaded_weight: torch.Tensor,
417+
weight_name: str,
418+
shard_id: str,
419+
expert_id: int,
420+
) -> None:
378421
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
379422
if expert_id == -1:
380423
return

python/sglang/srt/models/deepseek_v2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def __init__(
325325
num_experts=config.n_routed_experts
326326
+ self.num_fused_shared_experts
327327
+ global_server_args_dict["ep_num_redundant_experts"],
328+
num_fused_shared_experts=self.num_fused_shared_experts,
328329
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
329330
hidden_size=config.hidden_size,
330331
intermediate_size=config.moe_intermediate_size,
@@ -2112,6 +2113,7 @@ def determine_num_fused_shared_experts(
21122113

21132114
if disable_reason is not None:
21142115
global_server_args_dict["disable_shared_experts_fusion"] = True
2116+
self.num_fused_shared_experts = 0
21152117
log_info_on_rank0(
21162118
logger,
21172119
f"{disable_reason} Shared experts fusion optimization is disabled.",

python/sglang/srt/models/glm4_moe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def __init__(
434434
num_experts=config.n_routed_experts
435435
+ self.num_fused_shared_experts
436436
+ global_server_args_dict["ep_num_redundant_experts"],
437+
num_fused_shared_experts=self.num_fused_shared_experts,
437438
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
438439
hidden_size=config.hidden_size,
439440
intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ def determine_num_fused_shared_experts(
740741
global_server_args_dict["enable_deepep_moe"]
741742
or global_server_args_dict["enable_ep_moe"]
742743
):
743-
disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
744+
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
744745

745746
if disable_reason is not None:
746747
global_server_args_dict["disable_shared_experts_fusion"] = True
748+
self.num_fused_shared_experts = 0
747749
log_info_on_rank0(
748750
logger,
749751
f"{disable_reason} Shared experts fusion optimization is disabled.",

python/sglang/srt/models/granitemoe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
top_k: int,
4444
hidden_size: int,
4545
intermediate_size: int,
46+
layer_id: int,
4647
params_dtype: Optional[torch.dtype] = None,
4748
quant_config: Optional[QuantizationConfig] = None,
4849
tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ def __init__(
7172
top_k=top_k,
7273
hidden_size=hidden_size,
7374
intermediate_size=intermediate_size,
75+
layer_id=layer_id,
7476
params_dtype=params_dtype,
7577
reduce_results=True,
7678
quant_config=quant_config,
@@ -203,6 +205,7 @@ def __init__(
203205
top_k=config.num_experts_per_tok,
204206
hidden_size=config.hidden_size,
205207
intermediate_size=config.intermediate_size,
208+
layer_id=layer_id,
206209
quant_config=quant_config,
207210
prefix=f"{prefix}.block_sparse_moe",
208211
)

python/sglang/srt/models/grok.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
7878
def __init__(
7979
self,
8080
config: PretrainedConfig,
81+
layer_id: int,
8182
num_experts: int,
8283
top_k: int,
8384
hidden_size: int,
@@ -128,6 +129,7 @@ def __init__(
128129
self.experts = MoEImpl(
129130
num_experts=num_experts,
130131
top_k=top_k,
132+
layer_id=layer_id,
131133
hidden_size=hidden_size,
132134
intermediate_size=intermediate_size,
133135
params_dtype=params_dtype,
@@ -331,6 +333,7 @@ def __init__(
331333
)
332334
self.block_sparse_moe = Grok1MoE(
333335
config=config,
336+
layer_id=layer_id,
334337
num_experts=config.num_local_experts,
335338
top_k=config.num_experts_per_tok,
336339
hidden_size=config.hidden_size,

0 commit comments

Comments (0)