[Feat] support mixed ep #2969
Changes from all commits: bda8ad2, 50b78d3, 62db77a, e6a620d, f206ac2, 9ede5a2, e414fae, 84afd99, 7299c1e, c81a3f9
```diff
@@ -43,9 +43,10 @@ def __init__(
         num_max_dispatch_tokens_per_rank: int,
         hidden: int,
         num_experts: int,
-        moe_phase: MoEPhase,
         ep_size: int,
         ep_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
         async_finish: bool = False,
     ):
         """
```
```diff
@@ -65,26 +66,44 @@ def __init__(
         self.hidden = hidden
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
         self.moe_phase = moe_phase
         self.async_finish = async_finish

-        self.deepep_engine = None
+        self.prefill_deepep_engine = None
+        self.decode_deepep_engine = None

+        self.ep_config = Config(24, 6, 256)
+        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank

-        if moe_phase == MoEPhase.DECODER:
+        # In mixed EP mode on a single node, we dynamically switch between
+        # high throughput and low latency modes.
+        if splitwise_role == "mixed":
+            # decode engine
             logger.info("Initializing Low Latency Buffer")
-            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
             self.get_low_latency_buffer()
-        elif moe_phase == MoEPhase.PREFILL:
-            self.deepep_engine = deep_ep.Buffer(
+            # prefill engine
+            self.prefill_deepep_engine = deep_ep.Buffer(
                 self.group,
                 int(5e8),
                 0,
                 low_latency_mode=False,
                 num_qps_per_rank=1,
             )
-            self.ep_config = Config(24, 6, 256)
+        # In disaggregated mode on multiple nodes, we either use
+        # high throughput mode or low latency mode.
         else:
-            raise ValueError(f"Unknown generation phase {moe_phase}")
+            if moe_phase.phase == "decode":
+                logger.info("Initializing Low Latency Buffer")
+                self.get_low_latency_buffer()
+            elif moe_phase.phase == "prefill":
+                self.prefill_deepep_engine = deep_ep.Buffer(
+                    self.group,
+                    int(5e8),
+                    0,
+                    low_latency_mode=False,
+                    num_qps_per_rank=1,
+                )
+            else:
+                raise ValueError(f"Unknown generation phase {moe_phase}")
```
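To make the dual-engine scheme concrete, here is a minimal sketch (not part of the PR) of how a caller could route between the two buffers by phase. Only the attribute names `prefill_deepep_engine` and `decode_deepep_engine` come from this diff; the selector function itself is hypothetical.

```python
def select_engine(engine, phase: str):
    """Hypothetical helper: pick the DeepEP buffer for the current phase.

    In mixed EP mode both buffers exist on the same rank, so the caller
    switches per step instead of fixing one mode at startup.
    """
    if phase == "prefill":
        # High-throughput (normal) buffer, created with low_latency_mode=False.
        return engine.prefill_deepep_engine
    if phase == "decode":
        # Low-latency buffer allocated lazily by get_low_latency_buffer().
        return engine.decode_deepep_engine
    raise ValueError(f"Unknown generation phase {phase}")
```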
def get_low_latency_buffer(self):
```diff
@@ -105,14 +124,14 @@ def get_low_latency_buffer(self):
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.decode_deepep_engine is None
+            or self.decode_deepep_engine.group != self.group
+            or not self.decode_deepep_engine.low_latency_mode
+            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
+            self.decode_deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
```
```diff
@@ -149,7 +168,7 @@ def low_latency_dispatch(
             handle,
             _,
             dispatch_hook,
-        ) = self.deepep_engine.low_latency_dispatch(
+        ) = self.decode_deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
```
```diff
@@ -174,8 +193,22 @@ def low_latency_combine(
         Return:
             combined_hidden_states: [num_tokens, hidden]
         """
+        # TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
+        (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+        ) = handle
+        handle = (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            None,
+            num_experts,
+        )
```
Review comment: This code was deleted earlier, but I found that Paddle 3.0.1 still raises an error here, and the Paddle version recommended in the FastDeploy user docs is still 3.0.1, so from the user's point of view let's keep it for now.

Reply: This also needs to be adapted for develop. Let's make both compatible in the next PR.

Reply: OK.
```diff
-        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
```
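The TODO block above works around the DeepEP build shipped with Paddle 3.0.1, whose `low_latency_dispatch` returns a 4-field handle while `low_latency_combine` expects 5 fields. The same shim as a standalone helper (a sketch; the tuple layouts are inferred from this diff, not from DeepEP documentation):

```python
def pad_dispatch_handle(handle):
    """Expand a 4-field low-latency dispatch handle into the 5-field
    layout expected by low_latency_combine (Paddle 3.0.1 workaround)."""
    src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
    return (
        src_info,
        layout_range,
        num_max_dispatch_tokens_per_rank,
        None,  # field missing from the 4-tuple handle; unused downstream
        num_experts,
    )
```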
```diff
@@ -189,15 +222,19 @@ def clean_low_latency_buffer(self):
         """
         clean_low_latency_buffer
         """
-        self.deepep_engine.clean_low_latency_buffer(
+        self.decode_deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )

     def barrier_all(self):
         """
         barrier_all
         """
-        self.deepep_engine.barrier_all()
+        if self.prefill_deepep_engine is not None:
+            self.prefill_deepep_engine.barrier_all()
+
+        if self.decode_deepep_engine is not None:
+            self.decode_deepep_engine.barrier_all()
```
class EPRunner:
```diff
@@ -210,6 +247,7 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         moe_phase: MoEPhase,
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
```
```diff
@@ -223,9 +261,10 @@ def __init__(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
             num_experts=num_experts + redundant_experts_num,
-            moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
         )
```
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
```diff
@@ -286,15 +325,19 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.PREFILL,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank=256,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
```
```diff
@@ -314,7 +357,7 @@ def dispatch(
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
```
```diff
@@ -327,7 +370,7 @@ def dispatch(
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)

     def combine(
         self,
```
```diff
@@ -342,7 +385,7 @@ def combine(
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)

         return fused_moe_out
```
```diff
@@ -357,16 +400,19 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.DECODER,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
```
Review comment: In mixed mode, shouldn't this `low_latency_mode` be set to `True`?