6
6
7
7
import pytest
8
8
import timm
9
+ import torch
9
10
10
11
from dev_tools .algorithms_testing_tools import test_conversion_to_dag as _test_conversion_to_dag , \
11
12
test_orbitalization_and_channel_removal as _test_orbitalization_and_channel_removal
@@ -47,3 +48,32 @@ def test_timm_model_light(test_case: TimmModelTestCase, tmpdir):
47
48
)
48
49
def test_timm_model_heavy (test_case : TimmModelTestCase , tmpdir ):
49
50
run_timm_model_test (test_case , tmpdir )
51
+
52
+ def test_multi_output_model (
53
+ ):
54
+ class DummyModel (torch .nn .Module ):
55
+ def __init__ (self ):
56
+ super ().__init__ ()
57
+ self .input = torch .nn .Sequential (torch .nn .Conv2d (3 , 3 , 3 , 1 , 1 ), torch .nn .ReLU ())
58
+ self .head_0 = torch .nn .Sequential (torch .nn .Conv2d (3 , 64 , 3 , 1 , 1 ), torch .nn .ReLU ())
59
+ self .head_1 = torch .nn .Sequential (torch .nn .Conv2d (3 , 128 , 3 , 1 , 1 ), torch .nn .ReLU ())
60
+ self .head_2 = torch .nn .Sequential (torch .nn .Conv2d (3 , 256 , 3 , 1 , 1 ), torch .nn .ReLU ())
61
+
62
+ def forward (self , x ):
63
+ x = self .input (x )
64
+ out_0 = self .head_0 (x )
65
+ out_1 = self .head_1 (x )
66
+ out_2 = self .head_2 (x )
67
+ return out_0 , out_1 , out_2
68
+
69
+ torch_model = DummyModel ()
70
+ torch_model .eval ()
71
+ input_shape = (1 , 3 , 224 , 224 )
72
+ dag , msg = _test_conversion_to_dag (torch_model , input_shape = input_shape )
73
+ result_channel_pruning = _test_orbitalization_and_channel_removal (
74
+ dag = deepcopy (dag ),
75
+ input_shape = input_shape ,
76
+ prob_removal = 0.5 ,
77
+ )
78
+ assert result_channel_pruning ['status' ] == SUCCESS_RESULT
79
+ assert result_channel_pruning ['prunable_fraction' ] > 0.0
0 commit comments