 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
-from .utils import apply_info, prefill_wrapper
+from .utils import (apply_info, prefill_wrapper,
+                    prepare_inputs_labels_for_multimodal_with_index_masks)


 def visionzip_forward(
@@ -286,15 +287,19 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.dominant = special_config['dominant']
-        self.contextual = special_config['contextual']
+        self.dominant = self.special_config['dominant']
+        self.contextual = self.special_config['contextual']

-        self.pruning_paras = special_config
+        self.pruning_paras = self.special_config
+        prune_only = self.special_config.get('prune_only', False)
+        merge_only = self.special_config.get('merge_only', False)
+        assert not (prune_only and merge_only), 'prune_only and merge_only cannot both be True'
+        self.pruning_paras['prune_only'] = prune_only
+        self.pruning_paras['merge_only'] = merge_only

     def register_reduction_modules(self):

-        def visionzip_hook(m, images, image_forward_outs):
+        def visionzip_hook(m, images, image_forward_outs, pruning_paras, llava_next):
             attn_weights = image_forward_outs.attentions[-2]
             hidden_states = image_forward_outs.hidden_states[-2]
             metric = self.blocks[-2].self_attn.k_proj.metric
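Note: `add_sparse_config` now reads the two new flags straight from `self.special_config`. For reference, a minimal sketch of the `special` section this expects — the field values below are illustrative assumptions, not taken from this diff:

    # Hypothetical `special` config consumed by add_sparse_config (sketch).
    special = {
        'dominant': 54,               # dominant-token budget (illustrative value)
        'contextual': 10,             # contextual-token budget (illustrative value)
        'prune_only': False,          # keep only dominant tokens, skip merging
        'merge_only': False,          # skip dominant selection, merge all patches
        'vision_token_length': 576,   # None switches on the llava_next path below
    }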
@@ -306,17 +311,22 @@ def visionzip_hook(m, images, image_forward_outs):
             cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
             cls_attention_sum = cls_attention.sum(dim=1)
             topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
-            all_indices = torch.cat(
-                [
-                    torch.zeros(
-                        (hidden_states.shape[0], 1),
-                        dtype=topk_indices.dtype,
-                        device=topk_indices.device,
-                    ),
-                    topk_indices,
-                ],
-                dim=1,
-            )
+            if pruning_paras['merge_only']:
+                all_indices = torch.zeros(
+                    (hidden_states.shape[0], 1),
+                    dtype=topk_indices.dtype, device=topk_indices.device
+                )
+                dominant_num = 0
+            else:
+                all_indices = torch.cat(
+                    [
+                        torch.zeros(
+                            (hidden_states.shape[0], 1),
+                            dtype=topk_indices.dtype, device=topk_indices.device,
+                        ),
+                        topk_indices,
+                    ], dim=1,
+                )

             mask = torch.ones_like(
                 hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
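The `merge_only` branch collapses `all_indices` to just the CLS slot and zeroes `dominant_num`, so every patch token falls through to the merging stage instead of being kept as dominant. A self-contained sketch of the two paths with dummy shapes (all values illustrative):

    import torch

    B, H, N = 1, 4, 5                                 # batch, heads, 1 CLS + 4 patches
    attn = torch.rand(B, H, N, N).softmax(dim=-1)
    cls_attention_sum = attn[:, :, 0, 1:].sum(dim=1)  # CLS -> patch attention, heads summed
    dominant_num = 2
    topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1  # shift past CLS

    merge_only = True
    if merge_only:
        all_indices = torch.zeros((B, 1), dtype=topk_indices.dtype)  # CLS position only
        dominant_num = 0                                             # no dominant tokens kept
    else:
        all_indices = torch.cat(
            [torch.zeros((B, 1), dtype=topk_indices.dtype), topk_indices], dim=1)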
@@ -355,6 +365,15 @@ def visionzip_hook(m, images, image_forward_outs):
             target_indices = torch.arange(
                 0, metric_normalized.shape[1], step, device=metric_normalized.device
             )[:contextual_num]
+
+            # keep_idxs
+            index_masks = ~mask
+            if not pruning_paras['prune_only']:
+                pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(hidden_states.shape[0], -1)
+                target_index = pruned_indices[:, target_indices]
+                index_masks.scatter_(1, target_index, True)
+            pruning_paras['index_masks'] = index_masks[:, 1:]
+
             target_tokens = metric_normalized[:, target_indices, :]

             tokens_to_merge = metric_normalized[
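The new `index_masks` block records which original token positions survive: `~mask` marks the dominant slots, and unless `prune_only` is set, the stride-sampled contextual representatives are scattered in as well; the leading CLS slot is dropped before stashing the mask in `pruning_paras`. A toy reproduction of the scatter pattern (shapes and values illustrative):

    import torch

    B = 1
    mask = torch.tensor([[False, True, False, True, True, True]])    # True = non-dominant
    index_masks = ~mask                                              # True = kept (dominant)
    pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(B, -1)  # [[1, 3, 4, 5]]
    target_indices = torch.tensor([0, 2])                            # stride-sampled picks
    target_index = pruned_indices[:, target_indices]                 # [[1, 4]]
    index_masks.scatter_(1, target_index, True)                      # mark contextual survivors
    print(index_masks[:, 1:])                                        # CLS slot dropped, as above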
@@ -401,9 +420,15 @@ def visionzip_hook(m, images, image_forward_outs):
             ).to(images[0].dtype)

             res = list(image_forward_outs.hidden_states)
-            res[-2] = hidden_states_save.contiguous()
+            if not llava_next:
+                if pruning_paras['prune_only']:
+                    res[-2] = dominant_tokens.contiguous().to(images[0].dtype)
+                else:
+                    res[-2] = hidden_states_save.contiguous()
             image_forward_outs.hidden_states = tuple(res)

+            return image_forward_outs
+
         def store_key_hook(m, x, outputs):
             bsz = x[0].shape[0]
             raw_outputs = (
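`visionzip_hook` now returns `image_forward_outs`. For a PyTorch forward hook, a non-None return value replaces the module's output, which is what actually propagates the reduced hidden states to the caller. A minimal demonstration on a stand-in module:

    import torch
    import torch.nn as nn

    layer = nn.Identity()
    # A forward hook that returns a value overrides the module's output.
    layer.register_forward_hook(lambda module, inputs, output: output * 2)
    print(layer(torch.ones(2)))  # tensor([2., 2.])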
@@ -418,10 +443,13 @@ def update_output_attentions_hook(module, args, kwargs):
             kwargs['output_attentions'] = True
             return args, kwargs

+        def update_index_masks_hook(module, inps, outs, pruning_paras):
+            module.index_masks = pruning_paras['index_masks']
+
         if self.model.__class__.__name__ == 'LlavaHf':
             vision_tower = self.model.vlm_model.vision_tower
         elif self.model.__class__.__name__ == 'Llava':
-            vision_tower = self.model.vlm_model.model.vision_tower.vision_tower
+            vision_tower = self.model.vision_model.vision_tower

         if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
             apply_info(
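`update_output_attentions_hook` appears to follow the `register_forward_pre_hook(..., with_kwargs=True)` convention (its registration is presumably elsewhere in this file): returning an `(args, kwargs)` pair rewrites the call's arguments, here forcing the tower to emit attention maps for `visionzip_hook` to read. A sketch on a stand-in module:

    import torch
    import torch.nn as nn

    class Tower(nn.Module):  # stand-in, not the real vision tower
        def forward(self, x, output_attentions=False):
            return (x, 'attn') if output_attentions else (x,)

    def pre_hook(module, args, kwargs):
        kwargs['output_attentions'] = True  # rewrite the call-site kwargs
        return args, kwargs

    tower = Tower()
    tower.register_forward_pre_hook(pre_hook, with_kwargs=True)
    print(len(tower(torch.ones(1))))  # 2: attentions are now returned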
@@ -444,7 +472,25 @@ def update_output_attentions_hook(module, args, kwargs):
             block.self_attn.k_proj.head_dim = block.self_attn.head_dim
             block.self_attn.k_proj.register_forward_hook(store_key_hook)

-        vision_tower.register_forward_hook(visionzip_hook)
+        vision_tower.register_forward_hook(
+            functools.partial(
+                visionzip_hook,
+                pruning_paras=self.pruning_paras,
+                llava_next=self.special_config['vision_token_length'] is None
+            )
+        )
+
+        # llava_next
+        if self.special_config['vision_token_length'] is None:
+
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                prepare_inputs_labels_for_multimodal_with_index_masks,
+                self.model.vlm_model
+            )
+
+            self.model.vision_model.register_forward_hook(
+                functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
+            )

     def get_metric(fn, pruning_paras):
         @wraps(fn)
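The rewritten registration leans on two stock Python/PyTorch patterns: `functools.partial` pre-binds `pruning_paras` and `llava_next` onto the three-argument signature PyTorch calls hooks with, and `types.MethodType` rebinds `prepare_inputs_labels_for_multimodal` on the live `vlm_model` instance for the llava_next case. A compact illustration with stand-in objects (all names here are hypothetical):

    import functools
    from types import MethodType

    import torch
    import torch.nn as nn

    def hook(module, inputs, output, pruning_paras, llava_next):
        # PyTorch supplies (module, inputs, output); partial supplies the rest.
        module.seen = (pruning_paras['merge_only'], llava_next)

    tower = nn.Identity()
    tower.register_forward_hook(
        functools.partial(hook, pruning_paras={'merge_only': False}, llava_next=True))
    tower(torch.ones(1))

    class VLM:                 # stand-in for the wrapped model
        pass

    def patched(self, *args):  # replacement bound to this one instance
        return 'patched'

    vlm = VLM()
    vlm.prepare = MethodType(patched, vlm)
    print(tower.seen, vlm.prepare())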