Skip to content

Commit cdd0141

Browse files
committed
Update test_lfq.py
1 parent 9251915 commit cdd0141

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/test_lfq.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
torch.manual_seed(0)
1111

1212
@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
13+
@pytest.mark.parametrize('mask', (torch.tensor([False, False]),
14+
torch.tensor([True, False]),
15+
torch.tensor([True, True])))
1316
def test_masked_lfq(
14-
frac_per_sample_entropy
17+
frac_per_sample_entropy,
18+
mask
1519
):
1620
# you can specify either dim or codebook_size
1721
# if both specified, will be validated against each other
@@ -26,7 +30,7 @@ def test_masked_lfq(
2630

2731
image_feats = torch.randn(2, 16, 32, 32)
2832

29-
ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True) # you may want to experiment with temperature
33+
ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature
3034

3135
quantized, indices, _ = ret
3236
assert (quantized == quantizer.indices_to_codes(indices)).all()

0 commit comments

Comments
 (0)