-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
[feat] Adding UPerNet #926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
f2f717a
add UPerNet
brianhou0208 43000a3
update paper link
brianhou0208 748070e
update tests
brianhou0208 5e37177
update readme and doc
brianhou0208 b7b5a37
rename variable
brianhou0208 b636261
fix format and lint checks
brianhou0208 8d1e565
update UPerNet decoder
brianhou0208 5205fe6
update UPerNet decoder
brianhou0208 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .model import UPerNet | ||
|
||
__all__ = ["UPerNet"] |
139 changes: 139 additions & 0 deletions
139
segmentation_models_pytorch/decoders/upernet/decoder.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from segmentation_models_pytorch.base import modules as md | ||
|
||
|
||
class PSPModule(nn.Module): | ||
def __init__( | ||
self, | ||
in_channels, | ||
out_channels, | ||
sizes=(1, 2, 3, 6), | ||
use_batchnorm=True, | ||
): | ||
super().__init__() | ||
self.blocks = nn.ModuleList( | ||
[ | ||
nn.Sequential( | ||
nn.AdaptiveAvgPool2d(size), | ||
md.Conv2dReLU( | ||
in_channels, | ||
in_channels // len(sizes), | ||
kernel_size=1, | ||
use_batchnorm=use_batchnorm, | ||
), | ||
) | ||
for size in sizes | ||
] | ||
) | ||
self.out_conv = md.Conv2dReLU( | ||
in_channels=in_channels * 2, | ||
out_channels=out_channels, | ||
kernel_size=1, | ||
use_batchnorm=True, | ||
) | ||
|
||
def forward(self, x): | ||
_, _, height, weight = x.shape | ||
out = [x] + [ | ||
F.interpolate( | ||
block(x), size=(height, weight), mode="bilinear", align_corners=False | ||
) | ||
for block in self.blocks | ||
] | ||
out = self.out_conv(torch.cat(out, dim=1)) | ||
return out | ||
|
||
|
||
class FPNBlock(nn.Module): | ||
def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True): | ||
super().__init__() | ||
self.skip_conv = ( | ||
md.Conv2dReLU( | ||
skip_channels, | ||
pyramid_channels, | ||
kernel_size=1, | ||
use_batchnorm=use_bathcnorm, | ||
) | ||
if skip_channels != 0 | ||
else nn.Identity() | ||
) | ||
|
||
def forward(self, x, skip): | ||
_, channels, height, weight = skip.shape | ||
x = F.interpolate( | ||
x, size=(height, weight), mode="bilinear", align_corners=False | ||
) | ||
if channels != 0: | ||
skip = self.skip_conv(skip) | ||
x = x + skip | ||
return x | ||
|
||
|
||
class UPerNetDecoder(nn.Module): | ||
def __init__( | ||
self, | ||
encoder_channels, | ||
encoder_depth=5, | ||
pyramid_channels=256, | ||
segmentation_channels=64, | ||
): | ||
super().__init__() | ||
|
||
if encoder_depth < 3: | ||
raise ValueError( | ||
"Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format( | ||
encoder_depth | ||
) | ||
) | ||
|
||
encoder_channels = encoder_channels[::-1] | ||
|
||
# PSP Module | ||
self.psp = PSPModule( | ||
in_channels=encoder_channels[0], | ||
out_channels=pyramid_channels, | ||
sizes=(1, 2, 3, 6), | ||
use_batchnorm=True, | ||
) | ||
|
||
# FPN Module | ||
self.fpn_stages = nn.ModuleList( | ||
[FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]] | ||
) | ||
|
||
self.fpn_bottleneck = md.Conv2dReLU( | ||
in_channels=(len(encoder_channels) - 1) * pyramid_channels, | ||
out_channels=segmentation_channels, | ||
kernel_size=3, | ||
padding=1, | ||
use_batchnorm=True, | ||
) | ||
|
||
def forward(self, *features): | ||
output_size = features[0].shape[2:] | ||
target_size = [size // 4 for size in output_size] | ||
|
||
features = features[1:] # remove first skip with same spatial resolution | ||
features = features[::-1] # reverse channels to start from head of encoder | ||
|
||
psp_out = self.psp(features[0]) | ||
|
||
fpn_features = [psp_out] | ||
for feature, stage in zip(features[1:], self.fpn_stages): | ||
fpn_feature = stage(fpn_features[-1], feature) | ||
fpn_features.append(fpn_feature) | ||
|
||
# Resize all FPN features to 1/4 of the original resolution. | ||
resized_fpn_features = [] | ||
for feature in fpn_features: | ||
resized_feature = F.interpolate( | ||
feature, size=target_size, mode="bilinear", align_corners=False | ||
) | ||
resized_fpn_features.append(resized_feature) | ||
|
||
output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1)) | ||
|
||
return output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from typing import Optional, Union | ||
|
||
from segmentation_models_pytorch.encoders import get_encoder | ||
from segmentation_models_pytorch.base import ( | ||
SegmentationModel, | ||
SegmentationHead, | ||
ClassificationHead, | ||
) | ||
from .decoder import UPerNetDecoder | ||
|
||
|
||
class UPerNet(SegmentationModel): | ||
"""UPerNet is a unified perceptual parsing network for image segmentation. | ||
|
||
Args: | ||
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) | ||
to extract features of different spatial resolution | ||
encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features | ||
two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features | ||
with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). | ||
Default is 5 | ||
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and | ||
other pretrained weights (see table with available weights for each encoder_name) | ||
decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256 | ||
decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64 | ||
in_channels: A number of input channels for the model, default is 3 (RGB images) | ||
classes: A number of classes for output mask (or you can think as a number of channels of output mask) | ||
activation: An activation function to apply after the final convolution layer. | ||
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, | ||
**callable** and **None**. | ||
Default is **None** | ||
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build | ||
on top of encoder if **aux_params** is not **None** (default). Supported params: | ||
- classes (int): A number of classes | ||
- pooling (str): One of "max", "avg". Default is "avg" | ||
- dropout (float): Dropout factor in [0, 1) | ||
- activation (str): An activation function to apply "sigmoid"/"softmax" | ||
(could be **None** to return logits) | ||
|
||
Returns: | ||
``torch.nn.Module``: **UPerNet** | ||
|
||
.. _UPerNet: | ||
https://arxiv.org/abs/1807.10221 | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
encoder_name: str = "resnet34", | ||
encoder_depth: int = 5, | ||
encoder_weights: Optional[str] = "imagenet", | ||
decoder_pyramid_channels: int = 256, | ||
decoder_segmentation_channels: int = 64, | ||
in_channels: int = 3, | ||
classes: int = 1, | ||
activation: Optional[Union[str, callable]] = None, | ||
aux_params: Optional[dict] = None, | ||
): | ||
super().__init__() | ||
|
||
self.encoder = get_encoder( | ||
encoder_name, | ||
in_channels=in_channels, | ||
depth=encoder_depth, | ||
weights=encoder_weights, | ||
) | ||
|
||
self.decoder = UPerNetDecoder( | ||
encoder_channels=self.encoder.out_channels, | ||
encoder_depth=encoder_depth, | ||
pyramid_channels=decoder_pyramid_channels, | ||
segmentation_channels=decoder_segmentation_channels, | ||
) | ||
|
||
self.segmentation_head = SegmentationHead( | ||
in_channels=decoder_segmentation_channels, | ||
out_channels=classes, | ||
activation=activation, | ||
kernel_size=1, | ||
upsampling=4, | ||
) | ||
|
||
if aux_params is not None: | ||
self.classification_head = ClassificationHead( | ||
in_channels=self.encoder.out_channels[-1], **aux_params | ||
) | ||
else: | ||
self.classification_head = None | ||
|
||
self.name = "upernet-{}".format(encoder_name) | ||
self.initialize() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.