@@ -1,38 +1,12 @@
 #!/usr/bin/env python3
 import unittest
-from typing import Type
 
 import torch
 
 from captum.optim.models import googlenet
 from captum.optim.models._common import RedirectedReluLayer, SkipLayer
 from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
-
-
-def _check_layer_in_model(
-    self,
-    model: torch.nn.Module,
-    layer: Type[torch.nn.Module],
-) -> None:
-    def check_for_layer_in_model(model, layer) -> bool:
-        for name, child in model._modules.items():
-            if child is not None:
-                if isinstance(child, layer):
-                    return True
-                if check_for_layer_in_model(child, layer):
-                    return True
-        return False
-
-    self.assertTrue(check_for_layer_in_model(model, layer))
-
-
-def _check_layer_not_in_model(
-    self, model: torch.nn.Module, layer: Type[torch.nn.Module]
-) -> None:
-    for name, child in model._modules.items():
-        if child is not None:
-            self.assertNotIsInstance(child, layer)
-            _check_layer_not_in_model(self, child, layer)
+from tests.optim.helpers.models import check_layer_in_model
 
 
 class TestInceptionV1(BaseTest):
@@ -43,7 +17,7 @@ def test_load_inceptionv1_with_redirected_relu(self) -> None:
                 + " due to insufficient Torch version."
             )
         model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=True)
-        _check_layer_in_model(self, model, RedirectedReluLayer)
+        self.assertTrue(check_layer_in_model(model, RedirectedReluLayer))
 
     def test_load_inceptionv1_no_redirected_relu(self) -> None:
         if torch.__version__ <= "1.2.0":
@@ -52,8 +26,8 @@ def test_load_inceptionv1_no_redirected_relu(self) -> None:
                 + " due to insufficient Torch version."
             )
         model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False)
-        _check_layer_not_in_model(self, model, RedirectedReluLayer)
-        _check_layer_in_model(self, model, torch.nn.ReLU)
+        self.assertFalse(check_layer_in_model(model, RedirectedReluLayer))
+        self.assertTrue(check_layer_in_model(model, torch.nn.ReLU))
 
     def test_load_inceptionv1_linear(self) -> None:
         if torch.__version__ <= "1.2.0":
@@ -62,11 +36,11 @@ def test_load_inceptionv1_linear(self) -> None:
                 + " due to insufficient Torch version."
             )
         model = googlenet(pretrained=True, use_linear_modules_only=True)
-        _check_layer_not_in_model(self, model, RedirectedReluLayer)
-        _check_layer_not_in_model(self, model, torch.nn.ReLU)
-        _check_layer_not_in_model(self, model, torch.nn.MaxPool2d)
-        _check_layer_in_model(self, model, SkipLayer)
-        _check_layer_in_model(self, model, torch.nn.AvgPool2d)
+        self.assertFalse(check_layer_in_model(model, RedirectedReluLayer))
+        self.assertFalse(check_layer_in_model(model, torch.nn.ReLU))
+        self.assertFalse(check_layer_in_model(model, torch.nn.MaxPool2d))
+        self.assertTrue(check_layer_in_model(model, SkipLayer))
+        self.assertTrue(check_layer_in_model(model, torch.nn.AvgPool2d))
 
     def test_transform_inceptionv1(self) -> None:
         if torch.__version__ <= "1.2.0":
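
Not shown in this diff: the shared `check_layer_in_model` helper now imported from `tests.optim.helpers.models`. A minimal sketch of what it presumably looks like, assuming it folds the removed recursive check into a single function that returns a bool so the call sites can choose between `assertTrue` and `assertFalse`:

```python
from typing import Type

import torch


def check_layer_in_model(model: torch.nn.Module, layer: Type[torch.nn.Module]) -> bool:
    """Recursively search a model's child modules for an instance of `layer`."""
    # NOTE: hypothetical reconstruction based on the removed helper above.
    for child in model._modules.values():
        if child is None:
            continue
        # Match either the direct child or anything nested deeper inside it.
        if isinstance(child, layer) or check_layer_in_model(child, layer):
            return True
    return False
```

Returning a plain bool keeps the helper free of any `unittest` dependency, which is why a single function can replace both `_check_layer_in_model` and `_check_layer_not_in_model`.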