[sgl-kernel] feat: Support sm120 cutlass fp8 gemm kernel #9403
Conversation
Summary of Changes
Hello @kaln27, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for the NVIDIA SM120 (Blackwell) architecture's FP8 GEMM kernel using the Cutlass library. This enhancement aims to significantly improve the performance of FP8 matrix multiplications, addressing previous issues where FP8-Dynamic models failed to load or performed suboptimally with the existing Triton FP8 GEMM kernel. The new Cutlass-based implementation is expected to be substantially faster, providing a more efficient and reliable solution for FP8 computations on modern NVIDIA GPUs.
Highlights
- SM120 Cutlass FP8 GEMM Support: Implemented a new FP8 General Matrix Multiply (GEMM) kernel specifically optimized for NVIDIA SM120 architecture using the Cutlass library.
- Performance Improvement: The newly integrated Cutlass FP8 GEMM kernel is designed to deliver significantly improved performance compared to the previously used Triton FP8 GEMM kernel.
- FP8 Model Compatibility: This enhancement provides a more robust solution for FP8-Dynamic models, addressing previous issues where these models failed to load or exhibited suboptimal performance.
- Dynamic Kernel Dispatch: The system now includes logic to dynamically dispatch to the SM120-optimized kernel when the application is running on compatible hardware (SM version >= 120 and CUDA version >= 12.8).
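For context, a minimal sketch of what such a gate can look like. The helper names below (get_sm_version, sm120_fp8_dispatch, fallback_fp8_dispatch) are illustrative placeholders, not the identifiers used in this PR:

#include <cuda_runtime.h>

// Sketch only: route to the sm120 CUTLASS path when both the device compute
// capability and the CUDA toolkit version are new enough, otherwise fall back.
static int get_sm_version(int device_id) {
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return major * 10 + minor;  // e.g. 120 for sm120
}

void fp8_scaled_mm_dispatch(/* tensor arguments elided */) {
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12080  // CUDA >= 12.8
  if (get_sm_version(/*device_id=*/0) >= 120) {
    // sm120_fp8_dispatch(...);  // new CUTLASS FP8 path from this PR
    return;
  }
#endif
  // fallback_fp8_dispatch(...);  // existing sm100/sm90/Triton paths
}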
Code Review
This pull request adds support for the sm120 (Blackwell) architecture for FP8 GEMM kernels using CUTLASS. The implementation follows the pattern of existing kernels for other architectures like sm100.
My review focuses on improving performance and maintainability. I've pointed out that the ClusterShapes used for the new sm120 kernels are likely suboptimal and will negatively impact performance. I've also suggested improving the kernel dispatch logic for small batch sizes. Additionally, there is a significant amount of code duplication between the sm120 and sm100 implementations, and I've provided suggestions on how to refactor this using templates to improve maintainability.
using CTAShapeDefault = Shape<_256, _128, _64>;
using ClusterShapeDefault = Shape<_1, _1, _1>;

using CTAShape256 = Shape<_128, _128, _128>;
using ClusterShape256 = Shape<_1, _1, _1>;

using CTAShape64 = Shape<_128, _64, _128>;
using ClusterShape64 = Shape<_1, _1, _1>;
The ClusterShape for all configurations is set to Shape<_1, _1, _1>. This effectively disables thread block clustering, which is a key performance feature on Hopper and newer architectures. This is likely to lead to suboptimal performance.
For sm100, cluster shapes like Shape<_2, _2, _1> and Shape<_2, _1, _1> are used. For sm90, shapes like Shape<_1, _8, _1> are used.
I recommend using more appropriate cluster shapes for sm120 and tuning them for optimal performance. You could start with values similar to those for sm100 as a baseline.
using CTAShapeDefault = Shape<_256, _128, _64>;
using ClusterShapeDefault = Shape<_2, _2, _1>;
using CTAShape256 = Shape<_128, _128, _128>;
using ClusterShape256 = Shape<_2, _1, _1>;
using CTAShape64 = Shape<_128, _64, _128>;
using ClusterShape64 = Shape<_1, _1, _1>;
There is no programmatic multicast on this arch (sm120); it only supports cluster shape <_1, _1, _1>.
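If it helps, that constraint could also be captured at compile time. A minimal sketch (a suggestion, not code from this PR), assuming cute's Shape and _1 aliases are in scope as in the surrounding file:

#include <type_traits>

// Suggestion only: guard the sm120 configurations so a later tuning pass
// cannot silently enable multicast cluster shapes on this architecture.
static_assert(std::is_same_v<ClusterShapeDefault, Shape<_1, _1, _1>>,
              "sm120 has no programmatic multicast; cluster shape must be 1x1x1");
static_assert(std::is_same_v<ClusterShape256, Shape<_1, _1, _1>>,
              "sm120 has no programmatic multicast; cluster shape must be 1x1x1");
static_assert(std::is_same_v<ClusterShape64, Shape<_1, _1, _1>>,
              "sm120 has no programmatic multicast; cluster shape must be 1x1x1");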
typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm120 {
This struct DeviceGemmFp8RowwiseSm120 and its helper functions prepare_sm120_fp8_args and launch_sm120_fp8_scaled_mm are almost identical to their sm100 counterparts. This introduces significant code duplication.
To improve maintainability, consider refactoring this into a generic template that is parameterized by the architecture tag (e.g., cutlass::arch::Sm120).
For example, for the struct:
template <
typename Arch,
typename ElementType,
// ... other template params
>
struct DeviceGemmFp8Rowwise {
// ... generic implementation using Arch in CollectiveEpilogue and CollectiveMainloop
};
// Then define type aliases for each architecture
using DeviceGemmFp8RowwiseSm120 = DeviceGemmFp8Rowwise<cutlass::arch::Sm120, ...>;
The helper functions prepare_*_fp8_args and launch_*_fp8_scaled_mm are already generic enough and don't need to be duplicated for each architecture. You could rename them to remove the sm120 prefix and use them for all compatible architectures.
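To make the suggestion concrete, here is a minimal sketch of what a shared launcher could look like, assuming a CUTLASS 3.x GemmUniversalAdapter-style Gemm type; prepare_fp8_args and CUTLASS_CHECK are placeholders standing in for the existing helpers in this file, not new APIs:

#include <optional>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

// Sketch only: all architecture-specific detail lives in the Gemm type, so the
// same launcher can serve sm90, sm100, and sm120 without duplication.
template <typename Gemm, bool WithBias>
void launch_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const std::optional<torch::Tensor>& bias) {
  // Build CUTLASS arguments with the existing (arch-agnostic) helper.
  auto args = prepare_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  Gemm gemm_op;
  // Workspace sized by CUTLASS, allocated through torch so it lives on the right device.
  size_t workspace_size = Gemm::get_workspace_size(args);
  auto workspace = torch::empty(
      {static_cast<int64_t>(workspace_size)},
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device()));
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr(), stream));
  CUTLASS_CHECK(gemm_op.run(stream));
}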
if (bias) {
  if (mp2 <= 64) {
    return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
  } else if (mp2 <= 256) {
    return launch_sm120_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    return launch_sm120_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
  }
} else {
  if (mp2 <= 64) {
    return launch_sm120_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
  } else if (mp2 <= 256) {
    return launch_sm120_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
  } else {
    return launch_sm120_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
  }
}
The dispatch logic here is less granular than for other architectures like sm100, which has a special case for mp2 <= 16. This implementation starts with mp2 <= 64. For small batch sizes (e.g., M=1), this could lead to using a suboptimal kernel configuration. Consider adding a more fine-grained dispatch for smaller M values to improve performance, similar to what's done for sm100 and sm90.
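For illustration, a hedged sketch of what a finer-grained dispatch could look like; Gemm16 and BiasGemm16 are hypothetical type aliases that would still need their own CTA shapes and tuning, and are not defined in this PR:

// Sketch only: adds an mp2 <= 16 tier mirroring the sm100 dispatch.
if (bias) {
  if (mp2 <= 16) {
    return launch_sm120_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
  } else if (mp2 <= 64) {
    return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
  }
  // ... remaining tiers as in the current diff
} else {
  if (mp2 <= 16) {
    return launch_sm120_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
  } else if (mp2 <= 64) {
    return launch_sm120_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
  }
  // ... remaining tiers as in the current diff
}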
Please merge this, guys.
Does this support block-wise scale? And also compressed-tensors scales? In other words, does this support zai-org/GLM-4.5-Air-FP8, which has "quant_method": "compressed-tensors" (https://huggingface.co/zai-org/GLM-4.5-Air-FP8/blob/main/config.json)? And does this support block-wise scale, for example https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, which has "quant_method": "fp8", "weight_block_size": [128,128]?
@kaln27 I have tested this PR on 2x 6000 PRO:
python -m sglang.launch_server --model /mnt/GLM-4.5-Air-FP8/ --tp 2 --host 0.0.0.0 --port 8001 --mem-fraction-static 0.95 --context-length 128000
79 tokens/sec for a single query with your PR, 102 tokens/sec for the triton kernel. My triton version is 3.4.0. I have also backported the fp8 kernel from vllm (I did not create a PR for this) and I also get only 80 tokens/sec. Why is it slower than triton in my case? What model / use case / settings give you better results with cutlass fp8 vs triton?
P.S.: I tried your wheel as well to rule out any compilation issues; same results, 79 tokens/sec vs 102 tokens/sec for triton.
@voipmonitor I use a single RTX 5070Ti with the model Qwen2.5VL-7B-FP8-Dynamic. I use vllm.
Result for triton kernel
USE_TRITON_W8A8_FP8_KERNEL=1 python -m sglang.launch_server --model-path /data/models/quant/Qwen2.5-VL-7B-Instruct-FP8-Dynamic --context-length 4096 --mem-fraction-static 0.75 --max-running-requests 64 --port 8000
============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 24.06
Total input tokens: 2058
Total generated tokens: 9087
Request throughput (req/s): 4.16
Output token throughput (tok/s): 377.67
Total Token throughput (tok/s): 463.20
---------------Time to First Token----------------
Mean TTFT (ms): 11845.80
Median TTFT (ms): 10688.53
P99 TTFT (ms): 20884.12
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 88.14
Median TPOT (ms): 82.84
P99 TPOT (ms): 157.37
---------------Inter-token Latency----------------
Mean ITL (ms): 86.93
Median ITL (ms): 33.73
P99 ITL (ms): 1418.76
==================================================
Result for cutlass kernel
python -m sglang.launch_server --model-path /data/models/quant/Qwen2.5-VL-7B-Instruct-FP8-Dynamic --context-length 4096 --mem-fraction-static 0.75 --max-running-requests 64 --port 8000
============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 18.34
Total input tokens: 2058
Total generated tokens: 9675
Request throughput (req/s): 5.45
Output token throughput (tok/s): 527.49
Total Token throughput (tok/s): 639.70
---------------Time to First Token----------------
Mean TTFT (ms): 9514.83
Median TTFT (ms): 8476.88
P99 TTFT (ms): 15731.09
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 52.84
Median TPOT (ms): 49.88
P99 TPOT (ms): 93.08
---------------Inter-token Latency----------------
Mean ITL (ms): 51.86
Median ITL (ms): 26.61
P99 ITL (ms): 138.08
==================================================
I have tested this PR with a Qwen 2.5 FP8 w8a8 quantized model (quantized by llm-compressor) on a 5070Ti, and it works well.
Motivation
In this PR I add support for a cutlass fp8 gemm kernel.
Issue #7482 reports that some fp8 models fail to load. Those models are FP8-Dynamic, which may be quantized using llm-compressor. PR #9251 solves that issue by falling back to the triton fp8 gemm kernel, but in my experiments the cutlass fp8 gemm kernel added in this PR is much faster than the triton fp8 gemm kernel.
In Issue #5338 a lot of people mention this error.
I also publish a pre-built wheel to Google Drive: sgl-kernel wheel (CUDA 12.8, sm89 & sm120).
Modifications
Support the cutlass fp8 kernel in sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu and add an if condition to route to the correct kernel.
Accuracy Tests
WIP
Benchmarking and Profiling
WIP
Checklist