 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -72,11 +70,11 @@ def __init__(
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
@@ -97,15 +95,25 @@ def __init__(
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
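
The kv_indices buffer above is sized by a ceiling division: a request of max_context_len tokens occupies at most ceil(max_context_len / page_size) pages, so the allocation shrinks by roughly a factor of page_size relative to a per-token buffer. A minimal sketch of the arithmetic, with assumed values that do not come from the patch:

    max_bs, max_context_len, page_size = 160, 8192, 4  # assumed values
    max_pages_per_req = (max_context_len + page_size - 1) // page_size  # ceil
    assert max_pages_per_req == -(-max_context_len // page_size)  # same ceiling
    total_entries = max_bs * max_pages_per_req  # upper bound on page indices
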
@@ -148,6 +156,7 @@ def __init__(
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -205,16 +214,9 @@ def init_cuda_graph_state(
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        if kv_indices_buf is None:
-            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
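
Cloning self.kv_indices gives the CUDA-graph path the same page-granular capacity, replacing the old token-granular allocation of max_bs * max_context_len int32 entries. A rough sketch of the saving, under the same assumed values as above:

    max_bs, max_context_len, page_size = 160, 8192, 4  # assumed values
    old_entries = max_bs * max_context_len
    new_entries = max_bs * ((max_context_len + page_size - 1) // page_size)
    print(old_entries / new_entries)  # ~page_size, i.e. ~4x fewer entries here
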
@@ -240,6 +242,7 @@ def init_forward_metadata_capture_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -250,7 +253,6 @@ def init_forward_metadata_capture_cuda_graph(
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -321,11 +323,13 @@ def init_forward_metadata_replay_cuda_graph(
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -334,7 +338,6 @@ def init_forward_metadata_replay_cuda_graph(
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
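
With paging, the indptr becomes a prefix sum over page counts rather than token counts. A worked example with illustrative lengths:

    import torch

    page_size = 4
    seq_lens_cpu = torch.tensor([5, 8, 1])  # tokens per request
    num_pages_per_req = (seq_lens_cpu + page_size - 1) // page_size  # [2, 2, 1]
    kv_indptr = torch.zeros(len(seq_lens_cpu) + 1, dtype=torch.int64)
    kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)  # [0, 2, 4, 5]
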
@@ -381,7 +384,6 @@ def forward_extend(
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
@@ -401,7 +403,6 @@ def forward_extend(
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -422,6 +423,8 @@ def forward_extend(
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
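
The added view regroups the flat per-token-slot key buffer into pages so the paged wrapper can address whole pages; this assumes the pool's leading dimension is a multiple of page_size. A self-contained sketch with illustrative dimensions:

    import torch

    page_size, head_dim = 4, 576       # assumed MLA dim: 512 ckv + 64 kpe
    k_buf = torch.randn(32, head_dim)  # 32 token slots in the pool
    paged = k_buf.view(-1, page_size, k_buf.shape[-1])  # (8, 4, 576), no copy
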
@@ -483,17 +486,17 @@ def forward_decode(
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buffer[:, :, : layer.v_head_dim],
-            k_buffer[:, :, layer.v_head_dim :],
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
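
The last-dimension split hands the wrapper the compressed KV part and the rotary key part as two views of the same paged buffer. A standalone sketch, again with assumed MLA dimensions:

    import torch

    page_size, v_head_dim, rope_dim = 4, 512, 64  # assumed dims
    k_buf = torch.randn(8, page_size, v_head_dim + rope_dim)  # paged keys
    ckv = k_buf[:, :, :v_head_dim]  # compressed KV, (8, 4, 512)
    kpe = k_buf[:, :, v_head_dim:]  # rotary part,   (8, 4, 64)
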
@@ -512,9 +515,10 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -558,13 +562,17 @@ def call_begin_forward(
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -573,39 +581,40 @@ def call_begin_forward(
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
 
 
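
Switching plan() to keyword arguments makes the new page_size parameter explicit where the old positional call hard-coded 1. One quick way to see that the refactor preserves the default behavior: at page_size == 1 the page-count cumsum degenerates to the old token-count cumsum.

    import torch

    paged_kernel_lens = torch.tensor([5, 8, 1])
    page_size = 1
    num_pages_per_req = (paged_kernel_lens + page_size - 1) // page_size
    assert torch.equal(num_pages_per_req, paged_kernel_lens)  # pages == tokens
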
@@ -627,12 +636,14 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -646,7 +657,6 @@ def update(
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -680,13 +690,12 @@ def call_begin_forward(
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -695,6 +704,7 @@ def call_begin_forward(
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
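
Slicing the preallocated self.kv_indices instead of calling torch.empty avoids a fresh device allocation on every scheduling step; kv_indptr[-1] is the total page count, exactly the number of entries the Triton kernel fills. A minimal sketch:

    import torch

    kv_indices_pool = torch.empty(1024, dtype=torch.int32)  # allocated once
    kv_indptr = torch.tensor([0, 2, 4, 5], dtype=torch.int32)
    kv_indices = kv_indices_pool[: kv_indptr[-1]]  # view of 5 entries, no alloc
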
@@ -712,7 +722,6 @@ def call_begin_forward(
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -726,20 +735,26 @@ def call_begin_forward(
             )
         else:
             # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
 
 
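
Note that kv_len_arr stays in token units while kv_indptr is now in page units, so the old derivation kv_indptr[1:] - kv_indptr[:-1] remains valid only when page_size == 1, which is exactly the guarded speculative path. An illustration of the mismatch:

    import torch

    page_size = 4
    paged_kernel_lens = torch.tensor([5, 8])  # token lengths
    pages = (paged_kernel_lens + page_size - 1) // page_size  # [2, 2]
    assert not torch.equal(pages, paged_kernel_lens)  # page counts != token lens
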
@@ -834,6 +849,7 @@ def common_template(
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
@@ -869,6 +885,7 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -885,6 +902,7 @@ def call_fn(i, forward_batch):
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,