Commit 63f1256

Refactor GetBlockShapeAndSplitKVBlock Kernel V2

1 parent 0e5e305

File tree: 9 files changed, +191 −200 lines

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 28 additions & 54 deletions
@@ -199,6 +199,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
     paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
     paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
+    paddle::Tensor &encoder_batch_ids, // Inplace
+    paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
+    paddle::Tensor &kv_batch_ids, // Inplace
+    paddle::Tensor &kv_tile_ids_per_batch, // Inplace
+    paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
+    paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
@@ -223,14 +230,8 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
   int max_system_len = max_len_cpu_ptr[6];
   int max_just_dec_len_without_system = max_len_cpu_ptr[7];
 
-  paddle::Tensor encoder_batch_ids;
-  paddle::Tensor encoder_tile_ids_per_batch;
-  paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
-  paddle::Tensor kv_batch_ids;
-  paddle::Tensor kv_tile_ids_per_batch;
-  paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
-  paddle::Tensor max_len_kv_cpu; /*cpu*/
 
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(max_len_kv_cpu.data<int>(), 0, sizeof(int32_t), stream));
   auto max_len_kv =
       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
   get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
@@ -240,14 +241,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
   max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
 
   if (max_enc_len_this_time > 0) {
-    const uint32_t max_tile_size_per_bs_kv =
-        div_up(max_enc_dec_len_this_time, block_size);
-    kv_batch_ids =
-        GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
-                       seq_lens_encoder.place());
-    kv_tile_ids_per_batch =
-        GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
-                       seq_lens_encoder.place());
+    const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
+    const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
     auto kv_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
 
@@ -259,15 +257,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
                          block_size, block_size);
 
     kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-
-    const uint32_t encoder_max_tile_size_per_bs_q =
-        div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
-    encoder_batch_ids =
-        GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
-    encoder_tile_ids_per_batch =
-        GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
+    // Clear buffer
+    const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
+    const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
     auto encoder_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
     split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -277,19 +272,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
                                         encoder_block_shape_q, group_size);
     encoder_num_blocks_x_cpu =
         encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-  } else {
-    encoder_batch_ids =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    encoder_tile_ids_per_batch =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    encoder_num_blocks_x_cpu =
-        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
-    kv_batch_ids =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    kv_tile_ids_per_batch =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    kv_num_blocks_x_cpu =
-        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
   }
 
   if (max_just_dec_len_this_time > 0) {
@@ -314,15 +296,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
   }
 
-  return {
-      encoder_batch_ids,
-      encoder_tile_ids_per_batch,
-      encoder_num_blocks_x_cpu, /*cpu*/
-      kv_batch_ids,
-      kv_tile_ids_per_batch,
-      kv_num_blocks_x_cpu, /*cpu*/
-      max_len_kv_cpu, /*cpu*/
-  };
 }
 
 PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -333,16 +306,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
         "decoder_batch_ids",
         "decoder_tile_ids_per_batch",
         "decoder_num_blocks_x_cpu",
-        "max_len_tensor_cpu"
+        "max_len_tensor_cpu",
+        "encoder_batch_ids",
+        "encoder_tile_ids_per_batch",
+        "encoder_num_blocks_x_cpu",
+        "kv_batch_ids",
+        "kv_tile_ids_per_batch",
+        "kv_num_blocks_x_cpu",
+        "max_len_kv_cpu"
     })
     .Outputs({
-        paddle::Optional("encoder_batch_ids"),
-        paddle::Optional("encoder_tile_ids_per_batch"),
-        paddle::Optional("encoder_num_blocks_x_cpu"),
-        paddle::Optional("kv_batch_ids"),
-        paddle::Optional("kv_tile_ids_per_batch"),
-        paddle::Optional("kv_num_blocks_x_cpu"),
-        "max_len_kv_cpu"
+
     })
     .Attrs({
         "encoder_block_shape_q: int",

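The refactor inverts the allocation contract: the kernel no longer creates and returns the encoder/KV tensors, it receives caller-owned buffers, zeroes them with cudaMemsetAsync, and fills them in place. The capacities those memsets assume follow directly from the div_up arithmetic above. A small sizing sketch (all concrete values are assumed examples; max_model_len stands in as an upper bound on max_enc_dec_len_this_time):

    def div_up(a: int, b: int) -> int:
        return (a + b - 1) // b

    bsz = 8                        # assumed batch size
    max_model_len = 8192           # assumed bound on max_enc_dec_len_this_time
    group_size = 8                 # query heads per KV head
    encoder_block_shape_q = 64
    block_size = 64

    # kv_batch_ids / kv_tile_ids_per_batch capacity
    kv_elems = bsz * div_up(max_model_len, block_size)
    # encoder_batch_ids / encoder_tile_ids_per_batch capacity
    enc_elems = bsz * div_up(max_model_len * group_size, encoder_block_shape_q)

    print(kv_elems, enc_elems)  # 1024 8192
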
fastdeploy/model_executor/forward_meta.py

Lines changed: 8 additions & 0 deletions
@@ -88,6 +88,14 @@ class ForwardMeta:
     # Recorded multiple lengths related to prefill or decode
     max_len_tensor_cpu: Optional[paddle.Tensor] = None
 
+    encoder_batch_ids: Optional[paddle.Tensor] = None
+    encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None
+    kv_batch_ids: Optional[paddle.Tensor] = None
+    kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
+    max_len_kv_cpu: Optional[paddle.Tensor] = None
+
     # Sequence length of encoder for ever batch
     seq_lens_encoder: Optional[paddle.Tensor] = None
     # Sequence length of Encoder for ever batch
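
Because the op now writes into these fields in place, they must be materialized once before the first forward pass. A minimal allocation sketch (make_forward_meta_buffers and the sizes are illustrative, not part of the commit); the *_cpu scalars live in pinned host memory, matching the "Inplace, Pinned Memory" annotations in the C++ signature:

    import paddle

    def make_forward_meta_buffers(enc_elems: int, kv_elems: int) -> dict:
        # GPU index buffers; the op re-zeroes them itself via cudaMemsetAsync.
        # Pinned CPU scalars receive async device-to-host copies from the op.
        return {
            "encoder_batch_ids": paddle.zeros([enc_elems], dtype="int32"),
            "encoder_tile_ids_per_batch": paddle.zeros([enc_elems], dtype="int32"),
            "encoder_num_blocks_x_cpu": paddle.zeros([1], dtype="int32").pin_memory(),
            "kv_batch_ids": paddle.zeros([kv_elems], dtype="int32"),
            "kv_tile_ids_per_batch": paddle.zeros([kv_elems], dtype="int32"),
            "kv_num_blocks_x_cpu": paddle.zeros([1], dtype="int32").pin_memory(),
            "max_len_kv_cpu": paddle.zeros([1], dtype="int32").pin_memory(),
        }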

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 8 additions & 45 deletions
@@ -48,14 +48,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     AppendAttentionMetadata
     """
 
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
-
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
@@ -154,57 +146,28 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length
-        (
-            temp_encoder_batch_ids,
-            temp_encoder_tile_ids_per_batch,
-            temp_encoder_num_blocks,
-            temp_kv_batch_ids,
-            temp_kv_tile_ids_per_batch,
-            temp_kv_num_blocks,
-            temp_max_len_kv,
-            # metadata.encoder_batch_ids,
-            # metadata.encoder_tile_ids_per_batch,
-            # metadata.encoder_num_blocks,
-            # metadata.kv_batch_ids,
-            # metadata.kv_tile_ids_per_batch,
-            # metadata.kv_num_blocks,
-            # metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
 
-        self.share_inputs["encoder_batch_ids"].copy_(temp_encoder_batch_ids, False)
-        metadata.encoder_batch_ids = self.share_inputs["encoder_batch_ids"]
-
-        self.share_inputs["encoder_tile_ids_per_batch"].copy_(temp_encoder_tile_ids_per_batch, False)
-        metadata.encoder_tile_ids_per_batch = self.share_inputs["encoder_tile_ids_per_batch"]
-
-        self.share_inputs["encoder_num_blocks"].copy_(temp_encoder_num_blocks, False)
-        metadata.encoder_num_blocks = self.share_inputs["encoder_num_blocks"]
-
-        self.share_inputs["kv_batch_ids"].copy_(temp_kv_batch_ids, False)
-        metadata.kv_batch_ids = self.share_inputs["kv_batch_ids"]
-
-        self.share_inputs["kv_tile_ids_per_batch"].copy_(temp_kv_tile_ids_per_batch, False)
-        metadata.kv_tile_ids_per_batch = self.share_inputs["kv_tile_ids_per_batch"]
-
-        self.share_inputs["kv_num_blocks"].copy_(temp_kv_num_blocks, False)
-        metadata.kv_num_blocks = self.share_inputs["kv_num_blocks"]
-
-        self.share_inputs["max_len_kv"].copy_(temp_max_len_kv, False)
-        metadata.max_len_kv = self.share_inputs["max_len_kv"]
-
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
         if self.pd_disaggregation_mode == "per_chunk":
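
With the in-place contract the backend no longer round-trips results through share_inputs temporaries; the op's device-to-host copies land directly in the pinned *_cpu tensors. Those copies are asynchronous (copy_to(..., false) on the CUDA stream), so a host-side consumer has to synchronize before reading; whether FastDeploy relies on an explicit synchronize or on a later blocking call is not visible in this commit. A hedged reader sketch:

    import paddle

    def read_tile_counters(forward_meta):
        # Block until the op's async copies into pinned host memory complete.
        paddle.device.cuda.synchronize()
        num_encoder_blocks = int(forward_meta.encoder_num_blocks_x_cpu[0])
        max_len_kv = int(forward_meta.max_len_kv_cpu[0])
        return num_encoder_blocks, max_len_kv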

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 8 additions & 16 deletions
@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):
 
     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
 
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -198,22 +191,21 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 8 additions & 17 deletions
@@ -64,14 +64,6 @@ class MLAAttentionMetadata(AttentionMetadata):
     MLAAttentionMetadata for Multi-Layer Attention
     """
 
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
-
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
@@ -166,22 +158,21 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length
 
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,

fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py

Lines changed: 16 additions & 18 deletions
@@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block(
     decoder_tile_ids_per_batch: paddle.Tensor,
     decoder_num_blocks_x_cpu: paddle.Tensor,
     max_len_tensor_cpu: paddle.Tensor,
+    encoder_batch_ids: paddle.Tensor,
+    encoder_tile_ids_per_batch: paddle.Tensor,
+    encoder_num_blocks_x_cpu: paddle.Tensor,
+    kv_batch_ids: paddle.Tensor,
+    kv_tile_ids_per_batch: paddle.Tensor,
+    kv_num_blocks_x_cpu: paddle.Tensor,
+    max_len_kv_cpu: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -42,36 +49,27 @@ def get_block_shape_and_split_kv_block(
     get_block_shape_and_split_kv_block
     """
     if current_platform.is_cuda():
-        (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv_cpu,
-        ) = get_block_shape_and_split_kv_block_cuda(
+        get_block_shape_and_split_kv_block_cuda(
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            decoder_batch_ids,
            decoder_tile_ids_per_batch,
            decoder_num_blocks_x_cpu,
+           encoder_batch_ids,
+           encoder_tile_ids_per_batch,
+           encoder_num_blocks_x_cpu,
+           kv_batch_ids,
+           kv_tile_ids_per_batch,
+           kv_num_blocks_x_cpu,
+           max_len_kv_cpu,
            max_len_tensor_cpu,
            encoder_block_shape_q,
            decoder_block_shape_q,
            group_size,
            block_size,
            decoder_step_token_num,
        )
-        return (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv_cpu,
-        )
+
    else:
        raise NotImplementedError
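
Taken together, the wrapper is now a pure in-place launcher: it returns None and every result is read back from the argument tensors. An end-to-end call sketch under the same assumed sizes and buffer helper as in the earlier sketches (tensor contents and attr values are examples only; max_len_tensor_cpu needs at least the eight max_len_* slots the kernel reads):

    import paddle

    bsz = 8
    seq_lens_encoder = paddle.zeros([bsz], dtype="int32")
    seq_lens_decoder = paddle.zeros([bsz], dtype="int32")
    seq_lens_this_time = paddle.ones([bsz], dtype="int32")
    decoder_batch_ids = paddle.zeros([1024], dtype="int32")      # capacity is illustrative
    decoder_tile_ids_per_batch = paddle.zeros([1024], dtype="int32")
    decoder_num_blocks_x_cpu = paddle.zeros([1], dtype="int32").pin_memory()
    max_len_tensor_cpu = paddle.zeros([8], dtype="int32").pin_memory()
    buffers = make_forward_meta_buffers(enc_elems, kv_elems)     # from the earlier sketch

    get_block_shape_and_split_kv_block(
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        decoder_batch_ids,
        decoder_tile_ids_per_batch,
        decoder_num_blocks_x_cpu,
        max_len_tensor_cpu,
        buffers["encoder_batch_ids"],
        buffers["encoder_tile_ids_per_batch"],
        buffers["encoder_num_blocks_x_cpu"],
        buffers["kv_batch_ids"],
        buffers["kv_tile_ids_per_batch"],
        buffers["kv_num_blocks_x_cpu"],
        buffers["max_len_kv_cpu"],
        encoder_block_shape_q=64,     # example attrs
        decoder_block_shape_q=16,
        group_size=8,
        block_size=64,
        decoder_step_token_num=1,
    )
    # Returns None; consumers read the filled buffers afterwards.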
