
Commit 0894494

update multi turn question eval
1 parent 35ca4db commit 0894494

6 files changed: +58 -12 lines changed
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+base:
+    seed: &seed 42
+model:
+    type: Llava
+    path: model path
+    torch_dtype: auto
+eval:
+    eval_pos: [transformed] # transformed
+    name: custom_gen
+    type: just_infer
+    download: False
+    path: /data/nvme1/yongyang/projects/llmc_plus/general_custom_data
+    apply_chat_template: True
+    bs: 1
+    inference_per_block: False
+    max_new_tokens: 512
+    statistics: False
+sparse:
+    method: TokenReduction
+    special:
+        method: SparseVLM
+        pruning_loc: [2, 6, 15]
+        retained_tokens: 192
+        prune_flag: True
+        merge_flag: True
+save:
+    save_trans: False
+    save_fake: False
+    save_path: /path/to/save/
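For reference, a minimal sketch of reading the sparsification settings back out of a config like this one (PyYAML assumed; the filename is a placeholder, since the page does not show the new file's path):

import yaml  # PyYAML, assumed available

# Placeholder filename; the commit page does not show the config's path.
with open('sparsevlm_multi_turn.yml') as f:
    cfg = yaml.safe_load(f)

special = cfg['sparse']['special']
assert cfg['sparse']['method'] == 'TokenReduction'
print(special['method'], special['pruning_loc'], special['retained_tokens'])
# SparseVLM [2, 6, 15] 192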

llmc/compression/token_reduction/fastv.py

Lines changed: 6 additions & 0 deletions
@@ -90,6 +90,12 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
     top_attention_rank_index = \
         last_layer_attention_avg_last_tok_image.topk(
             round(image_token_length * (1 - rate))).indices + image_token_start_index
+
+    if self.model.first_turn_question:
+        module.register_buffer('top_attention_rank_index', top_attention_rank_index)
+    else:
+        top_attention_rank_index = module.top_attention_rank_index
+
     # keep index
     keep_indexs = torch.cat(
         (
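The change above follows the pattern repeated throughout this commit: compute the pruning decision on the first turn of a multi-turn conversation, stash it on the module with register_buffer, and reuse it on later turns so every turn keeps the same image tokens. A self-contained sketch of that pattern (toy tensors and a module-level flag; not the llmc API):

import torch
import torch.nn as nn

layer = nn.Identity()       # stands in for the decoder layer module
first_turn_question = True  # stands in for self.model.first_turn_question

def top_attention_indices(attn, image_token_start_index, image_token_length, rate):
    global first_turn_question
    if first_turn_question:
        # first turn: rank image tokens by attention and cache the decision
        idx = attn.topk(round(image_token_length * (1 - rate))).indices \
            + image_token_start_index
        layer.register_buffer('top_attention_rank_index', idx)
        first_turn_question = False
    else:
        # later turns: reuse the cached ranking
        idx = layer.top_attention_rank_index
    return idx

attn = torch.rand(16)
print(top_attention_indices(attn, 5, 16, 0.5))  # computed and cached
print(top_attention_indices(attn, 5, 16, 0.5))  # identical, read from the buffer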

llmc/compression/token_reduction/random.py

Lines changed: 9 additions & 10 deletions
@@ -76,23 +76,22 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
 
     device = hidden_states.device
 
+    vision_indexes = torch.arange(
+        image_token_start_index,
+        image_token_start_index + image_token_length,
+        device=device,
+    )
     if self.model.first_turn_question:
-        logger.info(' -----first_turn_question-----')
-        vision_indexes = torch.arange(
-            image_token_start_index,
-            image_token_start_index + image_token_length,
-            device=device,
-        )
         num_keep = round(image_token_length * (1 - rate))
         rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
         vision_indexes = vision_indexes[rand_idx]
 
-        # save vision_indexes to module
-        module.register_buffer('vision_indexes', vision_indexes)
+        # save rand_idx to module
+        module.register_buffer('rand_idx', rand_idx)
     else:
-        logger.info(' -----not first_turn_question-----')
         # load vision_indexes from module (prompt cache)
-        vision_indexes = module.vision_indexes
+        rand_idx = module.rand_idx
+        vision_indexes = vision_indexes[rand_idx]
 
     # keep index
     keep_indexs = torch.cat(
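Note what moved: torch.arange(...) is now computed on every turn, and only the relative permutation rand_idx is cached, where the old code cached the absolute vision_indexes. If a later turn's prompt shifts image_token_start_index, cached absolute indices would point at the wrong positions; relative ones are re-based each turn. A toy illustration (values are made up):

import torch

torch.manual_seed(0)
rand_idx = torch.randperm(16)[:8]  # first turn: which of the 16 image tokens survive

for image_token_start_index in (5, 9):  # the prefix length may differ per turn
    vision_indexes = torch.arange(
        image_token_start_index, image_token_start_index + 16
    )[rand_idx]
    print(vision_indexes)  # same surviving tokens, correctly re-based offsets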

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 7 additions & 0 deletions
@@ -251,6 +251,13 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
             text_token_start = prompt_length + image_shape
             policy[batch, text_token_start:] = 1
 
+    if self.model.first_turn_question:
+        vision_mask = policy[:, v_token_start:v_token_start + v_token_num]
+        module.register_buffer('vision_mask', vision_mask)
+    else:
+        vision_mask = module.vision_mask
+        policy[:, v_token_start:v_token_start + v_token_num] = vision_mask
+
     total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
 
     # merge and cluster
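Same idea for SparseVLM, but the cached object is a keep/drop mask over the vision-token span of the policy tensor rather than an index list. A hedged sketch with assumed shapes (policy is (batch, seq); 1 means keep, 0 means drop):

import torch

v_token_start, v_token_num = 4, 6

# first turn: some policy over a 16-token sequence; cache its vision slice
policy = torch.ones(1, 16)
policy[:, v_token_start:v_token_start + v_token_num] = \
    (torch.rand(1, v_token_num) > 0.5).float()
vision_mask = policy[:, v_token_start:v_token_start + v_token_num].clone()

# later turn: fresh policy, overwrite the vision span with the cached mask
policy = torch.ones(1, 16)
policy[:, v_token_start:v_token_start + v_token_num] = vision_mask
print(torch.where(policy == 0)[1])  # same vision tokens dropped as on turn one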

llmc/compression/token_reduction/visionzip.py

Lines changed: 6 additions & 0 deletions
@@ -323,6 +323,12 @@ def visionzip_hook(m, images, image_forward_outs):
     mask = torch.ones_like(
         hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
     ).scatter_(1, all_indices, False)
+
+    if self.model.first_turn_question:
+        m.register_buffer('mask', mask)
+    else:
+        mask = m.mask
+
     dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(
         hidden_states.shape[0], dominant_num + 1, hidden_states.shape[2]
    )
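Context for the lines around the insertion: mask is False exactly at the dominant-token positions, so ~mask.unsqueeze(-1) broadcasts over the hidden dimension and masked_select gathers those tokens back into a dense (batch, kept, hidden) tensor. A small standalone check (toy shapes, not the llmc code):

import torch

hidden_states = torch.randn(2, 10, 8)
all_indices = torch.tensor([[1, 3, 4], [0, 2, 9]])  # kept positions per batch row

mask = torch.ones(2, 10, dtype=torch.bool).scatter_(1, all_indices, False)
dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(2, 3, 8)
print(torch.equal(dominant_tokens[0], hidden_states[0, [1, 3, 4]]))  # True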

llmc/models/llava.py

Lines changed: 1 addition & 2 deletions
@@ -75,6 +75,7 @@ def build_model(self):
             'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX, # for llava
         }
         self.processor = None
+        self.first_turn_question = True
 
     def get_extra_rot_module_besides_embed_layers(self):
         return [self.vision_projector[2]]

@@ -163,8 +164,6 @@ def load_images(image_files):
                 out.append(image)
             return out
 
-        self.first_turn_question = True
-
        for data_idx, questions in enumerate(img_qas):
            self.first_turn_question = True
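The net effect in llava.py: the flag is now created once in build_model so the attribute always exists, and the one-off reset that used to sit before the loop is dropped as redundant, since the loop already re-arms the flag at the top of every sample. Schematically, with hypothetical data:

img_qas = [  # hypothetical: one image, several questions per sample
    {'image': 'a.jpg', 'question': ['q1', 'q2', 'q3']},
    {'image': 'b.jpg', 'question': ['q1', 'q2']},
]

first_turn_question = True  # initialized once, as in build_model
for data_idx, questions in enumerate(img_qas):
    first_turn_question = True  # re-armed per sample: each conversation prunes afresh
    for q in questions['question']:
        mode = 'compute & cache' if first_turn_question else 'reuse cache'
        print(data_idx, q, mode)
        first_turn_question = False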
