1
- from typing import Any , Optional
1
+ from collections .abc import Iterable
2
+ from typing import Any , Literal , Optional
3
+
2
4
3
5
from segmentation_models_pytorch .base import (
4
6
ClassificationHead ,
@@ -23,13 +25,17 @@ class DeepLabV3(SegmentationModel):
23
25
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
24
26
other pretrained weights (see table with available weights for each encoder_name)
25
27
decoder_channels: A number of convolution filters in ASPP module. Default is 256
28
+ encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
29
+ decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
30
+ decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False
31
+ decoder_aspp_dropout: Dropout probability used in the ASPP module projection layer. Default is 0.5
26
32
in_channels: A number of input channels for the model, default is 3 (RGB images)
27
33
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
28
34
activation: An activation function to apply after the final convolution layer.
29
35
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
30
36
**callable** and **None**.
31
37
Default is **None**
32
- upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
38
+ upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). Default is **None**, in which case the value of ``encoder_output_stride`` is used.
33
39
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
34
40
on top of encoder if **aux_params** is not **None** (default). Supported params:
35
41
- classes (int): A number of classes
@@ -52,11 +58,15 @@ def __init__(
52
58
encoder_name : str = "resnet34" ,
53
59
encoder_depth : int = 5 ,
54
60
encoder_weights : Optional [str ] = "imagenet" ,
61
+ encoder_output_stride : Literal [8 , 16 ] = 8 ,
55
62
decoder_channels : int = 256 ,
63
+ decoder_atrous_rates : Iterable [int ] = (12 , 24 , 36 ),
64
+ decoder_aspp_separable : bool = False ,
65
+ decoder_aspp_dropout : float = 0.5 ,
56
66
in_channels : int = 3 ,
57
67
classes : int = 1 ,
58
68
activation : Optional [str ] = None ,
59
- upsampling : int = 8 ,
69
+ upsampling : Optional [ int ] = None ,
60
70
aux_params : Optional [dict ] = None ,
61
71
** kwargs : dict [str , Any ],
62
72
):
@@ -67,20 +77,24 @@ def __init__(
67
77
in_channels = in_channels ,
68
78
depth = encoder_depth ,
69
79
weights = encoder_weights ,
70
- output_stride = 8 ,
80
+ output_stride = encoder_output_stride ,
71
81
** kwargs ,
72
82
)
73
83
74
84
self .decoder = DeepLabV3Decoder (
75
- in_channels = self .encoder .out_channels [- 1 ], out_channels = decoder_channels
85
+ in_channels = self .encoder .out_channels [- 1 ],
86
+ out_channels = decoder_channels ,
87
+ atrous_rates = decoder_atrous_rates ,
88
+ aspp_separable = decoder_aspp_separable ,
89
+ aspp_dropout = decoder_aspp_dropout ,
76
90
)
77
91
78
92
self .segmentation_head = SegmentationHead (
79
93
in_channels = self .decoder .out_channels ,
80
94
out_channels = classes ,
81
95
activation = activation ,
82
96
kernel_size = 1 ,
83
- upsampling = upsampling ,
97
+ upsampling = encoder_output_stride if upsampling is None else upsampling ,
84
98
)
85
99
86
100
if aux_params is not None :
@@ -105,7 +119,9 @@ class DeepLabV3Plus(SegmentationModel):
105
119
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
106
120
other pretrained weights (see table with available weights for each encoder_name)
107
121
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
108
- decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
122
+ decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
123
+ decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True
124
+ decoder_aspp_dropout: Dropout probability used in the ASPP module projection layer. Default is 0.5
109
125
decoder_channels: A number of convolution filters in ASPP module. Default is 256
110
126
in_channels: A number of input channels for the model, default is 3 (RGB images)
111
127
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -136,9 +152,11 @@ def __init__(
136
152
encoder_name : str = "resnet34" ,
137
153
encoder_depth : int = 5 ,
138
154
encoder_weights : Optional [str ] = "imagenet" ,
139
- encoder_output_stride : int = 16 ,
155
+ encoder_output_stride : Literal [ 8 , 16 ] = 16 ,
140
156
decoder_channels : int = 256 ,
141
- decoder_atrous_rates : tuple = (12 , 24 , 36 ),
157
+ decoder_atrous_rates : Iterable [int ] = (12 , 24 , 36 ),
158
+ decoder_aspp_separable : bool = True ,
159
+ decoder_aspp_dropout : float = 0.5 ,
142
160
in_channels : int = 3 ,
143
161
classes : int = 1 ,
144
162
activation : Optional [str ] = None ,
@@ -148,13 +166,6 @@ def __init__(
148
166
):
149
167
super ().__init__ ()
150
168
151
- if encoder_output_stride not in [8 , 16 ]:
152
- raise ValueError (
153
- "Encoder output stride should be 8 or 16, got {}" .format (
154
- encoder_output_stride
155
- )
156
- )
157
-
158
169
self .encoder = get_encoder (
159
170
encoder_name ,
160
171
in_channels = in_channels ,
@@ -169,6 +180,8 @@ def __init__(
169
180
out_channels = decoder_channels ,
170
181
atrous_rates = decoder_atrous_rates ,
171
182
output_stride = encoder_output_stride ,
183
+ aspp_separable = decoder_aspp_separable ,
184
+ aspp_dropout = decoder_aspp_dropout ,
172
185
)
173
186
174
187
self .segmentation_head = SegmentationHead (
0 commit comments