@@ -29,25 +29,6 @@ def inverse(out, inv_pattern = None):
 
     return packed, inverse
 
-def l2norm(t, dim = -1):
-    return F.normalize(t, dim = dim)
-
-def safe_div(num, den, eps = 1e-6):
-    return num / den.clamp(min = eps)
-
-def efficient_rotation_trick_transform(u, q, e):
-    """
-    4.2 in https://arxiv.org/abs/2410.06424
-    """
-    e = rearrange(e, 'b d -> b 1 d')
-    w = l2norm(u + q, dim = 1).detach()
-
-    return (
-        e -
-        2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
-        2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
-    )
-
 # class
 
 class SimVQ(Module):
@@ -61,7 +42,7 @@ def __init__(
         super().__init__()
         self.accept_image_fmap = accept_image_fmap
 
-        codebook = torch.randn(codebook_size, dim)
+        codebook = torch.randn(codebook_size, dim) * (dim ** -0.5)
         codebook = init_fn(codebook)
 
         # the codebook is actually implicit from a linear layer from frozen gaussian or uniform
@@ -89,25 +70,7 @@ def forward(
 
         # commit loss
 
-        commit_loss = (F.pairwise_distance(x, quantized.detach()) ** 2).mean()
-
-        # straight through
-
-        x, inverse = pack_one(x, '* d')
-        quantized, _ = pack_one(quantized, '* d')
-
-        norm_x = x.norm(dim = -1, keepdim = True)
-        norm_quantize = quantized.norm(dim = -1, keepdim = True)
-
-        rot_quantize = efficient_rotation_trick_transform(
-            safe_div(x, norm_x),
-            safe_div(quantized, norm_quantize),
-            x
-        ).squeeze()
-
-        quantized = rot_quantize * safe_div(norm_quantize, norm_x).detach()
-
-        x, quantized = inverse(x), inverse(quantized)
+        commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean()
 
         # quantized = (quantized - x).detach() + x
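For context, the commented-out line retained above, `# quantized = (quantized - x).detach() + x`, is the standard straight-through estimator that the removed rotation-trick transform (section 4.2 of https://arxiv.org/abs/2410.06424) had substituted for. A minimal sketch of that estimator, assuming `x` is the encoder output and `quantized` its selected codebook vector, might look like:

    import torch

    def straight_through(x, quantized):
        # forward value equals the quantized vector; .detach() stops the gradient
        # on (quantized - x), so backward treats the op as identity in x
        return x + (quantized - x).detach()

The deleted efficient_rotation_trick_transform instead rotated x onto quantized and let gradients flow through that (detached) rotation rather than through a plain pass-through.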