Skip to content

Optimize code of FastV and fix SparseVLM's bugs related to LLaVA. #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions llmc/compression/token_reduction/fastv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
from functools import wraps
from types import MethodType

import torch
Expand Down Expand Up @@ -39,26 +40,23 @@ def input_hook(module, input_args, pruning_paras):

return input_args

def make_hook_prepare_inputs_labels_for_multimodal(pruning_paras):
def hook_prepare_inputs_labels_for_multimodal(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
modalities=['image'],
image_sizes=None,
):
if 'image_token_start_index' not in pruning_paras:
token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
return self._original_prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask,
past_key_values, labels, images, modalities, image_sizes
)
return hook_prepare_inputs_labels_for_multimodal
def input_hook_llava(fn, pruning_paras):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]
token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()

outputs = fn(*args, **kwargs)
return outputs
return wrapper

def update_output_attentions_hook(module, args, kwargs, pruning_paras):
kwargs['output_attentions'] = True
Expand Down Expand Up @@ -129,9 +127,10 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
functools.partial(input_hook, pruning_paras=self.pruning_paras)
)
elif self.model.__class__.__name__ == 'Llava':
hook_fn = make_hook_prepare_inputs_labels_for_multimodal(self.pruning_paras)
self.model.vlm_model._original_prepare_inputs_labels_for_multimodal = (
self.model.vlm_model.prepare_inputs_labels_for_multimodal
from llava.constants import IMAGE_TOKEN_INDEX
hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
Expand Down
137 changes: 111 additions & 26 deletions llmc/compression/token_reduction/sparsevlm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
import functools
import math
from functools import wraps
from types import MethodType

Expand Down Expand Up @@ -66,22 +68,26 @@ def wrapper(self, *args, **kwargs):
input_ids = args[0]
attention_mask = args[2]

if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
Comment on lines +71 to +74

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to avoid using torch.ones_like with dtype=torch.bool directly. Instead, create a boolean tensor and then use it to create the ones tensor. This can prevent potential issues with type casting and ensure that the resulting tensor has the correct boolean type.

Consider using input_ids.bool().fill_(True) instead.

                    attention_mask = input_ids.bool().fill_(True)


pre_prompt_length_list = []
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
seq = cur_input_ids[cur_attention_mask]
image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
if len(image_token_index) > 0:
pre_prompt_length_list.append(image_token_index[0])
else:
pre_prompt_length_list.append(0)
image_token_index = (
[-1]
+ torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+ [seq.shape[0]]
)
Comment on lines +79 to +83

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list concatenation [-1] + ... + [seq.shape[0]] creates a new list in each iteration. This can be inefficient, especially when dealing with long sequences. Consider using torch.cat to concatenate tensors instead, which is generally more efficient for PyTorch tensors.

                        image_token_index = torch.cat([
                            torch.tensor([-1]),
                            torch.where(seq == IMAGE_TOKEN_INDEX)[0],
                            torch.tensor([seq.shape[0]])
                        ]).tolist()

pre_prompt_length_list.append(image_token_index[1])

pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list

outputs = fn(*args, **kwargs)

token_length_list = []
for cur_attention_mask in outputs[2]:
token_length_list.append(cur_attention_mask.sum().item())
pruning_paras['token_length_list'] = token_length_list
pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()

return outputs
return wrapper
Expand Down Expand Up @@ -128,14 +134,90 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
return args, kwargs

def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

if len(kwargs['position_ids'][0]) == 1:
return args, kwargs

from transformers.models.llama.modeling_llama import \
apply_rotary_pos_emb

if layer_idx != self.pruning_loc[0]:
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']

hidden_states = kwargs['hidden_states']
position_embeddings = kwargs['position_embeddings']
position_ids = kwargs['position_ids']
past_key_value = kwargs['past_key_value']
cache_position = kwargs['cache_position']
attention_mask = kwargs['attention_mask']

t_token_idx = pruning_pars['t_token_idx']
v_token_start = pruning_pars['v_token_start']
v_token_num = pruning_pars['v_token_num']

bsz, q_len, _ = hidden_states.size()
query_states = module.q_proj(hidden_states)
key_states = module.k_proj(hidden_states)
value_states = module.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, module.num_heads, module.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, module.num_key_value_heads, module.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, module.num_key_value_heads, module.head_dim
).transpose(1, 2)

if position_embeddings is None:
cos, sin = module.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
temp_cache = copy.deepcopy(past_key_value)
cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
key_states, value_states = temp_cache.update(
key_states, value_states,
layer_idx, cache_kwargs
)
t_token_idx = t_token_idx[1] + v_token_start + v_token_num
L, S = query_states.size(-2), key_states.size(-2)
scale_factor = 1 / math.sqrt(query_states.size(-1))
attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
if module.is_causal:
assert attention_mask is None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Asserting that attention_mask is None might not be the best approach. It's possible that the attention mask is not None but still needs to be handled. Consider checking if the module is causal and if the attention mask is needed based on the specific logic of the module.

            # assert attention_mask is None

temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
attn_bias.to(query_states.dtype)
Comment on lines +194 to +196

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of creating a temporary mask and then masking the attention bias, you can directly create the attention bias with the desired values. This can simplify the code and potentially improve performance.

                attn_bias = torch.where(torch.tril(torch.ones(L, S, dtype=torch.bool)),
                                      torch.zeros(L, S, dtype=query_states.dtype),
                                      torch.tensor(float('-inf'))).to(query_states.dtype)


attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
attn_logits += attn_bias.to(query_states.device)
attn_logits = torch.softmax(attn_logits, dim=-1)

pruning_pars['attn_logits'] = attn_logits

return args, kwargs

@prefill_wrapper
def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

attn_logits = layer_outputs[1]
# pruning_pars['attn_logits'] 对llavaHf运行存在BUG,
# 使用layer_outputs[1]运行llavaHf无问题,但精度没对上
# llava:attn_logits = pruning_pars['attn_logits']
# llavahf:attn_logits = layer_outputs[1]
if 'attn_logits' not in pruning_pars:
attn_logits = layer_outputs[1]
else:
attn_logits = pruning_pars['attn_logits']
v_token_start = pruning_pars['v_token_start']
v_token_num = pruning_pars['v_token_num']
text_token_start = pruning_pars['text_token_start']
t_token_idx = pruning_pars['t_token_idx']
v_token_num = pruning_pars['v_token_num']
retained_tokens = pruning_pars['retained_tokens']
B = pruning_pars['B']
pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
Expand All @@ -145,10 +227,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
pruning_pars['position_ids'] = position_ids
else:
position_ids = pruning_pars['position_ids']

hidden_states = inputs[0] # [B, L, D]
pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
image_shape = pruning_pars['image_shape']

pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
attn_logits,
Expand Down Expand Up @@ -177,7 +256,6 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer

# merge and cluster
if s_flag and total_sparse_token_idx.shape[1] > 0:
total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
Expand Down Expand Up @@ -208,20 +286,17 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
)
layer_outputs = (select_and_merge_token, layer_outputs[1])
position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
# prev_decision = policy
v_token_num = pred_score_vis.sum() + cluster_num
text_token_start = v_token_start + v_token_num
else:
select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
layer_outputs[1])
position_ids = position_ids[:, :len(select_token_idx[0])]
# prev_decision = policy
v_token_num = pred_score_vis.sum()
text_token_start = v_token_start + v_token_num

new_output = layer_outputs
# hidden_states = layer_outputs[0]
cache_position = position_ids.detach().clone()

pruning_pars['v_token_num'] = v_token_num
Expand Down Expand Up @@ -273,14 +348,24 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

for block_idx in range(sorted_pruning_locs[0], total_layers):
if block_idx in sorted_pruning_locs:
self.blocks[block_idx].register_forward_pre_hook(
functools.partial(
update_output_attentions_hook,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
if self.model.__class__.__name__ == 'LlavaHf':
self.blocks[block_idx].register_forward_pre_hook(
functools.partial(
update_output_attentions_hook,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
elif self.model.__class__.__name__ == 'Llava':
self.blocks[block_idx].self_attn.register_forward_pre_hook(
functools.partial(
get_attn_logits_hook,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
Comment on lines +360 to +368

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Registering the self_attn.register_forward_pre_hook only for the 'Llava' model might lead to unexpected behavior if the attention mechanism is used differently in other models. Consider making this registration more generic or providing a clear explanation of why it's only needed for 'Llava'.

                    # if self.model.__class__.__name__ == 'Llava':
                    self.blocks[block_idx].self_attn.register_forward_pre_hook(
                        functools.partial(
                            get_attn_logits_hook,
                            pruning_pars=self.pruning_paras,
                            layer_idx=block_idx,
                        ),
                        with_kwargs=True
                    )

self.blocks[block_idx].register_forward_hook(
functools.partial(
decoder_attn_hook,
Expand Down