
[sgl-kernel] feat: Support sm120 cutlass fp8 gemm kernel #9403


Open

kaln27 wants to merge 1 commit into main

Conversation

@kaln27 kaln27 commented Aug 20, 2025

Motivation

In this PR I add support for the cutlass fp8 gemm kernel.
Issue #7482 reports that some fp8 models fail to load. Those models are FP8-Dynamic and were likely quantized with llm-compressor.

PR #9251 works around that issue by falling back to the triton fp8 gemm kernel. However, in my experiments the cutlass fp8 gemm kernel added in this PR is far faster than the triton fp8 gemm kernel.

Many people report the same error in Issue #5338.

I have also published a pre-built wheel to Google Drive: sgl-kernel wheel (CUDA 12.8, sm89 & sm120). Install it with:

pip install <wheel> --force-reinstall --no-deps

Modifications

Add support for the cutlass fp8 kernel in sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu and add an if condition to route to the correct kernel, as sketched below.
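
A rough sketch of that routing is shown here. It is only an illustration of the condition described in this PR (SM version >= 120 and CUDA >= 12.8): getSMVersion() and the launcher names are assumptions for readability, not the exact code in fp8_gemm_kernel.cu.

  // Sketch of the routing condition inside the FP8 scaled-mm entry point.
  // Hypothetical names: getSMVersion() and sm120_fp8_scaled_mm() stand in for
  // whatever the real dispatch in fp8_gemm_kernel.cu uses.
  int sm = getSMVersion();
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  if (sm >= 120) {
    // New in this PR: CUTLASS FP8 rowwise GEMM compiled for sm120 (Blackwell).
    return sm120_fp8_scaled_mm(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
#endif
  // Otherwise fall through to the existing sm100 / sm90 / sm89 paths.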

Accuracy Tests

WIP

Benchmarking and Profiling

WIP

Checklist

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @kaln27, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for the NVIDIA SM120 (Blackwell) architecture's FP8 GEMM kernel using the Cutlass library. This enhancement aims to significantly improve the performance of FP8 matrix multiplications, addressing previous issues where FP8-Dynamic models failed to load or performed suboptimally with the existing Triton FP8 GEMM kernel. The new Cutlass-based implementation is expected to be substantially faster, providing a more efficient and reliable solution for FP8 computations on modern NVIDIA GPUs.

Highlights

  • SM120 Cutlass FP8 GEMM Support: Implemented a new FP8 General Matrix Multiply (GEMM) kernel specifically optimized for NVIDIA SM120 architecture using the Cutlass library.
  • Performance Improvement: The newly integrated Cutlass FP8 GEMM kernel is designed to deliver significantly improved performance compared to the previously used Triton FP8 GEMM kernel.
  • FP8 Model Compatibility: This enhancement provides a more robust solution for FP8-Dynamic models, addressing previous issues where these models failed to load or exhibited suboptimal performance.
  • Dynamic Kernel Dispatch: The system now includes logic to dynamically dispatch to the SM120-optimized kernel when the application is running on compatible hardware (SM version >= 120 and CUDA version >= 12.8).

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for sm120 (Blackwell) architecture for FP8 GEMM kernels using CUTLASS. The implementation follows the pattern of existing kernels for other architectures like sm100.

My review focuses on improving performance and maintainability. I've pointed out that the ClusterShapes used for the new sm120 kernels are likely suboptimal and will negatively impact performance. I've also suggested improving the kernel dispatch logic for small batch sizes. Additionally, there is a significant amount of code duplication between the sm120 and sm100 implementations, and I've provided suggestions on how to refactor this using templates to improve maintainability.

Comment on lines +1398 to +1405
using CTAShapeDefault = Shape<_256, _128, _64>;
using ClusterShapeDefault = Shape<_1, _1, _1>;

using CTAShape256 = Shape<_128, _128, _128>;
using ClusterShape256 = Shape<_1, _1, _1>;

using CTAShape64 = Shape<_128, _64, _128>;
using ClusterShape64 = Shape<_1, _1, _1>;
gemini-code-assist bot (Contributor), severity: high

The ClusterShape for all configurations is set to Shape<_1, _1, _1>. This effectively disables thread block clustering, which is a key performance feature on Hopper and newer architectures. This is likely to lead to suboptimal performance.

For sm100, cluster shapes like Shape<_2, _2, _1> and Shape<_2, _1, _1> are used. For sm90, shapes like Shape<_1, _8, _1> are used.

I recommend using more appropriate cluster shapes for sm120 and tuning them for optimal performance. You could start with values similar to those for sm100 as a baseline.

  using CTAShapeDefault = Shape<_256, _128, _64>;
  using ClusterShapeDefault = Shape<_2, _2, _1>;

  using CTAShape256 = Shape<_128, _128, _128>;
  using ClusterShape256 = Shape<_2, _1, _1>;

  using CTAShape64 = Shape<_128, _64, _128>;
  using ClusterShape64 = Shape<_1, _1, _1>;

@kaln27 (Author) replied:

There is no programmatic multicast on this arch (sm120); it only supports cluster shape <_1, _1, _1>.

typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm120 {
gemini-code-assist bot (Contributor), severity: medium

This struct DeviceGemmFp8RowwiseSm120 and its helper functions prepare_sm120_fp8_args and launch_sm120_fp8_scaled_mm are almost identical to their sm100 counterparts. This introduces significant code duplication.

To improve maintainability, consider refactoring this into a generic template that is parameterized by the architecture tag (e.g., cutlass::arch::Sm120).

For example, for the struct:

template <
    typename Arch,
    typename ElementType,
    // ... other template params
>
struct DeviceGemmFp8Rowwise {
  // ... generic implementation using Arch in CollectiveEpilogue and CollectiveMainloop
};

// Then define type aliases for each architecture
using DeviceGemmFp8RowwiseSm120 = DeviceGemmFp8Rowwise<cutlass::arch::Sm120, ...>;

The helper functions prepare_*_fp8_args and launch_*_fp8_scaled_mm are already generic enough and don't need to be duplicated for each architecture. You could rename them to remove the sm120 prefix and use them for all compatible architectures.
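
To make the suggested refactor more concrete, here is a simplified, self-contained illustration of the architecture-tag pattern. It is plain C++ with stand-in tags and no CUTLASS dependency; the real code would use cutlass::arch::Sm100 / cutlass::arch::Sm120 and forward the tag into the CUTLASS collective builders.

#include <cstdio>

// Stand-in architecture tags; in the real kernel these would be
// cutlass::arch::Sm100 and cutlass::arch::Sm120.
struct Sm100 {};
struct Sm120 {};

// One generic FP8 rowwise GEMM wrapper parameterized by the architecture tag,
// instead of near-duplicate DeviceGemmFp8RowwiseSm100 / ...Sm120 structs.
template <typename Arch, typename ElementType, bool WithBias = false>
struct DeviceGemmFp8Rowwise {
  static void run() {
    // The Arch tag would select the CollectiveMainloop / CollectiveEpilogue
    // builders; here we only show that a single template serves both paths.
    std::printf("fp8 rowwise GEMM (bias=%d)\n", static_cast<int>(WithBias));
  }
};

// A thin per-architecture alias keeps existing call sites unchanged.
template <typename ElementType, bool WithBias = false>
using DeviceGemmFp8RowwiseSm120 = DeviceGemmFp8Rowwise<Sm120, ElementType, WithBias>;

int main() {
  DeviceGemmFp8RowwiseSm120<float, true>::run();  // sm120 path
  DeviceGemmFp8Rowwise<Sm100, float>::run();      // sm100 path
}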

Comment on lines +1480 to +1496
  if (bias) {
    if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      return launch_sm120_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
    } else {
      return launch_sm120_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      return launch_sm120_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
    } else {
      return launch_sm120_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
    }
  }
gemini-code-assist bot (Contributor), severity: medium

The dispatch logic here is less granular than for other architectures like sm100, which has a special case for mp2 <= 16. This implementation starts with mp2 <= 64. For small batch sizes (e.g., M=1), this could lead to using a suboptimal kernel configuration. Consider adding a more fine-grained dispatch for smaller M values to improve performance, similar to what's done for sm100 and sm90.
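
For reference, a more fine-grained dispatch along these lines could look like the following sketch. The Gemm16 / BiasGemm16 configuration names are hypothetical and would need their own tuned CTA shapes; only the bias branch is shown, and the no-bias branch would mirror it.

  if (bias) {
    if (mp2 <= 16) {
      // Hypothetical small-M configuration, analogous to the sm100 dispatch.
      return launch_sm120_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      return launch_sm120_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
    } else {
      return launch_sm120_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
    }
  }
  // ... the no-bias branch mirrors this with Gemm16 / Gemm64 / Gemm256 / GemmDefault.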

@celsowm commented Aug 20, 2025

Please merge this, guys.

@voipmonitor (Contributor) commented:

> Motivation
>
> In this PR I add support for the cutlass fp8 gemm kernel. Issue #7482 reports that some fp8 models fail to load. Those models are FP8-Dynamic and were likely quantized with llm-compressor.

Does this support channel-wise scales, and also compressed-tensors scales? In other words, does this support zai-org/GLM-4.5-Air-FP8, which uses "quant_method": "compressed-tensors" (https://huggingface.co/zai-org/GLM-4.5-Air-FP8/blob/main/config.json)?

Does it also support block-wise scales? For example, https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, which uses "quant_method": "fp8", "weight_block_size": [128, 128].

@voipmonitor (Contributor) commented Aug 20, 2025

@kaln27 I have tested this PR on 2x 6000 PRO

python -m sglang.launch_server --model /mnt/GLM-4.5-Air-FP8/ --tp 2 --host 0.0.0.0 --port 8001 --mem-fraction-static 0.95 --context-length 128000

79 tokens/sec for single query with your PR

102 tokens/sec for the triton kernel

my triton version is 3.4.0

I have also backported the fp8 kernel from vllm (I did not create a PR for this) and I also get only 80 tokens/sec.

Why is it slower than triton in my case? What model / use case / settings give you better results with cutlass fp8 vs triton?

P.S.: I also tried your wheel to rule out any compilation issues. Same results: 79 tokens/sec vs triton 102 tokens/sec.

@kaln27 (Author) commented Aug 21, 2025

@voipmonitor I use a single RTX 5070 Ti with the Qwen2.5-VL-7B-FP8-Dynamic model. I use vllm's benchmark_serving.py script to benchmark the model.

Result for triton kernel

USE_TRITON_W8A8_FP8_KERNEL=1 python -m sglang.launch_server --model-path /data/models/quant/Qwen2.5-VL-7B-Instruct-FP8-Dynamic  --context-length 4096  --mem-fraction-static 0.75 --max-running-requests 64 --port 8000 

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  24.06     
Total input tokens:                      2058      
Total generated tokens:                  9087      
Request throughput (req/s):              4.16      
Output token throughput (tok/s):         377.67    
Total Token throughput (tok/s):          463.20    
---------------Time to First Token----------------
Mean TTFT (ms):                          11845.80  
Median TTFT (ms):                        10688.53  
P99 TTFT (ms):                           20884.12  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          88.14     
Median TPOT (ms):                        82.84     
P99 TPOT (ms):                           157.37    
---------------Inter-token Latency----------------
Mean ITL (ms):                           86.93     
Median ITL (ms):                         33.73     
P99 ITL (ms):                            1418.76   
==================================================

Result for cutlass kernel

python -m sglang.launch_server --model-path /data/models/quant/Qwen2.5-VL-7B-Instruct-FP8-Dynamic  --context-length 4096  --mem-fraction-static 0.75 --max-running-requests 64 --port 8000 

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  18.34     
Total input tokens:                      2058      
Total generated tokens:                  9675      
Request throughput (req/s):              5.45      
Output token throughput (tok/s):         527.49    
Total Token throughput (tok/s):          639.70    
---------------Time to First Token----------------
Mean TTFT (ms):                          9514.83   
Median TTFT (ms):                        8476.88   
P99 TTFT (ms):                           15731.09  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          52.84     
Median TPOT (ms):                        49.88     
P99 TPOT (ms):                           93.08     
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.86     
Median ITL (ms):                         26.61     
P99 ITL (ms):                            138.08    
==================================================

@ziyye commented Aug 21, 2025

I have tested this PR with a Qwen2.5 FP8 w8a8 quantized model (quantized by llm-compressor) on a 5070 Ti, and it works well.
I hope the community will review and merge this PR ASAP.
