Skip to content

Commit 72582f8

Browse files
zmingleixwu-intel
authored and committed
optimize pad operations in fa3 to accelerate 100+us (sgl-project#6077)
1 parent 463c87f commit 72582f8

File tree

1 file changed

+17
-39
lines changed

1 file changed

+17
-39
lines changed

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,12 +1525,9 @@ def init_forward_metadata_replay_cuda_graph(
15251525
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
15261526
self.speculative_step_id + 1
15271527
)
1528-
metadata.cu_seqlens_k.copy_(
1529-
torch.nn.functional.pad(
1530-
torch.cumsum(
1531-
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1532-
),
1533-
(1, 0),
1528+
metadata.cu_seqlens_k[1:].copy_(
1529+
torch.cumsum(
1530+
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
15341531
)
15351532
)
15361533

@@ -1554,12 +1551,9 @@ def init_forward_metadata_replay_cuda_graph(
15541551
# metadata.max_seq_len_q = self.topk, already set in capture
15551552
metadata.max_seq_len_k = seq_lens_cpu.max().item()
15561553
# metadata.cu_seqlens_q already set in capture
1557-
metadata.cu_seqlens_k.copy_(
1558-
torch.nn.functional.pad(
1559-
torch.cumsum(
1560-
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1561-
),
1562-
(1, 0),
1554+
metadata.cu_seqlens_k[1:].copy_(
1555+
torch.cumsum(
1556+
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
15631557
)
15641558
)
15651559

@@ -1616,13 +1610,8 @@ def init_forward_metadata_replay_cuda_graph(
16161610
metadata.max_seq_len_k = (
16171611
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
16181612
)
1619-
metadata.cu_seqlens_k.copy_(
1620-
torch.nn.functional.pad(
1621-
torch.cumsum(
1622-
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1623-
),
1624-
(1, 0),
1625-
)
1613+
metadata.cu_seqlens_k[1:].copy_(
1614+
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
16261615
)
16271616
max_seq_pages = (
16281617
metadata.max_seq_len_k + self.page_size - 1
@@ -1641,13 +1630,8 @@ def init_forward_metadata_replay_cuda_graph(
16411630
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
16421631
metadata.max_seq_len_k = seq_lens_cpu.max().item()
16431632
# metadata.cu_seqlens_q already set in capture
1644-
metadata.cu_seqlens_k.copy_(
1645-
torch.nn.functional.pad(
1646-
torch.cumsum(
1647-
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1648-
),
1649-
(1, 0),
1650-
)
1633+
metadata.cu_seqlens_k[1:].copy_(
1634+
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
16511635
)
16521636
page_table = self.req_to_token[
16531637
req_pool_indices, : metadata.max_seq_len_k
@@ -1705,14 +1689,11 @@ def init_forward_metadata_replay_cuda_graph(
17051689
metadata_expand.cache_seqlens_int32.copy_(
17061690
mask.sum(dim=1).to(torch.int32)
17071691
)
1708-
metadata_expand.cu_seqlens_k.copy_(
1709-
torch.nn.functional.pad(
1710-
torch.cumsum(
1711-
metadata_expand.cache_seqlens_int32,
1712-
dim=0,
1713-
dtype=torch.int32,
1714-
),
1715-
(1, 0),
1692+
metadata_expand.cu_seqlens_k[1:].copy_(
1693+
torch.cumsum(
1694+
metadata_expand.cache_seqlens_int32,
1695+
dim=0,
1696+
dtype=torch.int32,
17161697
)
17171698
)
17181699
metadata_expand.max_seq_len_k = (
@@ -1723,11 +1704,8 @@ def init_forward_metadata_replay_cuda_graph(
17231704
# Only support encoder size 1 for now
17241705
metadata.encoder_max_seq_len_k = encoder_lens[0]
17251706
metadata.encoder_lens_int32.copy_(encoder_lens[:1])
1726-
metadata.encoder_cu_seqlens_k.copy_(
1727-
torch.nn.functional.pad(
1728-
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
1729-
(1, 0),
1730-
)
1707+
metadata.encoder_cu_seqlens_k[1:].copy_(
1708+
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
17311709
)
17321710

17331711
metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(

0 commit comments

Comments
 (0)