Skip to content

Commit c061cb5

Browse files
committed
Allow EarlyExitUViT for class conditioning
1 parent 22fd9fb commit c061cb5

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

models/early_exit.py

Lines changed: 21 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,10 @@ def __init__(
1919
nn.Conv2d(in_chans, in_chans, 3, padding=1) if conv else nn.Identity()
2020
)
2121

22-
def forward(self, x):
22+
def forward(self, x, extras):
2323
x = self.norm(x)
2424
x = self.decoder_pred(x)
25-
x = x[:, 1:, :] # Ignore time vector
25+
x = x[:, extras:, :] # Ignore time (and class) vector
2626
x = unpatchify(x, self.in_chans)
2727
x = self.final_layer(x)
2828
return x
@@ -216,11 +216,14 @@ def __init__(
216216
# TODO: FIXME (right now we have to modify this both in get_classifier and here)
217217
if classifier_type == "attention_probe":
218218
self.matrix = nn.ModuleDict(
219-
{f"{i}": AttentionProbe(embed_dim=uvit.embed_dim) for i in range(13)}
219+
{
220+
f"{i}": AttentionProbe(embed_dim=uvit.embed_dim)
221+
for i in range(uvit.depth)
222+
}
220223
)
221224
elif classifier_type == "mlp_probe_per_layer":
222225
self.matrix = nn.ModuleDict(
223-
{f"{i}": MLPProbe(embed_dim=uvit.embed_dim) for i in range(13)}
226+
{f"{i}": MLPProbe(embed_dim=uvit.embed_dim) for i in range(uvit.depth)}
224227
)
225228
elif classifier_type == "mlp_probe_per_timestep":
226229
self.matrix = nn.ModuleDict(
@@ -231,7 +234,7 @@ def __init__(
231234
{
232235
f"{i}, {t}": MLPProbe(embed_dim=uvit.embed_dim)
233236
for t in range(1000)
234-
for i in range(13)
237+
for i in range(uvit.depth)
235238
}
236239
)
237240

@@ -262,7 +265,7 @@ def __init__(
262265
]
263266
)
264267

265-
def forward(self, x, timesteps):
268+
def forward(self, x, timesteps, y=None):
266269
t = int(timesteps[0])
267270
if self.uvit.normalize_timesteps:
268271
timesteps = timesteps.float() / 1000
@@ -277,29 +280,35 @@ def forward(self, x, timesteps):
277280
)
278281
time_token = time_token.unsqueeze(dim=1)
279282
x = torch.cat((time_token, x), dim=1)
283+
if y is not None:
284+
label_emb = self.uvit.label_emb(y)
285+
label_emb = label_emb.unsqueeze(dim=1)
286+
x = torch.cat((label_emb, x), dim=1)
280287
x = x + self.uvit.pos_embed
281288

282289
skips = []
283290

284291
for blk, layer_id, output_head in zip(
285-
self.uvit.in_blocks, range(6), self.in_blocks_heads
292+
self.uvit.in_blocks, range(self.uvit.depth // 2), self.in_blocks_heads
286293
):
287294
classifier = self.get_classifer(t, layer_id)
288-
outputs.append(output_head(x))
295+
outputs.append(output_head(x, self.uvit.extras))
289296
classifier_outputs.append(classifier(x))
290297
x = blk(x)
291298
skips.append(x)
292299

293-
classifier = self.get_classifer(t, 6)
294-
outputs.append(self.mid_block_head(x))
300+
classifier = self.get_classifer(t, self.uvit.depth // 2)
301+
outputs.append(self.mid_block_head(x, self.uvit.extras))
295302
classifier_outputs.append(classifier(x))
296303
x = self.uvit.mid_block(x)
297304

298305
for blk, layer_id, output_head in zip(
299-
self.uvit.out_blocks, range(7, 13), self.out_blocks_heads
306+
self.uvit.out_blocks,
307+
range(self.uvit.depth // 2 + 1, self.uvit.depth),
308+
self.out_blocks_heads,
300309
):
301310
classifier = self.get_classifer(t, layer_id)
302-
outputs.append(output_head(x))
311+
outputs.append(output_head(x, self.uvit.extras))
303312
classifier_outputs.append(classifier(x))
304313
x = blk(x, skips.pop())
305314

models/uvit.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -254,6 +254,7 @@ def __init__(
254254

255255
self.num_classes = num_classes
256256
self.in_chans = in_chans
257+
self.depth = depth
257258

258259
self.patch_embed = PatchEmbed(
259260
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim

0 commit comments

Comments (0)