Skip to content

Commit ab5a61a

Browse files
authored
remove double l2 normalization
subsequent l2 normalizations are not computational wasteful and do not change the output vector.
1 parent 1447998 commit ab5a61a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def forward(
751751
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
752752
embed_normalized = l2norm(embed_normalized)
753753

754-
self.embed.data.copy_(l2norm(embed_normalized))
754+
self.embed.data.copy_(embed_normalized)
755755
self.expire_codes_(x)
756756

757757
if needs_codebook_dim:

0 commit comments

Comments
 (0)