Skip to content

Commit 15629a5

Browse files
authored
Merge pull request #156 from lweitkamp/master
Remove double l2 normalization
2 parents 1447998 + ab5a61a commit 15629a5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def forward(
751751
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
752752
embed_normalized = l2norm(embed_normalized)
753753

754-
self.embed.data.copy_(l2norm(embed_normalized))
754+
self.embed.data.copy_(embed_normalized)
755755
self.expire_codes_(x)
756756

757757
if needs_codebook_dim:

0 commit comments

Comments
 (0)