from .utils import prefill_wrapper, prefill_wrapper_model

layer_dict = {}
+ prune_flag = True
+ merge_flag = True
+ sparse_token_list_192 = []
+ sparse_token_list_128 = []
+ sparse_token_list_64 = []
+ sparse_token_dict = {}


@TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -26,13 +32,13 @@ def add_sparse_config(self):
        special_config = self.config.get('special', {})

        self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
-       global layer_dict
+       global layer_dict, prune_flag, merge_flag
        layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
+       prune_flag = special_config.get('prune_flag', True)
+       merge_flag = special_config.get('merge_flag', True)
+       update_list()
        special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
-       special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
-       special_config['generate_process_count'] = 0
        special_config['pre_prompt_length_list'] = []
-       special_config['token_length_list'] = []
        special_config['image_shape'] = self.model.pruning_config['image_token_length']
        special_config['image_token_index'] = self.model.pruning_config['image_token_index']
        self.pruning_paras = special_config
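
For reference, a `special` config exercising the new flags might look like the sketch below (hypothetical values; only the keys read above are shown):

    # Hypothetical 'special' section consumed by add_sparse_config.
    special = {
        'pruning_loc': [2, 6, 15],   # decoder layers where reduction happens
        'prune_flag': True,          # drop low-relevance visual tokens
        'merge_flag': True,          # cluster and merge dropped tokens back in
        'retained_tokens': 192,      # budget; selects a sparse_token_list in update_list()
    }
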
@@ -42,7 +48,6 @@ def register_reduction_modules(self):
        def input_hook(module, input_args, pruning_pars):
            input_ids = input_args[0]
            pre_prompt_length_list = []
-           token_length_list = []
            IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']

            # find the position of the first image token
@@ -54,10 +59,7 @@ def input_hook(module, input_args, pruning_pars):
                    pre_prompt_length_list.append(image_token_index[0].item())
                else:
                    pre_prompt_length_list.append(0)
-               token_length_list.append(seq.shape[0])
-
            pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list
-           pruning_pars['token_length_list'] = token_length_list

            return input_args

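Stripped of the hook plumbing, the scan reduces to this standalone sketch (token id and shapes illustrative; the real id comes from `pruning_config`):

    import torch

    IMAGE_TOKEN_INDEX = 32000  # placeholder id
    input_ids = torch.tensor([[1, 306, 29871, 32000, 32000, 32000, 822]])

    pre_prompt_length_list = []
    for seq in input_ids:
        image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0]
        if len(image_token_index) > 0:
            pre_prompt_length_list.append(image_token_index[0].item())  # tokens before image
        else:
            pre_prompt_length_list.append(0)
    print(pre_prompt_length_list)  # [3]
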
@@ -90,11 +92,7 @@ def wrapper(self, *args, **kwargs):

                pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list

-               outputs = fn(*args, **kwargs)
-
-               pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
-
-               return outputs
+               return fn(*args, **kwargs)
            return wrapper

        @prefill_wrapper_model
@@ -106,12 +104,6 @@ def register_module_pars(module, args, kwargs, pruning_pars):

            B, L, _ = hidden_states.shape
            pruning_pars['B'] = B
-           init_n = pruning_pars['init_token_total_shape'] + \
-               pruning_pars['generate_process_count']  # 668
-           pruning_pars['prev_decision'] = torch.ones(
-               B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
-           pruning_pars['policy'] = torch.ones(
-               B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

            v_token_start = pre_prompt_length_list[0] if len(
                pre_prompt_length_list) != 0 else 0
@@ -123,8 +115,8 @@ def register_module_pars(module, args, kwargs, pruning_pars):
            if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                v_t = hidden_states[:, v_token_start: text_token_start, :]
                t_t = hidden_states[:, text_token_start:, :]
-               m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
-               m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
+               m_v_t = v_t @ t_t.transpose(1, 2)
+               m_v_t = m_v_t.softmax(2).mean(1)
                pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

            return args, kwargs
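
The `m_v_t` score above rates each text token by its mean softmax attention received from the vision tokens, and `t_token_idx` keeps the above-average ones. A standalone sketch (shapes illustrative; the comments stripped by this change said [1, 576, 52] and [1, 52]):

    import torch

    v_t = torch.randn(1, 576, 64)                    # vision tokens [B, V, D]
    t_t = torch.randn(1, 52, 64)                     # text tokens   [B, T, D]
    m_v_t = v_t @ t_t.transpose(1, 2)                # [1, 576, 52] affinities
    m_v_t = m_v_t.softmax(2).mean(1)                 # [1, 52] mean relevance per text token
    t_token_idx = torch.where(m_v_t > m_v_t.mean())  # above-average text tokens
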
@@ -133,6 +125,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
            kwargs['output_attentions'] = True
            if layer_idx != self.pruning_loc[0]:
                kwargs['position_ids'] = pruning_pars['position_ids']
+               kwargs['attention_mask'] = pruning_pars['attention_mask']
                kwargs['cache_position'] = pruning_pars['cache_position']
                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
            return args, kwargs
@@ -143,8 +136,14 @@ def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
                return args, kwargs
            if layer_idx != self.pruning_loc[0]:
                kwargs['position_ids'] = pruning_pars['position_ids']
+               kwargs['attention_mask'] = pruning_pars['attention_mask']
                kwargs['cache_position'] = pruning_pars['cache_position']
                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+           else:
+               pruning_pars['position_ids'] = kwargs['position_ids']
+               pruning_pars['attention_mask'] = kwargs['attention_mask']
+               pruning_pars['cache_position'] = kwargs['cache_position']
+               pruning_pars['position_embeddings'] = kwargs['position_embeddings']
            return args, kwargs

        def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
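
The new `else` branch makes the first pruning layer a snapshot point: it stashes the incoming kwargs in `pruning_pars`, and every later pruning layer restores (possibly re-indexed) versions of them. The pattern in isolation (toy module and hook name are illustrative, not from this file; `with_kwargs=True` needs PyTorch 2.0+):

    import functools
    import torch
    import torch.nn as nn

    class ToyLayer(nn.Module):
        def forward(self, x, position_ids=None):
            return x

    def cache_or_restore_hook(module, args, kwargs, pars, layer_idx, first_loc):
        if layer_idx != first_loc:
            kwargs['position_ids'] = pars['position_ids']   # restore for later layers
        else:
            pars['position_ids'] = kwargs['position_ids']   # snapshot at the first layer
        return args, kwargs

    pars = {}
    layers = [ToyLayer(), ToyLayer()]
    for idx, layer in enumerate(layers):
        layer.register_forward_pre_hook(
            functools.partial(cache_or_restore_hook, pars=pars, layer_idx=idx, first_loc=0),
            with_kwargs=True,
        )
    x = torch.zeros(1, 4)
    layers[0](x, position_ids=torch.arange(4))  # snapshots position_ids into pars
    layers[1](x)                                # hook re-injects pars['position_ids']
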
@@ -155,11 +154,6 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
            from transformers.models.llama.modeling_llama import \
                apply_rotary_pos_emb

-           if layer_idx != self.pruning_loc[0]:
-               kwargs['position_ids'] = pruning_pars['position_ids']
-               kwargs['cache_position'] = pruning_pars['cache_position']
-               kwargs['position_embeddings'] = pruning_pars['position_embeddings']
-
            hidden_states = kwargs['hidden_states']
            position_embeddings = kwargs['position_embeddings']
            position_ids = kwargs['position_ids']
@@ -215,9 +209,10 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
        def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

            if 'attn_logits' not in pruning_pars:
-               attn_logits = layer_outputs[1]
+               attn_logits = layer_outputs[1]  # for LlavaHf
            else:
                attn_logits = pruning_pars['attn_logits']
+           prune_flag = pruning_pars.get('prune_flag', True)
            merge_flag = pruning_pars['merge_flag']
            v_token_start = pruning_pars['v_token_start']
            v_token_num = pruning_pars['v_token_num']
@@ -227,13 +222,11 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
            B = pruning_pars['B']
            pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
            image_shape = pruning_pars['image_shape']
-           if layer_idx == self.pruning_loc[0]:
-               position_ids = kwargs['position_ids']
-               pruning_pars['position_ids'] = position_ids
-           else:
-               position_ids = pruning_pars['position_ids']
-           hidden_states = inputs[0]  # [B, L, D]

+           attention_mask = kwargs['attention_mask']
+           position_embeddings = kwargs['position_embeddings']
+
+           hidden_states = inputs[0]  # [B, L, D]
            pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
                attn_logits,
                v_token_start,
@@ -243,7 +236,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
                layer_idx,
                retained_tokens
            )
-
+           if not prune_flag:
+               pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
            policy = torch.ones(B, hidden_states.shape[1], dtype=hidden_states.dtype,
                                device=hidden_states.device)
            policy[:, v_token_start:text_token_start] = \
@@ -261,60 +255,91 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

            # merge and cluster
            if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
-               total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
+               total_sparse_token = batch_index_select(
+                   layer_outputs[0], total_sparse_token_idx
+               )

                merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
                merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
-               merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1  # Top 30%
+               if prune_flag:
+                   merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1
+               else:
+                   merge_token_num_stage1 = (
+                       merge_token_idx_stage1.shape[0]
+                       - sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
+                   )
                merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]
+               if not prune_flag:
+                   all_idx = torch.arange(
+                       merge_token_stage1.size(0),
+                       device=merge_token_stage1.device
+                   )
+                   non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
+                   pred_score_vis[0][non_topk_idx] = 1
+                   policy[:, v_token_start:text_token_start] = \
+                       pred_score_vis.type(dtype=hidden_states.dtype)

                merge_token_stage2 = total_sparse_token[:, merge_token_stage2_idx, :]
                cluster_num = int(merge_token_stage2.shape[1] / 10) + 1
                if cluster_num == 0:
                    cluster_num = merge_token_stage2.shape[1]
+               merge_sparse_token, index_down = cluster_and_merge(merge_token_stage2, cluster_num)

-               merge_sparse_token = cluster_and_merge(merge_token_stage2, cluster_num)
-
+               cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
+               cluster_idx = cluster_idx.squeeze(0)
                select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
                select_token = batch_index_select(layer_outputs[0], select_token_idx)
                select_vis_token_num = pred_score_vis.sum()
-
+               keep_indexs = torch.cat(
+                   (
+                       select_token_idx.squeeze(0)[:v_token_start + select_vis_token_num],
+                       cluster_idx,
+                       select_token_idx.squeeze(0)[v_token_start + select_vis_token_num:]
+                   )
+               )
                select_and_merge_token = torch.cat(
                    (
-                       select_token[:, :v_token_start +
-                                    select_vis_token_num, :],
+                       select_token[:, :v_token_start + select_vis_token_num, :],
                        merge_sparse_token,
-                       select_token[:, v_token_start +
-                                    select_vis_token_num:, :]
+                       select_token[:, v_token_start + select_vis_token_num:, :]
                    ),
                    dim=1
                )
                layer_outputs = (select_and_merge_token, layer_outputs[1])
-               position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
                v_token_num = pred_score_vis.sum() + cluster_num
-               text_token_start = v_token_start + v_token_num
+
            else:
-               select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
+               keep_indexs = torch.where(policy == 1)[1]
+               select_token_idx = keep_indexs.unsqueeze(0)
                layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                                 layer_outputs[1])
-               position_ids = position_ids[:, :len(select_token_idx[0])]
                v_token_num = pred_score_vis.sum()
-               text_token_start = v_token_start + v_token_num

+           text_token_start = v_token_start + v_token_num
+           position_ids = keep_indexs.unsqueeze(0)
            new_output = layer_outputs
-           cache_position = position_ids.detach().clone()
+           cache_position = position_ids.squeeze(0)
+
+           if attention_mask is not None:
+               attention_mask = attention_mask[:, :, keep_indexs][..., keep_indexs]  # rows and cols of surviving tokens
+           new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
+           new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+           position_embeddings = (new_pe0, new_pe1)

            pruning_pars['v_token_num'] = v_token_num
            pruning_pars['text_token_start'] = text_token_start
+
            pruning_pars['position_ids'] = position_ids
            pruning_pars['cache_position'] = cache_position
-           pruning_pars['position_embeddings'] = None
+           pruning_pars['position_embeddings'] = position_embeddings
+           pruning_pars['attention_mask'] = attention_mask

            return new_output

        @prefill_wrapper
        def read_parameter_hook(module, args, kwargs, pruning_pars):
            kwargs['position_ids'] = pruning_pars['position_ids']
+           kwargs['attention_mask'] = pruning_pars['attention_mask']
            kwargs['cache_position'] = pruning_pars['cache_position']
            kwargs['position_embeddings'] = pruning_pars['position_embeddings']

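Taken together, the new `keep_indexs` bookkeeping re-indexes every positional input by the surviving token positions. A shape-level sketch of the effect (sizes illustrative; note that indexing the mask's last two axes with the same tensor in a single subscript would select only the diagonal, hence the two-step indexing used above):

    import torch

    L, D = 600, 64
    keep_indexs = torch.randperm(L)[:200].sort()[0]        # surviving positions, ascending
    cos, sin = torch.randn(1, L, D), torch.randn(1, L, D)  # rotary tables [B, L, D]
    attention_mask = torch.zeros(1, 1, L, L)               # [B, 1, L, L]

    position_ids = keep_indexs.unsqueeze(0)                # [1, 200]
    cache_position = position_ids.squeeze(0)               # [200]
    position_embeddings = (cos[:, keep_indexs, :], sin[:, keep_indexs, :])
    attention_mask = attention_mask[:, :, keep_indexs][..., keep_indexs]
    print(attention_mask.shape)                            # torch.Size([1, 1, 200, 200])
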
@@ -363,7 +388,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                    with_kwargs=True
                )
            elif self.model.__class__.__name__ == 'Llava':
-               self.blocks[block_idx].self_attn.register_forward_pre_hook(
+               self.blocks[block_idx].register_forward_pre_hook(
                    functools.partial(
                        update_kwargs_hook,
                        pruning_pars=self.pruning_paras,
@@ -383,7 +408,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                    functools.partial(
                        decoder_attn_hook,
                        pruning_pars=self.pruning_paras,
-                       layer_idx=block_idx,
+                       layer_idx=block_idx
                    ),
                    with_kwargs=True
                )
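
Unlike the pre-hooks above, `decoder_attn_hook` runs as a forward hook, where returning a value replaces the layer's output; this is what lets it hand a shortened hidden-state sequence to the next layer. A minimal sketch of that mechanism (toy module and hook name are illustrative, not from this file):

    import functools
    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        def forward(self, x):
            return (x, None)  # (hidden_states, attn_logits), decoder-layer style

    def drop_half_hook(module, inputs, kwargs, layer_outputs, pruning_pars):
        hidden, attn = layer_outputs
        keep = hidden.shape[1] // 2
        return (hidden[:, :keep, :], attn)  # returned tuple replaces the output

    block = ToyBlock()
    block.register_forward_hook(
        functools.partial(drop_half_hook, pruning_pars={}),
        with_kwargs=True,
    )
    out, _ = block(torch.randn(1, 8, 4))
    print(out.shape)  # torch.Size([1, 4, 4])
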
@@ -397,17 +422,37 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                )


-   layer_dict = {2: 0, 6: 1, 15: 2}
-
-   sparse_token_list_192 = [300, 200, 110]  # 2*576 4*300 10*200 16*110
-   sparse_token_list_128 = [303, 110, 36]
-   sparse_token_list_64 = [66, 30, 17]
+   def update_list():
+       global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+       global prune_flag, merge_flag, sparse_token_dict
+
+       if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
+           sparse_token_list_192 = [300, 200, 110]
+           sparse_token_list_128 = [303, 110, 36]
+           sparse_token_list_64 = [66, 30, 17]
+           prune_flag, merge_flag = True, True
+       elif prune_flag and merge_flag:
+           sparse_token_list_192 = [180]
+           sparse_token_list_128 = [114]
+           sparse_token_list_64 = [48]
+       elif prune_flag:
+           sparse_token_list_192 = [192]
+           sparse_token_list_128 = [128]
+           sparse_token_list_64 = [64]
+       elif merge_flag:
+           sparse_token_list_192 = [149]
+           sparse_token_list_128 = [78]
+           sparse_token_list_64 = [7]
+       else:
+           raise RuntimeError(
+               'Both prune_flag and merge_flag are False: SparseVLM is inactive.'
+           )

-   sparse_token_dict = {
-       192: sparse_token_list_192,
-       128: sparse_token_list_128,
-       64: sparse_token_list_64
-   }
+       sparse_token_dict = {
+           192: sparse_token_list_192,
+           128: sparse_token_list_128,
+           64: sparse_token_list_64
+       }


def attn_postprocess_topk(
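
`update_list` keys the per-layer visual-token budgets by `retained_tokens`, and `decoder_attn_hook` reads the budget for the current layer as `sparse_token_dict[retained_tokens][layer_dict[layer_idx]]`. With the default schedule from above:

    layer_dict = {2: 0, 6: 1, 15: 2}
    sparse_token_dict = {192: [300, 200, 110], 128: [303, 110, 36], 64: [66, 30, 17]}

    retained_tokens, layer_idx = 192, 6
    budget = sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
    print(budget)  # 200 visual tokens kept after the layer-6 reduction
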
@@ -567,4 +612,4 @@ def cluster_and_merge(x, cluster_num):
        source=source.reshape(B * N, C).type(x.dtype))
    x_merged = x_merged.reshape(B, cluster_num, C)

-   return x_merged
+   return x_merged, index_down
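
`cluster_and_merge` now also returns `index_down`, the indices of the chosen cluster centers within its input, so the caller can chain them back through `merge_token_stage2_idx` and `total_sparse_token_idx` to original sequence positions. The chaining in miniature (values illustrative; `index_down` assumed to be [B, cluster_num], as the caller above implies):

    import torch

    total_sparse_token_idx = torch.tensor([[5, 9, 14, 20, 31]])  # sparse tokens' sequence positions
    merge_token_stage2_idx = torch.tensor([0, 2, 3])             # top-k among the sparse tokens
    index_down = torch.tensor([[1]])                             # cluster centers from cluster_and_merge

    cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
    print(cluster_idx.squeeze(0))  # tensor([14]): original position of the merged token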