@@ -23,19 +23,21 @@ def get_encoders():
23
23
24
24
def get_sample (model_class ):
25
25
if model_class in [
26
- smp .Unet ,
27
- smp .Linknet ,
28
26
smp .FPN ,
29
- smp .PSPNet ,
27
+ smp .Linknet ,
28
+ smp .Unet ,
30
29
smp .UnetPlusPlus ,
31
30
smp .MAnet ,
32
- smp .UPerNet ,
33
31
]:
34
32
sample = torch .ones ([1 , 3 , 64 , 64 ])
35
33
elif model_class == smp .PAN :
36
34
sample = torch .ones ([2 , 3 , 256 , 256 ])
37
- elif model_class == smp .DeepLabV3 :
35
+ elif model_class in [ smp .DeepLabV3 , smp . DeepLabV3Plus ] :
38
36
sample = torch .ones ([2 , 3 , 128 , 128 ])
37
+ elif model_class in [smp .PSPNet , smp .UPerNet ]:
38
+ # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
39
+ # from PSPModule pooling in PSPNet/UPerNet.
40
+ sample = torch .ones ([2 , 3 , 64 , 64 ])
39
41
else :
40
42
raise ValueError ("Not supported model class {}" .format (model_class ))
41
43
return sample
@@ -102,6 +104,8 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
102
104
smp .UnetPlusPlus ,
103
105
smp .MAnet ,
104
106
smp .DeepLabV3 ,
107
+ smp .DeepLabV3Plus ,
108
+ smp .UPerNet ,
105
109
],
106
110
)
107
111
def test_forward_backward (model_class ):
@@ -112,7 +116,18 @@ def test_forward_backward(model_class):
112
116
113
117
@pytest .mark .parametrize (
114
118
"model_class" ,
115
- [smp .PAN , smp .FPN , smp .PSPNet , smp .Linknet , smp .Unet , smp .UnetPlusPlus , smp .MAnet ],
119
+ [
120
+ smp .PAN ,
121
+ smp .FPN ,
122
+ smp .PSPNet ,
123
+ smp .Linknet ,
124
+ smp .Unet ,
125
+ smp .UnetPlusPlus ,
126
+ smp .MAnet ,
127
+ smp .DeepLabV3 ,
128
+ smp .DeepLabV3Plus ,
129
+ smp .UPerNet ,
130
+ ],
116
131
)
117
132
def test_aux_output (model_class ):
118
133
model = model_class (
0 commit comments