support DART for llava #403

Merged
merged 1 commit into from Jul 10, 2025
27 changes: 27 additions & 0 deletions configs/sparsification/methods/DART/dart.yml
@@ -0,0 +1,27 @@
base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    method: TokenReduction
    special:
        method: DART
        pruning_loc: 2
        reduction_ratio: 0.778
        max_num_trunction: 128
        pivot_image_token: 4
        pivot_text_token : 4


medium

There is an extra space before the colon. Please remove it for consistency.

        pivot_text_token: 4

save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
1 change: 1 addition & 0 deletions llmc/compression/token_reduction/__init__.py
@@ -1,4 +1,5 @@
from .base_blockwise_token_reduction import TokenReduction
from .dart import DART
from .dycoke import DyCoke
from .fastervlm import FasterVLM
from .fastv import FastV
243 changes: 243 additions & 0 deletions llmc/compression/token_reduction/dart.py
@@ -0,0 +1,243 @@
import functools
import math
from functools import wraps
from types import MethodType

import torch

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper


@TOKEN_REDUCTION_REGISTRY.register('DART')
class DART(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):

        self.pruning_loc = self.special_config['pruning_loc']
        self.special_config['image_token_length'] = \
            self.model.pruning_config['image_token_length']
        self.special_config['IMAGE_TOKEN_INDEX'] = \
            self.model.pruning_config['IMAGE_TOKEN_INDEX']

        self.pruning_paras = self.special_config

    def register_reduction_modules(self):

        def input_hook_llava(fn, pruning_paras):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                if len(args) == 0:
                    return fn(*args, **kwargs)
                input_args = args[0]
                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
                    return fn(*args, **kwargs)

                input_ids = args[0]
                attention_mask = args[2]
                token_indices = (
                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
                )
                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()

                outputs = fn(*args, **kwargs)
                return outputs
            return wrapper

        def get_seq_len_hook(module, args, kwargs, pruning_paras):
            if kwargs['input_ids'] is not None:
                pruning_paras['seq_len'] = kwargs['input_ids'].shape[1]
            elif kwargs['inputs_embeds'] is not None:
                pruning_paras['seq_len'] = kwargs['inputs_embeds'].shape[1]
            else:
                raise ValueError('You have to specify either input_ids or inputs_embeds')

        def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
            from transformers.models.llama.modeling_llama import (
                apply_rotary_pos_emb, repeat_kv)
            if len(kwargs['position_ids'][0]) == 1:
                return layer_outs

            hidden_states = kwargs['hidden_states']
            position_embeddings = kwargs['position_embeddings']
            position_ids = kwargs['position_ids']
            past_key_value = kwargs['past_key_value']

            bsz, q_len, _ = hidden_states.size()
            query_states = module.q_proj(hidden_states)
            key_states = module.k_proj(hidden_states)
            value_states = module.v_proj(hidden_states)
            query_states = query_states.view(
                bsz, q_len, module.num_heads, module.head_dim
            ).transpose(1, 2)
            key_states = key_states.view(
                bsz, q_len, module.num_key_value_heads, module.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, q_len, module.num_key_value_heads, module.head_dim
            ).transpose(1, 2)

            if position_embeddings is None:
                cos, sin = module.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            if past_key_value is not None:
                key_states = past_key_value.key_cache[layer_idx]
                value_states = past_key_value.value_cache[layer_idx]
            key_states = repeat_kv(key_states, module.num_key_value_groups)
            value_states = repeat_kv(value_states, module.num_key_value_groups)

            pruning_paras['any_states'] = (query_states, key_states, value_states)

            return layer_outs

        @prefill_wrapper
        def pruning_hook(module, args, kwargs, pruning_paras, normlayer):

            image_token_start_index = pruning_paras['image_token_start_index']
            image_token_length = pruning_paras['image_token_length']
            any_states = pruning_paras['any_states'][-2]
            seq_length = pruning_paras['seq_len']

            hidden_states = args[0]
            attention_mask = kwargs['attention_mask']
            device = hidden_states.device
            last_layer_state = normlayer(hidden_states)

            # keep index
            retained_image_tokens_index = get_retained_image_token(
                pruning_paras, last_layer_state, any_states)

            keep_indexs = torch.cat(
                (
                    torch.arange(image_token_start_index, device=device),
                    retained_image_tokens_index,
                    torch.arange(
                        image_token_start_index + image_token_length,
                        seq_length,
                        device=device
                    )
                )
            )
            # sort index
            keep_indexs = keep_indexs.sort().values
            hidden_states = hidden_states[:, keep_indexs, :]
            position_ids = keep_indexs.unsqueeze(0)
            if attention_mask is not None:
                attention_mask = attention_mask[
                    :, :, :hidden_states.shape[1], :hidden_states.shape[1]
                ]
                kwargs['attention_mask'].resize_as_(attention_mask).copy_(attention_mask.clone())
            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
                position_ids.squeeze(0).clone())
            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

            position_embeddings = kwargs['position_embeddings']
            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

            return (hidden_states,), kwargs

        hook_fn = input_hook_llava(
            self.model.vlm_model.prepare_inputs_labels_for_multimodal,
            self.pruning_paras
        )
        self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
            hook_fn, self.model.vlm_model
        )

        self.model.model.model.register_forward_pre_hook(
            functools.partial(get_seq_len_hook, pruning_paras=self.pruning_paras),
            with_kwargs=True
        )

        self.blocks[self.pruning_loc - 1].self_attn.register_forward_hook(
            functools.partial(
                get_any_states_hook,
                pruning_paras=self.pruning_paras,
                layer_idx=self.pruning_loc - 1
            ),
            with_kwargs=True
        )

        self.blocks[self.pruning_loc].register_forward_pre_hook(
            functools.partial(
                pruning_hook,
                pruning_paras=self.pruning_paras,
                normlayer=self.model.model.model.norm
            ),
            with_kwargs=True
        )


def get_retained_image_token(pruning_paras, last_layer_state, any_states):
    image_token_start_index = pruning_paras['image_token_start_index']
    image_token_length = pruning_paras['image_token_length']
    MAX_NUM_TRUNCTION = pruning_paras['max_num_trunction']
    pivot_image_token = pruning_paras['pivot_image_token']
    pivot_text_token = pruning_paras['pivot_text_token']
    reduction_ratio = pruning_paras['reduction_ratio']
    TOKEN_TOPK = math.ceil(
        (
            MAX_NUM_TRUNCTION if MAX_NUM_TRUNCTION is not None
            else (image_token_length * (1 - reduction_ratio))
        ) // (pivot_image_token + pivot_text_token))
    device = last_layer_state.device

    any_states = (
        any_states.permute(0, 2, 1, 3)
        .reshape(any_states.shape[0], any_states.shape[1], -1)
    )

    k_states_image_token = any_states[0][
        image_token_start_index:image_token_start_index + image_token_length, :
    ]
    k_states_query_token = any_states[0][image_token_start_index + image_token_length:, :]

    k_states_image_token_L1_norm = torch.norm(k_states_image_token, p=1, dim=-1)
    k_states_query_token_L1_norm = torch.norm(k_states_query_token, p=1, dim=-1)

    image_indices = (
        k_states_image_token_L1_norm.topk(pivot_image_token).indices
        + image_token_start_index
    ).tolist()
    query_indices = (
        k_states_query_token_L1_norm.topk(pivot_text_token).indices
        + image_token_start_index + image_token_length
    ).tolist()
    indices_set = set(image_indices + query_indices)

    valid_indices = set(
        range(image_token_start_index, image_token_start_index + image_token_length)
    ) - set(image_indices)

    valid_indices_list = list(valid_indices)
    for item in list(indices_set):


high

The list valid_indices_list can become empty inside this loop. If it does, valid_vectors on line 225 will be an empty tensor, and cos_sim.topk(...) on line 231 will raise an error. Please add a check at the beginning of the loop to handle this case, for example by breaking out of the loop if valid_indices_list is empty.
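A minimal sketch of the guard the reviewer is asking for (illustrative only, not part of the submitted diff):

    for item in list(indices_set):
        # stop once every candidate image token has been assigned,
        # so the empty-tensor topk failure described above cannot occur
        if not valid_indices_list:
            break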

        valid_vectors = last_layer_state[0][valid_indices_list, :]
        cos_sim = -torch.nn.functional.cosine_similarity(
            last_layer_state[0][item, :],
            valid_vectors,
            dim=-1
        )
        top_k_indices = cos_sim.topk(TOKEN_TOPK).indices


high

The topk method will raise an error if TOKEN_TOPK is greater than the number of elements in cos_sim. The number of elements in cos_sim depends on len(valid_indices_list), which can be smaller than TOKEN_TOPK. You should ensure that the value passed to topk is not larger than the number of elements in the tensor.

Suggested change
        top_k_indices = cos_sim.topk(TOKEN_TOPK).indices
        top_k_indices = cos_sim.topk(min(TOKEN_TOPK, cos_sim.shape[0])).indices


        top_k_real_indices = [valid_indices_list[i] for i in top_k_indices]
        indices_set.update(top_k_real_indices)

        valid_indices.difference_update(top_k_real_indices)
        valid_indices_list = list(valid_indices)

    indices_set.difference_update(query_indices)

    retained_image_tokens_index = torch.tensor(list(indices_set), device=device)

    return retained_image_tokens_index
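For a concrete sense of the token budget (an illustrative calculation, not part of the diff; the 576 image tokens are an assumption for a standard LLaVA-1.5 vision tower, while the real value comes from the model's pruning_config):

    import math

    # with the shipped config: max_num_trunction=128, pivot_image_token=4, pivot_text_token=4
    TOKEN_TOPK = math.ceil(128 // (4 + 4))   # = 16 image tokens kept per pivot
    retained = 4 + (4 + 4) * TOKEN_TOPK      # = 132 image tokens retained in total
    # 576 * (1 - 0.778) is roughly 128, so max_num_trunction is consistent with reduction_ratio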
9 changes: 4 additions & 5 deletions llmc/compression/token_reduction/fastv.py
@@ -4,14 +4,11 @@

import torch

from llmc.compression.sparsification.attn_utils import _update_causal_mask
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper

IMAGE_TOKEN_INDEX = -200


@TOKEN_REDUCTION_REGISTRY.register('FastV')
class FastV(TokenReductionModule):
@@ -25,6 +22,8 @@ def add_sparse_config(self):
        self.pruning_loc = self.special_config['pruning_loc']
        self.special_config['image_token_length'] = \
            self.model.pruning_config['image_token_length']
        self.special_config['IMAGE_TOKEN_INDEX'] = \
            self.model.pruning_config['IMAGE_TOKEN_INDEX']
        self.special_config['attn_scores'] = None

        self.pruning_paras = self.special_config
@@ -51,7 +50,8 @@ def wrapper(self, *args, **kwargs):

                input_ids = args[0]
                attention_mask = args[2]
                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
                token_indices = \
                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()

                outputs = fn(*args, **kwargs)
@@ -127,7 +127,6 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
                functools.partial(input_hook, pruning_paras=self.pruning_paras)
            )
        elif self.model.__class__.__name__ == 'Llava':
            from llava.constants import IMAGE_TOKEN_INDEX
            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras
6 changes: 4 additions & 2 deletions llmc/compression/token_reduction/pyramiddrop.py
@@ -25,6 +25,9 @@ def __init__(self, config, model, blocks):
    def add_sparse_config(self):

        self.pruning_loc = self.special_config['layer_list']
        self.special_config['IMAGE_TOKEN_INDEX'] = \
            self.model.pruning_config['IMAGE_TOKEN_INDEX']

        image_token_ratio_list = self.special_config['image_token_ratio_list']
        image_token_ratio_list.insert(0, 1.0)
        self.special_config['image_token_ratio_list'] = image_token_ratio_list
@@ -348,7 +351,7 @@ def wrapper(self, *args, **kwargs):
                vision_tokens = []
                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
                    seq = cur_input_ids[cur_attention_mask]
                    image_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
                    image_index = torch.where(seq == pruning_paras['IMAGE_TOKEN_INDEX'])[0].tolist()
                    if image_index == []:
                        image_token_posi.append(-1)
                        prompt_len.append(cur_input_ids.shape[0])
@@ -378,7 +381,6 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                functools.partial(input_hook, pruning_pars=self.pruning_paras)
            )
        elif self.model.__class__.__name__ == 'Llava':
            from llava.constants import IMAGE_TOKEN_INDEX
            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras
5 changes: 3 additions & 2 deletions llmc/models/llava.py
@@ -17,7 +17,7 @@

try:
    from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                                 DEFAULT_IMAGE_PATCH_TOKEN)
                                 DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model
    from llava.model.language_model.llava_llama import LlavaConfig
@@ -66,7 +66,8 @@ def build_model(self):
            'image_token_length': self.vlm_model_config.image_seq_length,
            'select_layer': self.vlm_model_config.vision_feature_layer,
            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
            'image_token_index': self.vlm_model_config.image_token_index
            'image_token_index': self.vlm_model_config.image_token_index,
            'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
        }
        self.processor = None
