From f2f717a03cda032513a7158046b2d08714007bac Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Sun, 15 Sep 2024 16:21:08 +0800
Subject: [PATCH 1/8] add UPerNet

---
 segmentation_models_pytorch/__init__.py | 3 +
 .../decoders/upernet/__init__.py | 3 +
 .../decoders/upernet/decoder.py | 134 ++++++++++++++++++
 .../decoders/upernet/model.py | 91 ++++++++++++
 4 files changed, 231 insertions(+)
 create mode 100644 segmentation_models_pytorch/decoders/upernet/__init__.py
 create mode 100644 segmentation_models_pytorch/decoders/upernet/decoder.py
 create mode 100644 segmentation_models_pytorch/decoders/upernet/model.py

diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py
index d3778ecc..5733e7b9 100644
--- a/segmentation_models_pytorch/__init__.py
+++ b/segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
 from .decoders.pspnet import PSPNet
 from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
 from .decoders.pan import PAN
+from .decoders.upernet import UPerNet
 
 from .base.hub_mixin import from_pretrained
 from .__version__ import __version__
@@ -48,6 +49,7 @@ def create_model(
         DeepLabV3,
         DeepLabV3Plus,
         PAN,
+        UPerNet,
     ]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
@@ -82,6 +84,7 @@ def create_model(
     "DeepLabV3",
     "DeepLabV3Plus",
     "PAN",
+    "UPerNet",
     "from_pretrained",
     "create_model",
     "__version__",
diff --git a/segmentation_models_pytorch/decoders/upernet/__init__.py b/segmentation_models_pytorch/decoders/upernet/__init__.py
new file mode 100644
index 00000000..012967b5
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/upernet/__init__.py
@@ -0,0 +1,3 @@
+from .model import UPerNet
+
+__all__ = ["UPerNet"]
diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py
new file mode 100644
index 00000000..fe823822
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/upernet/decoder.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from segmentation_models_pytorch.base import modules as md
+
+
+class PSPModule(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        sizes=(1, 2, 3, 6),
+        use_batchnorm=True,
+    ):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(size),
+                    md.Conv2dReLU(
+                        in_channels,
+                        in_channels // len(sizes),
+                        kernel_size=1,
+                        use_batchnorm=use_batchnorm,
+                    ),
+                )
+                for size in sizes
+            ]
+        )
+        self.out_conv = md.Conv2dReLU(
+            in_channels=in_channels * 2,
+            out_channels=out_channels,
+            kernel_size=1,
+            use_batchnorm=use_batchnorm,
+        )
+
+    def forward(self, x):
+        _, _, h, w = x.shape
+        out = [x] + [
+            F.interpolate(block(x), size=(h, w), mode="bilinear", align_corners=False)
+            for block in self.blocks
+        ]
+        out = self.out_conv(torch.cat(out, dim=1))
+        return out
+
+
+class FPNBlock(nn.Module):
+    def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True):
+        super().__init__()
+        self.skip_conv = (
+            md.Conv2dReLU(
+                skip_channels,
+                pyramid_channels,
+                kernel_size=1,
+                use_batchnorm=use_batchnorm,
+            )
+            if skip_channels != 0
+            else nn.Identity()
+        )
+
+    def forward(self, x, skip):
+        _, ch, h, w = skip.shape
+        x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
+        if ch != 0:
+            skip = self.skip_conv(skip)
+            x = x + skip
+        return x
+
+
+class UPerNetDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        encoder_depth=5,
+        pyramid_channels=256,
+        segmentation_channels=64,
+    ):
+        super().__init__()
+        self.out_channels = segmentation_channels
+        if encoder_depth < 3:
+            raise ValueError(
+                "Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format(
+                    encoder_depth
+                )
+            )
+
+        encoder_channels = encoder_channels[::-1]
+
+        # PSP Module
+        self.psp = PSPModule(
+            in_channels=encoder_channels[0],
+            out_channels=pyramid_channels,
+            sizes=(1, 2, 3, 6),
+            use_batchnorm=True,
+        )
+
+        # FPN Module
+        self.fpn_stages = nn.ModuleList(
+            [FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]]
+        )
+
+        self.fpn_bottleneck = md.Conv2dReLU(
+            in_channels=(len(encoder_channels) - 1) * pyramid_channels,
+            out_channels=segmentation_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=True,
+        )
+
+    def forward(self, *features):
+        # Resize all FPN features to the size of the largest feature
+        target_size = features[0].shape[2:]
+
+        features = features[1:]  # remove first skip with same spatial resolution
+        features = features[::-1]  # reverse channels to start from head of encoder
+
+        psp_out = self.psp(features[0])
+
+        fpn_features = [psp_out]
+        for feature, stage in zip(features[1:], self.fpn_stages):
+            fpn_feature = stage(fpn_features[-1], feature)
+            fpn_features.append(fpn_feature)
+
+        resized_fpn_features = []
+        for feature in fpn_features:
+            resized_feature = F.interpolate(
+                feature, size=target_size, mode="bilinear", align_corners=False
+            )
+            resized_fpn_features.append(resized_feature)
+
+        output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))
+
+        return output
diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py
new file mode 100644
index 00000000..0de37a6c
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/upernet/model.py
@@ -0,0 +1,91 @@
+from typing import Optional, Union
+
+from segmentation_models_pytorch.encoders import get_encoder
+from segmentation_models_pytorch.base import (
+    SegmentationModel,
+    SegmentationHead,
+    ClassificationHead,
+)
+from .decoder import UPerNetDecoder
+
+
+class UPerNet(SegmentationModel):
+    """UPerNet is a unified perceptual parsing network for image segmentation.
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256
+        decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+                **callable** and **None**.
+            Default is **None**
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of the encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **UPerNet**
+
+    .. _UPerNet:
+        https://arxiv.org/abs/1505.04597
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_pyramid_channels: int = 256,
+        decoder_segmentation_channels: int = 64,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = UPerNetDecoder(
+            encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
+            pyramid_channels=decoder_pyramid_channels,
+            segmentation_channels=decoder_segmentation_channels,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=self.decoder.out_channels,
+            out_channels=classes,
+            activation=activation,
+            kernel_size=3,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "upernet-{}".format(encoder_name)
+        self.initialize()

From 43000a3ba98e1250e5cf64d123889492031a00da Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Sun, 15 Sep 2024 19:02:32 +0800
Subject: [PATCH 2/8] update paper link

---
 segmentation_models_pytorch/decoders/upernet/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py
index 0de37a6c..5523b72e 100644
--- a/segmentation_models_pytorch/decoders/upernet/model.py
+++ b/segmentation_models_pytorch/decoders/upernet/model.py
@@ -41,7 +41,7 @@ class UPerNet(SegmentationModel):
     ``torch.nn.Module``: **UPerNet**
 
     .. _UPerNet:
-        https://arxiv.org/abs/1505.04597
+        https://arxiv.org/abs/1807.10221
 
     """
 

From 748070ebd4ad7672821a1efda529c8454257442f Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Sun, 15 Sep 2024 19:23:05 +0800
Subject: [PATCH 3/8] update tests

add UPerNet for test_models
---
 tests/test_models.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index f78f55d6..a1b5f2c6 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -29,6 +29,7 @@ def get_sample(model_class):
         smp.PSPNet,
         smp.UnetPlusPlus,
         smp.MAnet,
+        smp.UPerNet,
     ]:
         sample = torch.ones([1, 3, 64, 64])
     elif model_class == smp.PAN:
@@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False):
 @pytest.mark.parametrize("encoder_name", ENCODERS)
 @pytest.mark.parametrize("encoder_depth", [3, 5])
 @pytest.mark.parametrize(
-    "model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
+    "model_class",
+    [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
 )
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
     if (

From 5e3717712c8a33e06cc93336cdf988ec2d9cdf3c Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Fri, 20 Sep 2024 23:18:07 +0800
Subject: [PATCH 4/8] update readme and doc

---
 README.md | 3 ++-
 docs/models.rst | 7 +++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index a8c8bc17..8433222d 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of this library are:
 
  - High-level API (just two lines to create a neural network)
- - 9 models architectures for binary and multi class segmentation (including legendary Unet)
+ - 10 model architectures for binary and multi class segmentation (including legendary Unet)
  - 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
  - All encoders have pre-trained weights for faster and better convergence
 - Popular metrics and losses for training routines
@@ -94,6 +94,7 @@ Congratulations! You are done! Now you can train your model with your favorite framework!
 - PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
 - DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
 - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
+ - UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]

diff --git a/docs/models.rst b/docs/models.rst
index 003908a0..ad3682b0 100644
--- a/docs/models.rst
+++ b/docs/models.rst
@@ -66,3 +66,10 @@ MAnet
 PAN
 ~~~
 .. autoclass:: segmentation_models_pytorch.PAN
+
+
+.. _upernet:
+
+UPerNet
+~~~~~~~
+.. autoclass:: segmentation_models_pytorch.UPerNet

From b7b5a3780b9dd8508d598a1e7e2dab6d16d70b13 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Fri, 20 Sep 2024 23:24:47 +0800
Subject: [PATCH 5/8] rename variable

---
 .../decoders/upernet/decoder.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py
index fe823822..20a21e6a 100644
--- a/segmentation_models_pytorch/decoders/upernet/decoder.py
+++ b/segmentation_models_pytorch/decoders/upernet/decoder.py
@@ -36,9 +36,9 @@ def __init__(
         )
 
     def forward(self, x):
-        _, _, h, w = x.shape
+        _, _, height, width = x.shape
         out = [x] + [
-            F.interpolate(block(x), size=(h, w), mode="bilinear", align_corners=False)
+            F.interpolate(block(x), size=(height, width), mode="bilinear", align_corners=False)
             for block in self.blocks
@@ -60,9 +60,9 @@ def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True):
 
     def forward(self, x, skip):
-        _, ch, h, w = skip.shape
-        x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
-        if ch != 0:
+        _, channels, height, width = skip.shape
+        x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False)
+        if channels != 0:
             skip = self.skip_conv(skip)
             x = x + skip
         return x

From b63626149a704c2136b9d694d73b33e656b69b40 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Fri, 20 Sep 2024 23:29:23 +0800
Subject: [PATCH 6/8] fix format and lint checks

---
 segmentation_models_pytorch/decoders/upernet/decoder.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py
index 20a21e6a..59615da0 100644
--- a/segmentation_models_pytorch/decoders/upernet/decoder.py
+++ b/segmentation_models_pytorch/decoders/upernet/decoder.py
@@ -38,7 +38,9 @@ def forward(self, x):
         _, _, height, width = x.shape
         out = [x] + [
-            F.interpolate(block(x), size=(height, width), mode="bilinear", align_corners=False)
+            F.interpolate(
+                block(x), size=(height, width), mode="bilinear", align_corners=False
+            )
             for block in self.blocks
@@ -61,7 +63,9 @@ def forward(self, x, skip):
         _, channels, height, width = skip.shape
-        x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False)
+        x = F.interpolate(
+            x, size=(height, width), mode="bilinear", align_corners=False
+        )
         if channels != 0:
             skip = self.skip_conv(skip)
             x = x + skip

From 8d1e565a01ba9841f6d2e1e6db6156302aa843e7 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Sun, 22 Sep 2024 16:50:41 +0800
Subject: [PATCH 7/8] update UPerNet decoder

Resize all FPN output features to 1/4 of the original resolution.
--- segmentation_models_pytorch/decoders/upernet/decoder.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 59615da0..5b33ed93 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -113,8 +113,8 @@ def __init__( ) def forward(self, *features): - # Resize all FPN features to the size of the largest feature - target_size = features[0].shape[2:] + output_size = features[0].shape[2:] + target_size = [size // 4 for size in output_size] features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder @@ -126,6 +126,7 @@ def forward(self, *features): fpn_feature = stage(fpn_features[-1], feature) fpn_features.append(fpn_feature) + # Resize all FPN features to 1/4 of the original resolution. resized_fpn_features = [] for feature in fpn_features: resized_feature = F.interpolate( @@ -134,5 +135,8 @@ def forward(self, *features): resized_fpn_features.append(resized_feature) output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1)) + output = F.interpolate( + output, size=output_size, mode="bilinear", align_corners=False + ) return output From 5205fe6a69e45f848458b53c9ce2c511c5ce0ef8 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 3 Oct 2024 05:14:49 +0800 Subject: [PATCH 8/8] update UPerNet decoder 1. Use `SegmentationHead` for upsampling, set `upsampling=4` 2. Remove the additional variable `out_channels` from `UPerNetDecoder` 3. Fix `SegmentationHead` kernel size to 1 --- segmentation_models_pytorch/decoders/upernet/decoder.py | 5 +---- segmentation_models_pytorch/decoders/upernet/model.py | 5 +++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 5b33ed93..b36d3b40 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -81,7 +81,7 @@ def __init__( segmentation_channels=64, ): super().__init__() - self.out_channels = segmentation_channels + if encoder_depth < 3: raise ValueError( "Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format( @@ -135,8 +135,5 @@ def forward(self, *features): resized_fpn_features.append(resized_feature) output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1)) - output = F.interpolate( - output, size=output_size, mode="bilinear", align_corners=False - ) return output diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 5523b72e..18b97a94 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -74,10 +74,11 @@ def __init__( ) self.segmentation_head = SegmentationHead( - in_channels=self.decoder.out_channels, + in_channels=decoder_segmentation_channels, out_channels=classes, activation=activation, - kernel_size=3, + kernel_size=1, + upsampling=4, ) if aux_params is not None:
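
A minimal usage sketch of the new model, assuming all eight patches are applied; it is
not part of the patch series itself. It mirrors the 1x3x64x64 sample used in
tests/test_models.py; encoder_weights=None is passed only to skip the ImageNet weight
download, and the expected output shape follows from PATCH 8/8, where the decoder
output stays at 1/4 resolution and SegmentationHead upsamples it by 4 back to the
input size. The create_model keyword arguments are assumed to match the library's
existing signature.

    import torch
    import segmentation_models_pytorch as smp

    # Build UPerNet directly, with the defaults introduced in model.py.
    model = smp.UPerNet(encoder_name="resnet34", encoder_weights=None, classes=2)
    model.eval()

    # Forward pass: the mask comes back at the input resolution.
    x = torch.ones(1, 3, 64, 64)
    with torch.no_grad():
        mask = model(x)
    print(mask.shape)  # expected: torch.Size([1, 2, 64, 64])

    # PATCH 1/8 also registers the class with create_model under its lowercased name.
    model = smp.create_model(
        "upernet", encoder_name="resnet34", encoder_weights=None, classes=2
    )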