Optimize code of FastV and fix SparseVLM's bugs related to LLaVA. #402
@@ -1,4 +1,6 @@
import copy
import functools
import math
from functools import wraps
from types import MethodType

@@ -66,22 +68,26 @@ def wrapper(self, *args, **kwargs):
    input_ids = args[0]
    attention_mask = args[2]

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()

    pre_prompt_length_list = []
    for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
        seq = cur_input_ids[cur_attention_mask]
        image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
        if len(image_token_index) > 0:
            pre_prompt_length_list.append(image_token_index[0])
        else:
            pre_prompt_length_list.append(0)
        image_token_index = (
            [-1]
            + torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
            + [seq.shape[0]]
        )
Comment on lines +79 to +83

The list concatenation can be done with a single torch.cat call and converted to a Python list once:

    image_token_index = torch.cat([
        torch.tensor([-1]),
        torch.where(seq == IMAGE_TOKEN_INDEX)[0],
        torch.tensor([seq.shape[0]])
    ]).tolist()
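To sanity-check the equivalence, here is a minimal sketch on a toy sequence; the IMAGE_TOKEN_INDEX value of -200 is an assumption for illustration (LLaVA's usual placeholder), not taken from this diff:

    import torch

    IMAGE_TOKEN_INDEX = -200  # assumed placeholder id, illustration only

    seq = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 5, 6])

    # List-concatenation form from the diff:
    idx_list = [-1] + torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist() + [seq.shape[0]]

    # Single torch.cat form from the review suggestion:
    idx_cat = torch.cat([
        torch.tensor([-1]),
        torch.where(seq == IMAGE_TOKEN_INDEX)[0],
        torch.tensor([seq.shape[0]]),
    ]).tolist()

    assert idx_list == idx_cat == [-1, 2, 5]
    # idx[1] is the position of the first image token, i.e. the pre-prompt length.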
        pre_prompt_length_list.append(image_token_index[1])

    pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list

    outputs = fn(*args, **kwargs)

    token_length_list = []
    for cur_attention_mask in outputs[2]:
        token_length_list.append(cur_attention_mask.sum().item())
    pruning_paras['token_length_list'] = token_length_list
    pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()

    return outputs
return wrapper

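The switch from the per-sequence loop to a single sum is easy to verify on a toy mask; this sketch assumes outputs[2] is a 2-D 0/1 attention-mask tensor, as the loop form implies:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])

    # Loop form (previous code):
    loop_lengths = [m.sum().item() for m in attention_mask]

    # Vectorized form (new code):
    vec_lengths = attention_mask.sum(dim=1).tolist()

    assert loop_lengths == vec_lengths == [3, 2]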
@@ -128,14 +134,90 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
    kwargs['position_embeddings'] = pruning_pars['position_embeddings']
    return args, kwargs

def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

    if len(kwargs['position_ids'][0]) == 1:
        return args, kwargs

    from transformers.models.llama.modeling_llama import \
        apply_rotary_pos_emb

    if layer_idx != self.pruning_loc[0]:
        kwargs['position_ids'] = pruning_pars['position_ids']
        kwargs['cache_position'] = pruning_pars['cache_position']
        kwargs['position_embeddings'] = pruning_pars['position_embeddings']

    hidden_states = kwargs['hidden_states']
    position_embeddings = kwargs['position_embeddings']
    position_ids = kwargs['position_ids']
    past_key_value = kwargs['past_key_value']
    cache_position = kwargs['cache_position']
    attention_mask = kwargs['attention_mask']

    t_token_idx = pruning_pars['t_token_idx']
    v_token_start = pruning_pars['v_token_start']
    v_token_num = pruning_pars['v_token_num']

    bsz, q_len, _ = hidden_states.size()
    query_states = module.q_proj(hidden_states)
    key_states = module.k_proj(hidden_states)
    value_states = module.v_proj(hidden_states)
    query_states = query_states.view(
        bsz, q_len, module.num_heads, module.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, module.num_key_value_heads, module.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, module.num_key_value_heads, module.head_dim
    ).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = module.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    if past_key_value is not None:
        temp_cache = copy.deepcopy(past_key_value)
        cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
        key_states, value_states = temp_cache.update(
            key_states, value_states,
            layer_idx, cache_kwargs
        )
    t_token_idx = t_token_idx[1] + v_token_start + v_token_num
    L, S = query_states.size(-2), key_states.size(-2)
    scale_factor = 1 / math.sqrt(query_states.size(-1))
    attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
    if module.is_causal:
        assert attention_mask is None
Asserting that attention_mask is None can abort runs where a mask is actually passed in; the suggestion is to comment the assertion out:

    # assert attention_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
        attn_bias.to(query_states.dtype)

Comment on lines +194 to +196

Instead of creating a temporary mask and then masking the attention bias, you can directly create the attention bias with the desired values. This can simplify the code and potentially improve performance.

    attn_bias = torch.where(torch.tril(torch.ones(L, S, dtype=torch.bool)),
                            torch.zeros(L, S, dtype=query_states.dtype),
                            torch.tensor(float('-inf'))).to(query_states.dtype)
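A minimal check, with toy L, S and float32 assumed purely for illustration, that the torch.where form builds the same bias as the mask-then-fill form above:

    import torch

    L, S = 4, 4
    dtype = torch.float32

    # Mask-then-fill form from the diff:
    bias_a = torch.zeros(L, S, dtype=dtype)
    temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    bias_a.masked_fill_(temp_mask.logical_not(), float('-inf'))

    # Direct torch.where form from the review suggestion:
    bias_b = torch.where(torch.tril(torch.ones(L, S, dtype=torch.bool)),
                         torch.zeros(L, S, dtype=dtype),
                         torch.tensor(float('-inf'))).to(dtype)

    assert torch.equal(bias_a, bias_b)  # zeros on/below the diagonal, -inf above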

    attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
    attn_logits += attn_bias.to(query_states.device)
    attn_logits = torch.softmax(attn_logits, dim=-1)

    pruning_pars['attn_logits'] = attn_logits

    return args, kwargs

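The hook recomputes attention probabilities by hand rather than taking them from the layer output; below is a standalone sketch of the same scaled-dot-product score computation with toy shapes and no RoPE or KV cache (all names and sizes are illustrative, not from the diff):

    import math
    import torch

    bsz, num_heads, q_len, head_dim = 1, 2, 5, 8
    query_states = torch.randn(bsz, num_heads, q_len, head_dim)
    key_states = torch.randn(bsz, num_heads, q_len, head_dim)

    L, S = query_states.size(-2), key_states.size(-2)
    scale_factor = 1 / math.sqrt(query_states.size(-1))

    # Causal bias: 0 where a query may attend, -inf where it may not.
    attn_bias = torch.zeros(L, S)
    causal = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    attn_bias.masked_fill_(causal.logical_not(), float('-inf'))

    attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
    attn_logits = attn_logits + attn_bias
    attn_probs = torch.softmax(attn_logits, dim=-1)  # rows sum to 1 over allowed keys

    assert attn_probs.shape == (bsz, num_heads, L, S)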
@prefill_wrapper
def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

    attn_logits = layer_outputs[1]
    # pruning_pars['attn_logits'] has a bug when running with llavaHf;
    # using layer_outputs[1] works for llavaHf, but the accuracy does not match.
    # llava:   attn_logits = pruning_pars['attn_logits']
    # llavahf: attn_logits = layer_outputs[1]
    if 'attn_logits' not in pruning_pars:
        attn_logits = layer_outputs[1]
    else:
        attn_logits = pruning_pars['attn_logits']
    v_token_start = pruning_pars['v_token_start']
    v_token_num = pruning_pars['v_token_num']
    text_token_start = pruning_pars['text_token_start']
    t_token_idx = pruning_pars['t_token_idx']
    v_token_num = pruning_pars['v_token_num']
    retained_tokens = pruning_pars['retained_tokens']
    B = pruning_pars['B']
    pre_prompt_length_list = pruning_pars['pre_prompt_length_list']

@@ -145,10 +227,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
        pruning_pars['position_ids'] = position_ids
    else:
        position_ids = pruning_pars['position_ids']

    hidden_states = inputs[0]  # [B, L, D]
    pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
    image_shape = pruning_pars['image_shape']

    pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
        attn_logits,

@@ -177,7 +256,6 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

    # merge and cluster
    if s_flag and total_sparse_token_idx.shape[1] > 0:
        total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
        total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

        merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]

@@ -208,20 +286,17 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
        )
        layer_outputs = (select_and_merge_token, layer_outputs[1])
        position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
        # prev_decision = policy
        v_token_num = pred_score_vis.sum() + cluster_num
        text_token_start = v_token_start + v_token_num
    else:
        select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
        layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                         layer_outputs[1])
        position_ids = position_ids[:, :len(select_token_idx[0])]
        # prev_decision = policy
        v_token_num = pred_score_vis.sum()
        text_token_start = v_token_start + v_token_num

    new_output = layer_outputs
    # hidden_states = layer_outputs[0]
    cache_position = position_ids.detach().clone()

    pruning_pars['v_token_num'] = v_token_num
@@ -273,14 +348,24 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

for block_idx in range(sorted_pruning_locs[0], total_layers):
    if block_idx in sorted_pruning_locs:
        self.blocks[block_idx].register_forward_pre_hook(
            functools.partial(
                update_output_attentions_hook,
                pruning_pars=self.pruning_paras,
                layer_idx=block_idx,
            ),
            with_kwargs=True
        )
        if self.model.__class__.__name__ == 'LlavaHf':
            self.blocks[block_idx].register_forward_pre_hook(
                functools.partial(
                    update_output_attentions_hook,
                    pruning_pars=self.pruning_paras,
                    layer_idx=block_idx,
                ),
                with_kwargs=True
            )
        elif self.model.__class__.__name__ == 'Llava':
            self.blocks[block_idx].self_attn.register_forward_pre_hook(
                functools.partial(
                    get_attn_logits_hook,
                    pruning_pars=self.pruning_paras,
                    layer_idx=block_idx,
                ),
                with_kwargs=True
            )
Comment on lines +360 to +368

Registering the get_attn_logits_hook on self_attn is gated on the model class name here; the suggestion drops the explicit check:

    # if self.model.__class__.__name__ == 'Llava':
    self.blocks[block_idx].self_attn.register_forward_pre_hook(
        functools.partial(
            get_attn_logits_hook,
            pruning_pars=self.pruning_paras,
            layer_idx=block_idx,
        ),
        with_kwargs=True
    )
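For reference, a self-contained sketch of the registration pattern used here: with with_kwargs=True a forward pre-hook receives (module, args, kwargs) and may return replacements, and functools.partial binds the extra pruning arguments; the toy module and hook below are illustrative only:

    import functools
    import torch
    import torch.nn as nn

    def toy_pre_hook(module, args, kwargs, pruning_pars, layer_idx):
        # Record which layer ran; return (args, kwargs) to pass them through unchanged.
        pruning_pars.setdefault('seen_layers', []).append(layer_idx)
        return args, kwargs

    pruning_paras = {}
    block = nn.Linear(4, 4)
    block.register_forward_pre_hook(
        functools.partial(toy_pre_hook, pruning_pars=pruning_paras, layer_idx=0),
        with_kwargs=True,
    )

    block(torch.randn(1, 4))
    assert pruning_paras['seen_layers'] == [0]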

        self.blocks[block_idx].register_forward_hook(
            functools.partial(
                decoder_attn_hook,
It's better to avoid using torch.ones_like with dtype=torch.bool directly. Instead, create a boolean tensor and then use it to create the ones tensor. This can prevent potential issues with type casting and ensure that the resulting tensor has the correct boolean type. Consider using input_ids.bool().fill_(True) instead.
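Both spellings produce an all-True boolean mask with the shape of input_ids; a quick check on a toy tensor (values arbitrary):

    import torch

    input_ids = torch.tensor([[1, 5, 0, 7]])

    a = torch.ones_like(input_ids, dtype=torch.bool)  # current approach in the diff
    b = input_ids.bool().fill_(True)                   # reviewer's suggestion (fills a copy)

    assert a.dtype == b.dtype == torch.bool
    assert torch.equal(a, b)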