diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
index 1e146c306d..d5e373f029 100644
--- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
+++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
@@ -21,6 +21,18 @@
 )
 
 
+@paddle_use_triton_v2()
+def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
+                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
+                          compute_type):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @paddle_use_triton_v2()
 def fused_moe_kernel_paddle(
     a_ptr,
@@ -108,11 +120,20 @@ def fused_moe_kernel_paddle(
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens
 
+    off_experts = tl.load(expert_ids_ptr + pid_m)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
+                              offs_token, token_mask, BLOCK_SIZE_M,
+                              BLOCK_SIZE_N, compute_type)
+        return
+
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
-    off_experts = tl.load(expert_ids_ptr + pid_m)
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
 
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
 
     if use_int8_w8a16: