Commit 7de3a7e

fix bugs for dart fastv sparsevlm and update ci (#412)
1 parent 3299323 commit 7de3a7e

File tree

7 files changed: +60 −48 lines

.github/workflows/main.yml

Lines changed: 4 additions & 4 deletions

```diff
@@ -30,7 +30,7 @@ jobs:

       - name: Download dataset
         run: |
-          # pwd # /home/runner/work/llmc/llmc
+          # pwd # /home/runner/work/LightCompress/LightCompress
           cd tools
           python download_calib_dataset.py --save_path ../check/datasets/calib --dataset_name pileval
           python download_eval_dataset.py --save_path ../check/datasets/eval --dataset_name wikitext2
@@ -46,17 +46,17 @@ jobs:

       - name: Preparation for check.
         run: |
-          cd ci_check # /home/runner/work/llmc/llmc/ci_check
+          cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
           python change_files.py

       - name: Run awq check
         run: |
-          cd ci_check # /home/runner/work/llmc/llmc/ci_check
+          cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
           bash run_awq.sh

       - name: Run gptq check
         run: |
-          cd ci_check # /home/runner/work/llmc/llmc/ci_check
+          cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
           bash run_gptq.sh

       - name: Check success
```
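All of these path edits follow from the repository rename from `llmc` to `LightCompress`: the runner checkout directory `/home/runner/work/<repo>/<repo>` changes with the repository name. A rename-proof alternative (a sketch only, not part of this commit) is to use the built-in `github.workspace` context instead of hardcoding the path:

```yaml
# Sketch only, not part of this commit: github.workspace always points at the
# checkout directory, so a future repository rename needs no workflow edits.
- name: Run awq check
  run: |
    cd "${{ github.workspace }}/ci_check"
    bash run_awq.sh
```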

ci_check/awq_w4a16_fakequant_eval.yml

Lines changed: 4 additions & 4 deletions

```diff
@@ -2,12 +2,12 @@ base:
     seed: &seed 42
 model:
     type: Opt
-    path: /home/runner/work/llmc/llmc/ci_check/opt-125m
+    path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m
     torch_dtype: auto
 calib:
     name: pileval
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/calib/pileval
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/calib/pileval
     n_samples: 4 # 128
     bs: -1
     seq_len: 16 # 512
@@ -17,7 +17,7 @@ eval:
     eval_pos: [pretrain, transformed, fake_quant]
     name: wikitext2
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16 # 2048
     eval_token_consist: True
@@ -35,4 +35,4 @@ quant:
     clip_sym: False
 save:
     save_trans: False
-    save_path: /home/runner/work/llmc/llmc/save/opt-125m_awq_w4a16
+    save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_awq_w4a16
```

ci_check/gptq_w_only.yml

Lines changed: 4 additions & 4 deletions

```diff
@@ -2,13 +2,13 @@ base:
     seed: &seed 0
 model:
     type: Opt
-    path: /home/runner/work/llmc/llmc/ci_check/opt-125m
+    path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m
     torch_dtype: auto
 calib:
     name: wikitext2
     download: False
     n_samples: 4
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16
     preproc: wikitext2_gptq
@@ -17,7 +17,7 @@ eval:
     eval_pos: [fake_quant]
     name: wikitext2
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16
     inference_per_block: False
@@ -40,4 +40,4 @@ quant:
     quant_out: True
 save:
     save_fake: False
-    save_path: /home/runner/work/llmc/llmc/save/opt-125m_gptq_w4a16
+    save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_gptq_w4a16
```

configs/sparsification/methods/SparseVLM/sparsevlm.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@ sparse:
     pruning_loc: [2] # [2, 6, 15]
     retained_tokens: 192
     init_token_total_shape: 668
+    merge_flag: False
 save:
     save_trans: False
     save_fake: False
```
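The new `merge_flag` key is read in `llmc/compression/token_reduction/sparsevlm.py` (see the `decoder_attn_hook` change below) and gates the merge-and-cluster branch for pruned visual tokens. A minimal sketch of how a default could be supplied in `add_sparse_config`, assuming the same `special_config.get(...)` pattern used for the neighbouring keys (the actual default handling is not shown in this diff):

```python
# Hypothetical sketch: fall back to False when the YAML omits merge_flag,
# mirroring how retained_tokens and init_token_total_shape are defaulted.
special_config['merge_flag'] = special_config.get('merge_flag', False)
```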

llmc/compression/token_reduction/dart.py

Lines changed: 4 additions & 6 deletions

```diff
@@ -44,7 +44,7 @@ def wrapper(self, *args, **kwargs):
             token_indices = (
                 input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
             )
-            pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+            pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

             outputs = fn(*args, **kwargs)
             return outputs
@@ -67,7 +67,7 @@ def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
     hidden_states = kwargs['hidden_states']
     position_embeddings = kwargs['position_embeddings']
     position_ids = kwargs['position_ids']
-    past_key_value = kwargs['past_key_value']
+    past_key_value = layer_outs[2]

     bsz, q_len, _ = hidden_states.size()
     query_states = module.q_proj(hidden_states)
@@ -193,10 +193,8 @@ def get_retained_image_token(pruning_paras, last_layer_state, any_states):
     ) // (pivot_image_token + pivot_text_token))
     device = last_layer_state.device

-    any_states = (
-        any_states.permute(0, 2, 1, 3)
-        .reshape(any_states.shape[0], any_states.shape[1], -1)
-    )
+    any_states = any_states.permute(0, 2, 1, 3)
+    any_states = any_states.reshape(any_states.shape[0], any_states.shape[1], -1)

     k_states_image_token = any_states[0][
         image_token_start_index:image_token_start_index + image_token_length, :
```
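The `[0][0].item()` change in `dart.py` (and the identical fix in `fastv.py` below) handles prompts that contain more than one image token: `torch.where(token_indices)[0]` then holds several indices and `.item()` raises, so only the first occurrence is kept as the start index. A small self-contained illustration (the tensor values are made up):

```python
import torch

# Illustrative mask with two image-token positions.
token_indices = torch.tensor([False, True, False, True])

matches = torch.where(token_indices)[0]  # tensor([1, 3])
# matches.item() would raise a RuntimeError here: the tensor has 2 elements.
image_token_start_index = matches[0].item()  # 1: keep the first occurrence only
print(image_token_start_index)
```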

llmc/compression/token_reduction/fastv.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -52,7 +52,7 @@ def wrapper(self, *args, **kwargs):
             attention_mask = args[2]
             token_indices = \
                 input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-            pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+            pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

             outputs = fn(*args, **kwargs)
             return outputs
```

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 42 additions & 29 deletions

```diff
@@ -12,6 +12,8 @@
 from .token_reduction_module import TokenReductionModule
 from .utils import prefill_wrapper, prefill_wrapper_model

+layer_dict = {}
+

 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
 class SparseVLM(TokenReductionModule):
@@ -24,6 +26,8 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})

         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
+        global layer_dict
+        layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
         special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
         special_config['generate_process_count'] = 0
@@ -44,7 +48,8 @@ def input_hook(module, input_args, pruning_pars):
             # find the position of the first image token
             for seq in input_ids:
                 image_token_index = (
-                    seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                    seq == IMAGE_TOKEN_INDEX
+                ).nonzero(as_tuple=True)[0]
                 if len(image_token_index) > 0:
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
@@ -95,33 +100,31 @@ def wrapper(self, *args, **kwargs):
         @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-            inputs_embeds = kwargs['inputs_embeds']
-            if inputs_embeds is None:
-                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
-            hidden_states = inputs_embeds  # shape: (B, L, C)
+            hidden_states = kwargs['inputs_embeds']
+            if hidden_states is None:
+                hidden_states = module.embed_tokens(kwargs['input_ids'])

             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B
             init_n = pruning_pars['init_token_total_shape'] + \
-                 pruning_pars['generate_process_count']  # 668
+                pruning_pars['generate_process_count']  # 668
             pruning_pars['prev_decision'] = torch.ones(
                 B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
             pruning_pars['policy'] = torch.ones(
                 B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

-            pruning_pars['v_token_start'] = pre_prompt_length_list[0] if len(
-                pre_prompt_length_list) != 0 else 0  # 35
-            v_token_start = pruning_pars['v_token_start']
-            pruning_pars['text_token_start'] = pruning_pars['v_token_start'] + \
-                pruning_pars['image_shape']  # 35 + 576 = 611
-            text_token_start = pruning_pars['text_token_start']
+            v_token_start = pre_prompt_length_list[0] if len(
+                pre_prompt_length_list) != 0 else 0
+            text_token_start = v_token_start + pruning_pars['image_shape']
+            pruning_pars['v_token_start'] = v_token_start  # 35
+            pruning_pars['text_token_start'] = text_token_start  # 611
             pruning_pars['v_token_num'] = pruning_pars['image_shape']  # 576

             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
+                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

             return args, kwargs
@@ -134,10 +137,20 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs

-        def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+        def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):

             if len(kwargs['position_ids'][0]) == 1:
                 return args, kwargs
+            if layer_idx != self.pruning_loc[0]:
+                kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['cache_position'] = pruning_pars['cache_position']
+                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+            return args, kwargs
+
+        def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
+
+            if len(kwargs['position_ids'][0]) == 1:
+                return layer_outs

             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb
@@ -150,8 +163,7 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
-            past_key_value = kwargs['past_key_value']
-            cache_position = kwargs['cache_position']
+            past_key_value = layer_outs[2]
             attention_mask = kwargs['attention_mask']

             t_token_idx = pruning_pars['t_token_idx']
@@ -179,12 +191,8 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
             if past_key_value is not None:
-                temp_cache = copy.deepcopy(past_key_value)
-                cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
-                key_states, value_states = temp_cache.update(
-                    key_states, value_states,
-                    layer_idx, cache_kwargs
-                )
+                key_states = past_key_value.key_cache[layer_idx]
+                value_states = past_key_value.value_cache[layer_idx]
             t_token_idx = t_token_idx[1] + v_token_start + v_token_num
             L, S = query_states.size(-2), key_states.size(-2)
             scale_factor = 1 / math.sqrt(query_states.size(-1))
@@ -201,19 +209,16 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

             pruning_pars['attn_logits'] = attn_logits

-            return args, kwargs
+            return layer_outs

         @prefill_wrapper
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

-            # pruning_pars['attn_logits'] has a BUG when running with llavaHf;
-            # using layer_outputs[1] works for llavaHf, but the accuracy does not match
-            # llava:   attn_logits = pruning_pars['attn_logits']
-            # llavahf: attn_logits = layer_outputs[1]
             if 'attn_logits' not in pruning_pars:
                 attn_logits = layer_outputs[1]
             else:
                 attn_logits = pruning_pars['attn_logits']
+            merge_flag = pruning_pars['merge_flag']
             v_token_start = pruning_pars['v_token_start']
             v_token_num = pruning_pars['v_token_num']
             text_token_start = pruning_pars['text_token_start']
@@ -255,7 +260,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
             total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

             # merge and cluster
-            if s_flag and total_sparse_token_idx.shape[1] > 0:
+            if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
                 total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

                 merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -359,6 +364,14 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 )
             elif self.model.__class__.__name__ == 'Llava':
                 self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                    functools.partial(
+                        update_kwargs_hook,
+                        pruning_pars=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
+                self.blocks[block_idx].self_attn.register_forward_hook(
                     functools.partial(
                         get_attn_logits_hook,
                         pruning_pars=self.pruning_paras,
```
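The largest change moves `get_attn_logits_hook` from a forward pre-hook to a forward hook, so it runs after the decoder layer has already written this layer's keys and values into the cache returned in `layer_outs[2]`; the hook can then read `key_cache[layer_idx]` / `value_cache[layer_idx]` directly instead of deep-copying the cache and calling `update` a second time. A rough sketch of that idea with a standalone `transformers` `DynamicCache` and made-up tensor shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

# Made-up shapes: (batch, num_heads, seq_len, head_dim).
key = torch.randn(1, 8, 16, 64)
value = torch.randn(1, 8, 16, 64)

cache = DynamicCache()
cache.update(key, value, layer_idx=0)  # what the layer's own forward pass does

# A forward hook fires after the layer, so the entry for this layer is already
# up to date and can simply be read back: no deepcopy and no second update.
key_states = cache.key_cache[0]
value_states = cache.value_cache[0]
assert torch.equal(key_states, key) and torch.equal(value_states, value)
```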
