Skip to content

Commit b653412

Browse files
committed
simplify models
Former-commit-id: fd3dc23d1936ba21caeb8d21ceac051d96d37263 [formerly 22c422a] Former-commit-id: 45aa9929d6bb1e796cfa4c847258d1d6be1c580f
1 parent e030893 commit b653412

31 files changed

+384
-1270
lines changed

models/__init__.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
import os, torch
22
import segmentation_models_pytorch as smp
33

4+
from .adscnet import ADSCNet
5+
from .aglnet import AGLNet
6+
from .bisenetv1 import BiSeNetv1
47
from .bisenetv2 import BiSeNetv2
5-
from .enet import ENet
6-
from .fastscnn import FastSCNN
7-
from .lednet import LEDNet
8-
from .linknet import LinkNet
8+
from .canet import CANet
9+
from .cfpnet import CFPNet
10+
from .cgnet import CGNet
911
from .contextnet import ContextNet
10-
from .pp_liteseg import PPLiteSeg
11-
from .ddrnet import DDRNet
12-
from .espnet import ESPNet
13-
from .erfnet import ERFNet
14-
from .segnet import SegNet
1512
from .dabnet import DABNet
16-
from .bisenetv1 import BiSeNetv1
17-
from .espnetv2 import ESPNetv2
18-
from .aglnet import AGLNet
19-
from .cgnet import CGNet
13+
from .ddrnet import DDRNet
14+
from .dfanet import DFANet
2015
from .edanet import EDANet
16+
from .enet import ENet
17+
from .erfnet import ERFNet
2118
from .esnet import ESNet
22-
from .adscnet import ADSCNet
23-
from .canet import CANet
24-
from .cfpnet import CFPNet
19+
from .espnet import ESPNet
20+
from .espnetv2 import ESPNetv2
2521
from .farseenet import FarSeeNet
26-
from .fpenet import FPENet
22+
from .fastscnn import FastSCNN
2723
from .fddwnet import FDDWNet
28-
from .mininet import MiniNet
29-
from .mininetv2 import MiniNetv2
24+
from .fpenet import FPENet
25+
from .fssnet import FSSNet
3026
from .icnet import ICNet
27+
from .lednet import LEDNet
28+
from .linknet import LinkNet
3129
from .liteseg import LiteSeg
30+
from .mininet import MiniNet
31+
from .mininetv2 import MiniNetv2
32+
from .pp_liteseg import PPLiteSeg
33+
from .segnet import SegNet
3234
from .shelfnet import ShelfNet
33-
from .fssnet import FSSNet
34-
from .swiftnet import SwiftNet
3535
from .sqnet import SQNet
36-
from .dfanet import DFANet
36+
from .swiftnet import SwiftNet
3737

3838

3939
decoder_hub = {'deeplabv3':smp.DeepLabV3, 'deeplabv3p':smp.DeepLabV3Plus, 'fpn':smp.FPN,
@@ -42,18 +42,20 @@
4242

4343

4444
def get_model(config):
45-
model_hub = {'bisenetv2':BiSeNetv2, 'enet':ENet, 'fastscnn':FastSCNN, 'lednet':LEDNet,
46-
'linknet':LinkNet, 'contextnet':ContextNet, 'ppliteseg':PPLiteSeg,
47-
'ddrnet':DDRNet, 'espnet':ESPNet, 'erfnet':ERFNet, 'segnet':SegNet,
48-
'dabnet':DABNet, 'bisenetv1':BiSeNetv1, 'espnetv2':ESPNetv2,
49-
'aglnet':AGLNet, 'cgnet':CGNet, 'edanet':EDANet, 'esnet':ESNet,
50-
'adscnet':ADSCNet, 'canet':CANet, 'cfpnet':CFPNet, 'farseenet':FarSeeNet,
51-
'fpenet':FPENet, 'fddwnet':FDDWNet, 'mininet':MiniNet, 'mininetv2':MiniNetv2,
52-
'icnet':ICNet, 'liteseg':LiteSeg, 'shelfnet':ShelfNet, 'fssnet':FSSNet,
53-
'swiftnet':SwiftNet, 'sqnet':SQNet, 'dfanet':DFANet,}
45+
model_hub = {'adscnet':ADSCNet, 'aglnet':AGLNet, 'bisenetv1':BiSeNetv1,
46+
'bisenetv2':BiSeNetv2, 'canet':CANet, 'cfpnet':CFPNet,
47+
'cgnet':CGNet, 'contextnet':ContextNet, 'dabnet':DABNet,
48+
'ddrnet':DDRNet, 'dfanet':DFANet, 'edanet':EDANet,
49+
'enet':ENet, 'erfnet':ERFNet, 'esnet':ESNet,
50+
'espnet':ESPNet, 'espnetv2':ESPNetv2, 'farseenet':FarSeeNet,
51+
'fastscnn':FastSCNN, 'fddwnet':FDDWNet, 'fpenet':FPENet,
52+
'fssnet':FSSNet, 'icnet':ICNet, 'lednet':LEDNet,
53+
'linknet':LinkNet, 'liteseg':LiteSeg, 'mininet':MiniNet,
54+
'mininetv2':MiniNetv2, 'ppliteseg':PPLiteSeg, 'segnet':SegNet,
55+
'shelfnet':ShelfNet, 'sqnet':SQNet, 'swiftnet':SwiftNet,}
5456

5557
# The following models currently support auxiliary heads
56-
aux_models = ['bisenetv2', 'contextnet', 'fastscnn', 'ddrnet', 'icnet']
58+
aux_models = ['bisenetv2', 'ddrnet', 'icnet']
5759

5860
if config.model == 'smp': # Use segmentation models pytorch
5961
if config.decoder not in decoder_hub:
@@ -66,8 +68,6 @@ def get_model(config):
6668
elif config.model in model_hub.keys():
6769
if config.model in aux_models:
6870
model = model_hub[config.model](num_class=config.num_class, use_aux=config.use_aux)
69-
elif config.model == 'ppliteseg':
70-
model = model_hub[config.model](num_class=config.num_class, encoder_type=config.encoder)
7171
else:
7272
model = model_hub[config.model](num_class=config.num_class)
7373

models/adscnet.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111

12-
from .modules import conv1x1, ConvBNAct, DWConvBNAct, Activation
12+
from .modules import conv1x1, ConvBNAct, DWConvBNAct, DeConvBNAct, Activation
1313

1414

1515
class ADSCNet(nn.Module):
@@ -25,16 +25,16 @@ def __init__(self, num_class=1, n_channel=3, act_type='relu6'):
2525
self.conv5 = ADSCModule(64, 2, act_type=act_type)
2626
self.ddcc = DDCC(128, [3, 5, 9, 13], act_type)
2727
self.up1 = nn.Sequential(
28-
Upsample(128, 64),
28+
DeConvBNAct(128, 64),
2929
ADSCModule(64, 1, act_type=act_type)
3030
)
3131
self.up2 = nn.Sequential(
3232
ADSCModule(64, 1, act_type=act_type),
33-
Upsample(64, 32)
33+
DeConvBNAct(64, 32)
3434
)
3535
self.up3 = nn.Sequential(
3636
ADSCModule(32, 1, act_type=act_type),
37-
Upsample(32, num_class)
37+
DeConvBNAct(32, num_class)
3838
)
3939

4040
def forward(self, x):
@@ -123,31 +123,3 @@ def forward(self, x):
123123
x = self.conv_last(x)
124124

125125
return x
126-
127-
128-
class Upsample(nn.Module):
129-
def __init__(self, in_channels, out_channels, scale_factor=2, kernel_size=None, padding=None,
130-
upsample_type='deconvolution', act_type='relu'):
131-
super(Upsample, self).__init__()
132-
if upsample_type == 'deconvolution':
133-
if kernel_size is None:
134-
kernel_size = 2*scale_factor - 1
135-
if padding is None:
136-
padding = (kernel_size - 1) // 2
137-
output_padding = scale_factor - 1
138-
self.up_conv = nn.Sequential(
139-
nn.ConvTranspose2d(in_channels, out_channels,
140-
kernel_size=kernel_size,
141-
stride=scale_factor, padding=padding,
142-
output_padding=output_padding),
143-
nn.BatchNorm2d(out_channels),
144-
Activation(act_type)
145-
)
146-
else:
147-
self.up_conv = nn.Sequential(
148-
ConvBNAct(in_channels, out_channels, 1, act_type=act_type),
149-
nn.Upsample(scale_factor=scale_factor, mode='bilinear')
150-
)
151-
152-
def forward(self, x):
153-
return self.up_conv(x)

models/aglnet.py

Lines changed: 3 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212

13-
from .modules import conv1x1, ConvBNAct, Activation
13+
from .modules import conv1x1, ConvBNAct, Activation, channel_shuffle
14+
from .enet import InitialBlock as DownsamplingUnit
15+
from .lednet import SSnbtUnit
1416

1517

1618
class AGLNet(nn.Module):
@@ -54,18 +56,6 @@ def forward(self, x):
5456
return x
5557

5658

57-
class DownsamplingUnit(nn.Module):
58-
def __init__(self, in_channels, out_channels, act_type):
59-
super(DownsamplingUnit, self).__init__()
60-
self.conv = ConvBNAct(in_channels, out_channels - in_channels, 3, 2, act_type=act_type)
61-
self.pool = nn.MaxPool2d(3, 2, 1)
62-
63-
def forward(self, x):
64-
x = torch.cat([self.conv(x), self.pool(x)], dim=1)
65-
66-
return x
67-
68-
6959
def build_blocks(block, channels, num_block, dilations=[], act_type='relu'):
7060
if len(dilations) == 0:
7161
dilations = [1 for _ in range(num_block)]
@@ -79,62 +69,6 @@ def build_blocks(block, channels, num_block, dilations=[], act_type='relu'):
7969
return nn.Sequential(*layers)
8070

8171

82-
class SSnbtUnit(nn.Module):
83-
def __init__(self, channels, dilation, act_type):
84-
super(SSnbtUnit, self).__init__()
85-
assert channels % 2 == 0, 'Input channel should be multiple of 2.\n'
86-
split_channels = channels // 2
87-
self.split_channels = split_channels
88-
self.left_branch = nn.Sequential(
89-
nn.Conv2d(split_channels, split_channels, (3, 1), padding=(1,0)),
90-
Activation(act_type),
91-
ConvBNAct(split_channels, split_channels, (1, 3), act_type=act_type),
92-
nn.Conv2d(split_channels, split_channels, (3, 1),
93-
padding=(dilation,0), dilation=dilation),
94-
Activation(act_type),
95-
ConvBNAct(split_channels, split_channels, (1, 3), dilation=dilation, act_type=act_type),
96-
)
97-
98-
self.right_branch = nn.Sequential(
99-
nn.Conv2d(split_channels, split_channels, (1, 3), padding=(0,1)),
100-
Activation(act_type),
101-
ConvBNAct(split_channels, split_channels, (3, 1), act_type=act_type),
102-
nn.Conv2d(split_channels, split_channels, (1, 3),
103-
padding=(0,dilation), dilation=dilation),
104-
Activation(act_type),
105-
ConvBNAct(split_channels, split_channels, (3, 1), dilation=dilation, act_type=act_type),
106-
)
107-
self.act = Activation(act_type)
108-
109-
def forward(self, x):
110-
x_left = x[:, :self.split_channels].clone()
111-
x_right = x[:, self.split_channels:].clone()
112-
x_left = self.left_branch(x_left)
113-
x_right = self.right_branch(x_right)
114-
x_cat = torch.cat([x_left, x_right], dim=1)
115-
x += x_cat
116-
x = self.act(x)
117-
x = channel_shuffle(x)
118-
return x
119-
120-
121-
def channel_shuffle(x, groups=2):
122-
# Codes are borrowed from
123-
# https://github.com/pytorch/vision/blob/main/torchvision/models/shufflenetv2.py
124-
batchsize, num_channels, height, width = x.size()
125-
channels_per_group = num_channels // groups
126-
127-
# reshape
128-
x = x.view(batchsize, groups, channels_per_group, height, width)
129-
130-
x = torch.transpose(x, 1, 2).contiguous()
131-
132-
# flatten
133-
x = x.view(batchsize, -1, height, width)
134-
135-
return x
136-
137-
13872
class FAPM(nn.Module):
13973
def __init__(self, channels, act_type):
14074
super(FAPM, self).__init__()

models/backbone.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import torch.nn as nn
2+
3+
4+
class ResNet(nn.Module):
5+
# Load ResNet pretrained on ImageNet from torchvision, see
6+
# https://pytorch.org/vision/stable/models/resnet.html
7+
def __init__(self, resnet_type, pretrained=True):
8+
super(ResNet, self).__init__()
9+
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
10+
11+
resnet_hub = {'resnet18':resnet18, 'resnet34':resnet34, 'resnet50':resnet50,
12+
'resnet101':resnet101, 'resnet152':resnet152}
13+
if resnet_type not in resnet_hub:
14+
raise ValueError(f'Unsupported ResNet type: {resnet_type}.\n')
15+
16+
resnet = resnet_hub[resnet_type](pretrained=pretrained)
17+
self.conv1 = resnet.conv1
18+
self.bn1 = resnet.bn1
19+
self.relu = resnet.relu
20+
self.maxpool = resnet.maxpool
21+
self.layer1 = resnet.layer1
22+
self.layer2 = resnet.layer2
23+
self.layer3 = resnet.layer3
24+
self.layer4 = resnet.layer4
25+
26+
def forward(self, x):
27+
x = self.conv1(x) # 2x down
28+
x = self.bn1(x)
29+
x = self.relu(x)
30+
x = self.maxpool(x) # 4x down
31+
x1 = self.layer1(x)
32+
x2 = self.layer2(x1) # 8x down
33+
x3 = self.layer3(x2) # 16x down
34+
x4 = self.layer4(x3) # 32x down
35+
36+
return x1, x2, x3, x4
37+
38+
39+
class Mobilenetv2(nn.Module):
40+
def __init__(self, pretrained=True):
41+
super(Mobilenetv2, self).__init__()
42+
from torchvision.models import mobilenet_v2
43+
44+
mobilenet = mobilenet_v2(pretrained=pretrained)
45+
46+
self.layer1 = mobilenet.features[:4]
47+
self.layer2 = mobilenet.features[4:7]
48+
self.layer3 = mobilenet.features[7:14]
49+
self.layer4 = mobilenet.features[14:18]
50+
51+
def forward(self, x):
52+
x1 = self.layer1(x) # 4x down
53+
x2 = self.layer2(x1) # 8x down
54+
x3 = self.layer3(x2) # 16x down
55+
x4 = self.layer4(x3) # 32x down
56+
57+
return x1, x2, x3, x4

models/bisenetv1.py

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import torch.nn as nn
1010
import torch.nn.functional as F
1111

12-
from .modules import conv1x1, ConvBNAct
12+
from .modules import conv1x1, ConvBNAct, SegHead
13+
from .backbone import ResNet
1314

1415

1516
class BiSeNetv1(nn.Module):
@@ -57,7 +58,7 @@ def __init__(self, out_channels, backbone_type, act_type):
5758
self.conv_32 = conv1x1(channels[1], out_channels)
5859

5960
def forward(self, x):
60-
x_32, x_16 = self.backbone(x)
61+
_, _, x_16, x_32 = self.backbone(x)
6162
x_32_avg = self.pool(x_32)
6263
x_32 = self.arm_32(x_32)
6364
x_32 += x_32_avg
@@ -111,46 +112,3 @@ def forward(self, x_low, x_high):
111112
x = x + x_pool
112113

113114
return x
114-
115-
116-
class ResNet(nn.Module):
117-
# Load ResNet pretrained on ImageNet from torchvision, see
118-
# https://pytorch.org/vision/stable/models/resnet.html
119-
def __init__(self, resnet_type, pretrained=True):
120-
super(ResNet, self).__init__()
121-
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
122-
123-
resnet_hub = {'resnet18':resnet18, 'resnet34':resnet34, 'resnet50':resnet50,
124-
'resnet101':resnet101, 'resnet152':resnet152}
125-
if resnet_type not in resnet_hub:
126-
raise ValueError(f'Unsupported ResNet type: {resnet_type}.\n')
127-
128-
resnet = resnet_hub[resnet_type](pretrained=pretrained)
129-
self.conv1 = resnet.conv1
130-
self.bn1 = resnet.bn1
131-
self.relu = resnet.relu
132-
self.maxpool = resnet.maxpool
133-
self.layer1 = resnet.layer1
134-
self.layer2 = resnet.layer2
135-
self.layer3 = resnet.layer3
136-
self.layer4 = resnet.layer4
137-
138-
def forward(self, x):
139-
x = self.conv1(x) # 2x down
140-
x = self.bn1(x)
141-
x = self.relu(x)
142-
x = self.maxpool(x) # 4x down
143-
x = self.layer1(x)
144-
x = self.layer2(x) # 8x down
145-
x3 = self.layer3(x) # 16x down
146-
x = self.layer4(x3) # 32x down
147-
148-
return x, x3
149-
150-
151-
class SegHead(nn.Sequential):
152-
def __init__(self, in_channels, out_channels, act_type, hid_channels=128):
153-
super(SegHead, self).__init__(
154-
ConvBNAct(in_channels, hid_channels, 3, act_type=act_type),
155-
conv1x1(hid_channels, out_channels)
156-
)

0 commit comments

Comments
 (0)