@@ -19,10 +19,10 @@ def __init__(
19
19
nn .Conv2d (in_chans , in_chans , 3 , padding = 1 ) if conv else nn .Identity ()
20
20
)
21
21
22
def forward(self, x, extras):
    """Turn a token sequence into an output map via the prediction head.

    Args:
        x: Token tensor. It is indexed as ``x[:, extras:, :]`` after the
            prediction layer, so it must have at least three dimensions
            (presumably batch, tokens, features -- TODO confirm).
        extras: Number of leading entries along dimension 1 to discard
            before unpatchifying (the time vector and, when present, the
            class vector).

    Returns:
        The result of ``self.final_layer`` applied to the unpatchified
        prediction.
    """
    normed = self.norm(x)
    predicted = self.decoder_pred(normed)
    # Ignore time (and class) vector: only the remaining tokens are patches.
    patch_tokens = predicted[:, extras:, :]
    image = unpatchify(patch_tokens, self.in_chans)
    return self.final_layer(image)
@@ -216,11 +216,14 @@ def __init__(
216
216
# TODO: FIXME (right now we have to modify this both in get_classifier and here)
217
217
if classifier_type == "attention_probe" :
218
218
self .matrix = nn .ModuleDict (
219
- {f"{ i } " : AttentionProbe (embed_dim = uvit .embed_dim ) for i in range (13 )}
219
+ {
220
+ f"{ i } " : AttentionProbe (embed_dim = uvit .embed_dim )
221
+ for i in range (uvit .depth )
222
+ }
220
223
)
221
224
elif classifier_type == "mlp_probe_per_layer" :
222
225
self .matrix = nn .ModuleDict (
223
- {f"{ i } " : MLPProbe (embed_dim = uvit .embed_dim ) for i in range (13 )}
226
+ {f"{ i } " : MLPProbe (embed_dim = uvit .embed_dim ) for i in range (uvit . depth )}
224
227
)
225
228
elif classifier_type == "mlp_probe_per_timestep" :
226
229
self .matrix = nn .ModuleDict (
@@ -231,7 +234,7 @@ def __init__(
231
234
{
232
235
f"{ i } , { t } " : MLPProbe (embed_dim = uvit .embed_dim )
233
236
for t in range (1000 )
234
- for i in range (13 )
237
+ for i in range (uvit . depth )
235
238
}
236
239
)
237
240
@@ -262,7 +265,7 @@ def __init__(
262
265
]
263
266
)
264
267
265
- def forward (self , x , timesteps ):
268
+ def forward (self , x , timesteps , y = None ):
266
269
t = int (timesteps [0 ])
267
270
if self .uvit .normalize_timesteps :
268
271
timesteps = timesteps .float () / 1000
@@ -277,29 +280,35 @@ def forward(self, x, timesteps):
277
280
)
278
281
time_token = time_token .unsqueeze (dim = 1 )
279
282
x = torch .cat ((time_token , x ), dim = 1 )
283
+ if y is not None :
284
+ label_emb = self .uvit .label_emb (y )
285
+ label_emb = label_emb .unsqueeze (dim = 1 )
286
+ x = torch .cat ((label_emb , x ), dim = 1 )
280
287
x = x + self .uvit .pos_embed
281
288
282
289
skips = []
283
290
284
291
for blk , layer_id , output_head in zip (
285
- self .uvit .in_blocks , range (6 ), self .in_blocks_heads
292
+ self .uvit .in_blocks , range (self . uvit . depth // 2 ), self .in_blocks_heads
286
293
):
287
294
classifier = self .get_classifer (t , layer_id )
288
- outputs .append (output_head (x ))
295
+ outputs .append (output_head (x , self . uvit . extras ))
289
296
classifier_outputs .append (classifier (x ))
290
297
x = blk (x )
291
298
skips .append (x )
292
299
293
- classifier = self .get_classifer (t , 6 )
294
- outputs .append (self .mid_block_head (x ))
300
+ classifier = self .get_classifer (t , self . uvit . depth // 2 )
301
+ outputs .append (self .mid_block_head (x , self . uvit . extras ))
295
302
classifier_outputs .append (classifier (x ))
296
303
x = self .uvit .mid_block (x )
297
304
298
305
for blk , layer_id , output_head in zip (
299
- self .uvit .out_blocks , range (7 , 13 ), self .out_blocks_heads
306
+ self .uvit .out_blocks ,
307
+ range (self .uvit .depth // 2 + 1 , self .uvit .depth ),
308
+ self .out_blocks_heads ,
300
309
):
301
310
classifier = self .get_classifer (t , layer_id )
302
- outputs .append (output_head (x ))
311
+ outputs .append (output_head (x , self . uvit . extras ))
303
312
classifier_outputs .append (classifier (x ))
304
313
x = blk (x , skips .pop ())
305
314
0 commit comments