Commit 63ce494

handle distributed mean in lfq
1 parent b8d077d commit 63ce494

File tree

2 files changed: +22 -15 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.32"
+version = "1.14.33"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 21 additions & 14 deletions
@@ -7,8 +7,9 @@
 """

 from math import log2, ceil
-from functools import partial
+from functools import partial, cache
 from collections import namedtuple
+import torch.distributed as dist

 import torch
 from torch import nn, einsum

@@ -24,6 +25,20 @@

 LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

+# distributed helpers
+
+@cache
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
+def maybe_distributed_mean(t):
+    if not is_distributed():
+        return t
+
+    dist.all_reduce(t)
+    t = t / dist.get_world_size()
+    return t
+
 # helper functions

 def exists(v):

@@ -83,14 +98,14 @@ def __init__(
         keep_num_codebooks_dim = None,
         codebook_scale = 1.,                     # for residual LFQ, codebook scaled down by 2x at each layer
         frac_per_sample_entropy = 1.,            # make less than 1. to only use a random fraction of the probs for per sample entropy
-        use_code_agnostic_commit_loss = False,
         projection_has_bias = True,
         soft_clamp_input_value = None,
         cosine_sim_project_in = False,
         cosine_sim_project_in_scale = None,
         channel_first = None,
         experimental_softplus_entropy_loss = False,
         entropy_loss_offset = 5.,                # how much to shift the loss before softplus
+
     ):
         super().__init__()

@@ -149,7 +164,6 @@ def __init__(
         # commitment loss

         self.commitment_loss_weight = commitment_loss_weight
-        self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss

         # whether to soft clamp the input value from -value to value

@@ -305,6 +319,9 @@ def forward(
         # distribution over all available tokens in the batch

         avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
+
+        avg_prob = maybe_distributed_mean(avg_prob)
+
         codebook_entropy = entropy(avg_prob).mean()

         # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions

@@ -324,17 +341,7 @@ def forward(

         if self.training and self.commitment_loss_weight > 0.:

-            if self.use_code_agnostic_commit_loss:
-                # credit goes to @MattMcPartlon for sharing this in https://github.com/lucidrains/vector-quantize-pytorch/issues/120#issuecomment-2095089337
-
-                commit_loss = F.mse_loss(
-                    original_input ** 2,
-                    codebook_value ** 2,
-                    reduction = 'none'
-                )
-
-            else:
-                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

             if exists(mask):
                 commit_loss = commit_loss[mask]
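
For context on the new helper: maybe_distributed_mean is a no-op on a single process, and under torch.distributed it sums the tensor across ranks with all_reduce (which defaults to the SUM op) and divides by the world size, so every rank sees the same globally averaged code distribution before entropy(avg_prob) is computed. Below is a minimal, self-contained sketch of that pattern; the local entropy helper and the fake per-rank probabilities are illustrative assumptions, not code from this repository.

import torch
import torch.distributed as dist

def entropy(prob, eps = 1e-20):
    # -sum p log p over the last dimension
    return (-prob * torch.log(prob.clamp(min = eps))).sum(dim = -1)

def maybe_distributed_mean(t):
    # single-process run: return the tensor unchanged
    # (is_available() is an extra guard in this sketch for torch builds without distributed support)
    if not (dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1):
        return t

    # multi-process run: all_reduce defaults to SUM, so divide by world size to get the mean
    dist.all_reduce(t)
    return t / dist.get_world_size()

# hypothetical per-rank code usage distribution (num codebooks x codebook size)
avg_prob = torch.softmax(torch.randn(1, 16), dim = -1)

# averaged across ranks when launched with torchrun, unchanged otherwise
avg_prob = maybe_distributed_mean(avg_prob)

codebook_entropy = entropy(avg_prob).mean()

In this commit only the batch-entropy term sees the cross-rank average; the per-sample entropy and commitment losses are still computed from each rank's local batch.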
