@@ -15,6 +15,7 @@ def __init__(self, config, model, blocks):
15
15
self .register_reduction_modules ()
16
16
17
17
def add_sparse_config (self ):
18
+ self .pruning_loc = self .special_config ['pruning_loc' ]
18
19
self .pruning_paras = self .special_config
19
20
20
21
def register_reduction_modules (self ):
@@ -30,6 +31,7 @@ def conditional_pooling(
30
31
feat : torch .Tensor ,
31
32
threshold : float ,
32
33
window_size : Tuple [int , int ],
34
+ fix_r : int = 0 ,
33
35
) -> Tuple [Callable , Callable ]:
34
36
35
37
with torch .no_grad ():
@@ -91,7 +93,8 @@ def conditional_pooling(
91
93
node_mean = node_mean .repeat (1 , n_H )
92
94
r = torch .ge (similarity_map , node_mean ).sum (dim = 1 ).min ()
93
95
# -------------#
94
-
96
+ if fix_r != 0 :
97
+ r = fix_r
95
98
# get top k similar super patches
96
99
_ , sim_super_patch_idxs = similarity_map .topk (r , dim = - 1 )
97
100
@@ -184,17 +187,20 @@ def merge_wavg(
184
187
185
188
return x , size
186
189
187
- def spatial_merge_hook (module , args , kwargs , pruning_paras ):
190
+ def spatial_merge_hook (module , args , kwargs , layer_outs , pruning_paras ):
188
191
spatial_threshold = pruning_paras ['spatial_threshold' ]
189
192
window_size = pruning_paras ['window_size' ]
190
- hidden_states = args [0 ]
191
- merge = conditional_pooling (hidden_states , spatial_threshold , window_size )
193
+ hidden_states = layer_outs [0 ]
194
+ fix_r = 0
195
+ if pruning_paras .get ('retained_tokens' , None ) is not None :
196
+ retained_tokens = pruning_paras ['retained_tokens' ]
197
+ fix_r = (pruning_paras ['vision_token_length' ] - retained_tokens ) \
198
+ // (window_size [0 ] * window_size [1 ] - 1 )
199
+ merge = conditional_pooling (hidden_states , spatial_threshold , window_size , fix_r )
192
200
hidden_states , size = merge_wavg (merge , hidden_states , None )
193
- return (hidden_states ,) + args [ 1 :], kwargs
201
+ return (hidden_states ,)
194
202
195
- self .model .set_modality ('vision' )
196
- self .model .find_blocks ()
197
- self .model .blocks [1 ].register_forward_pre_hook (
203
+ self .blocks [self .pruning_loc - 1 ].register_forward_hook (
198
204
functools .partial (spatial_merge_hook , pruning_paras = self .pruning_paras ),
199
205
with_kwargs = True ,
200
206
)
0 commit comments