|
15 | 15 | #include "static_switch.h"
|
16 | 16 | #include "tile_size.h"
|
17 | 17 | #include "heuristics.h"
|
| 18 | +#include "cuda_check.h" |
18 | 19 |
|
19 | 20 | // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
|
20 | 21 | // This is so that we can pass in torch.dtype as a parameter to the function.
|
@@ -490,6 +491,127 @@ inline int round_up_headdim(int head_size) {
|
490 | 491 | return 256;
|
491 | 492 | }
|
492 | 493 |
|
// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
//
// Precomputes the tile-scheduler metadata (tile-count semaphore and, when
// enabled, the per-batch dynamic num_splits) for a later varlen forward call,
// so that work can be done ahead of time instead of on the decode hot path.
// Returns an int32 tensor laid out as:
//   [ semaphore (1 elem, iff scheduler_needs_semaphore) |
//     num_splits_dynamic (b elems, iff use_dynamic_split) ]
// which the caller passes back to mha_fwd as scheduler_metadata_.
// May be an undefined (empty) tensor if neither part is needed.
//
// Only shape/heuristic inputs are taken (no q/k/v tensors); pointer fields in
// Flash_fwd_params that the heuristics merely test for null are filled with
// non-null sentinel values and are never dereferenced here.
at::Tensor
mha_fwd_get_scheduler_metadata(
        int batch_size,
        int max_seqlen_q,
        int max_seqlen_k,
        int num_heads,
        int num_heads_k,
        int headdim,
        int headdim_v,
        at::ScalarType qkv_dtype,
        const at::Tensor &seqused_k, // b
        std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
        std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
        std::optional<const at::Tensor> &leftpad_k_, // b
        std::optional<int> page_size,
        int max_seqlen_k_new, // 0 means we're not appending new KV
        bool is_causal,
        int window_size_left,
        int window_size_right,
        bool has_softcap,
        int num_splits,
        std::optional<bool> pack_gqa_,
        int const sm_margin
        ) {

    TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // Reset the parameters
    Flash_fwd_params params{};
    params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
    params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
    params.b = batch_size;
    params.seqlen_q = max_seqlen_q;
    params.seqlen_k = max_seqlen_k;
    params.h = num_heads;
    params.h_k = num_heads_k;
    params.d = headdim;
    params.dv = headdim_v;
    params.d_rounded = round_up_headdim(headdim);
    params.dv_rounded = round_up_headdim(headdim_v);
    params.seqlen_knew = max_seqlen_k_new;

    bool const is_varlen_q = cu_seqlens_q_.has_value();
    params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
    params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
    params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
    params.seqused_k = seqused_k.data_ptr<int>();
    params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
    // Sentinel non-null pointer: downstream heuristics only check knew_ptr for
    // null to detect KV-append; the value is never dereferenced in this function.
    params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
    // A window at least as large as the sequence is equivalent to no window.
    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
        if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
            is_causal = false;
        }
    }
    if (is_causal) { window_size_right = 0; }

    // Canonicalize: causal == (unbounded left, right == 0); local == any other
    // finite window. Unbounded sides of a local window get concrete bounds.
    params.is_causal = window_size_left < 0 && window_size_right == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;
    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
    // Only whether softcap is active matters for the heuristics here, not its
    // value — presumably why a 1.0f placeholder suffices (TODO confirm).
    params.softcap = has_softcap ? 1.0f : 0.0f;

    params.page_size = page_size.has_value() ? page_size.value() : 1;
    // Sentinel non-null pointer again: heuristics only test page_table != nullptr.
    params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);

    // NOTE(review): dynamic split only for batch <= 992 — presumably a
    // scheduler-kernel capacity limit; confirm against prepare_varlen_num_blocks.
    bool const use_dynamic_split = params.b <= 992;
    params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);

    // Order matters: get_num_splits reads pagedkv_tma, and get_pack_gqa reads
    // num_splits (see comment below).
    params.pagedkv_tma = get_pagedkv_tma(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);

    // This helper only serves the seqused_k (varlen-like) path.
    bool is_varlen = true;

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};

    auto opts = seqused_k.options();
    // This needs to be set after get_num_splits
    at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
    bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
    if (scheduler_needs_semaphore || use_dynamic_split) {
        // Layout: [semaphore (0 or 1 int)] [num_splits_dynamic (0 or b ints)].
        tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
        if (scheduler_needs_semaphore) {
            if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
            params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
        } else {
            params.tile_count_semaphore = nullptr;
        }
        // Dynamic num_splits live right after the (single-int) semaphore slot.
        params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
    }

    if (params.num_splits_dynamic_ptr) {
        // Recover the block sizes the forward kernel will use so the scheduler
        // kernel counts m/n blocks consistently with the actual launch.
        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
        int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
        int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        // Fills num_splits_dynamic (and zeroes the semaphore slot) on device.
        prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
    return tile_count_semaphore;
}
| 614 | + |
493 | 615 | // b: batch_size
|
494 | 616 | // b_k: batch_size_k
|
495 | 617 | // s_q: seqlen_q
|
@@ -528,6 +650,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
|
528 | 650 | int window_size_right,
|
529 | 651 | float const softcap,
|
530 | 652 | bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
| 653 | + std::optional<at::Tensor> &scheduler_metadata_, // (b + 1) |
531 | 654 | int num_splits,
|
532 | 655 | std::optional<bool> pack_gqa_,
|
533 | 656 | int const sm_margin
|
@@ -814,21 +937,24 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
|
814 | 937 | bool const scheduler_needs_semaphore = params.arch >= 90
|
815 | 938 | ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
|
816 | 939 | : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
|
817 |
| - if (scheduler_needs_semaphore || use_dynamic_split) { // This needs to be set before get_num_splits |
818 |
| - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size}, opts.dtype(torch::kInt32)); |
819 |
| - if (scheduler_needs_semaphore) { |
820 |
| - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing |
821 |
| - params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>(); |
| 940 | + if (scheduler_needs_semaphore || use_dynamic_split) { |
| 941 | + int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; |
| 942 | + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); |
| 943 | + if (scheduler_metadata_.has_value()) { |
| 944 | + at::Tensor scheduler_metadata = scheduler_metadata_.value(); |
| 945 | + CHECK_DEVICE(scheduler_metadata); |
| 946 | + CHECK_SHAPE(scheduler_metadata, metadata_size); |
| 947 | + CHECK_CONTIGUOUS(scheduler_metadata); |
| 948 | + TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); |
| 949 | + tile_count_semaphore = scheduler_metadata; |
822 | 950 | } else {
|
823 |
| - params.tile_count_semaphore = nullptr; |
| 951 | + tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); |
824 | 952 | }
|
825 |
| - if (use_dynamic_split) { |
826 |
| - // params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr<int>(); |
827 |
| - // params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr<int>() + batch_size; |
828 |
| - params.num_splits_dynamic_ptr = tile_count_semaphore.data_ptr<int>() + 1; |
829 |
| - } else { |
830 |
| - params.num_splits_dynamic_ptr = nullptr; |
| 953 | + if (scheduler_needs_semaphore && !use_dynamic_split) { |
| 954 | + tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing |
831 | 955 | }
|
| 956 | + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr; |
| 957 | + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr; |
832 | 958 | }
|
833 | 959 |
|
834 | 960 | if (q_v_.has_value()) {
|
@@ -1449,4 +1575,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
1449 | 1575 | m.def("fwd", &mha_fwd, "Forward pass");
|
1450 | 1576 | m.def("bwd", &mha_bwd, "Backward pass");
|
1451 | 1577 | m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
|
| 1578 | + m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass"); |
1452 | 1579 | }
|
0 commit comments