Skip to content

Commit 007209d

Browse files
committed
update init
1 parent 7959292 commit 007209d

File tree

1 file changed

+2
-39
lines changed

1 file changed

+2
-39
lines changed

vector_quantize_pytorch/sim_vq.py

Lines changed: 2 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -29,25 +29,6 @@ def inverse(out, inv_pattern = None):
2929

3030
return packed, inverse
3131

32-
def l2norm(t, dim = -1):
33-
return F.normalize(t, dim = dim)
34-
35-
def safe_div(num, den, eps = 1e-6):
36-
return num / den.clamp(min = eps)
37-
38-
def efficient_rotation_trick_transform(u, q, e):
39-
"""
40-
4.2 in https://arxiv.org/abs/2410.06424
41-
"""
42-
e = rearrange(e, 'b d -> b 1 d')
43-
w = l2norm(u + q, dim = 1).detach()
44-
45-
return (
46-
e -
47-
2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
48-
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
49-
)
50-
5132
# class
5233

5334
class SimVQ(Module):
@@ -61,7 +42,7 @@ def __init__(
6142
super().__init__()
6243
self.accept_image_fmap = accept_image_fmap
6344

64-
codebook = torch.randn(codebook_size, dim)
45+
codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
6546
codebook = init_fn(codebook)
6647

6748
# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
@@ -89,25 +70,7 @@ def forward(
8970

9071
# commit loss
9172

92-
commit_loss = (F.pairwise_distance(x, quantized.detach()) ** 2).mean()
93-
94-
# straight through
95-
96-
x, inverse = pack_one(x, '* d')
97-
quantized, _ = pack_one(quantized, '* d')
98-
99-
norm_x = x.norm(dim = -1, keepdim = True)
100-
norm_quantize = quantized.norm(dim = -1, keepdim = True)
101-
102-
rot_quantize = efficient_rotation_trick_transform(
103-
safe_div(x, norm_x),
104-
safe_div(quantized, norm_quantize),
105-
x
106-
).squeeze()
107-
108-
quantized = rot_quantize * safe_div(norm_quantize, norm_x).detach()
109-
110-
x, quantized = inverse(x), inverse(quantized)
73+
commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean()
11174

11275
# quantized = (quantized - x).detach() + x
11376

0 commit comments

Comments (0)