Skip to content

Commit dafe02a

Browse files
authored
[stop sequence] support stop sequence (PaddlePaddle#3025)
* stop seqs in multi-ends
* unittest for gpu stop op
* kernel tid==0
1 parent 1a815b7 commit dafe02a

File tree

11 files changed

+193
-189
lines changed

11 files changed

+193
-189
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -266,13 +266,12 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
266266
const paddle::Tensor &seq_lens,
267267
const paddle::Tensor &end_ids,
268268
const paddle::Tensor &next_tokens,
269+
const paddle::Tensor &pre_ids,
270+
const paddle::Tensor &step_idx,
271+
const paddle::Tensor &stop_seqs,
272+
const paddle::Tensor &stop_seqs_len,
269273
const bool beam_search);
270274

271-
void GetStopFlagsMultiSeqs(
272-
const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
273-
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
274-
const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
275-
const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
276275

277276
void UpdateInputes(const paddle::Tensor &stop_flags,
278277
const paddle::Tensor &not_need_stop, // only on cpu
@@ -954,12 +953,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
954953
m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
955954
"update_inputs function");
956955

957-
/**
958-
* stop_generation_multi_stop_seqs.cu
959-
* set_stop_value_multi_seqs
960-
*/
961-
m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
962-
"update_inputs function");
963956

964957
/**
965958
* update_inputs.cu

custom_ops/gpu_ops/stop_generation_multi_ends.cu

Lines changed: 65 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -30,30 +30,62 @@ __global__ void set_value_by_flags(bool *stop_flags,
3030
const int *seq_lens,
3131
const int bs,
3232
const int end_length,
33+
const int64_t *pre_ids,
34+
const int pre_ids_len,
35+
const int64_t *step_idx,
36+
const int64_t *stop_seqs,
37+
const int *stop_seqs_len,
38+
const int stop_seqs_bs,
39+
const int stop_seqs_max_len,
3340
bool beam_search,
3441
bool prefill_one_step_stop) {
3542
int tid = threadIdx.x;
36-
if (tid < bs) {
37-
if (prefill_one_step_stop) {
38-
stop_flags[tid] = true;
39-
if (seq_lens[tid] == 0) {
40-
topk_ids[tid] = -1;
41-
}
42-
next_tokens[tid] = topk_ids[tid];
43-
} else {
44-
if (stop_flags[tid]) {
45-
if (seq_lens[tid] == 0) {
46-
topk_ids[tid] = -1;
47-
} else {
48-
topk_ids[tid] = end_ids[0];
49-
next_tokens[tid] = end_ids[0];
43+
int bid = blockIdx.x;
44+
if (tid >= stop_seqs_bs) return;
45+
if (bid < bs) {
46+
if(tid == 0){
47+
if (prefill_one_step_stop) {
48+
stop_flags[bid] = true;
49+
if (seq_lens[bid] == 0) {
50+
topk_ids[bid] = -1;
5051
}
52+
next_tokens[bid] = topk_ids[bid];
5153
} else {
52-
next_tokens[tid] = topk_ids[tid];
54+
if (stop_flags[bid]) {
55+
if (seq_lens[bid] == 0) {
56+
topk_ids[bid] = -1;
57+
} else {
58+
topk_ids[bid] = end_ids[0];
59+
next_tokens[bid] = end_ids[0];
60+
}
61+
} else {
62+
next_tokens[bid] = topk_ids[bid];
63+
}
64+
}
65+
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
66+
stop_flags[bid] = true;
67+
}
68+
}
69+
// dealing stop_seqs
70+
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
71+
if (stop_seq_len <= 0) return;
72+
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
73+
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
74+
const int64_t step_idx_now = step_idx[bid];
75+
76+
bool is_end = true;
77+
int count = 1;
78+
for (int i = stop_seq_len - 1; i >= 0; --i) {
79+
if ((step_idx_now - count) < 0 ||
80+
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
81+
is_end = false;
82+
break;
5383
}
5484
}
55-
if (!beam_search && is_in_end(topk_ids[tid], end_ids, end_length)) {
56-
stop_flags[tid] = true;
85+
if (is_end) {
86+
next_tokens[bid] = end_ids[0];
87+
stop_flags[bid] = true;
88+
topk_ids[bid] = end_ids[0];
5789
}
5890
}
5991
}
@@ -63,6 +95,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
6395
const paddle::Tensor &seq_lens,
6496
const paddle::Tensor &end_ids,
6597
const paddle::Tensor &next_tokens,
98+
const paddle::Tensor &pre_ids,
99+
const paddle::Tensor &step_idx,
100+
const paddle::Tensor &stop_seqs,
101+
const paddle::Tensor &stop_seqs_len,
66102
const bool beam_search) {
67103
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
68104
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -83,21 +119,30 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
83119
std::vector<int64_t> shape = topk_ids.shape();
84120
int64_t bs_now = shape[0];
85121
int64_t end_length = end_ids.shape()[0];
86-
int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
87-
set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
122+
int stop_seqs_bs = stop_seqs.shape()[1];
123+
int stop_seqs_max_len = stop_seqs.shape()[2];
124+
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
125+
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
88126
const_cast<bool *>(stop_flags.data<bool>()),
89127
const_cast<int64_t *>(topk_ids.data<int64_t>()),
90128
const_cast<int64_t *>(next_tokens.data<int64_t>()),
91129
end_ids.data<int64_t>(),
92130
seq_lens.data<int>(),
93131
bs_now,
94132
end_length,
133+
pre_ids.data<int64_t>(),
134+
pre_ids.shape()[1],
135+
step_idx.data<int64_t>(),
136+
stop_seqs.data<int64_t>(),
137+
stop_seqs_len.data<int>(),
138+
stop_seqs_bs,
139+
stop_seqs_max_len,
95140
beam_search,
96141
prefill_one_step_stop);
97142
}
98143

99144
PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
100-
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens"})
145+
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
101146
.Attrs({"beam_search: bool"})
102147
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
103148
.SetInplaceMap({{"topk_ids", "topk_ids_out"},

custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu

Lines changed: 0 additions & 133 deletions
This file was deleted.

custom_ops/setup_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,6 @@ def find_end_files(directory, end_str):
260260
"gpu_ops/token_penalty_only_once.cu",
261261
"gpu_ops/stop_generation.cu",
262262
"gpu_ops/stop_generation_multi_ends.cu",
263-
"gpu_ops/stop_generation_multi_stop_seqs.cu",
264263
"gpu_ops/set_flags.cu",
265264
"gpu_ops/update_inputs_v1.cu",
266265
"gpu_ops/recover_decode_task.cu",
@@ -529,7 +528,6 @@ def find_end_files(directory, end_str):
529528
sources=[
530529
"gpu_ops/get_padding_offset.cu",
531530
"gpu_ops/set_value_by_flags.cu",
532-
"gpu_ops/stop_generation_multi_stop_seqs.cu",
533531
"gpu_ops/rebuild_padding.cu",
534532
"gpu_ops/update_inputs.cu",
535533
"gpu_ops/stop_generation_multi_ends.cu",

fastdeploy/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@ def __init__(
101101
self,
102102
args,
103103
):
104-
self.max_stop_seqs_num = 5
105-
self.stop_seqs_max_len = 8
106-
107104
# NOTE(gongshaotain): form _load_model_init_val()
108105
self.top_p = 1.0
109106
self.temperature = 1.0
@@ -122,6 +119,9 @@ def __init__(
122119
self.enable_redundant_experts = False
123120
self.redundant_experts_num = 0
124121

122+
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
123+
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
124+
125125
for key, value in args.items():
126126
if hasattr(self, key):
127127
setattr(self, key, value)

fastdeploy/engine/sampling_params.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ class SamplingParams:
9090
min_p: float = 0.0
9191
seed: Optional[int] = None
9292
stop: Optional[Union[str, List[str]]] = None
93-
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
93+
stop_token_ids: Optional[List[int]] = None
94+
stop_seqs_len: Optional[int] = None
9495
max_tokens: Optional[int] = None
9596
reasoning_max_tokens: Optional[int] = None
9697
min_tokens: int = 1

fastdeploy/input/ernie_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ def update_stop_seq(self, stop_sequences):
414414
Update stop sequences from request.
415415
"""
416416
stop_seqs = []
417+
if isinstance(stop_sequences, str):
418+
stop_sequences = [stop_sequences]
417419
for seq in stop_sequences:
418420
if seq != self.tokenizer.eos_token_id:
419421
stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,29 @@ def post_process_normal(
210210
paddle.logical_or(model_output.stop_flags, length_cond),
211211
model_output.stop_flags,
212212
)
213-
# TODO(gongshaotian): Add use_stop_seqs
214-
set_stop_value_multi_ends(
215-
sampler_output.sampled_token_ids,
216-
model_output.stop_flags,
217-
model_output.seq_lens_this_time,
218-
model_output.eos_token_id,
219-
model_output.next_tokens,
220-
False,
221-
) # multi ends
213+
214+
if current_platform.is_cuda():
215+
set_stop_value_multi_ends(
216+
sampler_output.sampled_token_ids,
217+
model_output.stop_flags,
218+
model_output.seq_lens_this_time,
219+
model_output.eos_token_id,
220+
model_output.next_tokens,
221+
model_output.pre_ids,
222+
model_output.step_idx,
223+
model_output.stop_token_ids,
224+
model_output.stop_seqs_len,
225+
False,
226+
) # multi ends
227+
else:
228+
set_stop_value_multi_ends(
229+
sampler_output.sampled_token_ids,
230+
model_output.stop_flags,
231+
model_output.seq_lens_this_time,
232+
model_output.eos_token_id,
233+
model_output.next_tokens,
234+
False,
235+
)
222236

223237
# 2. Update the input buffer of the model
224238
with paddle.framework._no_check_dy2st_diff():

0 commit comments

Comments (0)