
Commit 49817f1

remove setup

Merge commit, 2 parents: eabf156 + 665f4b6

File tree

.github/workflows/build.yml
pyproject.toml
setup.py
vector_quantize_pytorch/finite_scalar_quantization.py

4 files changed: +85 -57 lines

.github/workflows/build.yml

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+name: Build with Rye
+on: push
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install Python
+        uses: actions/setup-python@v4
+      - name: Install the latest version of rye
+        uses: eifinger/setup-rye@v2
+      - name: Build with Rye
+        run: rye build
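
The job above boils down to a single `rye build`, which writes an sdist and a wheel into `dist/`. A minimal sketch of reproducing it locally from Python, assuming `rye` is installed and on PATH:

# Hedged local equivalent of the CI job above.
# Assumes `rye` is installed and on PATH; paths are illustrative.
import pathlib
import subprocess

subprocess.run(["rye", "build"], check=True)        # same command the workflow runs
artifacts = sorted(pathlib.Path("dist").glob("*"))  # rye writes the sdist + wheel here
print("\n".join(str(p) for p in artifacts))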

pyproject.toml

Lines changed: 51 additions & 2 deletions
@@ -1,3 +1,52 @@
+[project]
+name = "vector-quantize-pytorch"
+version = "1.14.9"
+description = "Vector Quantization - Pytorch"
+authors = [
+    { name = "Phil Wang", email = "lucidrains@gmail.com" }
+]
+readme = "README.md"
+requires-python = ">= 3.9"
+license = { file = "LICENSE" }
+keywords = [
+    'artificial intelligence',
+    'deep learning',
+    'pytorch',
+    'quantization'
+]
+classifiers=[
+    'Development Status :: 4 - Beta',
+    'Intended Audience :: Developers',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+    'License :: OSI Approved :: MIT License',
+    'Programming Language :: Python :: 3.6',
+]
+
+dependencies = [
+    "torch>=2.0",
+    "einops>=0.8.0",
+    "einx>=0.2.2",
+]
+
+[project.urls]
+Homepage = "https://pypi.org/project/vector-quantize-pytorch/"
+Repository = "https://github.com/lucidrains/vector-quantizer-pytorch"
+
+[project.optional-dependencies]
+examples = ["tqdm", "torchvision"]
+
 [build-system]
-requires = ["setuptools"]
-build-backend = "setuptools.build_meta"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.rye]
+managed = true
+dev-dependencies = [
+    "ruff>=0.4.2",
+]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["vector_quantize_pytorch"]
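
A quick way to sanity-check the new packaging metadata is to install the built wheel and import the package named under [tool.hatch.build.targets.wheel]. A hypothetical smoke test; the FSQ constructor arguments below are illustrative assumptions, not part of this commit:

# Hypothetical smoke test after `rye build` + installing the wheel from dist/.
# The FSQ levels below are illustrative assumptions, not from this commit.
import torch
from vector_quantize_pytorch import FSQ

quantizer = FSQ(levels = [8, 5, 5, 5])  # one quantization level per code dimension
x = torch.randn(1, 1024, 4)             # (batch, seq, dim); dim == len(levels)
xhat, indices = quantizer(x)            # quantized codes and their flat codebook indices
assert xhat.shape == x.shape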

setup.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 20 additions & 18 deletions
@@ -87,44 +87,46 @@ def __init__(
 
         self.allowed_dtypes = allowed_dtypes
 
-    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
-        """Bound `z`, an array of shape (..., d)."""
+    def bound(self, z, eps: float = 1e-3):
+        """ Bound `z`, an array of shape (..., d). """
         half_l = (self._levels - 1) * (1 + eps) / 2
         offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
         shift = (offset / half_l).atanh()
         return (z + shift).tanh() * half_l - offset
 
-    def quantize(self, z: Tensor) -> Tensor:
-        """Quantizes z, returns quantized zhat, same shape as z."""
+    def quantize(self, z):
+        """ Quantizes z, returns quantized zhat, same shape as z. """
         quantized = round_ste(self.bound(z))
         half_width = self._levels // 2 # Renormalize to [-1, 1].
         return quantized / half_width
 
-    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
+    def _scale_and_shift(self, zhat_normalized):
         half_width = self._levels // 2
         return (zhat_normalized * half_width) + half_width
 
-    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
+    def _scale_and_shift_inverse(self, zhat):
         half_width = self._levels // 2
         return (zhat - half_width) / half_width
 
-    def _indices_to_codes(self, indices: Tensor):
-        indices = rearrange(indices, '... -> ... 1')
-        codes_non_centered = (indices // self._basis) % self._levels
-        codes = self._scale_and_shift_inverse(codes_non_centered)
+    def _indices_to_codes(self, indices):
+        level_indices = self.indices_to_level_indices(indices)
+        codes = self._scale_and_shift_inverse(level_indices)
         return codes
 
-    def codes_to_indices(self, zhat: Tensor) -> Tensor:
-        """Converts a `code` to an index in the codebook."""
+    def codes_to_indices(self, zhat):
+        """ Converts a `code` to an index in the codebook. """
         assert zhat.shape[-1] == self.codebook_dim
         zhat = self._scale_and_shift(zhat)
         return (zhat * self._basis).sum(dim=-1).to(int32)
 
-    def indices_to_codes(
-        self,
-        indices: Tensor
-    ) -> Tensor:
-        """Inverse of `codes_to_indices`."""
+    def indices_to_level_indices(self, indices):
+        """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
+        indices = rearrange(indices, '... -> ... 1')
+        codes_non_centered = (indices // self._basis) % self._levels
+        return codes_non_centered
+
+    def indices_to_codes(self, indices):
+        """ Inverse of `codes_to_indices`. """
 
         is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
 
@@ -141,7 +143,7 @@ def indices_to_codes(
         return codes
 
     @autocast(enabled = False)
-    def forward(self, z: Tensor) -> Tensor:
+    def forward(self, z):
         """
         einstein notation
         b - batch
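
The new indices_to_level_indices helper factors a flat codebook index back into one index per quantization level, which is what a transformer with factorized embeddings would consume. The arithmetic is a mixed-radix digit expansion; a plain-integer sketch of it follows, where the `levels` value is an illustrative assumption:

# Plain-integer sketch of the mixed-radix decomposition behind
# indices_to_level_indices. `levels` is an illustrative assumption.
levels = [8, 5, 5, 5]

# FSQ's _basis is the cumulative product [1, 8, 40, 200]: the place
# value of each level's digit within the flat index.
basis = [1]
for l in levels[:-1]:
    basis.append(basis[-1] * l)

def index_to_level_indices(index):
    # Recover one digit per level with a floor-divide and a modulo,
    # mirroring `(indices // self._basis) % self._levels` in the diff.
    return [(index // b) % l for b, l in zip(basis, levels)]

assert index_to_level_indices(123) == [3, 0, 3, 0]  # 3*1 + 0*8 + 3*40 + 0*200 = 123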
