|
| 1 | +import functools |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY |
| 6 | + |
| 7 | +from .token_reduction_module import TokenReductionModule |
| 8 | + |
| 9 | + |
@TOKEN_REDUCTION_REGISTRY.register('MustDrop')
class MustDrop(TokenReductionModule):
    """Token reduction that merges windows of highly similar vision tokens.

    On construction a forward pre-hook is installed on ``self.model.blocks[1]``.
    The hook splits the non-leading tokens of the incoming hidden states into
    non-overlapping spatial windows, scores each window by the mean pairwise
    cosine similarity of its tokens, and collapses every window whose score
    reaches ``spatial_threshold`` into a single averaged token.
    """

    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        """Expose ``self.special_config`` under the name ``self.pruning_paras``."""
        # NOTE(review): special_config is presumably set by the base class
        # constructor -- confirm against TokenReductionModule.
        self.pruning_paras = self.special_config

    def register_reduction_modules(self):
        """Build the window-merging helpers and attach the pre-hook.

        Reads ``spatial_threshold`` and ``window_size`` from
        ``self.pruning_paras`` each time the hook fires.
        """
        import math
        from typing import Callable, Optional, Tuple

        import torch.nn.functional as F
        from einops import rearrange

        def conditional_pooling(
            feat: torch.Tensor,
            threshold: float,
            window_size: Tuple[int, int],
        ) -> Callable:
            """Plan which windows to merge and return the merge function.

            Args:
                feat: Tensor indexed as (batch, tokens, channels). The first
                    token is set aside (assumed to be a class token -- TODO
                    confirm); the remaining tokens must form a square grid.
                threshold: Minimum mean intra-window cosine similarity for a
                    window to be merged.
                window_size: (height, width) of the non-overlapping windows,
                    e.g. (2, 2).

            Returns:
                ``merge(x, mode='mean')``, which applies the planned merge to
                a tensor laid out like ``feat``.

            Raises:
                ValueError: If a window holds fewer than two tokens, the
                    token count is not a perfect square, or the grid is not
                    divisible by the window size.
            """
            with torch.no_grad():
                ws_h, ws_w = int(window_size[0]), int(window_size[1])
                stride_h, stride_w = ws_h, ws_w
                # Number of tokens per window (4 for a 2x2 window).
                num_token_window = stride_h * stride_w
                if num_token_window < 2:
                    raise ValueError(
                        'window_size must cover at least two tokens, '
                        f'got {tuple(window_size)}'
                    )

                # Keep every token except the leading one.
                feat = feat[:, 1:, :]
                B, N, _ = feat.size()
                base_grid_H = int(math.sqrt(N))
                base_grid_W = base_grid_H
                if base_grid_H * base_grid_W != N:
                    raise ValueError(
                        f'expected a square token grid, got {N} tokens'
                    )
                if base_grid_H % ws_h != 0 or base_grid_W % ws_w != 0:
                    raise ValueError(
                        f'token grid {base_grid_H}x{base_grid_W} is not '
                        f'divisible by window_size {(ws_h, ws_w)}'
                    )

                feat = rearrange(feat, 'b (h w) c -> b c h w', h=base_grid_H)
                feat = rearrange(
                    feat,
                    'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w',
                    gh=base_grid_H // ws_h,
                    gw=base_grid_W // ws_w,
                )
                b, gh, gw, c, ps_h, ps_w = feat.shape
                tokens_per_window = ps_h * ps_w

                # Flatten each window so its tokens can be compared pairwise.
                tensor_flattened = feat.reshape(b, gh, gw, c, -1)
                tensor_1 = tensor_flattened.unsqueeze(-1)
                tensor_2 = tensor_flattened.unsqueeze(-2)

                # Pairwise cosine similarity over the channel axis.
                sims = F.cosine_similarity(tensor_1, tensor_2, dim=3)

                # Zero the diagonal: a token's similarity with itself is 1.
                sims_mask = 1 - torch.eye(tokens_per_window, device=sims.device)
                sims = sims * sims_mask

                # Mean similarity over the off-diagonal pairs of each window.
                similarity_map = sims.sum(-1).sum(-1) / (
                    tokens_per_window * (tokens_per_window - 1)
                )
                similarity_map = similarity_map.reshape(b, -1)

                # Adaptive merge count: the number of windows at or above the
                # threshold, minimised over the batch so every sample merges
                # the same number of windows. The threshold is created on the
                # data's own device so non-CUDA execution works too.
                threshold_t = torch.tensor(threshold, device=similarity_map.device)
                r = int(torch.ge(similarity_map, threshold_t).sum(dim=1).min())

                # Indices of the r most self-similar windows per sample.
                _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)

                # Token ids laid out on the grid, then cut into the same
                # windows (and window order) as the similarity map.
                token_ids = torch.arange(
                    base_grid_H * base_grid_W, device=feat.device
                ).reshape(base_grid_H, base_grid_W)
                token_ids = token_ids.unsqueeze(0).repeat(B, 1, 1)
                windowed_tensor = token_ids.unfold(1, ws_h, stride_h).unfold(
                    2, ws_w, stride_w
                )
                windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)

                # Token ids belonging to the windows selected for merging.
                gathered_tensor = torch.gather(
                    windowed_tensor,
                    1,
                    sim_super_patch_idxs.unsqueeze(-1).expand(
                        -1, -1, num_token_window
                    ),
                )

                # Boolean mask over windows: True = keep unmerged.
                mask = torch.ones(
                    (B, windowed_tensor.shape[1]),
                    dtype=torch.bool,
                    device=feat.device,
                )
                mask_values = torch.zeros_like(
                    sim_super_patch_idxs, dtype=torch.bool
                )
                # Sets mask[b, idxs[b]] = False for every sample b.
                mask.scatter_(1, sim_super_patch_idxs, mask_values)

                # Token ids of the untouched windows, in ascending order.
                remaining_tensor = windowed_tensor[
                    mask.unsqueeze(-1).expand(-1, -1, num_token_window)
                ].reshape(B, -1, num_token_window)
                unm_idx = (
                    remaining_tensor.reshape(B, -1)
                    .sort(dim=-1)
                    .values.unsqueeze(-1)
                )

                # In each merged window the last token is the destination and
                # the others are sources folded into it.
                dim_index = num_token_window - 1
                src_idx = (
                    gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
                )
                dst_idx = (
                    gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
                )
                # For every source token, the slot of its destination.
                merge_idx = (
                    torch.arange(r, device=feat.device)
                    .repeat_interleave(dim_index)
                    .repeat(B, 1)
                    .unsqueeze(-1)
                )

            def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
                """Apply the planned merge to ``x``.

                The output is ordered as: leading token, one token per merged
                window, then all unmerged tokens in ascending index order.
                ``mode`` is forwarded to ``scatter_reduce`` as ``reduce``.
                """
                x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
                n, t1, c = x_feat.shape
                src = x_feat.gather(
                    dim=-2, index=src_idx.expand(n, r * dim_index, c)
                )
                dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
                unm = x_feat.gather(
                    dim=-2,
                    index=unm_idx.expand(n, t1 - r * num_token_window, c),
                )
                dst = dst.scatter_reduce(
                    -2, merge_idx.expand(n, r * dim_index, c), src, reduce=mode
                )
                return torch.cat([x_cls, dst, unm], dim=1)

            return merge

        def merge_wavg(
            merge: Callable,
            x: torch.Tensor,
            size: Optional[torch.Tensor] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """Merge ``x`` as a size-weighted average.

            Args:
                merge: Function returned by ``conditional_pooling``.
                x: Tokens to merge.
                size: Per-token weights with a trailing singleton axis;
                    defaults to all ones.

            Returns:
                ``(merged_x, merged_size)``, where ``merged_size`` is the
                summed weight behind each output token.
            """
            if size is None:
                size = torch.ones_like(x[..., 0, None])

            x = merge(x * size, mode='sum')
            size = merge(size, mode='sum')
            x = x / size

            return x, size

        def spatial_merge_hook(module, args, kwargs, pruning_paras):
            """Forward pre-hook: replace the first positional input with its merged form."""
            spatial_threshold = pruning_paras['spatial_threshold']
            window_size = pruning_paras['window_size']
            hidden_states = args[0]
            merge = conditional_pooling(
                hidden_states, spatial_threshold, window_size
            )
            hidden_states, _ = merge_wavg(merge, hidden_states, None)
            return (hidden_states,) + args[1:], kwargs

        # NOTE(review): presumably switches the model wrapper to its vision
        # tower and refreshes `self.model.blocks` -- confirm against the model
        # class.
        self.model.set_modality('vision')
        self.model.find_blocks()
        self.model.blocks[1].register_forward_pre_hook(
            functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
            with_kwargs=True,
        )
0 commit comments