Commit 64df6d5
refine sparsevlm for llava (#418)
1 parent ee525d4 commit 64df6d5

File tree
2 files changed: +111 additions, -66 deletions

configs/sparsification/methods/SparseVLM/sparsevlm.yml

Lines changed: 3 additions & 3 deletions
@@ -16,10 +16,10 @@ sparse:
     method: TokenReduction
     special:
         method: SparseVLM
-        pruning_loc: [2] # [2, 6, 15]
+        pruning_loc: [2, 6, 15]
         retained_tokens: 192
-        init_token_total_shape: 668
-        merge_flag: False
+        prune_flag: True
+        merge_flag: True
 save:
     save_trans: False
     save_fake: False
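For reference, a minimal sketch of reading the updated `special` block. The key names come from the hunk above; the nesting and the `.get()` defaults shown here are assumptions for illustration, not taken verbatim from the repo.

```python
# Minimal sketch: parse a `special` block shaped like the updated config above.
# The YAML is inlined for self-containment; key names mirror the diff,
# the nesting and defaults are illustrative assumptions.
import yaml

cfg_text = """
sparse:
    method: TokenReduction
    special:
        method: SparseVLM
        pruning_loc: [2, 6, 15]
        retained_tokens: 192
        prune_flag: True
        merge_flag: True
"""

special = yaml.safe_load(cfg_text)['sparse']['special']
pruning_loc = special.get('pruning_loc', [2, 6, 15])    # decoder layers that prune
retained_tokens = special.get('retained_tokens', 192)   # visual-token budget key
prune_flag = special.get('prune_flag', True)            # drop low-scoring visual tokens
merge_flag = special.get('merge_flag', True)            # cluster-and-merge dropped tokens
print(pruning_loc, retained_tokens, prune_flag, merge_flag)
```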

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 108 additions & 63 deletions
@@ -13,6 +13,12 @@
 from .utils import prefill_wrapper, prefill_wrapper_model
 
 layer_dict = {}
+prune_flag = True
+merge_flag = True
+sparse_token_list_192 = []
+sparse_token_list_128 = []
+sparse_token_list_64 = []
+sparse_token_dict = {}
 
 
 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -26,13 +32,13 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})
 
         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
-        global layer_dict
+        global layer_dict, prune_flag, merge_flag
         layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
+        prune_flag = special_config.get('prune_flag', True)
+        merge_flag = special_config.get('merge_flag', True)
+        update_list()
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
-        special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
-        special_config['generate_process_count'] = 0
         special_config['pre_prompt_length_list'] = []
-        special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
         self.pruning_paras = special_config
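add_sparse_config now derives layer_dict from pruning_loc and calls update_list() (defined further down in this diff) to pick per-layer visual-token budgets. A small standalone sketch of that selection follows; pick_budgets is a hypothetical helper, and the numbers are copied from the update_list() hunk below.

```python
# Hypothetical helper mirroring the branch logic of update_list() later in this
# diff: the default [2, 6, 15] schedule keeps the original three-stage budgets,
# any other schedule gets a single-stage budget chosen by prune_flag/merge_flag.
def pick_budgets(layer_dict, prune_flag, merge_flag):
    if layer_dict == {2: 0, 6: 1, 15: 2}:
        return {192: [300, 200, 110], 128: [303, 110, 36], 64: [66, 30, 17]}
    if prune_flag and merge_flag:
        return {192: [180], 128: [114], 64: [48]}
    if prune_flag:
        return {192: [192], 128: [128], 64: [64]}
    if merge_flag:
        return {192: [149], 128: [78], 64: [7]}
    raise RuntimeError('Both prune_flag and merge_flag are False')

# Shipped config: pruning_loc [2, 6, 15], retained_tokens 192 -> budgets 300/200/110.
layer_dict = {layer: idx for idx, layer in enumerate([2, 6, 15])}
print(pick_budgets(layer_dict, True, True)[192])   # [300, 200, 110]
```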
@@ -42,7 +48,6 @@ def register_reduction_modules(self):
         def input_hook(module, input_args, pruning_pars):
             input_ids = input_args[0]
             pre_prompt_length_list = []
-            token_length_list = []
             IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
 
             # find the position of the first image token
@@ -54,10 +59,7 @@ def input_hook(module, input_args, pruning_pars):
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
                     pre_prompt_length_list.append(0)
-                token_length_list.append(seq.shape[0])
-
             pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list
-            pruning_pars['token_length_list'] = token_length_list
 
             return input_args
 
@@ -90,11 +92,7 @@ def wrapper(self, *args, **kwargs):
 
                 pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-                outputs = fn(*args, **kwargs)
-
-                pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
-
-                return outputs
+                return fn(*args, **kwargs)
             return wrapper
 
         @prefill_wrapper_model
@@ -106,12 +104,6 @@ def register_module_pars(module, args, kwargs, pruning_pars):
 
             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B
-            init_n = pruning_pars['init_token_total_shape'] + \
-                pruning_pars['generate_process_count']  # 668
-            pruning_pars['prev_decision'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
-            pruning_pars['policy'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
 
             v_token_start = pre_prompt_length_list[0] if len(
                 pre_prompt_length_list) != 0 else 0
@@ -123,8 +115,8 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
+                m_v_t = v_t @ t_t.transpose(1, 2)
+                m_v_t = m_v_t.softmax(2).mean(1)
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())
 
             return args, kwargs
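The two rewritten lines above score how strongly visual tokens attend to each text token; t_token_idx keeps the above-average ones. A self-contained sketch with illustrative shapes (576 visual tokens, 52 text tokens, hidden size 32):

```python
# Sketch of the visual-to-text relevance computed in register_module_pars above.
# Shapes are illustrative only; the two shape comments match the removed
# inline comments ([1, 576, 52] and [1, 52]).
import torch

v_t = torch.randn(1, 576, 32)                    # visual-token hidden states
t_t = torch.randn(1, 52, 32)                     # text-token hidden states

m_v_t = v_t @ t_t.transpose(1, 2)                # [1, 576, 52] similarity scores
m_v_t = m_v_t.softmax(2).mean(1)                 # [1, 52] mean score per text token
t_token_idx = torch.where(m_v_t > m_v_t.mean())  # text tokens the vision tokens favour
print(m_v_t.shape, t_token_idx[1].numel())
```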
@@ -133,6 +125,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
             kwargs['output_attentions'] = True
             if layer_idx != self.pruning_loc[0]:
                 kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['attention_mask'] = pruning_pars['attention_mask']
                 kwargs['cache_position'] = pruning_pars['cache_position']
                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs
@@ -143,8 +136,14 @@ def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
                 return args, kwargs
             if layer_idx != self.pruning_loc[0]:
                 kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['attention_mask'] = pruning_pars['attention_mask']
                 kwargs['cache_position'] = pruning_pars['cache_position']
                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+            else:
+                pruning_pars['position_ids'] = kwargs['position_ids']
+                pruning_pars['attention_mask'] = kwargs['attention_mask']
+                pruning_pars['cache_position'] = kwargs['cache_position']
+                pruning_pars['position_embeddings'] = kwargs['position_embeddings']
             return args, kwargs
 
         def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
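update_kwargs_hook now also round-trips attention_mask: at the first pruning layer the incoming kwargs are stashed into pruning_pars, and at later pruning layers the already-pruned versions are re-injected. A toy sketch of that pre-hook pattern follows; ToyBlock, the single kwarg shown and the layer indices are invented for illustration.

```python
# Toy sketch of the stash-at-first / re-inject-at-later pre-hook pattern above.
# ToyBlock and the single `position_ids` kwarg are stand-ins for the real blocks.
import functools
import torch
from torch import nn

class ToyBlock(nn.Module):
    def forward(self, x, position_ids=None):
        return x

pruning_loc = [2, 6, 15]
cache = {}

def stash_or_restore(module, args, kwargs, layer_idx):
    if layer_idx == pruning_loc[0]:
        cache['position_ids'] = kwargs['position_ids']    # stash at first pruning layer
    else:
        kwargs['position_ids'] = cache['position_ids']    # re-inject at later layers
    return args, kwargs

blocks = nn.ModuleList(ToyBlock() for _ in range(16))
for idx in pruning_loc:
    blocks[idx].register_forward_pre_hook(
        functools.partial(stash_or_restore, layer_idx=idx), with_kwargs=True)

x = torch.zeros(1, 4, 8)
pos = torch.arange(4).unsqueeze(0)
for idx, blk in enumerate(blocks):
    x = blk(x, position_ids=pos)
```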
@@ -155,11 +154,6 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb
 
-            if layer_idx != self.pruning_loc[0]:
-                kwargs['position_ids'] = pruning_pars['position_ids']
-                kwargs['cache_position'] = pruning_pars['cache_position']
-                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
-
             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
@@ -215,9 +209,10 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
 
             if 'attn_logits' not in pruning_pars:
-                attn_logits = layer_outputs[1]
+                attn_logits = layer_outputs[1]  # for LlavaHf
             else:
                 attn_logits = pruning_pars['attn_logits']
+            prune_flag = pruning_pars.get('prune_flag', True)
             merge_flag = pruning_pars['merge_flag']
             v_token_start = pruning_pars['v_token_start']
             v_token_num = pruning_pars['v_token_num']
@@ -227,13 +222,11 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
             B = pruning_pars['B']
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             image_shape = pruning_pars['image_shape']
-            if layer_idx == self.pruning_loc[0]:
-                position_ids = kwargs['position_ids']
-                pruning_pars['position_ids'] = position_ids
-            else:
-                position_ids = pruning_pars['position_ids']
-            hidden_states = inputs[0]  # [B, L, D]
 
+            attention_mask = kwargs['attention_mask']
+            position_embeddings = kwargs['position_embeddings']
+
+            hidden_states = inputs[0]  # [B, L, D]
             pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
                 attn_logits,
                 v_token_start,
@@ -243,7 +236,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
                 layer_idx,
                 retained_tokens
             )
-
+            if not prune_flag:
+                pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
             policy = torch.ones(B, hidden_states.shape[1], dtype=hidden_states.dtype,
                                 device=hidden_states.device)
             policy[:, v_token_start:text_token_start] = \
@@ -261,60 +255,91 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
 
             # merge and cluster
             if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
-                total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
+                total_sparse_token = batch_index_select(
+                    layer_outputs[0], total_sparse_token_idx
+                )
 
                 merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
                 merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
-                merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1  # Top 30%
+                if prune_flag:
+                    merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1
+                else:
+                    merge_token_num_stage1 = (
+                        merge_token_idx_stage1.shape[0]
+                        - sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
+                    )
                 merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]
+                if not prune_flag:
+                    all_idx = torch.arange(
+                        merge_token_stage1.size(0),
+                        device=merge_token_stage1.device
+                    )
+                    non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
+                    pred_score_vis[0][non_topk_idx] = 1
+                    policy[:, v_token_start:text_token_start] = \
+                        pred_score_vis.type(dtype=hidden_states.dtype)
 
                 merge_token_stage2 = total_sparse_token[:, merge_token_stage2_idx, :]
                 cluster_num = int(merge_token_stage2.shape[1] / 10) + 1
                 if cluster_num == 0:
                     cluster_num = merge_token_stage2.shape[1]
+                merge_sparse_token, index_down = cluster_and_merge(merge_token_stage2, cluster_num)
 
-                merge_sparse_token = cluster_and_merge(merge_token_stage2, cluster_num)
-
+                cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
+                cluster_idx = cluster_idx.squeeze(0)
                 select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
                 select_token = batch_index_select(layer_outputs[0], select_token_idx)
                 select_vis_token_num = pred_score_vis.sum()
-
+                keep_indexs = torch.cat(
+                    (
+                        select_token_idx.squeeze(0)[:v_token_start + select_vis_token_num],
+                        cluster_idx,
+                        select_token_idx.squeeze(0)[v_token_start + select_vis_token_num:]
+                    )
+                )
                 select_and_merge_token = torch.cat(
                     (
-                        select_token[:, :v_token_start +
-                                     select_vis_token_num, :],
+                        select_token[:, :v_token_start + select_vis_token_num, :],
                         merge_sparse_token,
-                        select_token[:, v_token_start +
-                                     select_vis_token_num:, :]
+                        select_token[:, v_token_start + select_vis_token_num:, :]
                     ),
                     dim=1
                 )
                 layer_outputs = (select_and_merge_token, layer_outputs[1])
-                position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
                 v_token_num = pred_score_vis.sum() + cluster_num
-                text_token_start = v_token_start + v_token_num
+
             else:
-                select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
+                keep_indexs = torch.where(policy == 1)[1]
+                select_token_idx = keep_indexs.unsqueeze(0)
                 layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                                  layer_outputs[1])
-                position_ids = position_ids[:, :len(select_token_idx[0])]
                 v_token_num = pred_score_vis.sum()
-                text_token_start = v_token_start + v_token_num
 
+            text_token_start = v_token_start + v_token_num
+            position_ids = keep_indexs.unsqueeze(0)
             new_output = layer_outputs
-            cache_position = position_ids.detach().clone()
+            cache_position = position_ids.squeeze(0)
+
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
+            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
+            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            position_embeddings = (new_pe0, new_pe1)
 
             pruning_pars['v_token_num'] = v_token_num
             pruning_pars['text_token_start'] = text_token_start
+
             pruning_pars['position_ids'] = position_ids
             pruning_pars['cache_position'] = cache_position
-            pruning_pars['position_embeddings'] = None
+            pruning_pars['position_embeddings'] = position_embeddings
+            pruning_pars['attention_mask'] = attention_mask
 
             return new_output
 
         @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
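After pruning and merging, the hook now rebuilds every positional input from the surviving token indices instead of resetting position_embeddings to None: position_ids, cache_position, the rotary (cos, sin) pair and the 4-D attention mask are all indexed by keep_indexs and cached for the later pruning layers. A standalone sketch of that bookkeeping follows; shapes are made up, and the two-step mask indexing shown here is just one illustrative way to keep both rows and columns of the kept tokens, not a copy of the exact line above.

```python
# Sketch of the post-pruning bookkeeping at the end of decoder_attn_hook:
# every positional tensor is re-indexed by the surviving token positions.
# Shapes are illustrative; the mask indexing below is one possible variant.
import torch

L, D, kept = 600, 64, 250
keep_indexs = torch.randperm(L)[:kept].sort().values        # surviving positions

cos, sin = torch.randn(1, L, D), torch.randn(1, L, D)       # rotary embeddings
attention_mask = torch.zeros(1, 1, L, L)                     # 4-D additive mask

position_ids = keep_indexs.unsqueeze(0)                      # [1, kept]
cache_position = position_ids.squeeze(0)                     # [kept]
position_embeddings = (cos[:, keep_indexs, :].clone(),
                       sin[:, keep_indexs, :].clone())
attention_mask = attention_mask[:, :, keep_indexs, :][..., keep_indexs]

print(position_ids.shape, position_embeddings[0].shape, attention_mask.shape)
```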

@@ -363,7 +388,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                     with_kwargs=True
                 )
             elif self.model.__class__.__name__ == 'Llava':
-                self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         update_kwargs_hook,
                         pruning_pars=self.pruning_paras,
@@ -383,7 +408,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 functools.partial(
                     decoder_attn_hook,
                     pruning_pars=self.pruning_paras,
-                    layer_idx=block_idx,
+                    layer_idx=block_idx
                 ),
                 with_kwargs=True
             )
@@ -397,17 +422,37 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 )
 
 
-layer_dict = {2: 0, 6: 1, 15: 2}
-
-sparse_token_list_192 = [300, 200, 110]  # 2*576 4*300 10*200 16*110
-sparse_token_list_128 = [303, 110, 36]
-sparse_token_list_64 = [66, 30, 17]
+def update_list():
+    global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global prune_flag, merge_flag, sparse_token_dict
+
+    if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
+        sparse_token_list_192 = [300, 200, 110]
+        sparse_token_list_128 = [303, 110, 36]
+        sparse_token_list_64 = [66, 30, 17]
+        prune_flag, merge_flag = True, True
+    elif prune_flag and merge_flag:
+        sparse_token_list_192 = [180]
+        sparse_token_list_128 = [114]
+        sparse_token_list_64 = [48]
+    elif prune_flag:
+        sparse_token_list_192 = [192]
+        sparse_token_list_128 = [128]
+        sparse_token_list_64 = [64]
+    elif merge_flag:
+        sparse_token_list_192 = [149]
+        sparse_token_list_128 = [78]
+        sparse_token_list_64 = [7]
+    else:
+        raise RuntimeError(
+            'Both prune_flag and merge_flag are False — sparseVLM is inactive.'
+        )
 
-sparse_token_dict = {
-    192: sparse_token_list_192,
-    128: sparse_token_list_128,
-    64: sparse_token_list_64
-}
+    sparse_token_dict = {
+        192: sparse_token_list_192,
+        128: sparse_token_list_128,
+        64: sparse_token_list_64
+    }
 
 
 def attn_postprocess_topk(
@@ -567,4 +612,4 @@ def cluster_and_merge(x, cluster_num):
                          source=source.reshape(B * N, C).type(x.dtype))
     x_merged = x_merged.reshape(B, cluster_num, C)
 
-    return x_merged
+    return x_merged, index_down
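cluster_and_merge now also returns index_down, the positions (within its input) of the tokens chosen as cluster centers, which is what lets decoder_attn_hook translate merged tokens back into original sequence positions (the new cluster_idx lines in that hook). A toy sketch of that index mapping follows; the clustering itself is elided and index_down is simulated with a topk over random scores.

```python
# Toy sketch of why index_down matters: it maps "which merged token" back to
# "which original sequence position". The clustering is elided; index_down is
# simulated here with a topk over random scores.
import torch

total_sparse_token_idx = torch.tensor([[40, 55, 73, 110, 231, 402]])  # sparse-token positions
merge_token_stage2_idx = torch.tensor([1, 2, 4, 5])    # sparse tokens sent to clustering
scores = torch.randn(1, merge_token_stage2_idx.numel())
index_down = scores.topk(2, dim=1)[1]                  # stand-in for cluster-center picks

# Same mapping decoder_attn_hook performs with the returned index_down.
cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
cluster_idx = cluster_idx.squeeze(0)
print(cluster_idx)    # original positions of the kept cluster centers
```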

0 commit comments