@@ -199,6 +199,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
199199 paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
200200 paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
201201 paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
202+ paddle::Tensor &encoder_batch_ids, // Inplace
203+ paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
204+ paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
205+ paddle::Tensor &kv_batch_ids, // Inplace
206+ paddle::Tensor &kv_tile_ids_per_batch, // Inplace
207+ paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
208+ paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
202209 const int encoder_block_shape_q,
203210 const int decoder_block_shape_q,
204211 const int group_size,
@@ -223,14 +230,8 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
223230 int max_system_len = max_len_cpu_ptr[6 ];
224231 int max_just_dec_len_without_system = max_len_cpu_ptr[7 ];
225232
226- paddle::Tensor encoder_batch_ids;
227- paddle::Tensor encoder_tile_ids_per_batch;
228- paddle::Tensor encoder_num_blocks_x_cpu; /* cpu*/
229- paddle::Tensor kv_batch_ids;
230- paddle::Tensor kv_tile_ids_per_batch;
231- paddle::Tensor kv_num_blocks_x_cpu; /* cpu*/
232- paddle::Tensor max_len_kv_cpu; /* cpu*/
233233
234+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (max_len_kv_cpu.data <int >(), 0 , sizeof (int32_t ), stream));
234235 auto max_len_kv =
235236 GetEmptyTensor ({1 }, paddle::DataType::INT32, seq_lens_decoder.place ());
236237 get_max_len_kv_ernel<128 ><<<1 , 128 , 0 , stream>>> (
@@ -240,14 +241,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
240241 max_len_kv_cpu = max_len_kv.copy_to (paddle::CPUPlace (), false );
241242
242243 if (max_enc_len_this_time > 0 ) {
243- const uint32_t max_tile_size_per_bs_kv =
244- div_up (max_enc_dec_len_this_time, block_size);
245- kv_batch_ids =
246- GetEmptyTensor ({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
247- seq_lens_encoder.place ());
248- kv_tile_ids_per_batch =
249- GetEmptyTensor ({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
250- seq_lens_encoder.place ());
244+ const uint32_t max_tile_size_per_bs_kv = div_up (max_enc_dec_len_this_time, block_size);
245+ const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
246+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (kv_batch_ids.data <int >(), 0 , kv_batch_shape * sizeof (int32_t ), stream));
247+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (kv_tile_ids_per_batch.data <int >(), 0 , kv_batch_shape * sizeof (int32_t ), stream));
248+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (kv_num_blocks_x_cpu.data <int >(), 0 , sizeof (int32_t ), stream));
251249 auto kv_num_blocks_x =
252250 GetEmptyTensor ({1 }, paddle::DataType::INT32, seq_lens_encoder.place ());
253251
@@ -259,15 +257,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
259257 block_size, block_size);
260258
261259 kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to (paddle::CPUPlace (), false );
262-
263- const uint32_t encoder_max_tile_size_per_bs_q =
264- div_up ((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
265- encoder_batch_ids =
266- GetEmptyTensor ({bsz * encoder_max_tile_size_per_bs_q},
267- paddle::DataType::INT32, seq_lens_encoder.place ());
268- encoder_tile_ids_per_batch =
269- GetEmptyTensor ({bsz * encoder_max_tile_size_per_bs_q},
270- paddle::DataType::INT32, seq_lens_encoder.place ());
260+ // Clear buffer
261+ const uint32_t encoder_max_tile_size_per_bs_q = div_up ((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
262+ const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
263+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (encoder_batch_ids.data <int >(), 0 , encoder_batch_shape * sizeof (int32_t ), stream));
264+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (encoder_tile_ids_per_batch.data <int >(), 0 , encoder_batch_shape * sizeof (int32_t ), stream));
265+ PADDLE_ENFORCE_GPU_SUCCESS (cudaMemsetAsync (encoder_num_blocks_x_cpu.data <int >(), 0 , sizeof (int32_t ), stream));
271266 auto encoder_num_blocks_x =
272267 GetEmptyTensor ({1 }, paddle::DataType::INT32, seq_lens_encoder.place ());
273268 split_q_block<<<1 , 32 , 0 , stream>>> (seq_lens_encoder.data <int >(), nullptr ,
@@ -277,19 +272,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
277272 encoder_block_shape_q, group_size);
278273 encoder_num_blocks_x_cpu =
279274 encoder_num_blocks_x.copy_to (paddle::CPUPlace (), false );
280- } else {
281- encoder_batch_ids =
282- GetEmptyTensor ({0 }, paddle::DataType::INT32, seq_lens_encoder.place ());
283- encoder_tile_ids_per_batch =
284- GetEmptyTensor ({0 }, paddle::DataType::INT32, seq_lens_encoder.place ());
285- encoder_num_blocks_x_cpu =
286- GetEmptyTensor ({0 }, paddle::DataType::INT32, paddle::CPUPlace ());
287- kv_batch_ids =
288- GetEmptyTensor ({0 }, paddle::DataType::INT32, seq_lens_encoder.place ());
289- kv_tile_ids_per_batch =
290- GetEmptyTensor ({0 }, paddle::DataType::INT32, seq_lens_encoder.place ());
291- kv_num_blocks_x_cpu =
292- GetEmptyTensor ({0 }, paddle::DataType::INT32, paddle::CPUPlace ());
293275 }
294276
295277 if (max_just_dec_len_this_time > 0 ) {
@@ -314,15 +296,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
314296 decoder_num_blocks_x_cpu.copy_ (decoder_num_blocks_x, decoder_num_blocks_x_cpu.place (), false );
315297 }
316298
317- return {
318- encoder_batch_ids,
319- encoder_tile_ids_per_batch,
320- encoder_num_blocks_x_cpu, /* cpu*/
321- kv_batch_ids,
322- kv_tile_ids_per_batch,
323- kv_num_blocks_x_cpu, /* cpu*/
324- max_len_kv_cpu, /* cpu*/
325- };
326299}
327300
328301PD_BUILD_STATIC_OP (get_block_shape_and_split_kv_block)
@@ -333,16 +306,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
333306 " decoder_batch_ids" ,
334307 " decoder_tile_ids_per_batch" ,
335308 " decoder_num_blocks_x_cpu" ,
336- " max_len_tensor_cpu"
309+ " max_len_tensor_cpu" ,
310+ " encoder_batch_ids" ,
311+ " encoder_tile_ids_per_batch" ,
312+ " encoder_num_blocks_x_cpu" ,
313+ " kv_batch_ids" ,
314+ " kv_tile_ids_per_batch" ,
315+ " kv_num_blocks_x_cpu" ,
316+ " max_len_kv_cpu"
337317 })
338318 .Outputs({
339- paddle::Optional (" encoder_batch_ids" ),
340- paddle::Optional (" encoder_tile_ids_per_batch" ),
341- paddle::Optional (" encoder_num_blocks_x_cpu" ),
342- paddle::Optional (" kv_batch_ids" ),
343- paddle::Optional (" kv_tile_ids_per_batch" ),
344- paddle::Optional (" kv_num_blocks_x_cpu" ),
345- " max_len_kv_cpu"
319+
346320 })
347321 .Attrs({
348322 " encoder_block_shape_q: int" ,
0 commit comments