From 608dc720eea437b42f011bbba12f52d718c52f7d Mon Sep 17 00:00:00 2001
From: voldemortX
Date: Mon, 10 May 2021 15:00:29 +0800
Subject: [PATCH 1/7] Expose ERFNet as feature extractor

---
 torchvision_models/segmentation/erfnet.py | 20 +++++++++++++------
 .../segmentation/segmentation.py          |  6 ++++--
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/torchvision_models/segmentation/erfnet.py b/torchvision_models/segmentation/erfnet.py
index 928c6431..f2145213 100644
--- a/torchvision_models/segmentation/erfnet.py
+++ b/torchvision_models/segmentation/erfnet.py
@@ -167,14 +167,15 @@ def forward(self, input):
 # ERFNet
 class ERFNet(nn.Module):
     def __init__(self, num_classes, encoder=None, num_lanes=0, dropout_1=0.03, dropout_2=0.3, flattened_size=3965,
-                 scnn=False):
+                 scnn=False, encoder_only=False):
         super().__init__()
         if encoder is None:
             self.encoder = Encoder(num_classes=num_classes, dropout_1=dropout_1, dropout_2=dropout_2)
         else:
             self.encoder = encoder
-        self.decoder = Decoder(num_classes)
+        # Keep only the encoder (to be used as a backbone)
+        self.decoder = None if encoder_only else Decoder(num_classes)
         if scnn:
             self.spatial_conv = SpatialConv()
@@ -187,16 +188,23 @@ def __init__(self, num_classes, encoder=None, num_lanes=0, dropout_1=0.03, dropo
         else:
             self.lane_classifier = None
-    def forward(self, input, only_encode=False):
+    def forward(self, inputs, only_encode=False):
+        # only_encode=True is for the encoder pre-training step of 2-step segmentation training,
+        # to match the original implementation.
+        # To use the encoder as a feature extractor, set encoder_only=True at init; do not change this argument.
         out = OrderedDict()
         if only_encode:
-            return self.encoder.forward(input, predict=True)
         else:
-            output = self.encoder(input)  # predict=False by default
+            output = self.encoder(inputs)  # predict=False by default
+
            if self.spatial_conv is not None:
                 output = self.spatial_conv(output)
-            out['out'] = self.decoder.forward(output)
+
+            if self.decoder is not None:
+                out['out'] = self.decoder.forward(output)
             if self.lane_classifier is not None:
                 out['lane'] = self.lane_classifier(output)
+
             return out

diff --git a/torchvision_models/segmentation/segmentation.py b/torchvision_models/segmentation/segmentation.py
index 5830a39d..bbdc1476 100644
--- a/torchvision_models/segmentation/segmentation.py
+++ b/torchvision_models/segmentation/segmentation.py
@@ -228,14 +228,16 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
 def erfnet_resnet(pretrained_weights='erfnet_encoder_pretrained.pth.tar', num_classes=19, num_lanes=0,
-                  dropout_1=0.03, dropout_2=0.3, flattened_size=3965, scnn=False):
+                  dropout_1=0.03, dropout_2=0.3, flattened_size=3965, scnn=False, encoder_only=False):
     """Constructs a ERFNet model with ResNet-style backbone.
     Args:
         pretrained_weights (str): If not None, load ImageNet pre-trained weights from this filename
+        encoder_only (bool): If True, only the encoder is kept (for use as a feature extractor); loading of
+            ImageNet pre-trained weights is not affected
     """
     net = ERFNet(num_classes=num_classes, encoder=None, num_lanes=num_lanes, dropout_1=dropout_1, dropout_2=dropout_2,
-                 flattened_size=flattened_size, scnn=scnn)
+                 flattened_size=flattened_size, scnn=scnn, encoder_only=encoder_only)
     if pretrained_weights is not None:  # Load ImageNet pre-trained weights
         saved_weights = load(pretrained_weights)['state_dict']
         original_weights = net.state_dict()

From 64bf16b6922c5a9172f25810697564bdbc0c7e58 Mon Sep 17 00:00:00 2001
From: voldemortX
Date: Mon, 10 May 2021 15:10:34 +0800
Subject: [PATCH 2/7] PRNet model implemented

---
 torchvision_models/__init__.py             |  1 +
 .../lane_detection/common_models.py        | 24 +++++++
 torchvision_models/lane_detection/prnet.py | 70 +++++++++++++++++++
 3 files changed, 95 insertions(+)
 create mode 100644 torchvision_models/lane_detection/prnet.py

diff --git a/torchvision_models/__init__.py b/torchvision_models/__init__.py
index 3dec5932..8c499ead 100644
--- a/torchvision_models/__init__.py
+++ b/torchvision_models/__init__.py
@@ -1,2 +1,3 @@
+# The code structure is based on an older version of TorchVision
 from .resnet import *
 from . import segmentation

diff --git a/torchvision_models/lane_detection/common_models.py b/torchvision_models/lane_detection/common_models.py
index 6f035c58..84980cdf 100644
--- a/torchvision_models/lane_detection/common_models.py
+++ b/torchvision_models/lane_detection/common_models.py
@@ -84,6 +84,30 @@ def forward(self, input):
         return output


+# SCNN_D (a more efficient SCNN variant that keeps only the downward message passing; widely used in recent work)
+class SCNN_D(nn.Module):
+    def __init__(self, num_channels=128):
+        super().__init__()
+        self.conv_d = nn.Conv2d(num_channels, num_channels, (1, 9), padding=(0, 4))
+        self._adjust_initializations(num_channels=num_channels)
+
+    def _adjust_initializations(self, num_channels=128):
+        # https://github.com/XingangPan/SCNN/issues/82
+        bound = math.sqrt(2.0 / (num_channels * 9 * 5))
+        nn.init.uniform_(self.conv_d.weight, -bound, bound)
+
+    def forward(self, input):
+        output = input
+
+        # The first row remains unchanged (according to the original paper); why not add a ReLU afterwards?
+        # Update each row and send it to the next one
+        # Down
+        for i in range(1, output.shape[2]):
+            output[:, :, i:i + 1, :].add_(F.relu(self.conv_d(output[:, :, i - 1:i, :])))
+
+        return output
+
+
 # Typical lane existence head originated from the SCNN paper
 class SimpleLaneExist(nn.Module):
     def __init__(self, num_output, flattened_size=4500):

diff --git a/torchvision_models/lane_detection/prnet.py b/torchvision_models/lane_detection/prnet.py
new file mode 100644
index 00000000..5ef5b20a
--- /dev/null
+++ b/torchvision_models/lane_detection/prnet.py
@@ -0,0 +1,70 @@
+# Implementation of Polynomial Regression Network based on the original paper (PRNet):
+# http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123630698.pdf
+import torch.nn as nn
+from collections import OrderedDict
+from .common_models import RESAReducer, SCNN_D
+from .. 
import resnet +from ..segmentation import erfnet_resnet +from .._utils import IntermediateLayerGetter + + +# One convolution layer for each branch +# The kernel size 3x3 is an educated guess, the 3 branches are implemented separately for future flexibility +class PolynomialBranch(nn.Module): + def __init__(self, in_channels, order=3): + super(PolynomialBranch, self).__init__() + self.conv = nn.Conv2d(in_channels, order + 1, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, inputs): + return self.conv(inputs) + + +class InitializationBranch(nn.Module): + def __init__(self, in_channels): + super(InitializationBranch, self).__init__() + self.conv = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, inputs): + return self.conv(inputs) + + +class HeightBranch(nn.Module): + def __init__(self, in_channels): + super(HeightBranch, self).__init__() + self.conv = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, inputs): + return self.conv(inputs) + + +class PRNet(nn.Module): + def __init__(self, backbone_name, dropout_1=0.3, dropout_2=0.03, order=3): + super(PRNet, self).__init__() + if backbone_name == 'erfnet': + self.backbone = erfnet_resnet(dropout_1=dropout_1, dropout_2=dropout_2, encoder_only=True) + in_channels = 128 + else: + in_channels = 2048 if backbone_name == 'resnet50' or backbone_name == 'resnet101' else 512 + backbone = resnet.__dict__[backbone_name]( + pretrained=True, + replace_stride_with_dilation=[False, True, True]) + return_layers = {'layer4': 'out'} + self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + self.channel_reducer = RESAReducer(in_channels=in_channels) + self.spatial_conv = SCNN_D() + self.polynomial_branch = PolynomialBranch(in_channels=128, order=order) + self.initialization_branch = InitializationBranch(in_channels=128) + self.height_branch = HeightBranch(in_channels=128) + + def forward(self, inputs): + # Encoder (8x down-sampling) -> channel reduction (128, another educated guess) -> SCNN_D -> 3 branches + out = OrderedDict() + x = self.backbone(inputs) + x = self.channel_reducer(x) + x = self.spatial_conv(x) + out['polynomials'] = self.polynomial_branch(x) + out['initializations'] = self.initialization_branch(x) + out['heights'] = self.height_branch(x) + + return out From f87ca1a72e0382ae9658d2f5756fccef4fc257a3 Mon Sep 17 00:00:00 2001 From: voldemortX Date: Mon, 10 May 2021 15:11:01 +0800 Subject: [PATCH 3/7] some formatting problem in enet --- torchvision_models/segmentation/enet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision_models/segmentation/enet.py b/torchvision_models/segmentation/enet.py index f3cc5680..4dcc61aa 100644 --- a/torchvision_models/segmentation/enet.py +++ b/torchvision_models/segmentation/enet.py @@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter from ..lane_detection.common_models import EDLaneExist + class InitialBlock(nn.Module): """The initial block is composed of two branches: 1. 
a main branch which performs a regular convolution with stride 2; @@ -696,7 +697,6 @@ def forward(self, x): stage2_input_size, input_size) out['out'] = x - return out # net = ENet(num_classes=19,encoder_only=True) From 31e9d830dd8e0b3ba682fb9774f0f2740cedf1bb Mon Sep 17 00:00:00 2001 From: voldemortX Date: Mon, 10 May 2021 16:11:14 +0800 Subject: [PATCH 4/7] Loss framework --- torchvision_models/lane_detection/prnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision_models/lane_detection/prnet.py b/torchvision_models/lane_detection/prnet.py index 5ef5b20a..8f05275b 100644 --- a/torchvision_models/lane_detection/prnet.py +++ b/torchvision_models/lane_detection/prnet.py @@ -11,7 +11,7 @@ # One convolution layer for each branch # The kernel size 3x3 is an educated guess, the 3 branches are implemented separately for future flexibility class PolynomialBranch(nn.Module): - def __init__(self, in_channels, order=3): + def __init__(self, in_channels, order=2): super(PolynomialBranch, self).__init__() self.conv = nn.Conv2d(in_channels, order + 1, kernel_size=3, stride=1, padding=1, bias=False) @@ -37,8 +37,9 @@ def forward(self, inputs): return self.conv(inputs) +# Currently supported backbones: ERFNet, ResNets class PRNet(nn.Module): - def __init__(self, backbone_name, dropout_1=0.3, dropout_2=0.03, order=3): + def __init__(self, backbone_name, dropout_1=0.3, dropout_2=0.03, order=2): super(PRNet, self).__init__() if backbone_name == 'erfnet': self.backbone = erfnet_resnet(dropout_1=dropout_1, dropout_2=dropout_2, encoder_only=True) From 8d975a4c9b52e51bdf8c99560d4dc6193841bb5a Mon Sep 17 00:00:00 2001 From: voldemortX Date: Mon, 10 May 2021 16:11:37 +0800 Subject: [PATCH 5/7] Loss framework --- utils/losses/pr_loss.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 utils/losses/pr_loss.py diff --git a/utils/losses/pr_loss.py b/utils/losses/pr_loss.py new file mode 100644 index 00000000..7c946521 --- /dev/null +++ b/utils/losses/pr_loss.py @@ -0,0 +1,30 @@ +# Loss for PRNet +import torch +from torch.nn import functional as F +from ._utils import WeightedLoss + + +class PRLoss(WeightedLoss): + __constants__ = ['reduction'] + ignore_index: int + + def __init__(self, polynomial_weight=1, initialization_weight=1, height_weight=0.1, beta=0.005, m=20, + weight=None, size_average=None, reduce=None, reduction='mean'): + super(PRLoss, self).__init__(weight, size_average, reduce, reduction) + self.polynomial_weight = polynomial_weight + self.initialization_weight = initialization_weight + self.height_weight = height_weight + self.beta = beta # Beta for smoothed L1 loss + self.m = m # Number of sample points to calculate polynomial regression loss + + def forward(self, inputs, targets, net): + outputs = net(inputs) + pass + + @staticmethod + def beta_smoothed_l1_loss(inputs, targets, beta=0.005): + # Smoothed L1 loss with a hyper-parameter (as in PRNet paper) + # The original torch F.smooth_l1_loss() is equivalent to beta=1 + t = torch.abs(inputs - targets) + + return torch.where(t < beta, 0.5 * t ** 2 / beta, t - 0.5 * beta) From 730dd637d69d668a8baba36eaa583a15760b0908 Mon Sep 17 00:00:00 2001 From: voldemortX Date: Tue, 11 May 2021 21:29:33 +0800 Subject: [PATCH 6/7] sampling function --- utils/losses/pr_loss.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/utils/losses/pr_loss.py b/utils/losses/pr_loss.py index 7c946521..11ae424b 100644 --- a/utils/losses/pr_loss.py +++ 
b/utils/losses/pr_loss.py
@@ -4,6 +4,21 @@
 from ._utils import WeightedLoss


+def polynomial_curve_without_projection(coefficients, y):
+    # Polynomial curve model (arbitrary order)
+    # Return x coordinates
+    # coefficients: [..., m], ... means arbitrary number of leading dimensions
+    # m: number of coefficients, order increasing
+    # y: [N]
+    original_shape = coefficients.shape
+    coefficients = coefficients.unsqueeze(-1).expand(*original_shape, y.shape[0])
+    x = coefficients[..., 0, :]
+    for i in range(1, len(original_shape[-1])):
+        x += coefficients[..., i, :] * y ** i
+
+    return x  # [..., N]
+
+
 class PRLoss(WeightedLoss):
     __constants__ = ['reduction']
     ignore_index: int
@@ -17,8 +32,10 @@ def __init__(self, polynomial_weight=1, initialization_weight=1, height_weight=0
         self.beta = beta  # Beta for smoothed L1 loss
         self.m = m  # Number of sample points to calculate polynomial regression loss

-    def forward(self, inputs, targets, net):
+    def forward(self, inputs, targets, masks, net):
+        # masks: True for polynomial points (which have height & polynomial regression losses)
         outputs = net(inputs)
+
         pass

     @staticmethod

From 80688badcda8b1d43b2faaf5f49441598a28bd7b Mon Sep 17 00:00:00 2001
From: voldemortX
Date: Tue, 25 May 2021 17:13:02 +0800
Subject: [PATCH 7/7] curve function bug fix

---
 utils/losses/pr_loss.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/utils/losses/pr_loss.py b/utils/losses/pr_loss.py
index 11ae424b..dc3f1bc8 100644
--- a/utils/losses/pr_loss.py
+++ b/utils/losses/pr_loss.py
@@ -5,18 +5,17 @@
 def polynomial_curve_without_projection(coefficients, y):
-    # Polynomial curve model (arbitrary order)
+    # Arbitrary polynomial curve function
     # Return x coordinates
-    # coefficients: [..., m], ... means arbitrary number of leading dimensions
+    # coefficients: [d1, d2, ... , m]
     # m: number of coefficients, order increasing
-    # y: [N]
-    original_shape = coefficients.shape
-    coefficients = coefficients.unsqueeze(-1).expand(*original_shape, y.shape[0])
-    x = coefficients[..., 0, :]
-    for i in range(1, len(original_shape[-1])):
-        x += coefficients[..., i, :] * y ** i
-
-    return x  # [..., N]
+    # y: [d1, d2, ... , N]
+    y = y.permute(-1, *[i for i in range(len(y.shape) - 1)])
+    x = coefficients[..., 0]
+    for i in range(1, coefficients.shape[-1]):
+        x = x + coefficients[..., i] * y ** i  # out-of-place add: broadcasts x to [N, d1, d2, ...] and keeps coefficients intact
+
+    return x.permute(*[i + 1 for i in range(len(x.shape) - 1)], 0)  # [d1, d2, ... , N]


 class PRLoss(WeightedLoss):
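
A minimal standalone sketch for illustration (not part of the patches above): it shows the two numerical pieces this series introduces, the beta-smoothed L1 loss added in PATCH 5 (whose default is beta=0.005, while torch's F.smooth_l1_loss corresponds to beta=1) and the evaluation of an arbitrary-order polynomial x(y) at sampled heights, which PATCH 6/7 implement for the polynomial regression loss. All variable names and shapes below are assumptions made for the example, not the repository's API.

import torch

def beta_smooth_l1(inputs, targets, beta=0.005):
    # Elementwise smoothed L1: 0.5 * d ** 2 / beta if d < beta else d - 0.5 * beta, with d = |inputs - targets|.
    # Mirrors PRLoss.beta_smoothed_l1_loss from PATCH 5; no reduction is applied here.
    d = torch.abs(inputs - targets)
    return torch.where(d < beta, 0.5 * d ** 2 / beta, d - 0.5 * beta)

# Evaluate x(y) = sum_i c_i * y ** i for one order-2 curve at N sampled heights
coefficients = torch.tensor([[0.1, 0.5, -0.2]])                        # [1 curve, m = 3 coefficients]
y = torch.linspace(0.0, 1.0, steps=20)                                 # [N] normalized sample heights
powers = torch.stack([y ** i for i in range(coefficients.shape[-1])])  # [m, N] powers of y
x = coefficients @ powers                                              # [1, N] x coordinates along the curve
loss = beta_smooth_l1(x, x.detach() + 0.01).mean()                     # toy target, reduced to a scalar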