
Commit 1e06b9f

make append_attn supports mask_offset (#3138)
* make append_attn supports mask_offset
* add unittest
1 parent 6031f9a commit 1e06b9f

File tree: 10 files changed (+88, -20 lines)


custom_ops/gpu_ops/append_attention.cu

Lines changed: 10 additions & 0 deletions
@@ -72,6 +72,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const paddle::optional<paddle::Tensor>& mask_offset,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const paddle::optional<paddle::Tensor>& q_norm_weight,
     const paddle::optional<paddle::Tensor>& k_norm_weight,

@@ -441,6 +442,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const paddle::optional<paddle::Tensor>& mask_offset,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const paddle::optional<paddle::Tensor>& q_norm_weight,
     const paddle::optional<paddle::Tensor>& k_norm_weight,

@@ -479,6 +481,10 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.block_size = key_cache.dims()[2];
   meta_data.batch_size = seq_lens_this_time.dims()[0];

+  if (mask_offset) {
+    meta_data.mask_offset = mask_offset.get().data<int>();
+  }
+
   auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
     return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
         meta_data,

@@ -514,6 +520,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_v_zp,
         out_linear_shifts,
         out_linear_smooths,
+        mask_offset,
         kv_signal_data,
         q_norm_weight,
         k_norm_weight,

@@ -594,6 +601,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
     const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
+    const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
     const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
     const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
     const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,

@@ -657,6 +665,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
     const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
     const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
+    const paddle::optional<paddle::DataType>& mask_offset_dtype,
     const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
     const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
     const paddle::optional<paddle::DataType>& k_norm_weight_dtype,

@@ -738,6 +747,7 @@ PD_BUILD_STATIC_OP(append_attention)
             paddle::Optional("cache_v_zp"),
             paddle::Optional("out_linear_shifts"),
             paddle::Optional("out_linear_smooths"),
+            paddle::Optional("mask_offset"),
             paddle::Optional("kv_signal_data"),
             paddle::Optional("q_norm_weight"),
             paddle::Optional("k_norm_weight")})
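
Note on the new input (a reading of the diff, not text from the commit): mask_offset arrives as an optional int tensor on the append_attention op; when supplied, its data pointer is stored in meta_data.mask_offset and forwarded to the kernel launches in the files below. Judging by the masking change in append_attention_func.cuh, each entry holds the last KV index that the corresponding query token is allowed to attend to. A minimal host-side sketch of building such values for the ordinary causal case follows; the helper name and layout are illustrative assumptions, not code from this commit.

// Illustrative sketch only: per-query-token "last visible KV index" values that
// reproduce plain causal masking. Token q of a sequence with q_len new queries
// and kv_len total keys may see KV indices [0, kv_len - q_len + q].
#include <vector>

std::vector<int> build_causal_mask_offset(const std::vector<int>& q_lens,
                                          const std::vector<int>& kv_lens) {
  std::vector<int> mask_offset;  // laid out contiguously over all query tokens
  for (size_t b = 0; b < q_lens.size(); ++b) {
    for (int q = 0; q < q_lens[b]; ++q) {
      mask_offset.push_back(kv_lens[b] - q_lens[b] + q);
    }
  }
  return mask_offset;  // copy into an int tensor and pass as the optional input
}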

custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh

Lines changed: 14 additions & 5 deletions
@@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,

@@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel(
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel(
                    kv_len - q_len +
                        tile_id * num_rows_per_block / GROUP_SIZE,
                    chunk_start)))
-           : chunk_len) /
+           : mask_offset ? 0 : chunk_len) /
       (num_frags_z * 16);
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
       8 * (tid / 16) + tid % 8, (tid % 16) / 8);

@@ -250,7 +252,8 @@ __global__ void multi_query_append_attention_kernel(
           q_len,
           kv_len,
           chunk_end,
-          s_frag);
+          s_frag,
+          mask_offset_this_seq);
     }

     // update m,d

@@ -406,6 +409,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,

@@ -502,7 +506,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                          tid % 8 * num_elems_per_128b<T>();
     }
   }
-
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -543,7 +547,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                    kv_len - q_len +
                        tile_id * num_rows_per_block / GROUP_SIZE,
                    chunk_start)))
-           : chunk_len) /
+           : mask_offset ? 0 : chunk_len) /
       (NUM_WARP_KV * num_frags_z * 16);

   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -616,7 +620,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
           q_len,
           kv_len,
           chunk_end,
-          s_frag);
+          s_frag,
+          mask_offset_this_seq);
     }

     // update m,d

@@ -882,6 +887,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -939,6 +945,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1103,6 +1110,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1171,6 +1179,7 @@ void MultiQueryAppendAttention(
         tile_ids_per_batch.data<int>(),
        cu_seqlens_q.data<int>(),
        block_table.data<int>(),
+        meta_data.mask_offset,
        max_seq_len,
        max_dec_len,
        max_block_num_per_seq,
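
Two things happen in this kernel file, and the c4 and c8 variants below repeat them verbatim. First, each block shifts the global pointer by its tile's starting query token, mask_offset_this_seq = mask_offset + q_start_seq_id, so mask_s can index it by the token's position inside its own sequence. Second, the ternary feeding the mask-skip threshold now evaluates to 0 whenever mask_offset is present, which forces the per-element mask to run on every KV iteration instead of being skipped for chunks that are known to be fully visible. Below is a rough host-side model of that control flow; names such as mask_check_iteration and the loop guard are assumptions, since the diff only shows the ternary's fallback changing from chunk_len to 0.

// Rough model of the iteration logic (assumed names; illustrative only).
#include <cstdio>

void sweep_kv_chunk(int chunk_len, int tile_size, bool has_mask_offset) {
  // Iterations before this threshold are treated as fully in-bounds and skip
  // the per-element mask; with mask_offset the threshold is 0, so nothing skips.
  const int mask_check_iteration = (has_mask_offset ? 0 : chunk_len) / tile_size;
  const int num_iterations = (chunk_len + tile_size - 1) / tile_size;
  for (int iter = 0; iter < num_iterations; ++iter) {
    const bool apply_mask = iter >= mask_check_iteration;
    std::printf("iter %d: %s\n", iter, apply_mask ? "mask_s applied" : "mask skipped");
  }
}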

custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh

Lines changed: 14 additions & 5 deletions
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,

@@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel(
                    kv_len - q_len +
                        tile_id * num_rows_per_block / GROUP_SIZE,
                    chunk_start)))
-           : chunk_len) /
+           : mask_offset ? 0 : chunk_len) /
       (num_frags_z * 16);

   uint32_t k_smem_offset_r =

@@ -338,7 +340,8 @@ __global__ void multi_query_append_attention_c4_kernel(
           q_len,
           kv_len,
           chunk_end,
-          s_frag);
+          s_frag,
+          mask_offset_this_seq);
     }

     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(

@@ -505,6 +508,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,

@@ -627,7 +631,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                          tid % 8 * num_elems_per_128b<T>();
     }
   }
-
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -706,7 +710,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
                    kv_len - q_len +
                        tile_id * num_rows_per_block / GROUP_SIZE,
                    chunk_start)))
-           : chunk_len) /
+           : mask_offset ? 0 : chunk_len) /
       (NUM_WARP_KV * num_frags_z * 16);

   uint32_t k_smem_offset_r =

@@ -793,7 +797,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
           q_len,
           kv_len,
           chunk_end,
-          s_frag);
+          s_frag,
+          mask_offset_this_seq);
     }

     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(

@@ -1088,6 +1093,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1151,6 +1157,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1335,6 +1342,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1411,6 +1419,7 @@ void MultiQueryAppendC4Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh

Lines changed: 14 additions & 5 deletions
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,

@@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
   } else {
     o_base_ptr_int8 = out + o_offset;
   }
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel(
                    kv_len - q_len +
                        tile_id * num_rows_per_block / GROUP_SIZE,
                    chunk_start)))
-           : chunk_len) /
+           : mask_offset ? 0 : chunk_len) /
       (num_frags_z * 16);

   uint32_t k_smem_offset_r =

@@ -305,7 +307,8 @@ __global__ void multi_query_append_attention_c8_kernel(
           q_len,
           kv_len,
           chunk_end,
-          s_frag);
+          s_frag,
+          mask_offset_this_seq);
     }

     // update m,d

@@ -474,6 +477,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     const int *__restrict__ tile_ids_per_batch,
     const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
+    const int *__restrict__ mask_offset,
     const int max_seq_len,
     const int max_dec_len,
     const int max_block_num_per_seq,

@@ -601,7 +605,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                          tid % 8 * num_elems_per_128b<T>();
     }
   }
-
+  const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
   smem_t qo_smem(smem);

   uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(

@@ -642,7 +646,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
                    kv_len - q_len +
                        tile_id * num_rows_per_block / GROUP_SIZE,
                    chunk_start)))
-           : chunk_len) /
+           : mask_offset ? 0 : chunk_len) /
       (NUM_WARP_KV * num_frags_z * 16);

   uint32_t k_smem_offset_r =

@@ -733,7 +737,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
           q_len,
           kv_len,
           chunk_end,
-          s_frag);
+          s_frag,
+          mask_offset_this_seq);
     }

     // update m,d

@@ -1054,6 +1059,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1111,6 +1117,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1318,6 +1325,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

@@ -1388,6 +1396,7 @@ void MultiQueryAppendC8Attention(
         tile_ids_per_batch.data<int>(),
         cu_seqlens_q.data<int>(),
         block_table.data<int>(),
+        meta_data.mask_offset,
         max_seq_len,
         max_dec_len,
         max_block_num_per_seq,

custom_ops/gpu_ops/append_attn/append_attention_func.cuh

Lines changed: 11 additions & 5 deletions
@@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
                                        const uint32_t qo_len,
                                        const uint32_t kv_len,
                                        const uint32_t chunk_end,
-                                       float (*s_frag)[num_frags_z][8]) {
+                                       float (*s_frag)[num_frags_z][8],
+                                       const int *mask_offset = nullptr) {
   const uint32_t tx = threadIdx.x;
 #pragma unroll
   for (uint32_t fx = 0; fx < num_frags_x; ++fx) {

@@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
                        group_size,
               kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
                        8 * (reg_id / 4) + reg_id % 2;
-      const bool out_of_boundary =
-          (causal
-               ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
-               : kv_idx >= chunk_end);
+      bool out_of_boundary;
+      if (mask_offset) {
+        out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
+      } else {
+        out_of_boundary =
+            (causal
+                 ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
+                 : kv_idx >= chunk_end);
+      }
       if constexpr (std::is_same<T, half>::value) {
         s_frag[fx][fz][reg_id] =
             out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
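
The semantics of the new path, restated outside the register-fragment loop for readability (this is just the boundary test above written as a standalone function, not code from the commit): when mask_offset is non-null, query q_idx may attend to key kv_idx only while kv_idx <= mask_offset[q_idx], and rows at or beyond qo_len are masked entirely; otherwise the existing causal / chunk_end rule applies unchanged. Masked scores are then overwritten with -5e4f, a large negative value rather than -inf, exactly as before.

// Standalone restatement of the out_of_boundary test in mask_s (illustrative).
#include <cstdint>

inline bool out_of_boundary(uint32_t q_idx, uint32_t kv_idx, uint32_t qo_len,
                            uint32_t kv_len, uint32_t chunk_end, bool causal,
                            const int* mask_offset /* may be nullptr */) {
  if (mask_offset) {
    // Visible window per query token: kv_idx in [0, mask_offset[q_idx]].
    return q_idx < qo_len ? kv_idx > static_cast<uint32_t>(mask_offset[q_idx])
                          : true;
  }
  return causal ? (kv_idx > kv_len + q_idx - qo_len || kv_idx >= chunk_end)
                : kv_idx >= chunk_end;
}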
