File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change 14
14
train_iter = 10000
15
15
num_codes = 256
16
16
seed = 1234
17
- rotation_trick = True
17
+
18
+ rotation_trick = True # rotation trick instead ot straight-through
19
+ use_mlp = True # use a one layer mlp with relu instead of linear
20
+
18
21
device = "cuda" if torch .cuda .is_available () else "cpu"
19
22
20
23
def SimVQAutoEncoder (** vq_kwargs ):
@@ -77,7 +80,12 @@ def iterate_dataset(data_loader):
77
80
78
81
model = SimVQAutoEncoder (
79
82
codebook_size = num_codes ,
80
- rotation_trick = rotation_trick
83
+ rotation_trick = rotation_trick ,
84
+ codebook_transform = nn .Sequential (
85
+ nn .Linear (32 , 128 ),
86
+ nn .ReLU (),
87
+ nn .Linear (128 , 32 ),
88
+ ) if use_mlp else None
81
89
).to (device )
82
90
83
91
opt = torch .optim .AdamW (model .parameters (), lr = lr )
You can’t perform that action at this time.
0 commit comments