from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper, prefill_wrapper_model

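+ # filled in add_sparse_config(): maps each pruning layer index to its position in pruning_loc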
+ layer_dict = {}
+

@TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
class SparseVLM(TokenReductionModule):
@@ -24,6 +26,8 @@ def add_sparse_config(self):
        special_config = self.config.get('special', {})

        self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
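+         # decoder layers at which visual tokens are pruned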
+         global layer_dict
+         layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
        special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
        special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
        special_config['generate_process_count'] = 0
@@ -44,7 +48,8 @@ def input_hook(module, input_args, pruning_pars):
            # find the position of the first image token
            for seq in input_ids:
                image_token_index = (
-                   seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                   seq == IMAGE_TOKEN_INDEX
+               ).nonzero(as_tuple=True)[0]
                if len(image_token_index) > 0:
                    pre_prompt_length_list.append(image_token_index[0].item())
                else:
@@ -95,33 +100,31 @@ def wrapper(self, *args, **kwargs):
        @prefill_wrapper_model
        def register_module_pars(module, args, kwargs, pruning_pars):
            pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-             inputs_embeds = kwargs['inputs_embeds']
-             if inputs_embeds is None:
-                 inputs_embeds = module.embed_tokens(kwargs['input_ids'])
-             hidden_states = inputs_embeds  # shape: (B, L, C)
+             hidden_states = kwargs['inputs_embeds']
+             if hidden_states is None:
+                 hidden_states = module.embed_tokens(kwargs['input_ids'])

            B, L, _ = hidden_states.shape
            pruning_pars['B'] = B
            init_n = pruning_pars['init_token_total_shape'] + \
-                     pruning_pars['generate_process_count']  # 668
+                 pruning_pars['generate_process_count']  # 668
            pruning_pars['prev_decision'] = torch.ones(
                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
            pruning_pars['policy'] = torch.ones(
                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

-             pruning_pars['v_token_start'] = pre_prompt_length_list[0] if len(
-                 pre_prompt_length_list) != 0 else 0  # 35
-             v_token_start = pruning_pars['v_token_start']
-             pruning_pars['text_token_start'] = pruning_pars['v_token_start'] + \
-                 pruning_pars['image_shape']  # 35 + 576 = 611
-             text_token_start = pruning_pars['text_token_start']
+             v_token_start = pre_prompt_length_list[0] if len(
+                 pre_prompt_length_list) != 0 else 0
+             text_token_start = v_token_start + pruning_pars['image_shape']
+             pruning_pars['v_token_start'] = v_token_start  # 35
+             pruning_pars['text_token_start'] = text_token_start  # 611
            pruning_pars['v_token_num'] = pruning_pars['image_shape']  # 576

            if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                v_t = hidden_states[:, v_token_start:text_token_start, :]
                t_t = hidden_states[:, text_token_start:, :]
-                 m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?
-                 m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
+                 m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
+                 m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
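+                 # softmax over the text tokens, then averaged over the 576 visual tokens; values above the mean define t_token_idx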
                pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

            return args, kwargs
@@ -134,10 +137,20 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
            kwargs['position_embeddings'] = pruning_pars['position_embeddings']
            return args, kwargs

-         def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+         def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):

            if len(kwargs['position_ids'][0]) == 1:
                return args, kwargs
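+             # pruning layers after the first reuse the position_ids / cache_position / position_embeddings stored in pruning_pars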
+             if layer_idx != self.pruning_loc[0]:
+                 kwargs['position_ids'] = pruning_pars['position_ids']
+                 kwargs['cache_position'] = pruning_pars['cache_position']
+                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+             return args, kwargs
+
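+         # registered below via register_forward_hook, so it also receives the layer outputs (layer_outs)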
+         def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
+
+             if len(kwargs['position_ids'][0]) == 1:
+                 return layer_outs

            from transformers.models.llama.modeling_llama import \
                apply_rotary_pos_emb
@@ -150,8 +163,7 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
            hidden_states = kwargs['hidden_states']
            position_embeddings = kwargs['position_embeddings']
            position_ids = kwargs['position_ids']
-             past_key_value = kwargs['past_key_value']
-             cache_position = kwargs['cache_position']
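+             # the decoder layer's output tuple carries the updated KV cache at index 2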
+             past_key_value = layer_outs[2]
            attention_mask = kwargs['attention_mask']

            t_token_idx = pruning_pars['t_token_idx']
@@ -179,12 +191,8 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            if past_key_value is not None:
-                 temp_cache = copy.deepcopy(past_key_value)
-                 cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
-                 key_states, value_states = temp_cache.update(
-                     key_states, value_states,
-                     layer_idx, cache_kwargs
-                 )
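+                 # the layer forward has already written this layer's K/V into the cache; read them back instead of updating a copy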
+                 key_states = past_key_value.key_cache[layer_idx]
+                 value_states = past_key_value.value_cache[layer_idx]
            t_token_idx = t_token_idx[1] + v_token_start + v_token_num
            L, S = query_states.size(-2), key_states.size(-2)
            scale_factor = 1 / math.sqrt(query_states.size(-1))
@@ -201,19 +209,16 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

            pruning_pars['attn_logits'] = attn_logits

-             return args, kwargs
+             return layer_outs

        @prefill_wrapper
        def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

-             # pruning_pars['attn_logits'] hits a BUG when running with llavaHf;
-             # using layer_outputs[1] runs fine for llavaHf, but the accuracy does not match
-             # llava: attn_logits = pruning_pars['attn_logits']
-             # llavahf: attn_logits = layer_outputs[1]
            if 'attn_logits' not in pruning_pars:
                attn_logits = layer_outputs[1]
            else:
                attn_logits = pruning_pars['attn_logits']
+             merge_flag = pruning_pars['merge_flag']
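+             # merge_flag switches the merge-and-cluster branch below on or off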
            v_token_start = pruning_pars['v_token_start']
            v_token_num = pruning_pars['v_token_num']
            text_token_start = pruning_pars['text_token_start']
@@ -255,7 +260,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
            total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

            # merge and cluster
-             if s_flag and total_sparse_token_idx.shape[1] > 0:
+             if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
                total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

                merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -359,6 +364,14 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                    )
            elif self.model.__class__.__name__ == 'Llava':
                self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                     functools.partial(
+                         update_kwargs_hook,
+                         pruning_pars=self.pruning_paras,
+                         layer_idx=block_idx,
+                     ),
+                     with_kwargs=True
+                 )
+                 self.blocks[block_idx].self_attn.register_forward_hook(
                    functools.partial(
                        get_attn_logits_hook,
                        pruning_pars=self.pruning_paras,