diff --git a/README.md b/README.md
index a8c8bc17..8433222d 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of this library are:
 
  - High-level API (just two lines to create a neural network)
- - 9 models architectures for binary and multi class segmentation (including legendary Unet)
+ - 10 model architectures for binary and multi class segmentation (including legendary Unet)
  - 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
  - All encoders have pre-trained weights for faster and better convergence
  - Popular metrics and losses for training routines
@@ -94,6 +94,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
  - PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
  - DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
  - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
+ - UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
 
 #### Encoders
 
diff --git a/docs/models.rst b/docs/models.rst
index 003908a0..ad3682b0 100644
--- a/docs/models.rst
+++ b/docs/models.rst
@@ -66,3 +66,10 @@ MAnet
 PAN
 ~~~
 .. autoclass:: segmentation_models_pytorch.PAN
+
+
+.. _upernet:
+
+UPerNet
+~~~~~~~
+.. autoclass:: segmentation_models_pytorch.UPerNet
diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py
index d3778ecc..5733e7b9 100644
--- a/segmentation_models_pytorch/__init__.py
+++ b/segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
 from .decoders.pspnet import PSPNet
 from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
 from .decoders.pan import PAN
+from .decoders.upernet import UPerNet
 
 from .base.hub_mixin import from_pretrained
 from .__version__ import __version__
@@ -48,6 +49,7 @@ def create_model(
         DeepLabV3,
         DeepLabV3Plus,
         PAN,
+        UPerNet,
     ]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
@@ -82,6 +84,7 @@
     "DeepLabV3",
     "DeepLabV3Plus",
     "PAN",
+    "UPerNet",
     "from_pretrained",
     "create_model",
     "__version__",
diff --git a/segmentation_models_pytorch/decoders/upernet/__init__.py b/segmentation_models_pytorch/decoders/upernet/__init__.py
new file mode 100644
index 00000000..012967b5
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/upernet/__init__.py
@@ -0,0 +1,3 @@
+from .model import UPerNet
+
+__all__ = ["UPerNet"]
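With the registration above, the new architecture should be reachable both as `smp.UPerNet` and through the `create_model` factory (which matches `"upernet"` case-insensitively against the registered class names). A minimal usage sketch, assuming this branch is installed; `encoder_weights=None` avoids a pretrained-weights download:

```python
import torch
import segmentation_models_pytorch as smp

# Direct class access, as exported in __all__ above.
model = smp.UPerNet(encoder_name="resnet34", encoder_weights=None, classes=2)

# Factory access: create_model resolves the lowercased name in archs_dict.
model = smp.create_model("upernet", encoder_name="resnet34", encoder_weights=None, classes=2)

model.eval()
with torch.no_grad():
    mask = model(torch.randn(1, 3, 64, 64))
print(mask.shape)  # expected: torch.Size([1, 2, 64, 64])
```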
diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py
new file mode 100644
index 00000000..b36d3b40
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/upernet/decoder.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from segmentation_models_pytorch.base import modules as md
+
+
+class PSPModule(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        sizes=(1, 2, 3, 6),
+        use_batchnorm=True,
+    ):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(size),
+                    md.Conv2dReLU(
+                        in_channels,
+                        in_channels // len(sizes),
+                        kernel_size=1,
+                        use_batchnorm=use_batchnorm,
+                    ),
+                )
+                for size in sizes
+            ]
+        )
+        self.out_conv = md.Conv2dReLU(
+            in_channels=in_channels * 2,
+            out_channels=out_channels,
+            kernel_size=1,
+            use_batchnorm=use_batchnorm,
+        )
+
+    def forward(self, x):
+        _, _, height, width = x.shape
+        out = [x] + [
+            F.interpolate(
+                block(x), size=(height, width), mode="bilinear", align_corners=False
+            )
+            for block in self.blocks
+        ]
+        out = self.out_conv(torch.cat(out, dim=1))
+        return out
+
+
+class FPNBlock(nn.Module):
+    def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True):
+        super().__init__()
+        self.skip_conv = (
+            md.Conv2dReLU(
+                skip_channels,
+                pyramid_channels,
+                kernel_size=1,
+                use_batchnorm=use_batchnorm,
+            )
+            if skip_channels != 0
+            else nn.Identity()
+        )
+
+    def forward(self, x, skip):
+        _, channels, height, width = skip.shape
+        x = F.interpolate(
+            x, size=(height, width), mode="bilinear", align_corners=False
+        )
+        if channels != 0:
+            skip = self.skip_conv(skip)
+            x = x + skip
+        return x
+
+
+class UPerNetDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        encoder_depth=5,
+        pyramid_channels=256,
+        segmentation_channels=64,
+    ):
+        super().__init__()
+
+        if encoder_depth < 3:
+            raise ValueError(
+                "Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format(
+                    encoder_depth
+                )
+            )
+
+        encoder_channels = encoder_channels[::-1]
+
+        # PSP Module
+        self.psp = PSPModule(
+            in_channels=encoder_channels[0],
+            out_channels=pyramid_channels,
+            sizes=(1, 2, 3, 6),
+            use_batchnorm=True,
+        )
+
+        # FPN Module (skip the deepest feature, which goes to PSP,
+        # and the identity feature, which the decoder drops)
+        self.fpn_stages = nn.ModuleList(
+            [FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:-1]]
+        )
+
+        self.fpn_bottleneck = md.Conv2dReLU(
+            in_channels=(len(encoder_channels) - 1) * pyramid_channels,
+            out_channels=segmentation_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=True,
+        )
+
+    def forward(self, *features):
+        output_size = features[0].shape[2:]
+        target_size = [size // 4 for size in output_size]
+
+        features = features[1:]  # remove first skip with same spatial resolution
+        features = features[::-1]  # reverse channels to start from head of encoder
+
+        psp_out = self.psp(features[0])
+
+        fpn_features = [psp_out]
+        for feature, stage in zip(features[1:], self.fpn_stages):
+            fpn_feature = stage(fpn_features[-1], feature)
+            fpn_features.append(fpn_feature)
+
+        # Resize all FPN features to 1/4 of the original resolution.
+        resized_fpn_features = []
+        for feature in fpn_features:
+            resized_feature = F.interpolate(
+                feature, size=target_size, mode="bilinear", align_corners=False
+            )
+            resized_fpn_features.append(resized_feature)
+
+        output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))
+
+        return output
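To make the decoder's data flow concrete: the deepest encoder feature goes through the PSP module, each shallower feature is merged top-down by an FPN block, and every pyramid level is resized to 1/4 of the input resolution before the 3x3 bottleneck fuses them. A shape walk-through, assuming a `resnet34` encoder and a 64x64 input (the commented shapes are what the code above should produce for this configuration):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.UPerNet("resnet34", encoder_weights=None, classes=1)
x = torch.randn(1, 3, 64, 64)
features = model.encoder(x)
for f in features:
    print(tuple(f.shape))
# (1, 3, 64, 64)   identity feature, dropped by the decoder
# (1, 64, 32, 32)
# (1, 64, 16, 16)
# (1, 128, 8, 8)
# (1, 256, 4, 4)
# (1, 512, 2, 2)   deepest feature, input to PSPModule
decoder_output = model.decoder(*features)
print(tuple(decoder_output.shape))
# (1, 64, 16, 16): input/4 in space, segmentation_channels deep;
# the segmentation head then upsamples x4 back to 64x64
```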
diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py
new file mode 100644
index 00000000..18b97a94
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/upernet/model.py
@@ -0,0 +1,92 @@
+from typing import Optional, Union
+
+from segmentation_models_pytorch.encoders import get_encoder
+from segmentation_models_pytorch.base import (
+    SegmentationModel,
+    SegmentationHead,
+    ClassificationHead,
+)
+from .decoder import UPerNetDecoder
+
+
+class UPerNet(SegmentationModel):
+    """UPerNet is a unified perceptual parsing network for image segmentation.
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256
+        decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**.
+            Default is **None**
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is
+            built on top of the encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **UPerNet**
+
+    .. _UPerNet:
+        https://arxiv.org/abs/1807.10221
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_pyramid_channels: int = 256,
+        decoder_segmentation_channels: int = 64,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = UPerNetDecoder(
+            encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
+            pyramid_channels=decoder_pyramid_channels,
+            segmentation_channels=decoder_segmentation_channels,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_segmentation_channels,
+            out_channels=classes,
+            activation=activation,
+            kernel_size=1,
+            upsampling=4,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "upernet-{}".format(encoder_name)
+        self.initialize()
diff --git a/tests/test_models.py b/tests/test_models.py
index f78f55d6..a1b5f2c6 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -29,6 +29,7 @@ def get_sample(model_class):
         smp.PSPNet,
         smp.UnetPlusPlus,
         smp.MAnet,
+        smp.UPerNet,
     ]:
         sample = torch.ones([1, 3, 64, 64])
     elif model_class == smp.PAN:
@@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False):
 @pytest.mark.parametrize("encoder_name", ENCODERS)
 @pytest.mark.parametrize("encoder_depth", [3, 5])
 @pytest.mark.parametrize(
-    "model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
+    "model_class",
+    [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
 )
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
     if (
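For a quick local check outside the parametrized suite, a standalone forward/backward smoke test along the lines of the test addition above should suffice (the `resnet18` choice is arbitrary; both supported depth extremes are covered):

```python
import torch
import segmentation_models_pytorch as smp

for depth in (3, 5):
    model = smp.UPerNet(
        "resnet18", encoder_weights=None, encoder_depth=depth, classes=1
    )
    model.train()
    sample = torch.ones([1, 3, 64, 64])
    out = model(sample)
    # Output is at input resolution regardless of depth: the decoder emits
    # input/4 features and the segmentation head upsamples x4.
    assert out.shape == (1, 1, 64, 64)
    out.mean().backward()
```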