diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 2bec43c9..caeb95d1 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -30,17 +30,33 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +from collections.abc import Iterable, Sequence +from typing import Literal + import torch from torch import nn from torch.nn import functional as F -__all__ = ["DeepLabV3Decoder"] +__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"] class DeepLabV3Decoder(nn.Sequential): - def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)): + def __init__( + self, + in_channels: int, + out_channels: int, + atrous_rates: Iterable[int], + aspp_separable: bool, + aspp_dropout: float, + ): super().__init__( - ASPP(in_channels, out_channels, atrous_rates), + ASPP( + in_channels, + out_channels, + atrous_rates, + separable=aspp_separable, + dropout=aspp_dropout, + ), nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), @@ -54,10 +70,12 @@ def forward(self, *features): class DeepLabV3PlusDecoder(nn.Module): def __init__( self, - encoder_channels, - out_channels=256, - atrous_rates=(12, 24, 36), - output_stride=16, + encoder_channels: Sequence[int], + out_channels: int, + atrous_rates: Iterable[int], + output_stride: Literal[8, 16], + aspp_separable: bool, + aspp_dropout: float, ): super().__init__() if output_stride not in {8, 16}: @@ -69,7 +87,13 @@ def __init__( self.output_stride = output_stride self.aspp = nn.Sequential( - ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True), + ASPP( + encoder_channels[-1], + out_channels, + atrous_rates, + separable=aspp_separable, + dropout=aspp_dropout, + ), SeparableConv2d( out_channels, out_channels, kernel_size=3, padding=1, bias=False ), @@ -111,7 +135,7 @@ def
forward(self, *features): class ASPPConv(nn.Sequential): - def __init__(self, in_channels, out_channels, dilation): + def __init__(self, in_channels: int, out_channels: int, dilation: int): super().__init__( nn.Conv2d( in_channels, @@ -127,7 +151,7 @@ def __init__(self, in_channels, out_channels, dilation): class ASPPSeparableConv(nn.Sequential): - def __init__(self, in_channels, out_channels, dilation): + def __init__(self, in_channels: int, out_channels: int, dilation: int): super().__init__( SeparableConv2d( in_channels, @@ -143,7 +167,7 @@ def __init__(self, in_channels, out_channels, dilation): class ASPPPooling(nn.Sequential): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int): super().__init__( nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), @@ -159,16 +183,22 @@ def forward(self, x): class ASPP(nn.Module): - def __init__(self, in_channels, out_channels, atrous_rates, separable=False): + def __init__( + self, + in_channels: int, + out_channels: int, + atrous_rates: Iterable[int], + separable: bool, + dropout: float, + ): super(ASPP, self).__init__() - modules = [] - modules.append( + modules = [ nn.Sequential( - nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), ) - ) + ] rate1, rate2, rate3 = tuple(atrous_rates) ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv @@ -184,7 +214,7 @@ def __init__(self, in_channels, out_channels, atrous_rates, separable=False): nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), - nn.Dropout(0.5), + nn.Dropout(dropout), ) def forward(self, x): diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 6527c7a7..d67a3be3 100644 --- 
a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -1,4 +1,6 @@ -from typing import Any, Optional +from collections.abc import Iterable +from typing import Any, Literal, Optional + from segmentation_models_pytorch.base import ( ClassificationHead, @@ -23,13 +25,17 @@ class DeepLabV3(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) decoder_channels: A number of convolution filters in ASPP module. Default is 256 + encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) + decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) + decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False + decoder_aspp_dropout: Dropout probability applied in the ASPP module projection layer. Default is 0.5 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity + upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). Default is **None**, in which case ``encoder_output_stride`` is used. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). 
Supported params: - classes (int): A number of classes @@ -52,11 +58,15 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", + encoder_output_stride: Literal[8, 16] = 8, decoder_channels: int = 256, + decoder_atrous_rates: Iterable[int] = (12, 24, 36), + decoder_aspp_separable: bool = False, + decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, - upsampling: int = 8, + upsampling: Optional[int] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -67,12 +77,16 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, - output_stride=8, + output_stride=encoder_output_stride, **kwargs, ) self.decoder = DeepLabV3Decoder( - in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels + in_channels=self.encoder.out_channels[-1], + out_channels=decoder_channels, + atrous_rates=decoder_atrous_rates, + aspp_separable=decoder_aspp_separable, + aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead( @@ -80,7 +94,7 @@ def __init__( out_channels=classes, activation=activation, kernel_size=1, - upsampling=upsampling, + upsampling=encoder_output_stride if upsampling is None else upsampling, ) if aux_params is not None: @@ -105,7 +119,9 @@ class DeepLabV3Plus(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) - decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) + decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) + decoder_aspp_separable: Use separable convolutions in ASPP module. 
Default is True + decoder_aspp_dropout: Dropout probability applied in the ASPP module projection layer. Default is 0.5 decoder_channels: A number of convolution filters in ASPP module. Default is 256 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -136,9 +152,11 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - encoder_output_stride: int = 16, + encoder_output_stride: Literal[8, 16] = 16, decoder_channels: int = 256, - decoder_atrous_rates: tuple = (12, 24, 36), + decoder_atrous_rates: Iterable[int] = (12, 24, 36), + decoder_aspp_separable: bool = True, + decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, @@ -148,13 +166,6 @@ def __init__( ): super().__init__() - if encoder_output_stride not in [8, 16]: - raise ValueError( - "Encoder output stride should be 8 or 16, got {}".format( - encoder_output_stride - ) - ) - self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -169,6 +180,8 @@ def __init__( out_channels=decoder_channels, atrous_rates=decoder_atrous_rates, output_stride=encoder_output_stride, + aspp_separable=decoder_aspp_separable, + aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead(