Skip to content

Commit 949b0ba

Browse files
committed
cleanup
1 parent 80f4e84 commit 949b0ba

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

vector_quantize_pytorch/sim_vq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from einx import get_at
1010
from einops import rearrange, pack, unpack
1111

12-
from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to
12+
from vector_quantize_pytorch.vector_quantize_pytorch import rotate_to
1313

1414
# helper functions
1515

@@ -94,7 +94,7 @@ def forward(
9494

9595
if self.rotation_trick:
9696
# rotation trick from @cfifty
97-
quantized = rotate_from_to(quantized, x)
97+
quantized = rotate_to(x, quantized)
9898
else:
9999

100100
commit_loss = (

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,21 +250,21 @@ def efficient_rotation_trick_transform(u, q, e):
250250
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
251251
)
252252

253-
def rotate_from_to(src, tgt):
253+
def rotate_to(src, tgt):
254254
# rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
255-
tgt, inverse = pack_one(tgt, '* d')
256-
src, _ = pack_one(src, '* d')
255+
src, inverse = pack_one(src, '* d')
256+
tgt, _ = pack_one(tgt, '* d')
257257

258-
norm_tgt = tgt.norm(dim = -1, keepdim = True)
259258
norm_src = src.norm(dim = -1, keepdim = True)
259+
norm_tgt = tgt.norm(dim = -1, keepdim = True)
260260

261-
rotated_src = efficient_rotation_trick_transform(
262-
safe_div(tgt, norm_tgt),
261+
rotated_tgt = efficient_rotation_trick_transform(
263262
safe_div(src, norm_src),
264-
tgt
263+
safe_div(tgt, norm_tgt),
264+
src
265265
).squeeze()
266266

267-
rotated = rotated_src * safe_div(norm_src, norm_tgt).detach()
267+
rotated = rotated_tgt * safe_div(norm_tgt, norm_src).detach()
268268

269269
return inverse(rotated)
270270

@@ -1118,7 +1118,7 @@ def forward(
11181118
commit_quantize = maybe_detach(quantize)
11191119

11201120
if self.rotation_trick:
1121-
quantize = rotate_from_to(quantize, x)
1121+
quantize = rotate_to(x, quantize)
11221122
else:
11231123
# standard STE to get gradients through VQ layer.
11241124
quantize = x + (quantize - x).detach()

0 commit comments

Comments
 (0)