@@ -250,21 +250,21 @@ def efficient_rotation_trick_transform(u, q, e):
250
250
2 * (e @ rearrange (u , 'b d -> b d 1' ).detach () @ rearrange (q , 'b d -> b 1 d' ).detach ())
251
251
)
252
252
253
- def rotate_from_to (src , tgt ):
253
+ def rotate_to (src , tgt ):
254
254
# rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
255
- tgt , inverse = pack_one (tgt , '* d' )
256
- src , _ = pack_one (src , '* d' )
255
+ src , inverse = pack_one (src , '* d' )
256
+ tgt , _ = pack_one (tgt , '* d' )
257
257
258
- norm_tgt = tgt .norm (dim = - 1 , keepdim = True )
259
258
norm_src = src .norm (dim = - 1 , keepdim = True )
259
+ norm_tgt = tgt .norm (dim = - 1 , keepdim = True )
260
260
261
- rotated_src = efficient_rotation_trick_transform (
262
- safe_div (tgt , norm_tgt ),
261
+ rotated_tgt = efficient_rotation_trick_transform (
263
262
safe_div (src , norm_src ),
264
- tgt
263
+ safe_div (tgt , norm_tgt ),
264
+ src
265
265
).squeeze ()
266
266
267
- rotated = rotated_src * safe_div (norm_src , norm_tgt ).detach ()
267
+ rotated = rotated_tgt * safe_div (norm_tgt , norm_src ).detach ()
268
268
269
269
return inverse (rotated )
270
270
@@ -1118,7 +1118,7 @@ def forward(
1118
1118
commit_quantize = maybe_detach (quantize )
1119
1119
1120
1120
if self .rotation_trick :
1121
- quantize = rotate_from_to ( quantize , x )
1121
+ quantize = rotate_to ( x , quantize )
1122
1122
else :
1123
1123
# standard STE to get gradients through VQ layer.
1124
1124
quantize = x + (quantize - x ).detach ()
0 commit comments