
Commit 5ae60bd

able to turn off bias for FSQ as well, since in scalar quantization the codebook is implicit in the previous projection
1 parent be995a7 commit 5ae60bd

File tree

1 file changed: +4 −3 lines changed


vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 4 additions & 3 deletions
@@ -48,7 +48,8 @@ def __init__(
         keep_num_codebooks_dim: Optional[bool] = None,
         scale: Optional[float] = None,
         allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
-        channel_first: bool = False
+        channel_first: bool = False,
+        projection_has_bias: bool = True
     ):
         super().__init__()
         _levels = torch.tensor(levels, dtype=int32)
@@ -75,8 +76,8 @@ def __init__(
         self.channel_first = channel_first
 
         has_projections = self.dim != effective_codebook_dim
-        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
-        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
+        self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
+        self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()
 
         self.has_projections = has_projections
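For reference, a minimal usage sketch of the new projection_has_bias flag, assuming FSQ is importable from the package top level; the levels, dim, and input shape below are illustrative and not taken from the commit. The flag only has an effect when dim differs from the effective codebook dimension, since that is when project_in and project_out are created:

import torch
from vector_quantize_pytorch import FSQ

# illustrative values, not from the commit
quantizer = FSQ(
    levels = [8, 5, 5, 5],        # 4 scalar dimensions, implicit codebook of 8 * 5 * 5 * 5 codes
    dim = 512,                    # dim != effective codebook dim, so project_in / project_out are Linear layers
    projection_has_bias = False   # new flag: build those Linear layers without a bias term
)

x = torch.randn(1, 1024, 512)         # (batch, sequence, dim)
quantized, indices = quantizer(x)     # quantized matches x's shape; indices are the implicit codebook ids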
