Skip to content

Commit 7e290e9

Browse files
committed
Fix tests
1 parent b399ea5 commit 7e290e9

File tree

1 file changed

+0
-25
lines changed

1 file changed

+0
-25
lines changed

tests/models/test_early_exit.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -113,28 +113,3 @@ def test_backward(classifier_type):
113113

114114
fake_loss = torch.sum(y)
115115
fake_loss.backward()
116-
117-
118-
@pytest.mark.skip(reason="no early exit on inference mode yet")
119-
def test_inference_no_early_exit():
120-
model = EarlyExitUViT(UViT(**cifar10_config), exit_threshold=-torch.inf)
121-
model.eval()
122-
with torch.inference_mode():
123-
y, classifier_outputs, early_exit_layer = model(x, t)
124-
125-
assert len(classifier_outputs) == 13
126-
assert y.shape == (batch_size, num_channels, height, width)
127-
assert early_exit_layer == 13
128-
129-
130-
@pytest.mark.skip(reason="no early exit on inference mode yet")
131-
def test_inference_exit_first():
132-
model = EarlyExitUViT(UViT(**cifar10_config), exit_threshold=torch.inf)
133-
model.eval()
134-
with torch.inference_mode():
135-
y, classifier_outputs, early_exit_layer = model(x, t)
136-
137-
assert y.shape == (batch_size, num_channels, height, width)
138-
assert len(classifier_outputs) == 1
139-
assert all(classifier_outputs[0] < model.exit_threshold)
140-
assert early_exit_layer == 0

0 commit comments

Comments
 (0)