Skip to content

Commit f8f357a

Browse files
committed
cosine sim scale defaults to the codebook value
1 parent 371a1b9 commit f8f357a

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.18"
3+
version = "1.14.19"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
"""
88

99
from math import log2, ceil
10+
from functools import partial
1011
from collections import namedtuple
1112

1213
import torch
@@ -51,14 +52,20 @@ def entropy(prob):
5152
# cosine sim linear
5253

5354
class CosineSimLinear(Module):
54-
def __init__(self, dim_in, dim_out, **kwargs):
55+
def __init__(
56+
self,
57+
dim_in,
58+
dim_out,
59+
scale = 1.
60+
):
5561
super().__init__()
62+
self.scale = scale
5663
self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
5764

5865
def forward(self, x):
5966
x = F.normalize(x, dim = -1)
6067
w = F.normalize(self.weight, dim = 0)
61-
return x @ w
68+
return (x @ w) * self.scale
6269

6370
# class
6471

@@ -79,7 +86,8 @@ def __init__(
7986
use_code_agnostic_commit_loss = False,
8087
projection_has_bias = True,
8188
soft_clamp_input_value = None,
82-
cosine_sim_project_in = False
89+
cosine_sim_project_in = False,
90+
cosine_sim_project_in_scale = None
8391
):
8492
super().__init__()
8593

@@ -96,8 +104,13 @@ def __init__(
96104

97105
has_projections = dim != codebook_dims
98106

99-
project_in_klass = CosineSimLinear if cosine_sim_project_in else nn.Linear
100-
self.project_in = project_in_klass(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
107+
if cosine_sim_project_in:
108+
cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
109+
project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
110+
else:
111+
project_in_klass = partial(nn.Linear, bias = projection_has_bias)
112+
113+
self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
101114
self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
102115
self.has_projections = has_projections
103116

0 commit comments

Comments (0)