Commit a4933d9

update visionzip for llava-next

1 parent 695fbc3

3 files changed (+71, -25 lines)

configs/sparsification/methods/VisionZip/visionzip.yml

Lines changed: 4 additions & 2 deletions

```diff
@@ -17,8 +17,10 @@ sparse:
   method: TokenReduction
   special:
     method: VisionZip # retain
-    dominant: 191 # visual_tokens = dominant_tokens + 1 (cls_token)
-    contextual: 30
+    dominant: 162 # visual_tokens = dominant_tokens + contextual
+    contextual: 30 # llava: 162+30, 108+20, 54+10; llava_next: 108+20, 54+10, 27+5
+    prune_only: False
+    merge_only: False
 save:
   save_trans: False
   save_fake: False
```
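
The two new flags change what survives the reduction: with `prune_only: True` only the CLS-attended dominant tokens are kept (no contextual merging), and with `merge_only: True` only the merged contextual tokens are kept; `add_sparse_config` below asserts they are never both True. Under the updated accounting the retained budget is `dominant + contextual`, so the suggested pairs give 192, 128, and 64 visual tokens for LLaVA. A minimal sketch of that arithmetic, assuming the flag semantics implemented in visionzip.py below (`retained_tokens` is a made-up helper, not part of the commit):

```python
# Illustrative only: how the retained visual-token budget follows from the
# config above. Whether the CLS token is counted is glossed over here.
def retained_tokens(dominant, contextual, prune_only=False, merge_only=False):
    assert not (prune_only and merge_only)   # mirrors add_sparse_config's check
    if prune_only:
        return dominant                      # dominant tokens only, no merging
    if merge_only:
        return contextual                    # merged contextual tokens only
    return dominant + contextual             # default VisionZip budget

print(retained_tokens(162, 30))   # 192 visual tokens (the llava default above)
print(retained_tokens(108, 20))   # 128
```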

llmc/compression/token_reduction/utils.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -296,7 +296,7 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
         if 'maxpool2x2' in mm_patch_merge_type:
             raise NotImplementedError
         elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio:
-            NotImplementedError
+            raise NotImplementedError
         elif 'unpad' in mm_patch_merge_type:
             image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
             image_feature = image_feature.flatten(1, 2).flatten(2, 3)
```
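
This one-word fix matters more than it looks: a bare `NotImplementedError` is just an expression statement that evaluates the exception class and discards it, so the unsupported `anyres_max` branch used to continue silently instead of failing fast. A minimal demonstration:

```python
# A bare exception class is a no-op; only `raise` actually signals the error.
def silent():
    NotImplementedError        # evaluates the class object, then discards it

def loud():
    raise NotImplementedError  # aborts the call with a traceback

silent()                       # returns None; execution just continues
try:
    loud()
except NotImplementedError:
    print('raised as intended')
```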
```diff
@@ -446,7 +446,6 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(

         cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

-        # import pdb; pdb.set_trace()
         cur_new_input_embeds = torch.cat(cur_new_input_embeds)
         cur_new_labels = torch.cat(cur_new_labels)

@@ -554,7 +553,6 @@ def prepare_inputs_labels_for_multimodal_with_index_masks(
         right_add = random.randint(left_add, self.config.pos_skipping_range)
         position_ids[:, :split_position] += left_add
         position_ids[:, split_position:] += right_add
-        # import pdb; pdb.set_trace()
         # rank0_print("Finish preparing")
         # print(vtoken_length)
         return None, position_ids, attention_mask, past_key_values, \
```

llmc/compression/token_reduction/visionzip.py

Lines changed: 66 additions & 20 deletions

```diff
@@ -12,7 +12,8 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
-from .utils import apply_info, prefill_wrapper
+from .utils import (apply_info, prefill_wrapper,
+                    prepare_inputs_labels_for_multimodal_with_index_masks)


 def visionzip_forward(
@@ -286,15 +287,19 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.dominant = special_config['dominant']
-        self.contextual = special_config['contextual']
+        self.dominant = self.special_config['dominant']
+        self.contextual = self.special_config['contextual']

-        self.pruning_paras = special_config
+        self.pruning_paras = self.special_config
+        prune_only = self.special_config.get('prune_only', False)
+        merge_only = self.special_config.get('merge_only', False)
+        assert not (prune_only and merge_only), 'prune_only and merge_only cannot both be True'
+        self.pruning_paras['prune_only'] = prune_only
+        self.pruning_paras['merge_only'] = merge_only

     def register_reduction_modules(self):

-        def visionzip_hook(m, images, image_forward_outs):
+        def visionzip_hook(m, images, image_forward_outs, pruning_paras, llava_next):
             attn_weights = image_forward_outs.attentions[-2]
             hidden_states = image_forward_outs.hidden_states[-2]
             metric = self.blocks[-2].self_attn.k_proj.metric
```
```diff
@@ -306,17 +311,22 @@ def visionzip_hook(m, images, image_forward_outs):
             cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
             cls_attention_sum = cls_attention.sum(dim=1)
             topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
-            all_indices = torch.cat(
-                [
-                    torch.zeros(
-                        (hidden_states.shape[0], 1),
-                        dtype=topk_indices.dtype,
-                        device=topk_indices.device,
-                    ),
-                    topk_indices,
-                ],
-                dim=1,
-            )
+            if pruning_paras['merge_only']:
+                all_indices = torch.zeros(
+                    (hidden_states.shape[0], 1),
+                    dtype=topk_indices.dtype, device=topk_indices.device
+                )
+                dominant_num = 0
+            else:
+                all_indices = torch.cat(
+                    [
+                        torch.zeros(
+                            (hidden_states.shape[0], 1),
+                            dtype=topk_indices.dtype, device=topk_indices.device,
+                        ),
+                        topk_indices,
+                    ], dim=1,
+                )

             mask = torch.ones_like(
                 hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
```
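
For context, the `else` branch is VisionZip's original dominant-token selection: the CLS row of a late attention map scores every patch, the top `dominant_num` patches are kept, and index 0 (CLS) is prepended so it always survives; the new `merge_only` branch simply collapses that set to CLS alone and zeroes `dominant_num`. A self-contained sketch with illustrative shapes (the real hook reads `attentions[-2]` from the vision tower and gets `cls_idx` from the model):

```python
# Standalone sketch of dominant-token selection; B/H/N and cls_idx=0 are
# illustrative assumptions, not values from the commit.
import torch

B, H, N = 2, 8, 577                    # batch, heads, 1 CLS + 576 patches
dominant_num = 162
attn = torch.softmax(torch.randn(B, H, N, N), dim=-1)

cls_attention = attn[:, :, 0, 1:]              # CLS query row, patch keys only
cls_attention_sum = cls_attention.sum(dim=1)   # sum over heads -> (B, 576)
topk = cls_attention_sum.topk(dominant_num, dim=1).indices + 1  # shift past CLS
cls_col = torch.zeros(B, 1, dtype=topk.dtype, device=topk.device)
all_indices = torch.cat([cls_col, topk], dim=1)  # CLS is always kept
print(all_indices.shape)                         # torch.Size([2, 163])
```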
```diff
@@ -355,6 +365,15 @@ def visionzip_hook(m, images, image_forward_outs):
             target_indices = torch.arange(
                 0, metric_normalized.shape[1], step, device=metric_normalized.device
             )[:contextual_num]
+
+            # keep_idxs
+            index_masks = ~mask
+            if not pruning_paras['prune_only']:
+                pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(hidden_states.shape[0], -1)
+                target_index = pruned_indices[:, target_indices]
+                index_masks.scatter_(1, target_index, True)
+            pruning_paras['index_masks'] = index_masks[:, 1:]
+
             target_tokens = metric_normalized[:, target_indices, :]

             tokens_to_merge = metric_normalized[
```
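
The added block is the heart of the llava-next support: it records which original patch positions survive, as a boolean keep-mask. `~mask` marks the dominant tokens (`mask` flags the non-dominant candidates), and unless `prune_only` is set, the every-`step`-th merge targets are scattered back in; the CLS column is dropped before the mask is published via `pruning_paras`. A toy illustration with hand-made tensors:

```python
# Toy version of the keep-mask bookkeeping; the tensors are invented and
# tiny (1 CLS + 7 patches) so the result can be read off by eye.
import torch

B = 1
# True = not selected as dominant (a merge candidate); column 0 is CLS.
mask = torch.tensor([[False, False, True, False, True, True, False, True]])

index_masks = ~mask                                   # dominant keep-mask
pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(B, -1)  # [[2, 4, 5, 7]]
target_indices = torch.tensor([0, 2])                 # every `step`-th candidate
target_index = pruned_indices[:, target_indices]      # columns 2 and 5
index_masks.scatter_(1, target_index, True)           # merge targets survive too
print(index_masks[:, 1:])  # tensor([[True, True, True, False, True, True, False]])
```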
```diff
@@ -401,9 +420,15 @@ def visionzip_hook(m, images, image_forward_outs):
             ).to(images[0].dtype)

             res = list(image_forward_outs.hidden_states)
-            res[-2] = hidden_states_save.contiguous()
+            if not llava_next:
+                if pruning_paras['prune_only']:
+                    res[-2] = dominant_tokens.contiguous().to(images[0].dtype)
+                else:
+                    res[-2] = hidden_states_save.contiguous()
             image_forward_outs.hidden_states = tuple(res)

+            return image_forward_outs
+
         def store_key_hook(m, x, outputs):
             bsz = x[0].shape[0]
             raw_outputs = (
```
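
Note the division of labor this branch creates: for plain LLaVA the hook still swaps the reduced sequence into `hidden_states[-2]` (just the dominant tokens under `prune_only`), while for LLaVA-NeXT it leaves the tower output untouched and only publishes `index_masks`, which the patched `prepare_inputs_labels_for_multimodal_with_index_masks` consumes downstream. The new explicit `return image_forward_outs` also makes the forward hook replace the module's output rather than rely on in-place mutation.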
```diff
@@ -418,10 +443,13 @@ def update_output_attentions_hook(module, args, kwargs):
             kwargs['output_attentions'] = True
             return args, kwargs

+        def update_index_masks_hook(module, inps, outs, pruning_paras):
+            module.index_masks = pruning_paras['index_masks']
+
         if self.model.__class__.__name__ == 'LlavaHf':
             vision_tower = self.model.vlm_model.vision_tower
         elif self.model.__class__.__name__ == 'Llava':
-            vision_tower = self.model.vlm_model.model.vision_tower.vision_tower
+            vision_tower = self.model.vision_model.vision_tower

         if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
             apply_info(
@@ -444,7 +472,25 @@ def update_output_attentions_hook(module, args, kwargs):
                 block.self_attn.k_proj.head_dim = block.self_attn.head_dim
                 block.self_attn.k_proj.register_forward_hook(store_key_hook)

-            vision_tower.register_forward_hook(visionzip_hook)
+            vision_tower.register_forward_hook(
+                functools.partial(
+                    visionzip_hook,
+                    pruning_paras=self.pruning_paras,
+                    llava_next=self.special_config['vision_token_length'] is None
+                )
+            )
+
+            # llava_next
+            if self.special_config['vision_token_length'] is None:
+
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                    prepare_inputs_labels_for_multimodal_with_index_masks,
+                    self.model.vlm_model
+                )
+
+                self.model.vision_model.register_forward_hook(
+                    functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
+                )

         def get_metric(fn, pruning_paras):
             @wraps(fn)
```
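
Two standard-library patterns do the wiring here: `functools.partial` pre-binds extra keyword arguments onto a hook whose positional signature PyTorch fixes at `(module, inputs, output)`, and `types.MethodType` rebinds a free function as a bound method on one instance, which is how the index-mask-aware multimodal prep is swapped in for llava-next. A self-contained sketch on a toy module (the hook and method names are illustrative):

```python
# Minimal sketch of the two wiring patterns, on an nn.Linear stand-in.
import functools
from types import MethodType

import torch
import torch.nn as nn

def shape_hook(module, inputs, output, pruning_paras):
    # Extra kwarg `pruning_paras` was pre-bound with functools.partial.
    pruning_paras['last_out_shape'] = tuple(output.shape)

paras = {}
layer = nn.Linear(4, 4)
layer.register_forward_hook(functools.partial(shape_hook, pruning_paras=paras))
layer(torch.randn(2, 4))
print(paras)                     # {'last_out_shape': (2, 4)}

# MethodType turns a plain function into a bound method on one instance.
def describe(self):
    return f'{self.__class__.__name__} with {self.out_features} outputs'

layer.describe = MethodType(describe, layer)
print(layer.describe())          # Linear with 4 outputs
```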
