|
21 | 21 | )
|
22 | 22 |
|
23 | 23 |
|
@paddle_use_triton_v2()
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
                          compute_type):
    """Store an all-zero (BLOCK_SIZE_M, BLOCK_SIZE_N) tile into the output.

    Rows of ``c_ptr`` are addressed through ``offs_token`` and guarded by
    ``token_mask``; columns are the ``pid_n``-th block of ``BLOCK_SIZE_N``
    offsets, guarded against running past ``N``.  Lanes excluded by the
    combined mask are left untouched.
    """
    # Column offsets covered by this program's output block.
    col_offsets = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Valid-row AND valid-column guard for the store.
    store_mask = token_mask[:, None] & (col_offsets[None, :] < N)
    # Destination addresses: one row per token, one column per offset.
    out_ptrs = (c_ptr + stride_cm * offs_token[:, None]
                + stride_cn * col_offsets[None, :])
    zero_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    tl.store(out_ptrs, zero_tile, mask=store_mask)
| 35 | + |
24 | 36 | @paddle_use_triton_v2()
|
25 | 37 | def fused_moe_kernel_paddle(
|
26 | 38 | a_ptr,
|
@@ -108,11 +120,20 @@ def fused_moe_kernel_paddle(
|
108 | 120 | offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
109 | 121 | token_mask = offs_token < num_valid_tokens
|
110 | 122 |
|
| 123 | + off_experts = tl.load(expert_ids_ptr + pid_m) |
| 124 | + if off_experts == -1: |
| 125 | + # ----------------------------------------------------------- |
| 126 | + # Write back zeros to the output when the expert is not |
| 127 | + # in the current expert parallel rank. |
| 128 | + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, |
| 129 | + offs_token, token_mask, BLOCK_SIZE_M, |
| 130 | + BLOCK_SIZE_N, compute_type) |
| 131 | + return |
| 132 | + |
111 | 133 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
112 | 134 | offs_k = tl.arange(0, BLOCK_SIZE_K)
|
113 |
| - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) |
114 | 135 |
|
115 |
| - off_experts = tl.load(expert_ids_ptr + pid_m) |
| 136 | + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) |
116 | 137 | b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
117 | 138 |
|
118 | 139 | if use_int8_w8a16:
|
|
0 commit comments