Skip to content

Commit 665f4b6

Browse files
committed
add an extra method to FSQ for personal use
1 parent f40388a commit 665f4b6

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.8"
3+
version = "1.14.9"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,44 +87,46 @@ def __init__(
8787

8888
self.allowed_dtypes = allowed_dtypes
8989

90-
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
91-
"""Bound `z`, an array of shape (..., d)."""
90+
def bound(self, z, eps: float = 1e-3):
91+
""" Bound `z`, an array of shape (..., d). """
9292
half_l = (self._levels - 1) * (1 + eps) / 2
9393
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
9494
shift = (offset / half_l).atanh()
9595
return (z + shift).tanh() * half_l - offset
9696

97-
def quantize(self, z: Tensor) -> Tensor:
98-
"""Quantizes z, returns quantized zhat, same shape as z."""
97+
def quantize(self, z):
98+
""" Quantizes z, returns quantized zhat, same shape as z. """
9999
quantized = round_ste(self.bound(z))
100100
half_width = self._levels // 2 # Renormalize to [-1, 1].
101101
return quantized / half_width
102102

103-
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
103+
def _scale_and_shift(self, zhat_normalized):
104104
half_width = self._levels // 2
105105
return (zhat_normalized * half_width) + half_width
106106

107-
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
107+
def _scale_and_shift_inverse(self, zhat):
108108
half_width = self._levels // 2
109109
return (zhat - half_width) / half_width
110110

111-
def _indices_to_codes(self, indices: Tensor):
112-
indices = rearrange(indices, '... -> ... 1')
113-
codes_non_centered = (indices // self._basis) % self._levels
114-
codes = self._scale_and_shift_inverse(codes_non_centered)
111+
def _indices_to_codes(self, indices):
112+
level_indices = self.indices_to_level_indices(indices)
113+
codes = self._scale_and_shift_inverse(level_indices)
115114
return codes
116115

117-
def codes_to_indices(self, zhat: Tensor) -> Tensor:
118-
"""Converts a `code` to an index in the codebook."""
116+
def codes_to_indices(self, zhat):
117+
""" Converts a `code` to an index in the codebook. """
119118
assert zhat.shape[-1] == self.codebook_dim
120119
zhat = self._scale_and_shift(zhat)
121120
return (zhat * self._basis).sum(dim=-1).to(int32)
122121

123-
def indices_to_codes(
124-
self,
125-
indices: Tensor
126-
) -> Tensor:
127-
"""Inverse of `codes_to_indices`."""
122+
def indices_to_level_indices(self, indices):
123+
""" Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
124+
indices = rearrange(indices, '... -> ... 1')
125+
codes_non_centered = (indices // self._basis) % self._levels
126+
return codes_non_centered
127+
128+
def indices_to_codes(self, indices):
129+
""" Inverse of `codes_to_indices`. """
128130

129131
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
130132

@@ -141,7 +143,7 @@ def indices_to_codes(
141143
return codes
142144

143145
@autocast(enabled = False)
144-
def forward(self, z: Tensor) -> Tensor:
146+
def forward(self, z):
145147
"""
146148
einstein notation
147149
b - batch

0 commit comments

Comments
 (0)