Commit 371a1b9 (1 parent 5bf9b91)

ability to use a cosine sim distance for initial projection before LFQ

File tree: 2 files changed (+18, -3 lines)

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.17"
+version = "1.14.18"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py (17 additions, 2 deletions)

@@ -48,6 +48,18 @@ def log(t, eps = 1e-5):
 def entropy(prob):
     return (-prob * log(prob)).sum(dim=-1)

+# cosine sim linear
+
+class CosineSimLinear(Module):
+    def __init__(self, dim_in, dim_out, **kwargs):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
+
+    def forward(self, x):
+        x = F.normalize(x, dim = -1)
+        w = F.normalize(self.weight, dim = 0)
+        return x @ w
+
 # class

 class LFQ(Module):
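
For intuition, here is a minimal standalone sketch of what the added projection computes (the tensors and shapes are placeholders, not from the commit): with input rows and weight columns unit-normalized, each output entry is the cosine similarity between an input vector and a weight column, so the values handed to LFQ are bounded in [-1, 1].

import torch
import torch.nn.functional as F

x = torch.randn(2, 4)    # placeholder batch of inputs, dim_in = 4
w = torch.randn(4, 8)    # placeholder (dim_in, dim_out) weight, as in CosineSimLinear

# normalize input rows and weight columns, mirroring CosineSimLinear.forward
out = F.normalize(x, dim = -1) @ F.normalize(w, dim = 0)

# every entry is a cosine similarity, hence bounded in [-1, 1]
assert torch.all(out.abs() <= 1 + 1e-6)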
@@ -66,7 +78,8 @@ def __init__(
         frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
         use_code_agnostic_commit_loss = False,
         projection_has_bias = True,
-        soft_clamp_input_value = None
+        soft_clamp_input_value = None,
+        cosine_sim_project_in = False
     ):
         super().__init__()

@@ -82,7 +95,9 @@ def __init__(
         dim = default(dim, codebook_dims)

         has_projections = dim != codebook_dims
-        self.project_in = nn.Linear(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
+
+        project_in_klass = CosineSimLinear if cosine_sim_project_in else nn.Linear
+        self.project_in = project_in_klass(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
         self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
         self.has_projections = has_projections
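
A hypothetical usage sketch of the new flag, assuming the LFQ call pattern shown in the repository README; the dim and codebook_size values below are placeholders. The flag only takes effect when dim differs from the codebook dimension (log2 of codebook_size), since otherwise project_in is an nn.Identity. Note also that CosineSimLinear accepts **kwargs, so the bias = projection_has_bias keyword passed to it is silently ignored and the cosine-sim projection carries no bias.

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    dim = 256,                     # placeholder input feature dimension
    codebook_size = 65536,         # 2 ** 16, so codebook dimension is 16 and projections are active
    cosine_sim_project_in = True   # use CosineSimLinear instead of nn.Linear for project_in
)

x = torch.randn(1, 32, 256)        # (batch, sequence, dim)
quantized, indices, entropy_aux_loss = quantizer(x)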

Comments (0)