Commit c70c7f6

support DART for llava (#403)
1 parent 2c61449 commit c70c7f6

6 files changed: 282 additions, 9 deletions
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    method: TokenReduction
    special:
        method: DART
        pruning_loc: 2
        reduction_ratio: 0.778
        max_num_trunction: 128
        pivot_image_token: 4
        pivot_text_token: 4
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
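
For a sense of what these numbers mean in practice, the sketch below reproduces the retained-token budget arithmetic from get_retained_image_token in dart.py further down. The 576-token image length is an assumption (typical for LLaVA-1.5) and is not part of this config.

import math

# Assumed value (not in the config): LLaVA-1.5 produces 576 image tokens.
image_token_length = 576
max_num_trunction = 128      # spelling follows the config key
reduction_ratio = 0.778
pivot_image_token = 4
pivot_text_token = 4

# Per-pivot quota, as computed in get_retained_image_token().
budget = (max_num_trunction if max_num_trunction is not None
          else image_token_length * (1 - reduction_ratio))
token_topk = math.ceil(budget // (pivot_image_token + pivot_text_token))

# Each of the (4 + 4) pivots pulls in token_topk additional image tokens;
# the text pivots themselves are dropped again at the end.
retained_image_tokens = (pivot_image_token
                         + (pivot_image_token + pivot_text_token) * token_topk)
print(token_topk, retained_image_tokens)   # 16, 132 with the values above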

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .base_blockwise_token_reduction import TokenReduction
+from .dart import DART
 from .dycoke import DyCoke
 from .fastervlm import FasterVLM
 from .fastv import FastV
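
The single added import matters because DART is registered as a side effect of loading dart.py: the @TOKEN_REDUCTION_REGISTRY.register('DART') decorator only runs once the module is imported here. A minimal sketch of the decorator-registry pattern (illustrative only, not llmc's actual registry_factory implementation):

# Illustrative only -- not llmc's actual Registry implementation.
class Registry:
    def __init__(self):
        self._map = {}

    def register(self, name):
        def decorator(cls):
            self._map[name] = cls
            return cls
        return decorator

    def __getitem__(self, name):
        return self._map[name]


TOKEN_REDUCTION_REGISTRY = Registry()


@TOKEN_REDUCTION_REGISTRY.register('DART')
class DART:
    """Stand-in for the real class defined in dart.py."""


# Only after the module has been imported can the config's
# special.method string be resolved to a class:
assert TOKEN_REDUCTION_REGISTRY['DART'] is DART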

llmc/compression/token_reduction/dart.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
import functools
import math
from functools import wraps
from types import MethodType

import torch

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper


@TOKEN_REDUCTION_REGISTRY.register('DART')
class DART(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):

        self.pruning_loc = self.special_config['pruning_loc']
        self.special_config['image_token_length'] = \
            self.model.pruning_config['image_token_length']
        self.special_config['IMAGE_TOKEN_INDEX'] = \
            self.model.pruning_config['IMAGE_TOKEN_INDEX']

        self.pruning_paras = self.special_config

    def register_reduction_modules(self):

        # Wraps LLaVA's prepare_inputs_labels_for_multimodal to record where
        # the image tokens start in the prefill sequence.
        def input_hook_llava(fn, pruning_paras):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                if len(args) == 0:
                    return fn(*args, **kwargs)
                input_args = args[0]
                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
                    return fn(*args, **kwargs)

                input_ids = args[0]
                attention_mask = args[2]
                token_indices = (
                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
                )
                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()

                outputs = fn(*args, **kwargs)
                return outputs
            return wrapper

        # Records the prefill sequence length before any token is dropped.
        def get_seq_len_hook(module, args, kwargs, pruning_paras):
            if kwargs['input_ids'] is not None:
                pruning_paras['seq_len'] = kwargs['input_ids'].shape[1]
            elif kwargs['inputs_embeds'] is not None:
                pruning_paras['seq_len'] = kwargs['inputs_embeds'].shape[1]
            else:
                raise ValueError('You have to specify either input_ids or inputs_embeds')

        # Recomputes the (query, key, value) states of the layer before the
        # pruning location, with RoPE applied, and stashes them for the
        # pruning step.
        def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
            from transformers.models.llama.modeling_llama import (
                apply_rotary_pos_emb, repeat_kv)
            if len(kwargs['position_ids'][0]) == 1:
                return layer_outs

            hidden_states = kwargs['hidden_states']
            position_embeddings = kwargs['position_embeddings']
            position_ids = kwargs['position_ids']
            past_key_value = kwargs['past_key_value']

            bsz, q_len, _ = hidden_states.size()
            query_states = module.q_proj(hidden_states)
            key_states = module.k_proj(hidden_states)
            value_states = module.v_proj(hidden_states)
            query_states = query_states.view(
                bsz, q_len, module.num_heads, module.head_dim
            ).transpose(1, 2)
            key_states = key_states.view(
                bsz, q_len, module.num_key_value_heads, module.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, q_len, module.num_key_value_heads, module.head_dim
            ).transpose(1, 2)

            if position_embeddings is None:
                cos, sin = module.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            if past_key_value is not None:
                key_states = past_key_value.key_cache[layer_idx]
                value_states = past_key_value.value_cache[layer_idx]
            key_states = repeat_kv(key_states, module.num_key_value_groups)
            value_states = repeat_kv(value_states, module.num_key_value_groups)

            pruning_paras['any_states'] = (query_states, key_states, value_states)

            return layer_outs

        # Drops the non-retained image tokens right before the target block and
        # shrinks the attention mask / position metadata in place to match.
        @prefill_wrapper
        def pruning_hook(module, args, kwargs, pruning_paras, normlayer):

            image_token_start_index = pruning_paras['image_token_start_index']
            image_token_length = pruning_paras['image_token_length']
            any_states = pruning_paras['any_states'][-2]
            seq_length = pruning_paras['seq_len']

            hidden_states = args[0]
            attention_mask = kwargs['attention_mask']
            device = hidden_states.device
            last_layer_state = normlayer(hidden_states)

            # keep index
            retained_image_tokens_index = get_retained_image_token(
                pruning_paras, last_layer_state, any_states)

            keep_indexs = torch.cat(
                (
                    torch.arange(image_token_start_index, device=device),
                    retained_image_tokens_index,
                    torch.arange(
                        image_token_start_index + image_token_length,
                        seq_length,
                        device=device
                    )
                )
            )
            # sort index
            keep_indexs = keep_indexs.sort().values
            hidden_states = hidden_states[:, keep_indexs, :]
            position_ids = keep_indexs.unsqueeze(0)
            if attention_mask is not None:
                attention_mask = attention_mask[
                    :, :, :hidden_states.shape[1], :hidden_states.shape[1]
                ]
                kwargs['attention_mask'].resize_as_(attention_mask).copy_(attention_mask.clone())
            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
                position_ids.squeeze(0).clone())
            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

            position_embeddings = kwargs['position_embeddings']
            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
            position_embeddings[1].resize_as_(new_pe1).copy_(new_pe1)

            return (hidden_states,), kwargs

        hook_fn = input_hook_llava(
            self.model.vlm_model.prepare_inputs_labels_for_multimodal,
            self.pruning_paras
        )
        self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
            hook_fn, self.model.vlm_model
        )

        self.model.model.model.register_forward_pre_hook(
            functools.partial(get_seq_len_hook, pruning_paras=self.pruning_paras),
            with_kwargs=True
        )

        self.blocks[self.pruning_loc - 1].self_attn.register_forward_hook(
            functools.partial(
                get_any_states_hook,
                pruning_paras=self.pruning_paras,
                layer_idx=self.pruning_loc - 1
            ),
            with_kwargs=True
        )

        self.blocks[self.pruning_loc].register_forward_pre_hook(
            functools.partial(
                pruning_hook,
                pruning_paras=self.pruning_paras,
                normlayer=self.model.model.model.norm
            ),
            with_kwargs=True
        )


# DART's duplication-aware selection: pick pivot tokens by the L1 norm of
# their key states, then iteratively keep the image tokens least similar to
# each pivot.
def get_retained_image_token(pruning_paras, last_layer_state, any_states):
    image_token_start_index = pruning_paras['image_token_start_index']
    image_token_length = pruning_paras['image_token_length']
    MAX_NUM_TRUNCTION = pruning_paras['max_num_trunction']
    pivot_image_token = pruning_paras['pivot_image_token']
    pivot_text_token = pruning_paras['pivot_text_token']
    reduction_ratio = pruning_paras['reduction_ratio']
    TOKEN_TOPK = math.ceil(
        (
            MAX_NUM_TRUNCTION if MAX_NUM_TRUNCTION is not None
            else (image_token_length * (1 - reduction_ratio))
        ) // (pivot_image_token + pivot_text_token))
    device = last_layer_state.device

    # flatten heads: (bsz, heads, seq, head_dim) -> (bsz, seq, heads * head_dim)
    any_states = any_states.permute(0, 2, 1, 3)
    any_states = any_states.reshape(any_states.shape[0], any_states.shape[1], -1)

    k_states_image_token = any_states[0][
        image_token_start_index:image_token_start_index + image_token_length, :
    ]
    k_states_query_token = any_states[0][image_token_start_index + image_token_length:, :]

    k_states_image_token_L1_norm = torch.norm(k_states_image_token, p=1, dim=-1)
    k_states_query_token_L1_norm = torch.norm(k_states_query_token, p=1, dim=-1)

    image_indices = (
        k_states_image_token_L1_norm.topk(pivot_image_token).indices
        + image_token_start_index
    ).tolist()
    query_indices = (
        k_states_query_token_L1_norm.topk(pivot_text_token).indices
        + image_token_start_index + image_token_length
    ).tolist()
    indices_set = set(image_indices + query_indices)

    valid_indices = set(
        range(image_token_start_index, image_token_start_index + image_token_length)
    ) - set(image_indices)

    valid_indices_list = list(valid_indices)
    for item in list(indices_set):
        valid_vectors = last_layer_state[0][valid_indices_list, :]
        cos_sim = -torch.nn.functional.cosine_similarity(
            last_layer_state[0][item, :],
            valid_vectors,
            dim=-1
        )
        top_k_indices = cos_sim.topk(TOKEN_TOPK).indices

        top_k_real_indices = [valid_indices_list[i] for i in top_k_indices]
        indices_set.update(top_k_real_indices)

        valid_indices.difference_update(top_k_real_indices)
        valid_indices_list = list(valid_indices)

    indices_set.difference_update(query_indices)

    retained_image_tokens_index = torch.tensor(list(indices_set), device=device)

    return retained_image_tokens_index
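
A self-contained toy run of the same pivot-plus-dissimilarity selection, useful for seeing the mechanics without a model. The tensor shapes and token counts below are fabricated; the real hook feeds in the pre-pruning key states and the final-norm hidden states instead of random data.

import math
import torch

torch.manual_seed(0)

# Made-up toy dimensions.
seq_len, hidden, img_start, img_len = 64, 32, 5, 40
pivot_img, pivot_txt, max_trunc = 2, 2, 16

last_layer_state = torch.randn(1, seq_len, hidden)   # stand-in for norm(hidden_states)
key_states = torch.randn(1, seq_len, hidden)         # stand-in for flattened key states

token_topk = math.ceil(max_trunc // (pivot_img + pivot_txt))   # 4

# 1) pick pivots by the L1 norm of the key states
img_norm = key_states[0, img_start:img_start + img_len].norm(p=1, dim=-1)
txt_norm = key_states[0, img_start + img_len:].norm(p=1, dim=-1)
image_pivots = (img_norm.topk(pivot_img).indices + img_start).tolist()
query_pivots = (txt_norm.topk(pivot_txt).indices + img_start + img_len).tolist()

# 2) for each pivot, keep the image tokens *least* similar to it
keep = set(image_pivots)
candidates = set(range(img_start, img_start + img_len)) - set(image_pivots)
for pivot in image_pivots + query_pivots:
    cand = sorted(candidates)
    sim = torch.nn.functional.cosine_similarity(
        last_layer_state[0, pivot], last_layer_state[0, cand], dim=-1)
    picked = [cand[i] for i in (-sim).topk(token_topk).indices]
    keep.update(picked)
    candidates.difference_update(picked)

retained = torch.tensor(sorted(keep))
print(len(retained), retained)   # pivot_img + 4 pivots * token_topk = 18 tokens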

llmc/compression/token_reduction/fastv.py

Lines changed: 4 additions & 5 deletions
@@ -4,14 +4,11 @@
 
 import torch
 
-from llmc.compression.sparsification.attn_utils import _update_causal_mask
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
 from .token_reduction_module import TokenReductionModule
 from .utils import prefill_wrapper
 
-IMAGE_TOKEN_INDEX = -200
-
 
 @TOKEN_REDUCTION_REGISTRY.register('FastV')
 class FastV(TokenReductionModule):
@@ -25,6 +22,8 @@ def add_sparse_config(self):
         self.pruning_loc = self.special_config['pruning_loc']
         self.special_config['image_token_length'] = \
             self.model.pruning_config['image_token_length']
+        self.special_config['IMAGE_TOKEN_INDEX'] = \
+            self.model.pruning_config['IMAGE_TOKEN_INDEX']
         self.special_config['attn_scores'] = None
 
         self.pruning_paras = self.special_config
@@ -51,7 +50,8 @@ def wrapper(self, *args, **kwargs):
 
                 input_ids = args[0]
                 attention_mask = args[2]
-                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
+                token_indices = \
+                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
                 pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
 
                 outputs = fn(*args, **kwargs)
@@ -127,7 +127,6 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
                functools.partial(input_hook, pruning_paras=self.pruning_paras)
            )
        elif self.model.__class__.__name__ == 'Llava':
-            from llava.constants import IMAGE_TOKEN_INDEX
            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras

llmc/compression/token_reduction/pyramiddrop.py

Lines changed: 4 additions & 2 deletions
@@ -25,6 +25,9 @@ def __init__(self, config, model, blocks):
     def add_sparse_config(self):
 
         self.pruning_loc = self.special_config['layer_list']
+        self.special_config['IMAGE_TOKEN_INDEX'] = \
+            self.model.pruning_config['IMAGE_TOKEN_INDEX']
+
         image_token_ratio_list = self.special_config['image_token_ratio_list']
         image_token_ratio_list.insert(0, 1.0)
         self.special_config['image_token_ratio_list'] = image_token_ratio_list
@@ -348,7 +351,7 @@ def wrapper(self, *args, **kwargs):
                vision_tokens = []
                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
                    seq = cur_input_ids[cur_attention_mask]
-                    image_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    image_index = torch.where(seq == pruning_paras['IMAGE_TOKEN_INDEX'])[0].tolist()
                    if image_index == []:
                        image_token_posi.append(-1)
                        prompt_len.append(cur_input_ids.shape[0])
@@ -378,7 +381,6 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                functools.partial(input_hook, pruning_pars=self.pruning_paras)
            )
        elif self.model.__class__.__name__ == 'Llava':
-            from llava.constants import IMAGE_TOKEN_INDEX
            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras

llmc/models/llava.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 try:
     from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                                 DEFAULT_IMAGE_PATCH_TOKEN)
+                                 DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
     from llava.mm_utils import get_model_name_from_path
     from llava.model.builder import load_pretrained_model
     from llava.model.language_model.llava_llama import LlavaConfig
@@ -66,7 +66,8 @@ def build_model(self):
             'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
-            'image_token_index': self.vlm_model_config.image_token_index
+            'image_token_index': self.vlm_model_config.image_token_index,
+            'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
         }
         self.processor = None
 