[ROCm] Enable per token group quant fp8 in amd #3702
base: main
Conversation
@HaiShaw Hi, can you take a look? Thanks.
Preferably, a code refactor is needed.
There are also some correctness issues to solve.
if is_hip_:
    fp8_max = 224
else:
    fp8_max = finfo.max
Can you make an FP8_E4M3_MAX global (outside of functions), and refer to it later?
Sorry for the late reply. I have been working on MLA-related functionality since yesterday.
Sure, I can put it inside "sglang.srt.utils" so that it comes with "_is_hip".
Does that sound good?
Also, can I do it in a later PR, since this modification may be out of scope for this PR? I will fix it as you suggested.
Done.
FYI, #3959
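As a rough illustration of the refactor discussed above, a module-level constant could replace the per-function branch. This is a hedged sketch: the `fp8_e4m3_max` helper and the hard-coded constants are illustrative assumptions, not the code in this PR; the real version would live in `sglang.srt.utils` alongside `_is_hip` and derive the CUDA value from `torch.finfo`.

```python
# Illustrative sketch of the FP8_E4M3_MAX global suggested in the review.
# The constants below are assumptions for illustration; real code would
# use torch.finfo(torch.float8_e4m3fn).max on the CUDA side.
E4M3FN_MAX = 448.0    # max of float8_e4m3fn (CUDA)
E4M3FNUZ_MAX = 224.0  # max of float8_e4m3fnuz (ROCm), as in the diff above


def fp8_e4m3_max(is_hip: bool) -> float:
    """Return the platform-appropriate FP8 e4m3 max value."""
    return E4M3FNUZ_MAX if is_hip else E4M3FN_MAX


# A single module-level constant then replaces the per-function if/else:
# FP8_E4M3_MAX = fp8_e4m3_max(_is_hip)
```

Callers would then refer to `FP8_E4M3_MAX` directly instead of recomputing the branch inside each kernel wrapper.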
#include <flashinfer/vec_dtypes.cuh>
#else
#include "hip_vec_dtypes.h"
We should not boilerplate sgl-kernel code with flashinfer's.
It is better to make the changes in flashinfer and then use it.
Yes, I agree. I have marked it as a temporary solution, since flashinfer-rocm is not fully supported and ready to use.
As far as I know, SGLang will continue to use flashinfer::vec_t for vectorized 128-bit data loading. With this temporary support, we don't need to modify the related CUDA code.
Does that sound reasonable?
// Adapted from flashinfer
#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__
No need to keep using FLASHINFER_INLINE here; it is a very common macro.
Yes, it comes with the temporary flashinfer::vec_t device function support.
Force-pushed from c777940 to e1ec0e8
Please fix the conflicts.
Motivation
This is a follow-up to PR #3664.
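To make the feature concrete, here is a hedged sketch of what per-token-group FP8 quantization computes, not the kernel in this PR: each group of values within a token's row gets its own scale derived from the group's absolute maximum, so the group's largest value maps to the FP8 max.

```python
FP8_E4M3_MAX = 448.0  # float8_e4m3fn max; a ROCm e4m3fnuz variant would use 224.0


def per_token_group_quant(row, group_size, fp8_max=FP8_E4M3_MAX):
    """Quantize one token's row in groups of `group_size`: each group
    gets its own scale so its absolute maximum maps to fp8_max.
    Illustrative reference logic only, not the fused kernel."""
    assert len(row) % group_size == 0
    quantized, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        amax = max(abs(v) for v in group) or 1.0  # avoid division by zero
        scale = amax / fp8_max
        scales.append(scale)
        quantized.extend(v / scale for v in group)
    return quantized, scales
```

The real kernel fuses this per-group scan and scaling on the GPU and casts the scaled values to the FP8 dtype; the sketch keeps the values in float to show the arithmetic.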
Modifications
ROCm test
Checklist