Skip to content

Commit 8ef5d7f

Browse files
FlamingoPgsleepcoo
authored andcommitted
[feat] add fa3 in sgl-kernel (sgl-project#4902)
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
1 parent e56c394 commit 8ef5d7f

File tree

7 files changed

+1300
-0
lines changed

7 files changed

+1300
-0
lines changed

sgl-kernel/CMakeLists.txt

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,39 @@ find_package(Torch REQUIRED)
2525

2626
include(FetchContent)
2727

28+
# cutlass
2829
FetchContent_Declare(
2930
repo-cutlass
3031
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
3132
GIT_TAG 62750a2b75c802660e4894434dc55e839f322277
3233
GIT_SHALLOW ON
3334
)
3435
FetchContent_Populate(repo-cutlass)
36+
# DeepGEMM
3537
FetchContent_Declare(
3638
repo-deepgemm
3739
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
3840
GIT_TAG c57699ac933a93651c34d365797c2d8b41a4765b
3941
GIT_SHALLOW ON
4042
)
4143
FetchContent_Populate(repo-deepgemm)
44+
# flashinfer
4245
FetchContent_Declare(
4346
repo-flashinfer
4447
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
4548
GIT_TAG sgl-kernel
4649
GIT_SHALLOW OFF
4750
)
4851
FetchContent_Populate(repo-flashinfer)
52+
# flash-attention
53+
FetchContent_Declare(
54+
repo-flash-attention
55+
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
56+
GIT_TAG sgl-kernel
57+
GIT_SHALLOW OFF
58+
)
59+
FetchContent_Populate(repo-flash-attention)
60+
4961

5062
include_directories(
5163
${PROJECT_SOURCE_DIR}/include
@@ -54,6 +66,7 @@ include_directories(
5466
${repo-cutlass_SOURCE_DIR}/tools/util/include
5567
${repo-flashinfer_SOURCE_DIR}/include
5668
${repo-flashinfer_SOURCE_DIR}/csrc
69+
${repo-flash-attention_SOURCE_DIR}/hopper
5770
)
5871

5972
set(CMAKE_CXX_STANDARD 17)
@@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS
7891
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
7992
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
8093
"--expt-relaxed-constexpr"
94+
"--use_fast_math"
8195
"-Xcompiler=-Wconversion"
8296
"-Xcompiler=-fno-strict-aliasing"
8397
)
@@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
130144
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
131145
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
132146

147+
# set flash-attention sources file
148+
# BF16 source files
149+
file(GLOB FA3_BF16_GEN_SRCS
150+
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
151+
file(GLOB FA3_BF16_GEN_SRCS_
152+
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
153+
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
154+
155+
# FP16 source files
156+
file(GLOB FA3_FP16_GEN_SRCS
157+
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
158+
file(GLOB FA3_FP16_GEN_SRCS_
159+
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
160+
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
161+
162+
# FP8 source files
163+
file(GLOB FA3_FP8_GEN_SRCS
164+
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
165+
file(GLOB FA3_FP8_GEN_SRCS_
166+
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
167+
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
168+
169+
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
170+
133171
set(SOURCES
134172
"csrc/allreduce/trt_reduce_internal.cu"
135173
"csrc/allreduce/trt_reduce_kernel.cu"
@@ -160,6 +198,10 @@ set(SOURCES
160198
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
161199
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
162200
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
201+
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
202+
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
203+
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
204+
"${FA3_GEN_SRCS}"
163205
)
164206

165207
# Support abi3 for build
@@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl
173215

174216
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
175217

218+
# Add some flash-attention custom flag for inference
219+
target_compile_definitions(common_ops PRIVATE
220+
FLASHATTENTION_DISABLE_SM8x
221+
FLASHATTENTION_DISABLE_BACKWARD
222+
FLASHATTENTION_DISABLE_DROPOUT
223+
# FLASHATTENTION_DISABLE_ALIBI
224+
# FLASHATTENTION_DISABLE_SOFTCAP
225+
FLASHATTENTION_DISABLE_UNEVEN_K
226+
# FLASHATTENTION_DISABLE_LOCAL
227+
FLASHATTENTION_VARLEN_ONLY
228+
)
229+
176230
# JIT Logic
177231
# DeepGEMM
178232

sgl-kernel/README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,36 @@ Steps to add a new kernel:
9292
)
9393
```
9494

95+
### Integrating Third-Party Libraries with Data Type Conversion
96+
97+
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
98+
99+
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
100+
101+
When you need to support new data type conversions, you can easily add conversion functions like this:
102+
103+
```cpp
104+
// Map `int` -> `int64_t`
105+
template <>
106+
struct pytorch_library_compatible_type<int> {
107+
using type = int64_t;
108+
static int convert_from_type(int64_t arg) {
109+
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
110+
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
111+
return arg;
112+
}
113+
};
114+
```
115+
116+
To use this with your library functions, simply wrap them with make_pytorch_shim:
117+
118+
```cpp
119+
/*
120+
* From flash-attention
121+
*/
122+
m.def("fwd", make_pytorch_shim(mha_fwd));
123+
```
124+
95125
### Build & Install
96126

97127
Development build:

sgl-kernel/csrc/torch_extension.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
9191
m.def("top_p_renorm_probs", top_p_renorm_probs);
9292
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
9393
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
94+
95+
/*
96+
* From flash-attention
97+
*/
98+
m.def("fwd", make_pytorch_shim(mha_fwd));
9499
}
95100

96101
REGISTER_EXTENSION(common_ops)

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ limitations under the License.
2323

2424
#include <vector>
2525

26+
#include "sgl_kernel_torch_shim.h"
27+
2628
#define _CONCAT(A, B) A##B
2729
#define CONCAT(A, B) _CONCAT(A, B)
2830

@@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
291293
double top_p_val,
292294
bool deterministic,
293295
int64_t cuda_stream);
296+
297+
/*
298+
* From flash-attention
299+
*/
300+
std::vector<at::Tensor> mha_fwd(
301+
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
302+
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
303+
// h_k, d) if there is page_table.
304+
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
305+
// page_size, h_k, dv) if there is page_table.
306+
std::optional<const at::Tensor>&
307+
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
308+
std::optional<const at::Tensor>&
309+
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
310+
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
311+
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
312+
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
313+
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
314+
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
315+
std::optional<const at::Tensor>&
316+
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
317+
std::optional<const at::Tensor>&
318+
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
319+
std::optional<int> max_seqlen_q_,
320+
// TODO: check if we need max_seqlen_k
321+
std::optional<int> max_seqlen_k_,
322+
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
323+
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
324+
std::optional<const at::Tensor>& leftpad_k_, // b
325+
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
326+
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
327+
std::optional<const at::Tensor>& seqlens_rotary_, // b
328+
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
329+
std::optional<at::Tensor>& k_descale_, // (b, h_k)
330+
std::optional<at::Tensor>& v_descale_, // (b, h_k)
331+
float const softmax_scale,
332+
bool is_causal,
333+
int window_size_left,
334+
int window_size_right,
335+
float const softcap,
336+
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
337+
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
338+
int num_splits,
339+
std::optional<bool> pack_gqa_,
340+
int const sm_margin);
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*Adapt from:
2+
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
3+
Copyright 2025 SGLang Team. All Rights Reserved.
4+
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
==============================================================================*/
17+
18+
#pragma once
19+
20+
#include <torch/library.h>
21+
22+
/**
23+
* Unforunately, the type signatures of the flash_attn ops are not compatible
24+
* with the PyTorch library bindings. To get around that we use
25+
* `make_pytorch_shim` which creates a lambda that exponses the API using
26+
* PyTorch compatible types to the types, then converts them to the types
27+
* expected by the flash_attn ops. This shims allows us to make minimal changes
28+
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
29+
*
30+
* The `pytorch_library_compatible_type` struct is used to map from the
31+
* flash_attn ops types to a PyTorch library compatible one. The main issues is
32+
* that the following types are not support by PyTorch libary bindings:
33+
* - `int`
34+
* - `float`
35+
* - `std::optional<T> &`
36+
* - `std::optional<const at::Tensor> &`
37+
* So we convert them to (respectively):
38+
* - `int64_t`
39+
* - `double`
40+
* - `const std::optional<T>&`
41+
* - `const std::optional<at::Tensor>&`
42+
*/
43+
44+
template <typename T>
45+
struct pytorch_library_compatible_type {
46+
using type = T;
47+
static T convert_from_type(T arg) {
48+
return arg;
49+
}
50+
};
51+
52+
template <typename T>
53+
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
54+
55+
template <typename T>
56+
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
57+
return pytorch_library_compatible_type<T>::convert_from_type(arg);
58+
}
59+
60+
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
61+
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
62+
// the optional container)
63+
template <typename T>
64+
struct pytorch_library_compatible_type<c10::optional<T>&> {
65+
using type = const c10::optional<T>&;
66+
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
67+
return const_cast<c10::optional<T>&>(arg);
68+
}
69+
};
70+
71+
// Map `c10::optional<T>` ->
72+
// `c10::optional<pytorch_library_compatible_type_t<T>>`
73+
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
74+
template <typename T>
75+
struct pytorch_library_compatible_type<c10::optional<T>> {
76+
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
77+
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
78+
return arg;
79+
}
80+
};
81+
82+
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
83+
template <>
84+
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
85+
using type = const c10::optional<at::Tensor>&;
86+
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
87+
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
88+
}
89+
};
90+
91+
// Map `int` -> `int64_t`
92+
template <>
93+
struct pytorch_library_compatible_type<int> {
94+
using type = int64_t;
95+
static int convert_from_type(int64_t arg) {
96+
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
97+
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
98+
return arg;
99+
}
100+
};
101+
102+
// Map `float` -> `double`
103+
template <>
104+
struct pytorch_library_compatible_type<float> {
105+
using type = double;
106+
static float convert_from_type(double arg) {
107+
TORCH_CHECK(
108+
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
109+
return arg;
110+
}
111+
};
112+
113+
//
114+
// Shim Utils
115+
//
116+
117+
template <typename Ret, typename... Args>
118+
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
119+
return [fun](pytorch_library_compatible_type_t<Args>... args) {
120+
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
121+
};
122+
}

0 commit comments

Comments
 (0)