Skip to content

Commit 05eedff

Browse files
authored
mustdrop (#413)
1 parent 7de3a7e commit 05eedff

File tree

2 files changed

+201
-0
lines changed

2 files changed

+201
-0
lines changed

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .fastv import FastV
77
from .fastvid import FastVID
88
from .holitom import HoliTom
9+
from .mustdrop import MustDrop
910
from .prunevid import PruneVid
1011
from .pyramiddrop import PyramidDrop
1112
from .sparsevlm import SparseVLM
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import functools
2+
3+
import torch
4+
5+
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
6+
7+
from .token_reduction_module import TokenReductionModule
8+
9+
10+
@TOKEN_REDUCTION_REGISTRY.register('MustDrop')
class MustDrop(TokenReductionModule):
    """Token-reduction module that merges similar vision tokens window by window.

    On construction it copies the module's special config into
    ``pruning_paras`` and registers a forward pre-hook on the second vision
    block (``self.model.blocks[1]``). The hook merges the tokens of every
    spatial window whose mean pairwise cosine similarity reaches
    ``pruning_paras['spatial_threshold']``.
    """

    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        """Expose the special config as ``self.pruning_paras``."""
        # NOTE(review): ``special_config`` is presumably set by the base
        # class constructor -- confirm against TokenReductionModule.
        self.pruning_paras = self.special_config

    def register_reduction_modules(self):
        """Build the spatial-merge hook and attach it to the vision tower."""
        import math
        from typing import Callable, Tuple

        import torch.nn.functional as F
        from einops import rearrange

        def conditional_pooling(
            feat: torch.Tensor,
            threshold: float,
            window_size: Tuple[int, int],
        ) -> Callable:
            """Return a ``merge(x, mode)`` closure for the given features.

            Args:
                feat: Tensor indexed as (batch, 1 + N, dim). The first token
                    is dropped before the similarity computation, so it is
                    treated as a CLS token; the remaining N tokens must form
                    a square grid.
                threshold: Minimum mean intra-window cosine similarity for a
                    window to be merged.
                window_size: (height, width) of the non-overlapping windows;
                    both must divide the grid side length.

            Returns:
                A callable ``merge(x, mode='mean')`` that collapses every
                selected window of ``x`` into a single token.
            """
            with torch.no_grad():
                # Window size; the stride equals the window size, so windows
                # never overlap.
                ws_h, ws_w = int(window_size[0]), int(window_size[1])
                stride_h, stride_w = ws_h, ws_w
                # Number of tokens inside one window.
                num_token_window = stride_h * stride_w

                # Keep every token except the leading (CLS) one.
                # NOTE(review): the original comment mentions 576 vision
                # tokens, i.e. presumably a 24x24 grid -- confirm per model.
                feat = feat[:, 1:, :]
                B, N, D = feat.size()
                base_grid_H = int(math.sqrt(N))
                base_grid_W = base_grid_H
                assert (
                    base_grid_H * base_grid_W == N
                    and base_grid_H % ws_h == 0
                    and base_grid_W % ws_w == 0
                ), (
                    'vision tokens must form a square grid whose side is '
                    'divisible by the window size'
                )

                feat = rearrange(feat, 'b (h w) c -> b c h w', h=base_grid_H)
                feat = rearrange(
                    feat,
                    'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w',
                    gh=base_grid_H // ws_h,
                    gw=base_grid_W // ws_w,
                )
                b, gh, gw, c, ps_h, ps_w = feat.shape

                # Flatten each window so its tokens can be compared pairwise.
                tensor_flattened = feat.reshape(b, gh, gw, c, -1)

                # Broadcast to every (token_i, token_j) pair of a window.
                tensor_1 = tensor_flattened.unsqueeze(-1)
                tensor_2 = tensor_flattened.unsqueeze(-2)

                # Cosine similarity along the channel axis.
                sims = F.cosine_similarity(tensor_1, tensor_2, dim=3)

                # Zero the diagonal: a token's similarity with itself is 1.
                sims_mask = 1 - torch.eye(ps_h * ps_w, device=sims.device)
                sims = sims * sims_mask

                # Mean over the off-diagonal pairs of each window.
                similarity_map = sims.sum(-1).sum(-1) / (
                    (ps_h * ps_w) * (ps_h * ps_w - 1)
                )

                # One similarity score per window: (batch, gh * gw).
                similarity_map = similarity_map.flatten(1)

                # --- adaptive section ---
                n_B, n_H = similarity_map.shape
                # Build the threshold on the activations' own device; the
                # previous ``.cuda(device)`` call only accepts CUDA devices.
                node_mean = torch.tensor(
                    threshold, device=similarity_map.device
                ).repeat(1, n_H)
                # Use the smallest per-sample count of windows above the
                # threshold so every sample merges the same number of
                # windows. Converted to a Python int because it is used as
                # ``k`` for topk and in ``expand`` sizes below.
                r = int(
                    torch.ge(similarity_map, node_mean).sum(dim=1).min().item()
                )

                # Indices of the r most self-similar windows per sample.
                _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)

                # --- split windows into mergeable and unmergeable ones ---
                # Flat token index of every grid position.
                tensor = torch.arange(
                    base_grid_H * base_grid_W, device=feat.device
                ).reshape(base_grid_H, base_grid_W)

                # One copy of the index grid per sample.
                tensor = tensor.unsqueeze(0).repeat(B, 1, 1)

                # Cut the index grid into the same non-overlapping windows.
                windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(
                    2, ws_w, stride_w
                )

                # (batch, num_windows, tokens_per_window)
                windowed_tensor = windowed_tensor.reshape(
                    B, -1, num_token_window
                )

                # Token indices belonging to the selected windows.
                gathered_tensor = torch.gather(
                    windowed_tensor,
                    1,
                    sim_super_patch_idxs.unsqueeze(-1).expand(
                        -1, -1, num_token_window
                    ),
                )

                # Window mask: True = keep unmerged, False = selected.
                mask = torch.ones(
                    (B, windowed_tensor.shape[1]),
                    dtype=torch.bool,
                    device=feat.device,
                )
                mask_values = torch.zeros_like(
                    sim_super_patch_idxs, dtype=torch.bool
                )
                # Sets mask[b, sim_super_patch_idxs[b]] = False for every b.
                mask.scatter_(1, sim_super_patch_idxs, mask_values)

                # Token indices of the windows that are left untouched.
                remaining_tensor = windowed_tensor[
                    mask.unsqueeze(-1).expand(-1, -1, num_token_window)
                ].reshape(B, -1, num_token_window)
                unm_idx = (
                    remaining_tensor.reshape(B, -1)
                    .sort(dim=-1)
                    .values.unsqueeze(-1)
                )

                # In each selected window the last token is the destination
                # and all preceding tokens are sources folded into it.
                dim_index = num_token_window - 1
                src_idx = (
                    gathered_tensor[:, :, :dim_index]
                    .reshape(B, -1)
                    .unsqueeze(-1)
                )
                dst_idx = (
                    gathered_tensor[:, :, dim_index]
                    .reshape(B, -1)
                    .unsqueeze(-1)
                )
                # For every source token, the position of its window's
                # destination inside ``dst``.
                merge_idx = (
                    torch.arange(src_idx.shape[1] // dim_index)
                    .repeat_interleave(dim_index)
                    .repeat(B, 1)
                    .unsqueeze(-1)
                    .to(feat.device)
                )

            def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
                """Merge the selected windows of ``x`` using ``mode``.

                The output token order is: the first (CLS) token, then one
                merged token per selected window, then all unmerged tokens
                in ascending index order.
                """
                x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
                n, t1, c = x_feat.shape
                src = x_feat.gather(
                    dim=-2, index=src_idx.expand(n, r * dim_index, c)
                )
                dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
                unm = x_feat.gather(
                    dim=-2,
                    index=unm_idx.expand(n, t1 - (r * num_token_window), c),
                )
                dst = dst.scatter_reduce(
                    -2, merge_idx.expand(n, r * dim_index, c), src, reduce=mode
                )
                x = torch.cat([dst, unm], dim=1)
                x = torch.cat((x_cls, x), dim=1)
                return x

            return merge

        def merge_wavg(
            merge: Callable, x: torch.Tensor, size: torch.Tensor = None
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """Apply ``merge`` as a size-weighted average.

            Args:
                merge: Closure returned by ``conditional_pooling``.
                x: Token tensor to merge.
                size: Per-token weights (how many original tokens each token
                    stands for); defaults to all ones.

            Returns:
                The merged tokens and their updated sizes.
            """
            if size is None:
                size = torch.ones_like(x[..., 0, None])

            # Sum the weighted tokens and the weights, then normalise.
            x = merge(x * size, mode='sum')
            size = merge(size, mode='sum')
            x = x / size

            return x, size

        def spatial_merge_hook(module, args, kwargs, pruning_paras):
            """Forward pre-hook that swaps in the merged hidden states."""
            spatial_threshold = pruning_paras['spatial_threshold']
            window_size = pruning_paras['window_size']
            hidden_states = args[0]
            merge = conditional_pooling(
                hidden_states, spatial_threshold, window_size
            )
            hidden_states, _ = merge_wavg(merge, hidden_states, None)
            return (hidden_states,) + args[1:], kwargs

        self.model.set_modality('vision')
        self.model.find_blocks()
        # Hook the second vision block so merging happens on its input.
        self.model.blocks[1].register_forward_pre_hook(
            functools.partial(
                spatial_merge_hook, pruning_paras=self.pruning_paras
            ),
            with_kwargs=True,
        )

0 commit comments

Comments
 (0)