Skip to content

Commit 5589b75

Browse files
Add treemask mode to build_eagle_tree & release sgl-kernel 0.2.3 (sgl-project#7756)
Co-authored-by: Pranjal Shankhdhar <pranjal.ssh@gmail.com>
1 parent c04a8a8 commit 5589b75

File tree

6 files changed

+101
-36
lines changed

6 files changed

+101
-36
lines changed

python/sglang/srt/speculative/build_eagle_tree.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# NOTE: Please run this file to make sure the test cases are correct.
22

3-
from typing import List
3+
import math
4+
from enum import IntEnum
5+
from typing import List, Optional
46

57
import torch
68

7-
from sglang.srt.utils import is_cuda, is_hip, rank0_log
9+
from sglang.srt.utils import is_cuda, is_hip
810

911
if is_cuda() or is_hip():
1012
from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
4042
return parent_list, top_scores_index, draft_tokens
4143

4244

45+
class TreeMaskMode(IntEnum):
46+
FULL_MASK = 0
47+
QLEN_ONLY = 1
48+
QLEN_ONLY_BITPACKING = 2
49+
50+
4351
def build_tree_kernel_efficient(
4452
verified_id: torch.Tensor,
4553
score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
5058
topk: int,
5159
spec_steps: int,
5260
num_verify_tokens: int,
61+
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
62+
tree_mask_buf: Optional[torch.Tensor] = None,
63+
position_buf: Optional[torch.Tensor] = None,
5364
):
5465
parent_list, top_scores_index, draft_tokens = (
5566
build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
6677
device = seq_lens.device
6778
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
6879
# where each row indicates the attending pattern of each draft token
80+
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
81+
if tree_mask_buf is not None:
82+
tree_mask = tree_mask_buf
83+
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
84+
tree_mask = torch.full(
85+
(num_verify_tokens * bs * num_verify_tokens,),
86+
True,
87+
dtype=torch.bool,
88+
device=device,
89+
)
90+
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
91+
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
92+
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
93+
tree_mask = torch.zeros(
94+
(num_verify_tokens * bs,),
95+
dtype=packed_dtypes[packed_dtype_idx],
96+
device=device,
97+
)
98+
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
99+
tree_mask = torch.full(
100+
(
101+
seq_lens_sum * num_verify_tokens
102+
+ num_verify_tokens * num_verify_tokens * bs,
103+
),
104+
True,
105+
device=device,
106+
)
107+
else:
108+
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
109+
69110
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
70-
tree_mask = torch.full(
71-
(
72-
seq_lens_sum * num_verify_tokens
73-
+ num_verify_tokens * num_verify_tokens * bs,
74-
),
75-
True,
76-
device=device,
77-
)
78111
retrive_index = torch.full(
79112
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
80113
)
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
87120
# position: where each token belongs to
88121
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
89122
# then, positions = [7, 8, 8, 9]
90-
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
123+
if position_buf is not None:
124+
positions = position_buf
125+
else:
126+
positions = torch.empty(
127+
(bs * num_verify_tokens,), device=device, dtype=torch.long
128+
)
91129

92130
sgl_build_tree_kernel_efficient(
93131
parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
101139
topk,
102140
spec_steps,
103141
num_verify_tokens,
142+
tree_mask_mode,
104143
)
105144
return (
106145
tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
344383
num_verify_tokens=num_draft_token,
345384
)
346385

347-
rank0_log("=========== build tree kernel efficient ==========")
348-
# rank0_log(f"{tree_mask=}")
349-
rank0_log(f"{position=}")
350-
rank0_log(f"{retrive_index=}")
351-
rank0_log(f"{retrive_next_token=}")
352-
rank0_log(f"{retrive_next_sibling=}")
353-
rank0_log(f"{draft_tokens=}")
386+
print("=========== build tree kernel efficient ==========")
387+
print(f"{tree_mask=}")
388+
print(f"{position=}")
389+
print(f"{retrive_index=}")
390+
print(f"{retrive_next_token=}")
391+
print(f"{retrive_next_sibling=}")
392+
print(f"{draft_tokens=}")
354393
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
355394
assert retrive_index.tolist() == [
356395
[0, 1, 2, 3, 4, 5, 6, 7],

sgl-kernel/csrc/common_extension.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
232232
m.def(
233233
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
234234
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
235-
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
235+
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
236+
"()");
236237
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
237238

238239
m.def(

sgl-kernel/csrc/speculative/eagle_utils.cu

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "pytorch_extension_utils_rocm.h"
2424
#endif
2525

26+
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
27+
2628
// parent_list [bs, topk * (depth - 1) + 1)]
2729
// selected_index [bs, draft_token_num - 1]
2830
// verified_seq_len [bs]
@@ -40,7 +42,8 @@ __global__ void build_tree_efficient(
4042
int64_t* retrive_next_sibling,
4143
int topk,
4244
int depth,
43-
int draft_token_num) {
45+
int draft_token_num,
46+
int tree_mask_mode) {
4447
int bid = blockIdx.x;
4548
int tid = threadIdx.x;
4649

@@ -52,7 +55,13 @@ __global__ void build_tree_efficient(
5255
seq_tree_idx += verified_seq_len[i] * draft_token_num;
5356
}
5457
int seq_len = verified_seq_len[bid];
55-
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
58+
int token_tree_idx;
59+
if (tree_mask_mode == FULL_MASK) {
60+
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
61+
} else {
62+
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
63+
}
64+
tree_mask[token_tree_idx - 1] = true;
5665
for (int i = 0; i < draft_token_num - 1; i++) {
5766
tree_mask[token_tree_idx + i] = false;
5867
}
@@ -124,26 +133,38 @@ void build_tree_kernel_efficient(
124133
at::Tensor retrive_next_sibling,
125134
int64_t topk,
126135
int64_t depth,
127-
int64_t draft_token_num) {
136+
int64_t draft_token_num,
137+
int64_t tree_mask_mode) {
128138
// TODO (ying) check shape
129139
// TODO (ying) check type
130140
int bs = parent_list.size(0);
131141
dim3 grid(bs);
132142
dim3 block(draft_token_num);
133143
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
134144

135-
build_tree_efficient<<<grid, block, 0, stream>>>(
136-
static_cast<int64_t*>(parent_list.data_ptr()),
137-
static_cast<int64_t*>(selected_index.data_ptr()),
138-
static_cast<int64_t*>(verified_seq_len.data_ptr()),
139-
static_cast<bool*>(tree_mask.data_ptr()),
140-
static_cast<int64_t*>(positions.data_ptr()),
141-
static_cast<int64_t*>(retrive_index.data_ptr()),
142-
static_cast<int64_t*>(retrive_next_token.data_ptr()),
143-
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
144-
int32_t(topk),
145-
int32_t(depth),
146-
int32_t(draft_token_num));
145+
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
146+
size_t num_bytes_per_item = 1;
147+
if (draft_token_num > 16) {
148+
num_bytes_per_item = 4;
149+
} else if (draft_token_num > 8) {
150+
num_bytes_per_item = 2;
151+
}
152+
throw std::runtime_error("Not implemented");
153+
} else {
154+
build_tree_efficient<<<grid, block, 0, stream>>>(
155+
static_cast<int64_t*>(parent_list.data_ptr()),
156+
static_cast<int64_t*>(selected_index.data_ptr()),
157+
static_cast<int64_t*>(verified_seq_len.data_ptr()),
158+
static_cast<bool*>(tree_mask.data_ptr()),
159+
static_cast<int64_t*>(positions.data_ptr()),
160+
static_cast<int64_t*>(retrive_index.data_ptr()),
161+
static_cast<int64_t*>(retrive_next_token.data_ptr()),
162+
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
163+
int32_t(topk),
164+
int32_t(depth),
165+
int32_t(draft_token_num),
166+
int32_t(tree_mask_mode));
167+
}
147168
}
148169

149170
template <typename IdType, typename IdType2>

sgl-kernel/csrc/torch_extension_rocm.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
7878
m.def(
7979
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
8080
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
81-
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
81+
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
82+
"()");
8283
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
8384
}
8485

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@ void build_tree_kernel_efficient(
374374
at::Tensor retrive_next_sibling,
375375
int64_t topk,
376376
int64_t depth,
377-
int64_t draft_token_num);
377+
int64_t draft_token_num,
378+
int64_t tree_mask_mode);
378379

379380
void segment_packbits(
380381
at::Tensor x,

sgl-kernel/python/sgl_kernel/speculative.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def build_tree_kernel_efficient(
7272
topk: int,
7373
depth: int,
7474
draft_token_num: int,
75+
tree_mask_mode: int,
7576
) -> None:
7677
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
7778
parent_list,
@@ -85,6 +86,7 @@ def build_tree_kernel_efficient(
8586
topk,
8687
depth,
8788
draft_token_num,
89+
tree_mask_mode,
8890
)
8991

9092

0 commit comments

Comments
 (0)