Skip to content

Commit fa60e7c

Browse files
committed
Add option to precompute scheduler metadata
1 parent 90f27a2 commit fa60e7c

9 files changed

+235
-36
lines changed

hopper/benchmark_attn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
5656
# time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
5757
# # return time_f[1].mean
5858
# return time_f[1]
59-
return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
59+
return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3)
6060

6161

6262
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):
@@ -404,7 +404,8 @@ def run(*args, **kwargs):
404404
# import pickle
405405
# # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
406406
# # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp:
407-
# with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp:
407+
# with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp:
408+
# # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp:
408409
# # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp:
409410
# # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp:
410411
# pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)

hopper/cuda_check.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdio.h>   // fprintf, stderr — required by CHECK_CUDA below
#include <stdlib.h>  // exit

// Check the result of a CUDA runtime API call. On any error, print the error
// string together with the call site (file:line) to stderr and terminate the
// process. Wrapped in do/while(0) so it behaves as a single statement after
// `if`/`else` without braces.
#define CHECK_CUDA(call) \
    do { \
        cudaError_t status_ = call; \
        if (status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1); \
        } \
    } while(0)

// Surface kernel-launch configuration errors (bad grid/block/smem), which a
// <<<...>>> launch does not report directly; call immediately after a launch.
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())

hopper/flash.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ struct Flash_fwd_params : public Qkv_params {
153153
// int * __restrict__ num_m_blocks_ptr;
154154
// int * __restrict__ num_n_blocks_ptr;
155155
int * __restrict__ num_splits_dynamic_ptr;
156+
bool skip_scheduler_metadata_computation;
156157

157158
int arch;
158159
int num_sm;
@@ -208,7 +209,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
208209

209210
template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
210211
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
211-
void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN);
212+
void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
212213
template <int Arch, typename T, int kHeadDim, bool Has_softcap>
213214
void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
214215
template <typename T, typename Tpartial, int kBlockK>

hopper/flash_api.cpp

Lines changed: 139 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "static_switch.h"
1616
#include "tile_size.h"
1717
#include "heuristics.h"
18+
#include "cuda_check.h"
1819

1920
// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
2021
// This is so that we can pass in torch.dtype as a parameter to the function.
@@ -490,6 +491,127 @@ inline int round_up_headdim(int head_size) {
490491
return 256;
491492
}
492493

494+
// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
//
// Precomputes the scheduler metadata for a later varlen forward call: an int32
// tensor holding the tile-count semaphore (1 element, when needed) followed by
// the per-batch dynamic num_splits (params.b elements, when batch <= 992).
// A forward call that receives this tensor can skip recomputing it. Returns an
// unallocated (undefined) tensor when neither part is needed.
at::Tensor
mha_fwd_get_scheduler_metadata(
        int batch_size,
        int max_seqlen_q,
        int max_seqlen_k,
        int num_heads,
        int num_heads_k,
        int headdim,
        int headdim_v,
        at::ScalarType qkv_dtype,
        const at::Tensor &seqused_k, // b
        std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
        std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
        std::optional<const at::Tensor> &leftpad_k_, // b
        std::optional<int> page_size,
        int max_seqlen_k_new, // 0 means we're not appending new KV
        bool is_causal,
        int window_size_left,
        int window_size_right,
        bool has_softcap,
        int num_splits,
        std::optional<bool> pack_gqa_,
        int const sm_margin
        ) {

    TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // Reset the parameters
    // Fill in just enough of Flash_fwd_params for the heuristics
    // (get_pagedkv_tma / get_num_splits / get_pack_gqa) and the prepare kernel.
    Flash_fwd_params params{};
    params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
    params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
    params.b = batch_size;
    params.seqlen_q = max_seqlen_q;
    params.seqlen_k = max_seqlen_k;
    params.h = num_heads;
    params.h_k = num_heads_k;
    params.d = headdim;
    params.dv = headdim_v;
    params.d_rounded = round_up_headdim(headdim);
    params.dv_rounded = round_up_headdim(headdim_v);
    params.seqlen_knew = max_seqlen_k_new;

    bool const is_varlen_q = cu_seqlens_q_.has_value();
    params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
    params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
    params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
    params.seqused_k = seqused_k.data_ptr<int>();
    params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
    // NOTE(review): non-null sentinel marking "new KV present" — presumably only
    // null-checked by the heuristics below, never dereferenced here; confirm.
    params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
        if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
            is_causal = false;
        }
    }
    if (is_causal) { window_size_right = 0; }

    params.is_causal = window_size_left < 0 && window_size_right == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;
    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
    // Only the presence of softcap matters for the tile-size heuristics;
    // the actual softcap value is not needed here.
    params.softcap = has_softcap ? 1.0f : 0.0f;

    params.page_size = page_size.has_value() ? page_size.value() : 1;
    // NOTE(review): same non-null-sentinel pattern as knew_ptr above — confirm
    // the heuristics only test page_table for null.
    params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);

    bool const use_dynamic_split = params.b <= 992;
    // Temporary sentinel; replaced with the real device pointer after the
    // metadata tensor is allocated below. Must be set before get_num_splits.
    params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);

    params.pagedkv_tma = get_pagedkv_tma(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);

    // This entry point is only used for the varlen/kvcache path; consumed by
    // tile_size_fwd_sm8x below.
    bool is_varlen = true;

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};

    auto opts = seqused_k.options();
    // This needs to be set after get_num_splits
    at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
    bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
    if (scheduler_needs_semaphore || use_dynamic_split) {
        tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
        if (scheduler_needs_semaphore) {
            if (!use_dynamic_split) { tile_count_semaphore.zero_(); }  // If varlen we'll manually do the zero-ing
            params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
        } else {
            params.tile_count_semaphore = nullptr;
        }
        // Real pointer for the dynamic splits, overwriting the sentinel set above.
        params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
    }

    if (params.num_splits_dynamic_ptr) {
        // Recompute the tile sizes the forward kernel will use, so the prepare
        // kernel partitions work with matching block dimensions.
        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
        int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
        int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        // No PDL here: there is no dependent kernel launch to overlap with.
        prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
    return tile_count_semaphore;
}
614+
493615
// b: batch_size
494616
// b_k: batch_size_k
495617
// s_q: seqlen_q
@@ -528,6 +650,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
528650
int window_size_right,
529651
float const softcap,
530652
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
653+
std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
531654
int num_splits,
532655
std::optional<bool> pack_gqa_,
533656
int const sm_margin
@@ -814,21 +937,24 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
814937
bool const scheduler_needs_semaphore = params.arch >= 90
815938
? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
816939
: ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
817-
if (scheduler_needs_semaphore || use_dynamic_split) { // This needs to be set before get_num_splits
818-
tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size}, opts.dtype(torch::kInt32));
819-
if (scheduler_needs_semaphore) {
820-
if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
821-
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
940+
if (scheduler_needs_semaphore || use_dynamic_split) {
941+
int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
942+
params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
943+
if (scheduler_metadata_.has_value()) {
944+
at::Tensor scheduler_metadata = scheduler_metadata_.value();
945+
CHECK_DEVICE(scheduler_metadata);
946+
CHECK_SHAPE(scheduler_metadata, metadata_size);
947+
CHECK_CONTIGUOUS(scheduler_metadata);
948+
TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
949+
tile_count_semaphore = scheduler_metadata;
822950
} else {
823-
params.tile_count_semaphore = nullptr;
951+
tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
824952
}
825-
if (use_dynamic_split) {
826-
// params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr<int>();
827-
// params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr<int>() + batch_size;
828-
params.num_splits_dynamic_ptr = tile_count_semaphore.data_ptr<int>() + 1;
829-
} else {
830-
params.num_splits_dynamic_ptr = nullptr;
953+
if (scheduler_needs_semaphore && !use_dynamic_split) {
954+
tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
831955
}
956+
params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
957+
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
832958
}
833959

834960
if (q_v_.has_value()) {
@@ -1449,4 +1575,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14491575
m.def("fwd", &mha_fwd, "Forward pass");
14501576
m.def("bwd", &mha_bwd, "Backward pass");
14511577
m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
1578+
m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
14521579
}

hopper/flash_attn_interface.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _flash_attn_forward(
4444
window_size=(-1, -1),
4545
softcap=0.0,
4646
rotary_interleaved=True,
47+
scheduler_metadata=None,
4748
num_splits=1,
4849
pack_gqa=None,
4950
sm_margin=0):
@@ -86,11 +87,12 @@ def _flash_attn_forward(
8687
window_size[1],
8788
softcap,
8889
rotary_interleaved,
90+
scheduler_metadata,
8991
num_splits,
9092
pack_gqa,
9193
sm_margin,
9294
)
93-
return (out, softmax_lse, *rest)
95+
return out, softmax_lse, *rest
9496

9597

9698
def _flash_attn_backward(
@@ -608,6 +610,7 @@ def flash_attn_with_kvcache(
608610
window_size=(-1, -1), # -1 means infinite context window
609611
softcap=0.0, # 0.0 means deactivated
610612
rotary_interleaved=True,
613+
scheduler_metadata=None,
611614
num_splits=0, # Can be tuned for speed
612615
pack_gqa=None, # Can be tuned for speed
613616
sm_margin=0, # Can be tuned if some SMs are used for communication
@@ -733,9 +736,51 @@ def flash_attn_with_kvcache(
733736
window_size=window_size,
734737
softcap=softcap,
735738
rotary_interleaved=rotary_interleaved,
739+
scheduler_metadata=scheduler_metadata,
736740
num_splits=num_splits,
737741
pack_gqa=pack_gqa,
738742
sm_margin=sm_margin,
739743
)
740744
# return (out, softmax_lse) if return_softmax_lse else out
741745
return (out, softmax_lse, *rest) if return_softmax_lse else out
746+
747+
748+
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute scheduler metadata for a varlen forward pass with KV cache.

    Thin wrapper around ``flash_attn_3_cuda.get_scheduler_metadata``. The
    returned int32 tensor can later be passed as ``scheduler_metadata`` to
    ``flash_attn_with_kvcache`` so the forward call can skip recomputing it.
    ``headdim_v`` defaults to ``headdim`` when not given.
    """
    # The value head dimension matches the query/key head dimension unless
    # explicitly overridden (e.g. for asymmetric-headdim models).
    headdim_v = headdim if headdim_v is None else headdim_v
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        maybe_contiguous(cache_seqlens),
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )

hopper/flash_fwd_launch_template.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
155155
params.num_splits_dynamic_ptr,
156156
};
157157

158-
if (Varlen && params.num_splits_dynamic_ptr) {
159-
prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN);
158+
if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
159+
prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
160160
CHECK_CUDA_KERNEL_LAUNCH();
161161
}
162162

@@ -188,7 +188,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
188188
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
189189
}
190190
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
191-
cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, Arch >= 90 && Varlen /*launch_with_pdl*/);
191+
cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
192+
Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
192193
}
193194
CHECK_CUDA_KERNEL_LAUNCH();
194195
}

hopper/flash_prepare_scheduler.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ __global__ void prepare_varlen_num_blocks_kernel(
2020
cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
2121
int* const tile_count_semaphore,
2222
// int* const num_m_blocks_ptr,
23-
int* const num_splits_dynamic_ptr) {
23+
int* const num_splits_dynamic_ptr,
24+
bool enable_pdl) {
2425

2526
static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
2627
static constexpr int kSmemSize = 1;
2728
// Assume that there's only one block in the grid
2829
__shared__ int total_blocks_smem[kSmemSize];
2930

3031
// There's only 1 block in the grid, so might as well start launching the main attn kernel
31-
cutlass::arch::launch_dependent_grids();
32+
if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }
3233

3334
if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
3435
__syncthreads();
@@ -108,7 +109,7 @@ __global__ void prepare_varlen_num_blocks_kernel(
108109
} // flash
109110

110111
void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,
111-
int blockM, int blockN) {
112+
int blockM, int blockN, bool enable_pdl) {
112113
// Only support batch <= 992 (32 warps, each with 31 batches)
113114
int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
114115
flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
@@ -119,5 +120,5 @@ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bo
119120
cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
120121
params.tile_count_semaphore,
121122
// params.num_m_blocks_ptr,
122-
params.num_splits_dynamic_ptr);
123+
params.num_splits_dynamic_ptr, enable_pdl);
123124
}

0 commit comments

Comments
 (0)