Skip to content

Commit 48430f0

Browse files
committed
add STDC
Former-commit-id: 2d64fd1765c65934f53510384d954ced9e917b99 [formerly 0428310] Former-commit-id: ee8150c29f36fc54cf4c5bc53998abd5be51f8aa
1 parent d174805 commit 48430f0

File tree

10 files changed

+322
-84
lines changed

10 files changed

+322
-84
lines changed

README.md

Lines changed: 78 additions & 68 deletions
Large diffs are not rendered by default.

configs/base_config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ def __init__(self,):
1212
self.decoder = None
1313
self.encoder_weights = 'imagenet'
1414

15+
# Detail Head (For STDC)
16+
self.use_detail_head = False
17+
self.detail_thrs = 0.1
18+
self.detail_loss_coef = 1.0
19+
self.dice_loss_coef = 1.0
20+
self.bce_loss_coef = 1.0
21+
1522
# Training
1623
self.total_epoch = 200
1724
self.base_lr = 0.01

configs/parser.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def get_parser():
3232
'enet', 'erfnet', 'esnet', 'espnet', 'espnetv2', 'farseenet',
3333
'fastscnn', 'fddwnet', 'fpenet', 'fssnet', 'icnet', 'lednet',
3434
'linknet', 'liteseg', 'mininet', 'mininetv2', 'ppliteseg',
35-
'regseg', 'segnet', 'shelfnet', 'sqnet', 'swiftnet', 'smp'],
35+
'regseg', 'segnet', 'shelfnet', 'sqnet', 'stdc', 'swiftnet',
36+
'smp'],
3637
help='choose which model you want to use')
3738
parser.add_argument('--encoder', type=str, default=None,
3839
help='choose which encoder of SMP model you want to use (please refer to SMP repo)')

core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .base_trainer import BaseTrainer
22
from .seg_trainer import SegTrainer
3-
from .loss import get_loss_fn, kd_loss_fn
3+
from .loss import get_loss_fn, kd_loss_fn, get_detail_loss_fn

core/loss.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,42 @@ def forward(self, logits, labels):
1616
loss_hard = loss[loss > self.thresh]
1717
if loss_hard.numel() < n_min:
1818
loss_hard, _ = loss.topk(n_min)
19+
1920
return torch.mean(loss_hard)
2021

2122

23+
class DiceLoss(nn.Module):
24+
def __init__(self, smooth=1):
25+
super(DiceLoss, self).__init__()
26+
self.smooth = smooth
27+
28+
def forward(self, logits, labels):
29+
logits = torch.flatten(logits, 1)
30+
labels = torch.flatten(labels, 1)
31+
32+
intersection = torch.sum(logits * labels, dim=1)
33+
loss = 1 - ((2 * intersection + self.smooth) / (logits.sum(1) + labels.sum(1) + self.smooth))
34+
35+
return torch.mean(loss)
36+
37+
38+
class DetailLoss(nn.Module):
39+
'''Implement detail loss used in paper
40+
`Rethinking BiSeNet For Real-time Semantic Segmentation`'''
41+
def __init__(self, dice_loss_coef=1., bce_loss_coef=1., smooth=1):
42+
super(DetailLoss, self).__init__()
43+
self.dice_loss_coef = dice_loss_coef
44+
self.bce_loss_coef = bce_loss_coef
45+
self.dice_loss_fn = DiceLoss(smooth)
46+
self.bce_loss_fn = nn.BCEWithLogitsLoss()
47+
48+
def forward(self, logits, labels):
49+
loss = self.dice_loss_coef * self.dice_loss_fn(logits, labels) + \
50+
self.bce_loss_coef * self.bce_loss_fn(logits, labels)
51+
52+
return loss
53+
54+
2255
def get_loss_fn(config, device):
2356
if config.class_weights is None:
2457
weights = None
@@ -28,16 +61,22 @@ def get_loss_fn(config, device):
2861
if config.loss_type == 'ce':
2962
criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_index,
3063
reduction=config.reduction, weight=weights)
31-
64+
3265
elif config.loss_type == 'ohem':
3366
criterion = OhemCELoss(thresh=config.ohem_thrs, ignore_index=config.ignore_index)
3467

3568
else:
3669
raise NotImplementedError(f"Unsupport loss type: {config.loss_type}")
37-
70+
3871
return criterion
39-
40-
72+
73+
74+
def get_detail_loss_fn(config):
75+
detail_loss_fn = DetailLoss(dice_loss_coef=config.dice_loss_coef, bce_loss_coef=config.bce_loss_coef)
76+
77+
return detail_loss_fn
78+
79+
4180
def kd_loss_fn(config, outputs, outputsT):
4281
if config.kd_loss_type == 'kl_div':
4382
lossT = F.kl_div(F.log_softmax(outputs/config.kd_temperature, dim=1),

core/seg_trainer.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ def __init__(self, config):
2121
self.teacher_model = get_teacher_model(config, self.device)
2222
self.metrics = get_seg_metrics(config).to(self.device)
2323

24+
if config.use_detail_head:
25+
from .loss import get_detail_loss_fn
26+
from models import LaplacianConv
27+
28+
self.laplacian_conv = LaplacianConv(self.device)
29+
self.detail_loss_fn = get_detail_loss_fn(config)
30+
2431
def train_one_epoch(self, config):
2532
self.model.train()
2633

@@ -57,14 +64,33 @@ def train_one_epoch(self, config):
5764
with amp.autocast(enabled=config.amp_training):
5865
loss += config.aux_coef[i] * self.loss_fn(preds_aux[i], masks_aux)
5966

67+
# Detail loss proposed in paper for model STDC
68+
elif config.use_detail_head:
69+
masks_detail = masks.unsqueeze(1).float()
70+
masks_detail = self.laplacian_conv(masks_detail)
71+
72+
with amp.autocast(enabled=config.amp_training):
73+
# Detail ground truth
74+
masks_detail = self.model.module.detail_conv(masks_detail)
75+
masks_detail[masks_detail > config.detail_thrs] = 1
76+
masks_detail[masks_detail <= config.detail_thrs] = 0
77+
detail_size = masks_detail.size()[2:]
78+
79+
preds, preds_detail = self.model(images, is_training=True)
80+
preds_detail = F.interpolate(preds_detail, detail_size, mode='bilinear', align_corners=True)
81+
loss_detail = self.detail_loss_fn(preds_detail, masks_detail)
82+
loss = self.loss_fn(preds, masks) + config.detail_loss_coef * loss_detail
83+
6084
else:
6185
with amp.autocast(enabled=config.amp_training):
6286
preds = self.model(images)
6387
loss = self.loss_fn(preds, masks)
6488

6589
if config.use_tb and self.main_rank:
6690
self.writer.add_scalar('train/loss', loss.detach(), self.train_itrs)
67-
91+
if config.use_detail_head:
92+
self.writer.add_scalar('train/loss_detail', loss_detail.detach(), self.train_itrs)
93+
6894
# Knowledge distillation
6995
if config.kd_training:
7096
with amp.autocast(enabled=config.amp_training):
@@ -75,8 +101,8 @@ def train_one_epoch(self, config):
75101
loss += config.kd_loss_coefficient * loss_kd
76102

77103
if config.use_tb and self.main_rank:
78-
self.writer.add_scalar('train/loss_kd', loss_kd.detach(), self.train_itrs)
79-
self.writer.add_scalar('train/loss_total', loss.detach(), self.train_itrs)
104+
self.writer.add_scalar('train/loss_kd', loss_kd.detach(), self.train_itrs)
105+
self.writer.add_scalar('train/loss_total', loss.detach(), self.train_itrs)
80106

81107
# Backward path
82108
self.scaler.scale(loss).backward()

models/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .segnet import SegNet
3535
from .shelfnet import ShelfNet
3636
from .sqnet import SQNet
37+
from .stdc import STDC, LaplacianConv
3738
from .swiftnet import SwiftNet
3839

3940

@@ -54,10 +55,13 @@ def get_model(config):
5455
'linknet':LinkNet, 'liteseg':LiteSeg, 'mininet':MiniNet,
5556
'mininetv2':MiniNetv2, 'ppliteseg':PPLiteSeg, 'regseg':RegSeg,
5657
'segnet':SegNet, 'shelfnet':ShelfNet, 'sqnet':SQNet,
57-
'swiftnet':SwiftNet,}
58+
'stdc':STDC, 'swiftnet':SwiftNet,}
5859

5960
# The following models currently support auxiliary heads
6061
aux_models = ['bisenetv2', 'ddrnet', 'icnet']
62+
63+
# The following models currently support detail heads
64+
detail_head_models = ['stdc']
6165

6266
if config.model == 'smp': # Use segmentation models pytorch
6367
if config.decoder not in decoder_hub:
@@ -70,7 +74,14 @@ def get_model(config):
7074
elif config.model in model_hub.keys():
7175
if config.model in aux_models:
7276
model = model_hub[config.model](num_class=config.num_class, use_aux=config.use_aux)
77+
elif config.model in detail_head_models:
78+
model = model_hub[config.model](num_class=config.num_class, use_detail_head=config.use_detail_head, use_aux=config.use_aux)
7379
else:
80+
if config.use_aux:
81+
raise ValueError(f'Model {config.model} does not support auxiliary heads.\n')
82+
if config.use_detail_head:
83+
raise ValueError(f'Model {config.model} does not support detail heads.\n')
84+
7485
model = model_hub[config.model](num_class=config.num_class)
7586

7687
else:
@@ -83,7 +94,7 @@ def get_teacher_model(config, device):
8394
if config.kd_training:
8495
if not os.path.isfile(config.teacher_ckpt):
8596
raise ValueError(f'Could not find teacher checkpoint at path {config.teacher_ckpt}.')
86-
97+
8798
if config.teacher_decoder not in decoder_hub.keys():
8899
raise ValueError(f"Unsupported teacher decoder type: {config.teacher_decoder}")
89100

@@ -93,10 +104,10 @@ def get_teacher_model(config, device):
93104
teacher_ckpt = torch.load(config.teacher_ckpt, map_location=torch.device('cpu'))
94105
model.load_state_dict(teacher_ckpt['state_dict'])
95106
del teacher_ckpt
96-
107+
97108
model = model.to(device)
98109
model.eval()
99110
else:
100111
model = None
101-
112+
102113
return model

models/stdc.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""
2+
Paper: Rethinking BiSeNet For Real-time Semantic Segmentation
3+
Url: https://arxiv.org/abs/2104.13188
4+
Create by: zh320
5+
Date: 2024/01/20
6+
"""
7+
8+
import torch
9+
import torch.nn as nn
10+
import torch.nn.functional as F
11+
12+
from .modules import conv1x1, ConvBNAct, SegHead
13+
from .bisenetv1 import AttentionRefinementModule, FeatureFusionModule
14+
15+
16+
class STDC(nn.Module):
17+
def __init__(self, num_class=1, n_channel=3, encoder_type='stdc1', use_detail_head=False, use_aux=False,
18+
act_type='relu'):
19+
super(STDC, self).__init__()
20+
repeat_times_hub = {'stdc1': [1,1,1], 'stdc2': [3,4,2]}
21+
if encoder_type not in repeat_times_hub.keys():
22+
raise ValueError('Unsupported encoder type.\n')
23+
repeat_times = repeat_times_hub[encoder_type]
24+
assert not use_detail_head * use_aux, 'Currently only support either aux-head or detail head.\n'
25+
self.use_detail_head = use_detail_head
26+
self.use_aux = use_aux
27+
28+
self.stage1 = ConvBNAct(n_channel, 32, 3, 2)
29+
self.stage2 = ConvBNAct(32, 64, 3, 2)
30+
self.stage3 = self._make_stage(64, 256, repeat_times[0], act_type)
31+
self.stage4 = self._make_stage(256, 512, repeat_times[1], act_type)
32+
self.stage5 = self._make_stage(512, 1024, repeat_times[2], act_type)
33+
34+
if use_aux:
35+
self.aux_head3 = SegHead(256, num_class, act_type)
36+
self.aux_head4 = SegHead(512, num_class, act_type)
37+
self.aux_head5 = SegHead(1024, num_class, act_type)
38+
39+
self.pool = nn.AdaptiveAvgPool2d(1)
40+
self.arm4 = AttentionRefinementModule(512)
41+
self.arm5 = AttentionRefinementModule(1024)
42+
self.conv4 = conv1x1(512, 256)
43+
self.conv5 = conv1x1(1024, 256)
44+
45+
self.ffm = FeatureFusionModule(256+256, 128, act_type)
46+
47+
self.seg_head = SegHead(128, num_class, act_type)
48+
if use_detail_head:
49+
self.detail_head = SegHead(256, 1, act_type)
50+
self.detail_conv = conv1x1(3, 1)
51+
52+
def _make_stage(self, in_channels, out_channels, repeat_times, act_type):
53+
layers = [STDCModule(in_channels, out_channels, 2, act_type)]
54+
55+
for _ in range(repeat_times):
56+
layers.append(STDCModule(out_channels, out_channels, 1, act_type))
57+
return nn.Sequential(*layers)
58+
59+
def forward(self, x, is_training=False):
60+
size = x.size()[2:]
61+
62+
x = self.stage1(x)
63+
x = self.stage2(x)
64+
x3 = self.stage3(x)
65+
if self.use_aux:
66+
aux3 = self.aux_head3(x3)
67+
68+
x4 = self.stage4(x3)
69+
if self.use_aux:
70+
aux4 = self.aux_head4(x4)
71+
72+
x5 = self.stage5(x4)
73+
if self.use_aux:
74+
aux5 = self.aux_head5(x5)
75+
76+
x5_pool = self.pool(x5)
77+
x5 = x5_pool + self.arm5(x5)
78+
x5 = self.conv5(x5)
79+
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=True)
80+
81+
x4 = self.arm4(x4)
82+
x4 = self.conv4(x4)
83+
x4 += x5
84+
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=True)
85+
86+
x = self.ffm(x4, x3)
87+
x = self.seg_head(x)
88+
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
89+
90+
if self.use_detail_head and is_training:
91+
x_detail = self.detail_head(x3)
92+
return x, x_detail
93+
elif self.use_aux and is_training:
94+
return x, (aux3, aux4, aux5)
95+
else:
96+
return x
97+
98+
99+
class STDCModule(nn.Module):
100+
def __init__(self, in_channels, out_channels, stride, act_type):
101+
super(STDCModule, self).__init__()
102+
if out_channels % 8 != 0:
103+
raise ValueError('Output channel should be evenly divided by 8.\n')
104+
if stride not in [1, 2]:
105+
raise ValueError(f'Unsupported stride: {stride}\n')
106+
107+
self.stride = stride
108+
self.block1 = ConvBNAct(in_channels, out_channels//2, 1)
109+
self.block2 = ConvBNAct(out_channels//2, out_channels//4, 3, stride)
110+
if self.stride == 2:
111+
self.pool = nn.AvgPool2d(3, 2, 1)
112+
self.block3 = ConvBNAct(out_channels//4, out_channels//8, 3)
113+
self.block4 = ConvBNAct(out_channels//8, out_channels//8, 3)
114+
115+
def forward(self, x):
116+
x1 = self.block1(x)
117+
x2 = self.block2(x1)
118+
if self.stride == 2:
119+
x1 = self.pool(x1)
120+
x3 = self.block3(x2)
121+
x4 = self.block4(x3)
122+
123+
return torch.cat([x1, x2, x3, x4], dim=1)
124+
125+
126+
class LaplacianConv(nn.Module):
127+
def __init__(self, device):
128+
super(LaplacianConv, self).__init__()
129+
self.laplacian_kernel = torch.tensor([[[[-1.,-1.,-1.],[-1.,8.,-1.],[-1.,-1.,-1.]]]]).to(device)
130+
131+
def forward(self, lbl):
132+
size = lbl.size()[2:]
133+
lbl_1x = F.conv2d(lbl, self.laplacian_kernel, stride=1, padding=1)
134+
lbl_2x = F.conv2d(lbl, self.laplacian_kernel, stride=2, padding=1)
135+
lbl_4x = F.conv2d(lbl, self.laplacian_kernel, stride=4, padding=1)
136+
137+
lbl_2x = F.interpolate(lbl_2x, size, mode='nearest')
138+
lbl_4x = F.interpolate(lbl_4x, size, mode='nearest')
139+
140+
lbl = torch.cat([lbl_1x, lbl_2x, lbl_4x], dim=1)
141+
142+
return lbl

tools/get_model_infos.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from os import path
33
sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
44

5-
from configs import MyConfig
5+
from configs import MyConfig, load_parser
66
from models import get_model
77

88

@@ -29,5 +29,6 @@ def cal_model_params(config, imgw=1024, imgh=512):
2929

3030
if __name__ == '__main__':
3131
config = MyConfig()
32+
config = load_parser(config)
3233

3334
cal_model_params(config)

tools/test_speed.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from os import path
33
sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
44

5-
from configs import MyConfig
5+
from configs import MyConfig, load_parser
66
from models import get_model
77

88

@@ -63,5 +63,6 @@ def test_model_speed(config, ratio=0.5, imgw=2048, imgh=1024, iterations=None):
6363

6464
if __name__ == '__main__':
6565
config = MyConfig()
66-
66+
config = load_parser(config)
67+
6768
test_model_speed(config)

0 commit comments

Comments
 (0)