Skip to content

Commit 934a8e1

Browse files
Update DeepLab models (#959)
* Fix DeepLabV3+ import warnings * Clean up code - Added type hints to ASPP components. - Updated ASPP module list creation block. - Updated model type hints and docs. - Delegated encoder output stride checks to decoder. * Expose hidden block arguments * Forgot comma * Fixed imports
1 parent d538393 commit 934a8e1

File tree

2 files changed

+76
-33
lines changed

2 files changed

+76
-33
lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,33 @@
3030
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3131
"""
3232

33+
from collections.abc import Iterable, Sequence
34+
from typing import Literal
35+
3336
import torch
3437
from torch import nn
3538
from torch.nn import functional as F
3639

37-
__all__ = ["DeepLabV3Decoder"]
40+
__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"]
3841

3942

4043
class DeepLabV3Decoder(nn.Sequential):
41-
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
44+
def __init__(
45+
self,
46+
in_channels: int,
47+
out_channels: int,
48+
atrous_rates: Iterable[int],
49+
aspp_separable: bool,
50+
aspp_dropout: float,
51+
):
4252
super().__init__(
43-
ASPP(in_channels, out_channels, atrous_rates),
53+
ASPP(
54+
in_channels,
55+
out_channels,
56+
atrous_rates,
57+
separable=aspp_separable,
58+
dropout=aspp_dropout,
59+
),
4460
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
4561
nn.BatchNorm2d(out_channels),
4662
nn.ReLU(),
@@ -54,10 +70,12 @@ def forward(self, *features):
5470
class DeepLabV3PlusDecoder(nn.Module):
5571
def __init__(
5672
self,
57-
encoder_channels,
58-
out_channels=256,
59-
atrous_rates=(12, 24, 36),
60-
output_stride=16,
73+
encoder_channels: Sequence[int, ...],
74+
out_channels: int,
75+
atrous_rates: Iterable[int],
76+
output_stride: Literal[8, 16],
77+
aspp_separable: bool,
78+
aspp_dropout: float,
6179
):
6280
super().__init__()
6381
if output_stride not in {8, 16}:
@@ -69,7 +87,13 @@ def __init__(
6987
self.output_stride = output_stride
7088

7189
self.aspp = nn.Sequential(
72-
ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
90+
ASPP(
91+
encoder_channels[-1],
92+
out_channels,
93+
atrous_rates,
94+
separable=aspp_separable,
95+
dropout=aspp_dropout,
96+
),
7397
SeparableConv2d(
7498
out_channels, out_channels, kernel_size=3, padding=1, bias=False
7599
),
@@ -111,7 +135,7 @@ def forward(self, *features):
111135

112136

113137
class ASPPConv(nn.Sequential):
114-
def __init__(self, in_channels, out_channels, dilation):
138+
def __init__(self, in_channels: int, out_channels: int, dilation: int):
115139
super().__init__(
116140
nn.Conv2d(
117141
in_channels,
@@ -127,7 +151,7 @@ def __init__(self, in_channels, out_channels, dilation):
127151

128152

129153
class ASPPSeparableConv(nn.Sequential):
130-
def __init__(self, in_channels, out_channels, dilation):
154+
def __init__(self, in_channels: int, out_channels: int, dilation: int):
131155
super().__init__(
132156
SeparableConv2d(
133157
in_channels,
@@ -143,7 +167,7 @@ def __init__(self, in_channels, out_channels, dilation):
143167

144168

145169
class ASPPPooling(nn.Sequential):
146-
def __init__(self, in_channels, out_channels):
170+
def __init__(self, in_channels: int, out_channels: int):
147171
super().__init__(
148172
nn.AdaptiveAvgPool2d(1),
149173
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
@@ -159,16 +183,22 @@ def forward(self, x):
159183

160184

161185
class ASPP(nn.Module):
162-
def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
186+
def __init__(
187+
self,
188+
in_channels: int,
189+
out_channels: int,
190+
atrous_rates: Iterable[int],
191+
separable: bool,
192+
dropout: float,
193+
):
163194
super(ASPP, self).__init__()
164-
modules = []
165-
modules.append(
195+
modules = [
166196
nn.Sequential(
167-
nn.Conv2d(in_channels, out_channels, 1, bias=False),
197+
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
168198
nn.BatchNorm2d(out_channels),
169199
nn.ReLU(),
170200
)
171-
)
201+
]
172202

173203
rate1, rate2, rate3 = tuple(atrous_rates)
174204
ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
@@ -184,7 +214,7 @@ def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
184214
nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
185215
nn.BatchNorm2d(out_channels),
186216
nn.ReLU(),
187-
nn.Dropout(0.5),
217+
nn.Dropout(dropout),
188218
)
189219

190220
def forward(self, x):

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Optional
1+
from collections.abc import Iterable
2+
from typing import Any, Literal, Optional
3+
24

35
from segmentation_models_pytorch.base import (
46
ClassificationHead,
@@ -23,13 +25,17 @@ class DeepLabV3(SegmentationModel):
2325
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
2426
other pretrained weights (see table with available weights for each encoder_name)
2527
decoder_channels: A number of convolution filters in ASPP module. Default is 256
28+
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
29+
decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
30+
decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False
31+
decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5
2632
in_channels: A number of input channels for the model, default is 3 (RGB images)
2733
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
2834
activation: An activation function to apply after the final convolution layer.
2935
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
3036
**callable** and **None**.
3137
Default is **None**
32-
upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
38+
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
3339
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
3440
on top of encoder if **aux_params** is not **None** (default). Supported params:
3541
- classes (int): A number of classes
@@ -52,11 +58,15 @@ def __init__(
5258
encoder_name: str = "resnet34",
5359
encoder_depth: int = 5,
5460
encoder_weights: Optional[str] = "imagenet",
61+
encoder_output_stride: Literal[8, 16] = 8,
5562
decoder_channels: int = 256,
63+
decoder_atrous_rates: Iterable[int] = (12, 24, 36),
64+
decoder_aspp_separable: bool = False,
65+
decoder_aspp_dropout: float = 0.5,
5666
in_channels: int = 3,
5767
classes: int = 1,
5868
activation: Optional[str] = None,
59-
upsampling: int = 8,
69+
upsampling: Optional[int] = None,
6070
aux_params: Optional[dict] = None,
6171
**kwargs: dict[str, Any],
6272
):
@@ -67,20 +77,24 @@ def __init__(
6777
in_channels=in_channels,
6878
depth=encoder_depth,
6979
weights=encoder_weights,
70-
output_stride=8,
80+
output_stride=encoder_output_stride,
7181
**kwargs,
7282
)
7383

7484
self.decoder = DeepLabV3Decoder(
75-
in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels
85+
in_channels=self.encoder.out_channels[-1],
86+
out_channels=decoder_channels,
87+
atrous_rates=decoder_atrous_rates,
88+
aspp_separable=decoder_aspp_separable,
89+
aspp_dropout=decoder_aspp_dropout,
7690
)
7791

7892
self.segmentation_head = SegmentationHead(
7993
in_channels=self.decoder.out_channels,
8094
out_channels=classes,
8195
activation=activation,
8296
kernel_size=1,
83-
upsampling=upsampling,
97+
upsampling=encoder_output_stride if upsampling is None else upsampling,
8498
)
8599

86100
if aux_params is not None:
@@ -105,7 +119,9 @@ class DeepLabV3Plus(SegmentationModel):
105119
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
106120
other pretrained weights (see table with available weights for each encoder_name)
107121
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
108-
decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
122+
decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
123+
decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True
124+
decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5
109125
decoder_channels: A number of convolution filters in ASPP module. Default is 256
110126
in_channels: A number of input channels for the model, default is 3 (RGB images)
111127
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -136,9 +152,11 @@ def __init__(
136152
encoder_name: str = "resnet34",
137153
encoder_depth: int = 5,
138154
encoder_weights: Optional[str] = "imagenet",
139-
encoder_output_stride: int = 16,
155+
encoder_output_stride: Literal[8, 16] = 16,
140156
decoder_channels: int = 256,
141-
decoder_atrous_rates: tuple = (12, 24, 36),
157+
decoder_atrous_rates: Iterable[int] = (12, 24, 36),
158+
decoder_aspp_separable: bool = True,
159+
decoder_aspp_dropout: float = 0.5,
142160
in_channels: int = 3,
143161
classes: int = 1,
144162
activation: Optional[str] = None,
@@ -148,13 +166,6 @@ def __init__(
148166
):
149167
super().__init__()
150168

151-
if encoder_output_stride not in [8, 16]:
152-
raise ValueError(
153-
"Encoder output stride should be 8 or 16, got {}".format(
154-
encoder_output_stride
155-
)
156-
)
157-
158169
self.encoder = get_encoder(
159170
encoder_name,
160171
in_channels=in_channels,
@@ -169,6 +180,8 @@ def __init__(
169180
out_channels=decoder_channels,
170181
atrous_rates=decoder_atrous_rates,
171182
output_stride=encoder_output_stride,
183+
aspp_separable=decoder_aspp_separable,
184+
aspp_dropout=decoder_aspp_dropout,
172185
)
173186

174187
self.segmentation_head = SegmentationHead(

0 commit comments

Comments
 (0)