Skip to content

Commit 6fc1759

Browse files
authored
Optimize a pad operation to accelerate 25us (sgl-project#5945)
1 parent ad506a4 commit 6fc1759

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,8 +1587,9 @@ def init_forward_metadata_replay_cuda_graph(
15871587
metadata.max_seq_len_k = max_len
15881588

15891589
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
1590-
metadata.cu_seqlens_k = torch.nn.functional.pad(
1591-
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
1590+
# Optimize cumulative sequence length calculation
1591+
metadata.cu_seqlens_k[1:].copy_(
1592+
torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
15921593
)
15931594

15941595
max_seq_pages = (

0 commit comments

Comments
 (0)