@@ -38,36 +38,11 @@ def unpack_one(t, ps, pattern):
38
38
39
39
# entropy
40
40
41
- def binary_entropy (prob ):
42
- return - prob * log (prob ) - (1 - prob ) * log (1 - prob )
43
-
44
- # tensor helpers
45
-
46
41
def log (t , eps = 1e-20 ):
47
42
return t .clamp (min = eps ).log ()
48
43
49
- # convert to bit representations and back
50
-
51
- def decimal_to_bits (x , bits ):
52
- device = x .device
53
-
54
- x = x .int ()
55
-
56
- mask = 2 ** torch .arange (bits - 1 , - 1 , - 1 , device = device )
57
- x = rearrange (x , 'b n -> b n 1' )
58
-
59
- bits = ((x & mask ) != 0 ).float ()
60
- bits = rearrange (bits , 'b n d -> b n d' )
61
- return bits * 2 - 1
62
-
63
- def bits_to_decimal (x , bits ):
64
- device = x .device
65
-
66
- x = (x > 0 ).int ()
67
-
68
- mask = 2 ** torch .arange (bits - 1 , - 1 , - 1 , device = device , dtype = torch .int32 )
69
- dec = reduce (x * mask , 'b n d -> b n' , 'sum' )
70
- return dec
44
+ def binary_entropy (prob ):
45
+ return - prob * log (prob ) - (1 - prob ) * log (1 - prob )
71
46
72
47
# class
73
48
@@ -105,6 +80,7 @@ def __init__(
105
80
106
81
# for no auxiliary loss, during inference
107
82
83
+ self .register_buffer ('mask' , 2 ** torch .arange (codebook_dim - 1 , - 1 , - 1 ))
108
84
self .register_buffer ('zero' , torch .zeros (1 ,), persistent = False )
109
85
110
86
def indices_to_codes (
@@ -114,14 +90,10 @@ def indices_to_codes(
114
90
):
115
91
is_img_or_video = indices .ndim >= 3
116
92
117
- # rearrange if image or video into (batch, seq, dimension)
118
-
119
- if is_img_or_video :
120
- indices , ps = pack_one (indices , 'b *' )
121
-
122
93
# indices to codes, which are bits of either -1 or 1
123
94
124
- codes = decimal_to_bits (indices , self .codebook_dim )
95
+ bits = ((indices [..., None ].int () & self .mask ) != 0 ).float ()
96
+ codes = bits * 2 - 1
125
97
126
98
# whether to project codes out to original dimensions
127
99
# if the input feature dimensions were not log2(codebook size)
@@ -132,7 +104,6 @@ def indices_to_codes(
132
104
# rearrange codes back to original shape
133
105
134
106
if is_img_or_video :
135
- codes = unpack_one (codes , ps , 'b * d' )
136
107
codes = rearrange (codes , 'b ... d -> b d ...' )
137
108
138
109
return codes
@@ -163,10 +134,8 @@ def forward(
163
134
164
135
# quantize by eq 3.
165
136
166
- greater_than_zero = x > 0
167
137
ones = torch .ones_like (x )
168
-
169
- quantized = torch .where (greater_than_zero , ones , - ones )
138
+ quantized = torch .where (x > 0 , ones , - ones )
170
139
171
140
# use straight-through gradients with tanh if training
172
141
@@ -178,7 +147,7 @@ def forward(
178
147
179
148
# calculate indices
180
149
181
- indices = bits_to_decimal ( x , self .codebook_dim )
150
+ indices = reduce (( x > 0 ). int () * self .mask . int (), 'b n d -> b n' , 'sum' )
182
151
183
152
# entropy aux loss
184
153
0 commit comments