Skip to content

Commit 8380b15

Browse files
authored
Update test_models.py (#940)
* Update test_models.py * Update test_models.py
1 parent ebd4091 commit 8380b15

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

tests/test_models.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,21 @@ def get_encoders():
2323

2424
def get_sample(model_class):
2525
if model_class in [
26-
smp.Unet,
27-
smp.Linknet,
2826
smp.FPN,
29-
smp.PSPNet,
27+
smp.Linknet,
28+
smp.Unet,
3029
smp.UnetPlusPlus,
3130
smp.MAnet,
32-
smp.UPerNet,
3331
]:
3432
sample = torch.ones([1, 3, 64, 64])
3533
elif model_class == smp.PAN:
3634
sample = torch.ones([2, 3, 256, 256])
37-
elif model_class == smp.DeepLabV3:
35+
elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]:
3836
sample = torch.ones([2, 3, 128, 128])
37+
elif model_class in [smp.PSPNet, smp.UPerNet]:
38+
# Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
39+
# from PSPModule pooling in PSPNet/UPerNet.
40+
sample = torch.ones([2, 3, 64, 64])
3941
else:
4042
raise ValueError("Not supported model class {}".format(model_class))
4143
return sample
@@ -102,6 +104,8 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
102104
smp.UnetPlusPlus,
103105
smp.MAnet,
104106
smp.DeepLabV3,
107+
smp.DeepLabV3Plus,
108+
smp.UPerNet,
105109
],
106110
)
107111
def test_forward_backward(model_class):
@@ -112,7 +116,18 @@ def test_forward_backward(model_class):
112116

113117
@pytest.mark.parametrize(
114118
"model_class",
115-
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet],
119+
[
120+
smp.PAN,
121+
smp.FPN,
122+
smp.PSPNet,
123+
smp.Linknet,
124+
smp.Unet,
125+
smp.UnetPlusPlus,
126+
smp.MAnet,
127+
smp.DeepLabV3,
128+
smp.DeepLabV3Plus,
129+
smp.UPerNet,
130+
],
116131
)
117132
def test_aux_output(model_class):
118133
model = model_class(

0 commit comments

Comments
 (0)