Skip to content

Commit f967856

Browse files
Alejandro Monroy MuñozAlejandro Monroy Muñoz
authored andcommitted
Fix bug with class_label param in UVit class
1 parent 5107295 commit f967856

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

models/early_exit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def forward(self, x, timesteps, y=None):
280280
)
281281
time_token = time_token.unsqueeze(dim=1)
282282
x = torch.cat((time_token, x), dim=1)
283-
if y is not None:
283+
if y is not None and self.uvit.label_emb is not None:
284284
label_emb = self.uvit.label_emb(y)
285285
label_emb = label_emb.unsqueeze(dim=1)
286286
x = torch.cat((label_emb, x), dim=1)

models/uvit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def __init__(
275275
self.label_emb = nn.Embedding(self.num_classes, embed_dim)
276276
self.extras = 2
277277
else:
278+
self.label_emb = None
278279
self.extras = 1
279280

280281
self.pos_embed = nn.Parameter(

0 commit comments

Comments
 (0)