# SCNN_D (more efficient and used by a lot of people nowadays)
class SCNN_D(nn.Module):
    """Downward-only spatial message passing (the "D" direction of SCNN).

    Rows of the feature map (dim 2) are visited top to bottom; each row
    receives ``relu(conv_d(previous_row))`` added to it, so later rows see
    the already-updated earlier rows. The first row remains unchanged
    (according to the original paper).

    Args:
        num_channels: Number of input (= output) channels.
        kernel_width: Width of the ``1 x kernel_width`` convolution. Must be a
            positive odd number so that the padding keeps the row width
            unchanged. Defaults to 9, the previously hard-coded value.

    Raises:
        ValueError: If ``kernel_width`` is not a positive odd number.
    """

    def __init__(self, num_channels=128, kernel_width=9):
        super().__init__()
        if kernel_width < 1 or kernel_width % 2 == 0:
            # An even width would change the row width and break the in-place add in forward()
            raise ValueError('kernel_width must be a positive odd number, got {}'.format(kernel_width))
        self.conv_d = nn.Conv2d(num_channels, num_channels, (1, kernel_width), padding=(0, kernel_width // 2))
        self._adjust_initializations(num_channels=num_channels, kernel_width=kernel_width)

    def _adjust_initializations(self, num_channels=128, kernel_width=9):
        """Re-initialize the conv weight uniformly in [-bound, bound]."""
        # https://github.com/XingangPan/SCNN/issues/82
        # The bound is derived from the same kernel width the convolution uses
        bound = math.sqrt(2.0 / (num_channels * kernel_width * 5))
        nn.init.uniform_(self.conv_d.weight, -bound, bound)

    def forward(self, input):
        """Propagate information downwards, row by row.

        Note: ``input`` is modified in place and the very same tensor is returned.
        """
        output = input

        # First one remains unchanged (according to the original paper)
        # Update and send to next
        # Down
        for i in range(1, output.shape[2]):
            output[:, :, i:i + 1, :].add_(F.relu(self.conv_d(output[:, :, i - 1:i, :])))

        return output
# Implementation of Polynomial Regression Network based on the original paper (PRNet):
# http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123630698.pdf
from collections import OrderedDict

import torch.nn as nn


# One convolution layer for each branch
# The kernel size 3x3 is an educated guess, the 3 branches are implemented separately for future flexibility
class PolynomialBranch(nn.Module):
    """Single 3x3 convolution predicting ``order + 1`` polynomial coefficient maps."""

    def __init__(self, in_channels, order=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, order + 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        return self.conv(inputs)


class InitializationBranch(nn.Module):
    """Single 3x3 convolution producing a 1-channel map for the initialization branch."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        return self.conv(inputs)


class HeightBranch(nn.Module):
    """Single 3x3 convolution producing a 1-channel map for the height branch."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        return self.conv(inputs)


# Currently supported backbones: ERFNet, ResNets
class PRNet(nn.Module):
    """Backbone -> channel reduction -> SCNN_D -> polynomial / initialization / height branches.

    Args:
        backbone_name: 'erfnet', or the name of a constructor found in the ``resnet`` module.
        dropout_1: Forwarded to ``erfnet_resnet`` (ERFNet backbone only).
        dropout_2: Forwarded to ``erfnet_resnet`` (ERFNet backbone only).
        order: Polynomial order; the polynomial branch outputs ``order + 1`` channels.

    NOTE(review): ERFNet / erfnet_resnet default to dropout_1=0.03, dropout_2=0.3, the defaults
    here are the other way round - confirm which one is intended.
    """

    def __init__(self, backbone_name, dropout_1=0.3, dropout_2=0.03, order=2):
        super().__init__()
        # Sibling modules must be imported package-relative: a bare `from common_models import ...`
        # only resolves when this directory itself happens to be on sys.path.
        # The imports are kept at function scope because segmentation/enet.py imports
        # lane_detection.common_models while this module needs the segmentation package.
        # NOTE(review): confirm whether that import cycle can actually be triggered.
        from .common_models import RESAReducer, SCNN_D

        if backbone_name == 'erfnet':
            from ..segmentation import erfnet_resnet

            self.backbone = erfnet_resnet(dropout_1=dropout_1, dropout_2=dropout_2, encoder_only=True)
            in_channels = 128
        else:
            from .. import resnet
            from .._utils import IntermediateLayerGetter

            in_channels = 2048 if backbone_name in ('resnet50', 'resnet101') else 512
            backbone = resnet.__dict__[backbone_name](
                pretrained=True,
                replace_stride_with_dilation=[False, True, True])
            return_layers = {'layer4': 'out'}
            self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

        self.channel_reducer = RESAReducer(in_channels=in_channels)
        self.spatial_conv = SCNN_D()
        self.polynomial_branch = PolynomialBranch(in_channels=128, order=order)
        self.initialization_branch = InitializationBranch(in_channels=128)
        self.height_branch = HeightBranch(in_channels=128)

    def forward(self, inputs):
        """Return an OrderedDict with keys 'polynomials', 'initializations' and 'heights'."""
        # Encoder (8x down-sampling) -> channel reduction (128, another educated guess) -> SCNN_D -> 3 branches
        features = self.backbone(inputs)
        if isinstance(features, dict):
            # The ResNet path is wrapped with return_layers={'layer4': 'out'} and the ERFNet
            # forward also returns a dict, so unwrap the feature map before the reducer
            features = features['out']
        features = self.channel_reducer(features)
        features = self.spatial_conv(features)

        out = OrderedDict()
        out['polynomials'] = self.polynomial_branch(features)
        out['initializations'] = self.initialization_branch(features)
        out['heights'] = self.height_branch(features)

        return out
def forward(self, inputs, only_encode=False):
    """Run ERFNet and collect its outputs in an OrderedDict.

    Args:
        inputs: Input passed to the encoder.
        only_encode: True is for the pre-training step of 2-step segmentation training,
            in order to match with the original implementation; the encoder is then called
            with predict=True and its result is returned directly (not wrapped in a dict).
            If the encoder is used as a feature extractor, set encoder_only=True in the
            class init, but do not change this variable.

    Returns:
        OrderedDict with
            'out': decoder output, or the encoder feature map (after the optional spatial
                   convolution) when the network has no decoder (encoder_only=True),
            'lane': lane classifier output, only when a lane classifier exists.
    """
    if only_encode:
        return self.encoder(inputs, predict=True)

    out = OrderedDict()
    features = self.encoder(inputs)  # predict=False by default

    if self.spatial_conv is not None:
        features = self.spatial_conv(features)

    if self.decoder is not None:
        out['out'] = self.decoder(features)
    else:
        # Without a decoder the features themselves are the result; leaving them out of the
        # dict would discard the backbone output altogether
        out['out'] = features
    if self.lane_classifier is not None:
        out['lane'] = self.lane_classifier(features)

    return out
def polynomial_curve_without_projection(coefficients, y):
    """Evaluate an arbitrary polynomial curve and return the x coordinates.

    Args:
        coefficients: [d1, d2, ... , m], m: number of coefficients, order increasing.
        y: [d1, d2, ... , N], the positions at which each polynomial is evaluated.

    Returns:
        x coordinates, [d1, d2, ... , N], with x = sum_i coefficients[..., i] * y ** i.
        Neither input is modified.
    """
    # Horner's scheme, evaluated out-of-place: an in-place `+=` on `coefficients[..., 0]`
    # would write into the caller's tensor and cannot broadcast up to N sample points.
    # Keeping a trailing singleton dim on each coefficient lets it broadcast against N.
    x = torch.zeros_like(y)
    for i in reversed(range(coefficients.shape[-1])):
        x = x * y + coefficients[..., i:i + 1]

    return x  # [d1, d2, ... , N]
class PRLoss(WeightedLoss):
    """Loss for PRNet (work in progress: only the smoothed L1 helper is implemented).

    Args:
        polynomial_weight: Weight stored for the polynomial regression term.
        initialization_weight: Weight stored for the initialization term.
        height_weight: Weight stored for the height term.
        beta: Beta for smoothed L1 loss.
        m: Number of sample points to calculate polynomial regression loss.
        weight, size_average, reduce, reduction: Forwarded unchanged to WeightedLoss.
    """
    __constants__ = ['reduction']

    def __init__(self, polynomial_weight=1, initialization_weight=1, height_weight=0.1, beta=0.005, m=20,
                 weight=None, size_average=None, reduce=None, reduction='mean'):
        super().__init__(weight, size_average, reduce, reduction)
        self.polynomial_weight = polynomial_weight
        self.initialization_weight = initialization_weight
        self.height_weight = height_weight
        self.beta = beta  # Beta for smoothed L1 loss
        self.m = m  # Number of sample points to calculate polynomial regression loss

    def forward(self, inputs, targets, masks, net):
        """Not implemented yet.

        masks: True for polynomial points (which have height & polynomial regression losses)

        Raises:
            NotImplementedError: Always. Fail loudly instead of handing `None` to the caller as a loss.
        """
        raise NotImplementedError('PRLoss.forward is not implemented yet')

    @staticmethod
    def beta_smoothed_l1_loss(inputs, targets, beta=0.005):
        """Element-wise smoothed L1 loss with a hyper-parameter (as in PRNet paper).

        The original torch F.smooth_l1_loss() is equivalent to beta=1.

        Returns:
            0.5 * t ** 2 / beta where t < beta, t - 0.5 * beta elsewhere, with t = |inputs - targets|.
            No reduction is applied. beta == 0 gives the plain L1 distance.

        Raises:
            ValueError: If beta is negative.
        """
        if beta < 0:
            raise ValueError('beta must be non-negative, got {}'.format(beta))
        t = torch.abs(inputs - targets)
        if beta == 0:
            # torch.where evaluates both branches, so the quadratic one would divide by zero
            return t

        return torch.where(t < beta, 0.5 * t ** 2 / beta, t - 0.5 * beta)