Add moe topk softmax templated from vllm #4302
Conversation
Once this PR #4432 is merged, you can use it. Does that sound good to you?
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
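For reference, the diff lines above assemble into a complete warp-wide max reduction roughly like the sketch below. The helper name `warpReduceMax` and the surrounding structure are assumptions; only the macro-based line appears in this diff, and it mirrors the `warpReduceSum` proposed later in this thread.

```cuda
// Butterfly (XOR-shuffle) max reduction across a 32-lane warp.
// SGLANG_SHFL_XOR_SYNC abstracts over __shfl_xor_sync (CUDA) and
// __shfl_xor (ROCm). After 5 steps every lane holds the warp max.
__device__ __forceinline__ float warpReduceMax(float max_value) {
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
  return max_value;
}
```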
#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
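The `#else` branch above is the ROCm fallback; it presumably pairs with a CUDA-side branch along these lines (the `USE_ROCM` guard condition is an assumption; only the ROCm half is shown in the diff):

```cuda
#ifndef USE_ROCM
// CUDA: the *_sync intrinsics take and honor the lane mask.
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
#else
// ROCm: the legacy intrinsics take no mask argument.
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
```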
I had kept these lines since you use the ROCm-specific macro in many places (the plain CUDA intrinsic is no longer safe if we employ this approach). But now that #4432 is merged, the macro is no longer needed.
Thank you! I'll change those and remove the definition.
const int thread_row_offset = blockIdx.x * num_cols;

cub::Sum sum;
hipCUB is experimental; we can try it, but it introduces new dependencies. Could we just use a simple reduction kernel?
Definitely. We can change to customized reductions for both max and sum; I'll do that together with the macro change in a follow-up PR. How about the following? I can test correctness on CUDA, and I may need your help with testing on an AMD machine.
__device__ __forceinline__ float warpReduceSum(float sum_value) {
  // Butterfly (XOR-shuffle) reduction: after log2(32) = 5 steps,
  // every lane in the warp holds the full sum.
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 16);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 8);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 4);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 2);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 1);
  return sum_value;
}

__device__ __forceinline__ float blockReduceSum(float sum_value) {
  static __shared__ float warpLevelSums[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;
  // Reduce within each warp, then stage one partial sum per warp.
  sum_value = warpReduceSum(sum_value);
  if (laneId == 0) warpLevelSums[warpId] = sum_value;
  __syncthreads();
  // The first warp reduces the per-warp partials.
  sum_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelSums[laneId] : 0;
  if (warpId == 0) sum_value = warpReduceSum(sum_value);
  return sum_value;
}
Resolved in #4448.
But I also recommend a shfl_xor-based implementation. The old solution from FasterTransformer (later incorporated into TRT-LLM) uses shared memory heavily for the reduction:
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
With a shfl_xor-based implementation you can get a better result.
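The CUB line quoted above typically sits in a context roughly like this sketch (`BLOCK_THREADS`, the kernel name, and the variable wiring are assumptions for illustration; only the `BlockReduce(...).Reduce(...)` call comes from the comment):

```cuda
#include <cub/cub.cuh>

// Shared-memory-backed block reduction via CUB, in the
// FasterTransformer / TRT-LLM style the comment refers to.
template <int BLOCK_THREADS>
__global__ void rowMaxKernel(const float* input, float* row_max) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;  // per-block scratch
  float threadData = input[blockIdx.x * BLOCK_THREADS + threadIdx.x];
  // Result is only valid in thread 0 of the block.
  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) row_max[blockIdx.x] = maxElem;
}
```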
Sounds great! I'll remove the macro definition and change back to using it.
Motivation
#2965
Modifications
Add a `token_expert_indices` output; use `warpReduceMax`/`blockReduceMax` to handle the AMD use case as well.

Tests
Unit tests + benchmarking aligned with the vllm counterpart.
Checklist