Skip to content
Merged
54 changes: 54 additions & 0 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,39 @@ find_package(Torch REQUIRED)

include(FetchContent)

# cutlass
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG 62750a2b75c802660e4894434dc55e839f322277
GIT_SHALLOW ON
)
FetchContent_Populate(repo-cutlass)
# DeepGEMM
FetchContent_Declare(
repo-deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
GIT_TAG c57699ac933a93651c34d365797c2d8b41a4765b
GIT_SHALLOW ON
)
FetchContent_Populate(repo-deepgemm)
# flashinfer
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
# flash-attention
FetchContent_Declare(
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)


include_directories(
${PROJECT_SOURCE_DIR}/include
Expand All @@ -54,6 +66,7 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-flash-attention_SOURCE_DIR}/hopper
)

set(CMAKE_CXX_STANDARD 17)
Expand All @@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
Expand Down Expand Up @@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

# set flash-attention sources file
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})

set(SOURCES
"csrc/allreduce/trt_reduce_internal.cu"
"csrc/allreduce/trt_reduce_kernel.cu"
Expand Down Expand Up @@ -160,6 +198,10 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)

# Support abi3 for build
Expand All @@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl

install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")

# Add some flash-attention custom flag for inference
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_VARLEN_ONLY
)

# JIT Logic
# DeepGEMM

Expand Down
30 changes: 30 additions & 0 deletions sgl-kernel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@ Steps to add a new kernel:
)
```

### Integrating Third-Party Libraries with Data Type Conversion

When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.

To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.

When you need to support new data type conversions, you can easily add conversion functions like this:

```cpp
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
```

To use this with your library functions, simply wrap them with make_pytorch_shim:

```cpp
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
```

### Build & Install

Development build:
Expand Down
5 changes: 5 additions & 0 deletions sgl-kernel/csrc/torch_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("top_p_renorm_probs", top_p_renorm_probs);
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);

/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
}

REGISTER_EXTENSION(common_ops)
47 changes: 47 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ limitations under the License.

#include <vector>

#include "sgl_kernel_torch_shim.h"

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

Expand Down Expand Up @@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
int64_t cuda_stream);

/*
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin);
122 changes: 122 additions & 0 deletions sgl-kernel/include/sgl_kernel_torch_shim.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*Adapt from:
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/library.h>

/**
* Unforunately, the type signatures of the flash_attn ops are not compatible
* with the PyTorch library bindings. To get around that we use
* `make_pytorch_shim` which creates a lambda that exponses the API using
* PyTorch compatible types to the types, then converts them to the types
* expected by the flash_attn ops. This shims allows us to make minimal changes
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
*
* The `pytorch_library_compatible_type` struct is used to map from the
* flash_attn ops types to a PyTorch library compatible one. The main issues is
* that the following types are not support by PyTorch libary bindings:
* - `int`
* - `float`
* - `std::optional<T> &`
* - `std::optional<const at::Tensor> &`
* So we convert them to (respectively):
* - `int64_t`
* - `double`
* - `const std::optional<T>&`
* - `const std::optional<at::Tensor>&`
*/

template <typename T>
struct pytorch_library_compatible_type {
using type = T;
static T convert_from_type(T arg) {
return arg;
}
};

template <typename T>
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;

template <typename T>
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
return pytorch_library_compatible_type<T>::convert_from_type(arg);
}

// Map `c10::optional<T> &` -> `const c10::optional<T>&`
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
using type = const c10::optional<T>&;
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
return const_cast<c10::optional<T>&>(arg);
}
};

// Map `c10::optional<T>` ->
// `c10::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
return arg;
}
};

// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
using type = const c10::optional<at::Tensor>&;
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
}
};

// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};

// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
using type = double;
static float convert_from_type(double arg) {
TORCH_CHECK(
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
return arg;
}
};

//
// Shim Utils
//

template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
return [fun](pytorch_library_compatible_type_t<Args>... args) {
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
};
}
Loading
Loading