@@ -1525,12 +1525,9 @@ def init_forward_metadata_replay_cuda_graph(
1525
1525
metadata .max_seq_len_k = seq_lens_cpu .max ().item () + (
1526
1526
self .speculative_step_id + 1
1527
1527
)
1528
- metadata .cu_seqlens_k .copy_ (
1529
- torch .nn .functional .pad (
1530
- torch .cumsum (
1531
- metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
1532
- ),
1533
- (1 , 0 ),
1528
+ metadata .cu_seqlens_k [1 :].copy_ (
1529
+ torch .cumsum (
1530
+ metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
1534
1531
)
1535
1532
)
1536
1533
@@ -1554,12 +1551,9 @@ def init_forward_metadata_replay_cuda_graph(
1554
1551
# metadata.max_seq_len_q = self.topk, already set in capture
1555
1552
metadata .max_seq_len_k = seq_lens_cpu .max ().item ()
1556
1553
# metadata.cu_seqlens_q already set in capture
1557
- metadata .cu_seqlens_k .copy_ (
1558
- torch .nn .functional .pad (
1559
- torch .cumsum (
1560
- metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
1561
- ),
1562
- (1 , 0 ),
1554
+ metadata .cu_seqlens_k [1 :].copy_ (
1555
+ torch .cumsum (
1556
+ metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
1563
1557
)
1564
1558
)
1565
1559
@@ -1616,13 +1610,8 @@ def init_forward_metadata_replay_cuda_graph(
1616
1610
metadata .max_seq_len_k = (
1617
1611
seq_lens_cpu .max ().item () + self .speculative_num_draft_tokens
1618
1612
)
1619
- metadata .cu_seqlens_k .copy_ (
1620
- torch .nn .functional .pad (
1621
- torch .cumsum (
1622
- metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
1623
- ),
1624
- (1 , 0 ),
1625
- )
1613
+ metadata .cu_seqlens_k [1 :].copy_ (
1614
+ torch .cumsum (metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32 )
1626
1615
)
1627
1616
max_seq_pages = (
1628
1617
metadata .max_seq_len_k + self .page_size - 1
@@ -1641,13 +1630,8 @@ def init_forward_metadata_replay_cuda_graph(
1641
1630
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
1642
1631
metadata .max_seq_len_k = seq_lens_cpu .max ().item ()
1643
1632
# metadata.cu_seqlens_q already set in capture
1644
- metadata .cu_seqlens_k .copy_ (
1645
- torch .nn .functional .pad (
1646
- torch .cumsum (
1647
- metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
1648
- ),
1649
- (1 , 0 ),
1650
- )
1633
+ metadata .cu_seqlens_k [1 :].copy_ (
1634
+ torch .cumsum (metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32 )
1651
1635
)
1652
1636
page_table = self .req_to_token [
1653
1637
req_pool_indices , : metadata .max_seq_len_k
@@ -1705,14 +1689,11 @@ def init_forward_metadata_replay_cuda_graph(
1705
1689
metadata_expand .cache_seqlens_int32 .copy_ (
1706
1690
mask .sum (dim = 1 ).to (torch .int32 )
1707
1691
)
1708
- metadata_expand .cu_seqlens_k .copy_ (
1709
- torch .nn .functional .pad (
1710
- torch .cumsum (
1711
- metadata_expand .cache_seqlens_int32 ,
1712
- dim = 0 ,
1713
- dtype = torch .int32 ,
1714
- ),
1715
- (1 , 0 ),
1692
+ metadata_expand .cu_seqlens_k [1 :].copy_ (
1693
+ torch .cumsum (
1694
+ metadata_expand .cache_seqlens_int32 ,
1695
+ dim = 0 ,
1696
+ dtype = torch .int32 ,
1716
1697
)
1717
1698
)
1718
1699
metadata_expand .max_seq_len_k = (
@@ -1723,11 +1704,8 @@ def init_forward_metadata_replay_cuda_graph(
1723
1704
# Only support encoder size 1 for now
1724
1705
metadata .encoder_max_seq_len_k = encoder_lens [0 ]
1725
1706
metadata .encoder_lens_int32 .copy_ (encoder_lens [:1 ])
1726
- metadata .encoder_cu_seqlens_k .copy_ (
1727
- torch .nn .functional .pad (
1728
- torch .cumsum (metadata .encoder_lens_int32 , dim = 0 , dtype = torch .int32 ),
1729
- (1 , 0 ),
1730
- )
1707
+ metadata .encoder_cu_seqlens_k [1 :].copy_ (
1708
+ torch .cumsum (metadata .encoder_lens_int32 , dim = 0 , dtype = torch .int32 )
1731
1709
)
1732
1710
1733
1711
metadata .encoder_page_table [:, : metadata .encoder_max_seq_len_k ].copy_ (
0 commit comments