Skip to content

Commit 493f358

Browse files
committed
Fix pytests
1 parent d4cf479 commit 493f358

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/models/test_early_exit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
patch_size=4,
3838
in_chans=3,
3939
embed_dim=512,
40-
depth=12,
40+
depth=13,
4141
num_heads=8,
4242
mlp_ratio=4,
4343
qkv_bias=False,
@@ -51,7 +51,7 @@
5151
patch_size=2,
5252
in_chans=3,
5353
embed_dim=512,
54-
depth=12,
54+
depth=13,
5555
num_heads=8,
5656
mlp_ratio=4,
5757
qkv_bias=False,
@@ -73,7 +73,7 @@ def test_output_head():
7373

7474
x = torch.zeros((16, 257, 512))
7575

76-
y = output_head(x)
76+
y = output_head(x, extras=1)
7777
assert y.shape == (16, 3, 32, 32)
7878

7979

@@ -109,7 +109,7 @@ def test_backward(classifier_type):
109109
y, classifier_outputs, outputs = model(x, t)
110110

111111
assert y.shape == x.shape
112-
assert len(outputs) == len(classifier_outputs) == 13
112+
assert len(outputs) == len(classifier_outputs) == model.uvit.depth
113113

114114
fake_loss = torch.sum(y)
115115
fake_loss.backward()

tests/models/test_uvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
patch_size=4,
3939
in_chans=3,
4040
embed_dim=512,
41-
depth=12,
41+
depth=13,
4242
num_heads=8,
4343
mlp_ratio=4,
4444
qkv_bias=False,
@@ -52,7 +52,7 @@
5252
patch_size=2,
5353
in_chans=3,
5454
embed_dim=512,
55-
depth=12,
55+
depth=13,
5656
num_heads=8,
5757
mlp_ratio=4,
5858
qkv_bias=False,

0 commit comments

Comments
 (0)