diff --git a/tests/test_models.py b/tests/test_models.py index a1b5f2c6..206635e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -23,19 +23,21 @@ def get_encoders(): def get_sample(model_class): if model_class in [ - smp.Unet, - smp.Linknet, smp.FPN, - smp.PSPNet, + smp.Linknet, + smp.Unet, smp.UnetPlusPlus, smp.MAnet, - smp.UPerNet, ]: sample = torch.ones([1, 3, 64, 64]) elif model_class == smp.PAN: sample = torch.ones([2, 3, 256, 256]) - elif model_class == smp.DeepLabV3: + elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]: sample = torch.ones([2, 3, 128, 128]) + elif model_class in [smp.PSPNet, smp.UPerNet]: + # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input + # from PSPModule pooling in PSPNet/UPerNet. + sample = torch.ones([2, 3, 64, 64]) else: raise ValueError("Not supported model class {}".format(model_class)) return sample @@ -102,6 +104,8 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs): smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3, + smp.DeepLabV3Plus, + smp.UPerNet, ], ) def test_forward_backward(model_class): @@ -112,7 +116,18 @@ def test_forward_backward(model_class): @pytest.mark.parametrize( "model_class", - [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet], + [ + smp.PAN, + smp.FPN, + smp.PSPNet, + smp.Linknet, + smp.Unet, + smp.UnetPlusPlus, + smp.MAnet, + smp.DeepLabV3, + smp.DeepLabV3Plus, + smp.UPerNet, + ], ) def test_aux_output(model_class): model = model_class(