Skip to content

Commit be1f709

Browse files
committed
Optimizing kernel performance
1 parent 6fdd83d commit be1f709

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

fastdeploy/model_executor/layers/moe/triton_moe_kernels.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@
2121
)
2222

2323

24+
# write_zeros_to_output: store an all-zero tile into the output buffer.
#
# Builds a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of zeros with dtype `compute_type`
# and stores it at c_ptr + stride_cm * offs_token[:, None]
#                        + stride_cn * offs_cn[None, :],
# where offs_cn = pid_n * BLOCK_SIZE_N + arange(0, BLOCK_SIZE_N).
#
# The store is masked: a row is written only where `token_mask` is true, and a
# column only where its offset is < N, so out-of-range columns of the last
# tile are not written.
#
# In this diff it is called from fused_moe_kernel_paddle when the expert id
# loaded for the current program is -1; per the comment at that call site, this
# means the expert is not in the current expert-parallel rank. The caller
# returns immediately afterwards.
@paddle_use_triton_v2()
25+
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
26+
token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
27+
compute_type):
28+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
29+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
30+
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
31+
None, :]
32+
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
33+
tl.store(c_ptrs, accumulator, mask=c_mask)
34+
35+
2436
@paddle_use_triton_v2()
2537
def fused_moe_kernel_paddle(
2638
a_ptr,
@@ -108,11 +120,20 @@ def fused_moe_kernel_paddle(
108120
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
109121
token_mask = offs_token < num_valid_tokens
110122

123+
off_experts = tl.load(expert_ids_ptr + pid_m)
124+
if off_experts == -1:
125+
# -----------------------------------------------------------
126+
# Write back zeros to the output when the expert is not
127+
# in the current expert parallel rank.
128+
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
129+
offs_token, token_mask, BLOCK_SIZE_M,
130+
BLOCK_SIZE_N, compute_type)
131+
return
132+
111133
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
112134
offs_k = tl.arange(0, BLOCK_SIZE_K)
113-
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
114135

115-
off_experts = tl.load(expert_ids_ptr + pid_m)
136+
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
116137
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
117138

118139
if use_int8_w8a16:

0 commit comments

Comments
 (0)