
Commit 562ffb1

pavanimajety authored and jimpang committed
Add Support for Page Size greater than 1 for Flashinfer MLA Backend (sgl-project#8593)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent 40f5fd5 commit 562ffb1

File tree

5 files changed: +291 -104 lines changed

python/sglang/srt/layers/attention/flashinfer_mla_backend.py

Lines changed: 90 additions & 72 deletions
@@ -24,9 +24,7 @@
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -72,11 +70,11 @@ def __init__(
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
@@ -97,15 +95,25 @@ def __init__(
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 

@@ -148,6 +156,7 @@ def __init__(
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -205,16 +214,9 @@ def init_cuda_graph_state(
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        if kv_indices_buf is None:
-            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -240,6 +242,7 @@ def init_forward_metadata_capture_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -250,7 +253,6 @@ def init_forward_metadata_capture_cuda_graph(
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -321,11 +323,13 @@ def init_forward_metadata_replay_cuda_graph(
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -334,7 +338,6 @@ def init_forward_metadata_replay_cuda_graph(
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
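Note: in the replay path above and in the decode indices updater further down, kv_indptr is now the cumulative sum of per-request page counts rather than token counts, and kv_indices is sliced out of the preallocated buffer instead of being freshly allocated each step. A worked example of the indptr arithmetic, with illustrative values:

    import torch

    page_size = 64
    seq_lens_cpu = torch.tensor([100, 130, 64])  # tokens per request
    num_pages_per_req = (seq_lens_cpu + page_size - 1) // page_size
    # -> tensor([2, 3, 1]), ceil-division page counts
    kv_indptr = torch.zeros(len(seq_lens_cpu) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
    # -> tensor([0, 2, 5, 6]); request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]]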
@@ -381,7 +384,6 @@ def forward_extend(
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
@@ -401,7 +403,6 @@ def forward_extend(
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -422,6 +423,8 @@ def forward_extend(
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -483,17 +486,17 @@ def forward_decode(
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buffer[:, :, : layer.v_head_dim],
-            k_buffer[:, :, layer.v_head_dim :],
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
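Note: both forward_extend (paged branch) and forward_decode now view the flat key buffer as pages before splitting off the rope part. A rough shape sketch, assuming a token-major buffer whose last dimension is v_head_dim + qk_rope_head_dim; the sizes are placeholders, not taken from the commit:

    import torch

    page_size, num_tokens = 32, 4096
    v_head_dim, qk_rope_head_dim = 512, 64
    k_buf = torch.randn(num_tokens, v_head_dim + qk_rope_head_dim)

    # (num_tokens, d) -> (num_pages, page_size, d); pages are contiguous blocks of tokens
    k_paged = k_buf.view(-1, page_size, k_buf.shape[-1])
    ckv = k_paged[:, :, :v_head_dim]   # compressed-KV part, first operand of the kernel
    kpe = k_paged[:, :, v_head_dim:]   # rotary part, second operand of the kernel
    print(k_paged.shape, ckv.shape, kpe.shape)
    # torch.Size([128, 32, 576]) torch.Size([128, 32, 512]) torch.Size([128, 32, 64])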

@@ -512,9 +515,10 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -558,13 +562,17 @@ def call_begin_forward(
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -573,39 +581,40 @@ def call_begin_forward(
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
 
 
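Note: the keyword form makes the units explicit: kv_indptr and kv_indices address pages (the cumulative page counts built above), kv_len_arr stays in tokens, and page_size is self.page_size instead of the previously hard-coded 1. A small consistency sketch of how the two relate, with illustrative values:

    import torch

    page_size = 16
    kv_len_arr = torch.tensor([40, 17], dtype=torch.int32)  # tokens per request
    kv_indptr = torch.tensor([0, 3, 5], dtype=torch.int32)  # page offsets per request
    pages_per_req = kv_indptr[1:] - kv_indptr[:-1]          # tensor([3, 2])
    # each request's page count must cover its token count
    assert torch.all(pages_per_req == (kv_len_arr + page_size - 1) // page_size)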

@@ -627,12 +636,14 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -646,7 +657,6 @@ def update(
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -680,13 +690,12 @@ def call_begin_forward(
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -695,6 +704,7 @@ def call_begin_forward(
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
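Note: create_flashinfer_kv_indices_triton is now launched with an extra page_size argument, so kv_indices is filled with page indices rather than per-token cache locations. A rough, non-Triton sketch of the gather it is expected to perform, assuming req_to_token maps (request slot, token position) to token-level cache slots and pages are aligned, contiguous blocks of page_size slots; the helper name is hypothetical:

    import torch

    def gather_page_indices(req_to_token, req_pool_indices, seq_lens, page_size):
        # Hypothetical Python equivalent: one page index per allocated page, per request.
        out = []
        for req, length in zip(req_pool_indices.tolist(), seq_lens.tolist()):
            token_slots = req_to_token[req, 0:length:page_size]  # first slot of each page
            out.append(token_slots // page_size)                 # slot index -> page index
        return torch.cat(out).to(torch.int32)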
@@ -712,7 +722,6 @@ def call_begin_forward(
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -726,20 +735,26 @@ def call_begin_forward(
             )
         else:
             # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
 
 
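Note: the new assertion guards a unit mismatch: when spec_info supplies kv_indptr directly, kv_lens is recovered as kv_indptr[1:] - kv_indptr[:-1], which equals per-request token counts only when each page holds a single token. A tiny illustration of why, with made-up numbers:

    import torch

    # page_size = 4: requests with 5 and 9 tokens occupy 2 and 3 pages
    kv_indptr = torch.tensor([0, 2, 5])
    print(kv_indptr[1:] - kv_indptr[:-1])  # tensor([2, 3]), pages rather than 5 and 9 tokens
    # only with page_size = 1 do page counts and token counts coincide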

@@ -834,6 +849,7 @@ def common_template(
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -869,6 +885,7 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -885,6 +902,7 @@ def call_fn(i, forward_batch):
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
