Skip to content

Commit 7c524f4

Browse files
Merge pull request #53 from maciej-smyl-tcl/fix-multi-output-orbitalization
fix-multi-output-orbitalization
2 parents 4ae34e3 + a853d62 commit 7c524f4

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

tests/test_timm.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
import timm
9+
import torch
910

1011
from dev_tools.algorithms_testing_tools import test_conversion_to_dag as _test_conversion_to_dag, \
1112
test_orbitalization_and_channel_removal as _test_orbitalization_and_channel_removal
@@ -47,3 +48,32 @@ def test_timm_model_light(test_case: TimmModelTestCase, tmpdir):
4748
)
4849
def test_timm_model_heavy(test_case: TimmModelTestCase, tmpdir):
4950
run_timm_model_test(test_case, tmpdir)
51+
52+
def test_multi_output_model(
53+
):
54+
class DummyModel(torch.nn.Module):
55+
def __init__(self):
56+
super().__init__()
57+
self.input = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.ReLU())
58+
self.head_0 = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 1, 1), torch.nn.ReLU())
59+
self.head_1 = torch.nn.Sequential(torch.nn.Conv2d(3, 128, 3, 1, 1), torch.nn.ReLU())
60+
self.head_2 = torch.nn.Sequential(torch.nn.Conv2d(3, 256, 3, 1, 1), torch.nn.ReLU())
61+
62+
def forward(self, x):
63+
x = self.input(x)
64+
out_0 = self.head_0(x)
65+
out_1 = self.head_1(x)
66+
out_2 = self.head_2(x)
67+
return out_0, out_1, out_2
68+
69+
torch_model = DummyModel()
70+
torch_model.eval()
71+
input_shape = (1, 3, 224, 224)
72+
dag, msg = _test_conversion_to_dag(torch_model, input_shape=input_shape)
73+
result_channel_pruning = _test_orbitalization_and_channel_removal(
74+
dag=deepcopy(dag),
75+
input_shape=input_shape,
76+
prob_removal=0.5,
77+
)
78+
assert result_channel_pruning['status'] == SUCCESS_RESULT
79+
assert result_channel_pruning['prunable_fraction'] > 0.0

torch_dag/core/dag_module.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,14 @@ def flatten(self, input_shape_for_verification: Optional[Tuple[int, ...]] = None
542542
if input_shape_for_verification:
543543
dag.eval()
544544
new_output = dag(x)
545-
assert torch.abs(reference_output - new_output).sum() == 0.0
545+
if isinstance(reference_output, (list, tuple)):
546+
assert isinstance(new_output, (list, tuple))
547+
assert len(reference_output) == len(new_output)
548+
for ref, new in zip(reference_output, new_output):
549+
assert torch.abs(ref - new).sum() == 0.0
550+
else:
551+
assert torch.abs(reference_output - new_output).sum() == 0.0
552+
546553

547554
# TODO: Remove after validation
548555
# self._update_inner_modules()

0 commit comments

Comments
 (0)