Skip to content

Commit 43e54ac

Browse files
authored
update mustdrop (#414)
1 parent 05eedff commit 43e54ac

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

llmc/compression/token_reduction/mustdrop.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(self, config, model, blocks):
1515
self.register_reduction_modules()
1616

1717
def add_sparse_config(self):
18+
self.pruning_loc = self.special_config['pruning_loc']
1819
self.pruning_paras = self.special_config
1920

2021
def register_reduction_modules(self):
@@ -30,6 +31,7 @@ def conditional_pooling(
3031
feat: torch.Tensor,
3132
threshold: float,
3233
window_size: Tuple[int, int],
34+
fix_r: int = 0,
3335
) -> Tuple[Callable, Callable]:
3436

3537
with torch.no_grad():
@@ -91,7 +93,8 @@ def conditional_pooling(
9193
node_mean = node_mean.repeat(1, n_H)
9294
r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
9395
# -------------#
94-
96+
if fix_r != 0:
97+
r = fix_r
9598
# get top k similar super patches
9699
_, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)
97100

@@ -184,17 +187,20 @@ def merge_wavg(
184187

185188
return x, size
186189

187-
def spatial_merge_hook(module, args, kwargs, pruning_paras):
190+
def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras):
188191
spatial_threshold = pruning_paras['spatial_threshold']
189192
window_size = pruning_paras['window_size']
190-
hidden_states = args[0]
191-
merge = conditional_pooling(hidden_states, spatial_threshold, window_size)
193+
hidden_states = layer_outs[0]
194+
fix_r = 0
195+
if pruning_paras.get('retained_tokens', None) is not None:
196+
retained_tokens = pruning_paras['retained_tokens']
197+
fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
198+
// (window_size[0] * window_size[1] - 1)
199+
merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r)
192200
hidden_states, size = merge_wavg(merge, hidden_states, None)
193-
return (hidden_states,) + args[1:], kwargs
201+
return (hidden_states,)
194202

195-
self.model.set_modality('vision')
196-
self.model.find_blocks()
197-
self.model.blocks[1].register_forward_pre_hook(
203+
self.blocks[self.pruning_loc - 1].register_forward_hook(
198204
functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
199205
with_kwargs=True,
200206
)

0 commit comments

Comments
 (0)