Skip to content

[Code Simplification] remove cum_offsets #3410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Aug 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4950687
rebuild_padding
lizexu123 Jul 30, 2025
f1db526
update rebuild_padding
lizexu123 Aug 4, 2025
0869206
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 Aug 4, 2025
4f4a7b8
add rebuild_padding test
lizexu123 Aug 4, 2025
d62c164
fix
lizexu123 Aug 4, 2025
dcb809d
fix test_rebuild_padding.py
lizexu123 Aug 4, 2025
8d6e604
Remove the Chinese comments and translate this sentence into English.
lizexu123 Aug 4, 2025
9407c17
fix comments:
lizexu123 Aug 4, 2025
97a73a7
fix rebuild_padding.cc
lizexu123 Aug 4, 2025
5566244
fix rebuild_padding.cc
lizexu123 Aug 4, 2025
767d1d1
fix comments
lizexu123 Aug 4, 2025
ed2da59
Adapt to MTP.
lizexu123 Aug 5, 2025
43fc43e
fix mtp RebuildAppendPaddingKernel
lizexu123 Aug 6, 2025
aa2ab40
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 Aug 6, 2025
beded76
merge develop
lizexu123 Aug 10, 2025
569f7be
gaoxingneng
lizexu123 Aug 10, 2025
c6cd8f7
pre-commit
lizexu123 Aug 10, 2025
2ccb4a7
merge develop
lizexu123 Aug 11, 2025
586f8df
fix mtp
lizexu123 Aug 12, 2025
1ccf9f1
delete BLCOK_SIZE
lizexu123 Aug 12, 2025
6a9f3ee
fix
lizexu123 Aug 14, 2025
9f6bf4c
merge develop
lizexu123 Aug 14, 2025
ad381f1
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 Aug 14, 2025
3baf2b5
fix pre-commit
lizexu123 Aug 14, 2025
dcf8f56
fix pre-commit
lizexu123 Aug 14, 2025
c63b394
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 Aug 18, 2025
0b69bf9
merge develop
lizexu123 Aug 18, 2025
9593694
fix
lizexu123 Aug 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions custom_ops/cpu_ops/get_padding_offset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
seq_length,
bsz);
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
Expand All @@ -97,7 +96,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
Expand All @@ -106,7 +105,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
Expand All @@ -115,7 +113,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})
Expand Down
51 changes: 27 additions & 24 deletions custom_ops/cpu_ops/rebuild_padding.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,11 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif


template <typename T>
void RebuildPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cum_offsets_data,
const int *cu_seqlens_q_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
Expand All @@ -40,11 +41,12 @@ void RebuildPaddingCPUImpl(T *output_data,
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}

if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx =
bi * max_input_length - cum_offsets_data[bi] + seq_id;

const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;

output_data[i] = input_data[src_offset];
Expand All @@ -54,7 +56,7 @@ void RebuildPaddingCPUImpl(T *output_data,
template <typename T>
void RebuildAppendPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cum_offsets_data,
const int *cu_seqlens_q_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
Expand All @@ -69,30 +71,32 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 &&
seq_lens_encoder_data[bi] == 0)) {
continue;
}
seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;

if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;

output_data[i] = input_data[src_offset];
}
}

std::vector<paddle::Tensor> RebuildPaddingCPU(
const paddle::Tensor &tmp_out,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu =
Expand All @@ -107,7 +111,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(

int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cum_offsets_cpu.shape()[0];
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;

paddle::Tensor out;
if (output_padding_offset_cpu) {
Expand All @@ -128,7 +132,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
}

const int *cum_offsets_data = cum_offsets_cpu.data<int>();
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
Expand All @@ -141,7 +145,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cum_offsets_data,
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
Expand All @@ -154,7 +158,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cum_offsets_data,
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
Expand All @@ -167,7 +171,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cum_offsets_data,
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
Expand All @@ -186,7 +190,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cum_offsets_data,
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
Expand All @@ -198,7 +202,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cum_offsets_data,
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
Expand All @@ -207,11 +211,10 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
elem_nums);
break;
case paddle::DataType::BFLOAT16:

RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cum_offsets_data,
cu_seqlens_q_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
Expand All @@ -230,7 +233,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(

std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &cu_seqlens_q_shape,
const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
Expand All @@ -239,14 +242,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cum_offsets_shape[0];
int64_t bsz = cu_seqlens_q_shape[0] - 1;
return {{bsz, dim_embed}};
}
}

std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &cu_seqlens_q_dtype,
const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
Expand All @@ -256,7 +259,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(

PD_BUILD_STATIC_OP(rebuild_padding_cpu)
.Inputs({"tmp_out",
"cum_offsets",
"cu_seqlens_q",
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",
Expand Down
5 changes: 1 addition & 4 deletions custom_ops/gpu_ops/get_padding_offset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
cum_offsets_out.data<int>(),
seq_length);
return {x_remove_padding,
cum_offsets_out,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
Expand All @@ -114,7 +113,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
Expand All @@ -123,7 +122,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
Expand All @@ -132,7 +130,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})
Expand Down
Loading
Loading