24 changes: 15 additions & 9 deletions sgl-kernel/CMakeLists.txt
@@ -121,12 +121,12 @@ set(SGL_KERNEL_CUDA_FLAGS
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
)

option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
@@ -233,14 +233,15 @@ install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)

# ============================ Optional Install ============================= #
# set flash-attention sources file
# BF16 source files
# FA3 now supports sm80/sm86/sm90
if (SGL_KERNEL_ENABLE_FA3)
set(SGL_FLASH_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_90a,code=sm_90a"
"-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1"
@@ -256,6 +257,10 @@ if (SGL_KERNEL_ENABLE_FA3)
"-Xcompiler=-fno-strict-aliasing"
)

# SM8X Logic
file(GLOB FA3_SM8X_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")

file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
@@ -276,7 +281,7 @@ if (SGL_KERNEL_ENABLE_FA3)
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})

set(FLASH_SOURCES
"csrc/flash_extension.cc"
@@ -297,7 +302,7 @@ if (SGL_KERNEL_ENABLE_FA3)
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")

target_compile_definitions(flash_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
# FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
@@ -318,3 +323,4 @@ install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"

install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
DESTINATION "deep_gemm/include/cutlass")

8 changes: 8 additions & 0 deletions sgl-kernel/README.md
@@ -41,6 +41,14 @@ Third-party libraries:
- [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM)
- [FlashAttention](https://github.com/Dao-AILab/flash-attention)

### FlashAttention FYI

FA3 can fail when there is not enough shared memory for some shapes, such as a larger hidden_dim or certain special cases. Right now, FA3 is supported on sm80/sm87 and sm86/sm89.

The main difference between sm80/sm87 and sm86/sm89 is the shared memory size; see https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x for more information.

Right now, sgl-kernel builds FA3 for sm80/sm86/sm89/sm90a. That means you can use FA3 on **A100 (tested)**/A*0/**L20 (tested)**/L40/L40s/**3090 (tested)**.
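
For a quick standalone check, the snippet below mirrors the gate used by `is_fa3_supported` in `sgl_kernel/python/sgl_kernel/flash_attn.py`. It is only a sketch; the helper name `can_use_fa3` is illustrative and not part of the sgl-kernel API:

```python
import torch


def can_use_fa3(device=None) -> bool:
    # sgl-kernel builds FA3 for sm80/sm86/sm89/sm90a, i.e. compute capability
    # major 8 or 9, plus CUDA >= 12.3 (string comparison, matching the check
    # in flash_attn.py).
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major in (8, 9) and torch.version.cuda >= "12.3"


if __name__ == "__main__":
    print("FA3 usable on this GPU:", can_use_fa3())
```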

### Kernel Development

Steps to add a new kernel:
15 changes: 12 additions & 3 deletions sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -10,9 +10,18 @@


def is_fa3_supported(device=None) -> bool:
# now sgl-kernel only build fa3 for sm90a && cuda >= 12.3
return (torch.cuda.get_device_capability(device)[0] == 9) and (
torch.version.cuda >= "12.3"
# Some notes on FA3:
# FA3 can fail when there is not enough shared memory for some shapes,
# such as a larger hidden_dim or certain special cases.
# Right now, FA3 is supported on sm80/sm87 and sm86/sm89. The main difference
# between sm80/sm87 and sm86/sm89 is the shared memory size; see
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# sgl-kernel currently builds FA3 for sm80/sm86/sm89/sm90a, which means
# you can use FA3 on A100/A*0/L20/L40/L40s/4090.
return (
(torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8)
and (torch.version.cuda >= "12.3")
)


25 changes: 17 additions & 8 deletions sgl-kernel/tests/test_flash_attention.py
@@ -11,16 +11,24 @@
apply_rotary_emb = None


def is_hopper():
# Only Hopper supports different V headdim
return torch.cuda.get_device_properties(0).major >= 9


def is_fa3_supported(device=None) -> bool:
# FA3 can fail without a enough shared memory for a some shapes, currently
# only 8.0 and 8.7 have enough shared memory for all shapes
# Some notes on FA3:
# FA3 can fail when there is not enough shared memory for some shapes,
# such as a larger hidden_dim or certain special cases.
# Right now, FA3 is supported on sm80/sm87 and sm86/sm89. The main difference
# between sm80/sm87 and sm86/sm89 is the shared memory size; see
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
# sgl-kernel currently builds FA3 for sm80/sm86/sm89/sm90a, which means
# you can use FA3 on A100/A*0/L20/L40/L40s/4090.
return (
(torch.cuda.get_device_capability(device)[0] == 9)
and (torch.version.cuda >= "12.4")
# or torch.cuda.get_device_capability(device) == (8, 0)
# or torch.cuda.get_device_capability(device) == (8, 7)
(torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8)
and (torch.version.cuda >= "12.3")
)


@@ -558,7 +566,8 @@ def test_flash_attn_kvcache(
assert nheads % nheads_k == 0
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
if dtype == torch.float8_e4m3fn:
if dtype == torch.float8_e4m3fn or not is_hopper():
# for fp8 and the Ampere arch, v head dim != qk head dim is not supported
dv_vals = [d]
for dv in dv_vals:
has_qv = d == 64 and dv >= 256