Skip to content

Commit 3bb00f5

Browse files
committed
seems to work even better with a one layer mlp
1 parent 949b0ba commit 3bb00f5

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

examples/autoencoder_sim_vq.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
train_iter = 10000
1515
num_codes = 256
1616
seed = 1234
17-
rotation_trick = True
17+
18+
rotation_trick = True # rotation trick instead ot straight-through
19+
use_mlp = True # use a one layer mlp with relu instead of linear
20+
1821
device = "cuda" if torch.cuda.is_available() else "cpu"
1922

2023
def SimVQAutoEncoder(**vq_kwargs):
@@ -77,7 +80,12 @@ def iterate_dataset(data_loader):
7780

7881
model = SimVQAutoEncoder(
7982
codebook_size = num_codes,
80-
rotation_trick = rotation_trick
83+
rotation_trick = rotation_trick,
84+
codebook_transform = nn.Sequential(
85+
nn.Linear(32, 128),
86+
nn.ReLU(),
87+
nn.Linear(128, 32),
88+
) if use_mlp else None
8189
).to(device)
8290

8391
opt = torch.optim.AdamW(model.parameters(), lr=lr)

0 commit comments

Comments
 (0)