Skip to content

Commit aae34d0

Browse files
authored
Merge pull request #863 from ProGamerGov/optim-wip-helpers-1
Optim-wip: Move inception test helpers to separate file
2 parents 82ab88d + f97d2e3 commit aae34d0

File tree

2 files changed

+21
-35
lines changed

2 files changed

+21
-35
lines changed

tests/optim/helpers/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Type
2+
3+
import torch
4+
5+
6+
def check_layer_in_model(model: torch.nn.Module, layer: Type[torch.nn.Module]) -> bool:
    """Return True if ``model`` contains an instance of ``layer`` at any depth.

    The module's submodule tree is searched recursively; ``model`` itself is
    never tested, only its (direct and nested) children — matching the
    original helper's semantics.

    Args:
        model: The module whose submodule tree is searched.
        layer: The layer class to look for (e.g. ``torch.nn.ReLU``).

    Returns:
        True if any direct or nested child module is an instance of
        ``layer``, False otherwise.
    """
    # Use the public ``children()`` iterator instead of the private
    # ``model._modules`` dict; it already skips ``None`` placeholder
    # entries, so no manual None-check is needed.
    return any(
        isinstance(child, layer) or check_layer_in_model(child, layer)
        for child in model.children()
    )

tests/optim/models/test_models.py renamed to tests/optim/models/test_inceptionv1.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,12 @@
11
#!/usr/bin/env python3
22
import unittest
3-
from typing import Type
43

54
import torch
65

76
from captum.optim.models import googlenet
87
from captum.optim.models._common import RedirectedReluLayer, SkipLayer
98
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
10-
11-
12-
def _check_layer_in_model(
13-
self,
14-
model: torch.nn.Module,
15-
layer: Type[torch.nn.Module],
16-
) -> None:
17-
def check_for_layer_in_model(model, layer) -> bool:
18-
for name, child in model._modules.items():
19-
if child is not None:
20-
if isinstance(child, layer):
21-
return True
22-
if check_for_layer_in_model(child, layer):
23-
return True
24-
return False
25-
26-
self.assertTrue(check_for_layer_in_model(model, layer))
27-
28-
29-
def _check_layer_not_in_model(
30-
self, model: torch.nn.Module, layer: Type[torch.nn.Module]
31-
) -> None:
32-
for name, child in model._modules.items():
33-
if child is not None:
34-
self.assertNotIsInstance(child, layer)
35-
_check_layer_not_in_model(self, child, layer)
9+
from tests.optim.helpers.models import check_layer_in_model
3610

3711

3812
class TestInceptionV1(BaseTest):
@@ -43,7 +17,7 @@ def test_load_inceptionv1_with_redirected_relu(self) -> None:
4317
+ " due to insufficient Torch version."
4418
)
4519
model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=True)
46-
_check_layer_in_model(self, model, RedirectedReluLayer)
20+
self.assertTrue(check_layer_in_model(model, RedirectedReluLayer))
4721

4822
def test_load_inceptionv1_no_redirected_relu(self) -> None:
4923
if torch.__version__ <= "1.2.0":
@@ -52,8 +26,8 @@ def test_load_inceptionv1_no_redirected_relu(self) -> None:
5226
+ " due to insufficient Torch version."
5327
)
5428
model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False)
55-
_check_layer_not_in_model(self, model, RedirectedReluLayer)
56-
_check_layer_in_model(self, model, torch.nn.ReLU)
29+
self.assertFalse(check_layer_in_model(model, RedirectedReluLayer))
30+
self.assertTrue(check_layer_in_model(model, torch.nn.ReLU))
5731

5832
def test_load_inceptionv1_linear(self) -> None:
5933
if torch.__version__ <= "1.2.0":
@@ -62,11 +36,11 @@ def test_load_inceptionv1_linear(self) -> None:
6236
+ " due to insufficient Torch version."
6337
)
6438
model = googlenet(pretrained=True, use_linear_modules_only=True)
65-
_check_layer_not_in_model(self, model, RedirectedReluLayer)
66-
_check_layer_not_in_model(self, model, torch.nn.ReLU)
67-
_check_layer_not_in_model(self, model, torch.nn.MaxPool2d)
68-
_check_layer_in_model(self, model, SkipLayer)
69-
_check_layer_in_model(self, model, torch.nn.AvgPool2d)
39+
self.assertFalse(check_layer_in_model(model, RedirectedReluLayer))
40+
self.assertFalse(check_layer_in_model(model, torch.nn.ReLU))
41+
self.assertFalse(check_layer_in_model(model, torch.nn.MaxPool2d))
42+
self.assertTrue(check_layer_in_model(model, SkipLayer))
43+
self.assertTrue(check_layer_in_model(model, torch.nn.AvgPool2d))
7044

7145
def test_transform_inceptionv1(self) -> None:
7246
if torch.__version__ <= "1.2.0":

0 commit comments

Comments
 (0)