diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py
index ad422dbc..6527c7a7 100644
--- a/segmentation_models_pytorch/decoders/deeplabv3/model.py
+++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py
@@ -1,11 +1,12 @@
-from typing import Optional
+from typing import Any, Optional
 
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
 from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
@@ -36,6 +37,8 @@ class DeepLabV3(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
+
     Returns:
         ``torch.nn.Module``: **DeepLabV3**
@@ -55,6 +58,7 @@ def __init__(
         activation: Optional[str] = None,
         upsampling: int = 8,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -64,6 +68,7 @@ def __init__(
             depth=encoder_depth,
             weights=encoder_weights,
             output_stride=8,
+            **kwargs,
         )
 
         self.decoder = DeepLabV3Decoder(
@@ -116,6 +121,8 @@ class DeepLabV3Plus(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
+
     Returns:
         ``torch.nn.Module``: **DeepLabV3Plus**
@@ -137,6 +144,7 @@ def __init__(
         activation: Optional[str] = None,
         upsampling: int = 4,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -153,6 +161,7 @@ def __init__(
             depth=encoder_depth,
             weights=encoder_weights,
             output_stride=encoder_output_stride,
+            **kwargs,
         )
 
         self.decoder = DeepLabV3PlusDecoder(
diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py
index f18457d5..373269c5 100644
--- a/segmentation_models_pytorch/decoders/fpn/model.py
+++ b/segmentation_models_pytorch/decoders/fpn/model.py
@@ -1,11 +1,12 @@
-from typing import Optional
+from typing import Any, Optional
 
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
 from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import FPNDecoder
@@ -40,6 +41,7 @@ class FPN(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **FPN**
@@ -63,6 +65,7 @@ def __init__(
         activation: Optional[str] = None,
         upsampling: int = 4,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -77,6 +80,7 @@ def __init__(
             in_channels=in_channels,
             depth=encoder_depth,
             weights=encoder_weights,
+            **kwargs,
         )
 
         self.decoder = FPNDecoder(
diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py
index b8c3139f..708ea562 100644
--- a/segmentation_models_pytorch/decoders/linknet/model.py
+++ b/segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from segmentation_models_pytorch.base import (
+    ClassificationHead,
     SegmentationHead,
     SegmentationModel,
-    ClassificationHead,
 )
 from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import LinknetDecoder
@@ -43,6 +44,7 @@ class Linknet(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **Linknet**
@@ -61,6 +63,7 @@ def __init__(
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -74,6 +77,7 @@ def __init__(
             in_channels=in_channels,
             depth=encoder_depth,
             weights=encoder_weights,
+            **kwargs,
         )
 
         self.decoder = LinknetDecoder(
diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py
index 08e64a2a..6651dee6 100644
--- a/segmentation_models_pytorch/decoders/manet/model.py
+++ b/segmentation_models_pytorch/decoders/manet/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union, List
+from typing import Any, List, Optional, Union
 
-from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
+from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import MAnetDecoder
@@ -45,6 +46,7 @@ class MAnet(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **MAnet**
@@ -66,6 +68,7 @@ def __init__(
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -74,6 +77,7 @@ def __init__(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
+           **kwargs,
         )
 
         self.decoder = MAnetDecoder(
diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py
index 8086d024..5c46f489 100644
--- a/segmentation_models_pytorch/decoders/pan/model.py
+++ b/segmentation_models_pytorch/decoders/pan/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
-from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
+from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import PANDecoder
@@ -38,6 +39,7 @@ class PAN(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **PAN**
@@ -58,6 +60,7 @@ def __init__(
         activation: Optional[Union[str, callable]] = None,
         upsampling: int = 4,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -74,6 +77,7 @@ def __init__(
             depth=5,
             weights=encoder_weights,
             output_stride=encoder_output_stride,
+            **kwargs,
         )
 
         self.decoder = PANDecoder(
diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py
index 9f9997f8..dbf04ea4 100644
--- a/segmentation_models_pytorch/decoders/pspnet/model.py
+++ b/segmentation_models_pytorch/decoders/pspnet/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
-from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
+from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import PSPDecoder
@@ -44,6 +45,7 @@ class PSPNet(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **PSPNet**
@@ -65,6 +67,7 @@ def __init__(
         activation: Optional[Union[str, callable]] = None,
         upsampling: int = 8,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -73,6 +76,7 @@ def __init__(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
+           **kwargs,
         )
 
         self.decoder = PSPDecoder(
diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py
index 46528c5a..0ac7b5bd 100644
--- a/segmentation_models_pytorch/decoders/unet/model.py
+++ b/segmentation_models_pytorch/decoders/unet/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union, List
+from typing import Any, List, Optional, Union
 
-from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
+from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import UnetDecoder
@@ -44,6 +45,7 @@ class Unet(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: Unet
@@ -65,6 +67,7 @@ def __init__(
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -73,6 +76,7 @@ def __init__(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
+           **kwargs,
         )
 
         self.decoder = UnetDecoder(
diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py
index 60d591f0..9ba72321 100644
--- a/segmentation_models_pytorch/decoders/unetplusplus/model.py
+++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union, List
+from typing import Any, List, Optional, Union
 
-from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
+from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import UnetPlusPlusDecoder
@@ -44,6 +45,7 @@ class UnetPlusPlus(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **Unet++**
@@ -65,6 +67,7 @@ def __init__(
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -78,6 +81,7 @@ def __init__(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
+           **kwargs,
         )
 
         self.decoder = UnetPlusPlusDecoder(
diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py
index 18b97a94..de30a7bb 100644
--- a/segmentation_models_pytorch/decoders/upernet/model.py
+++ b/segmentation_models_pytorch/decoders/upernet/model.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
-from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base import (
-    SegmentationModel,
-    SegmentationHead,
     ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
 )
+from segmentation_models_pytorch.encoders import get_encoder
+
 from .decoder import UPerNetDecoder
@@ -36,6 +37,7 @@ class UPerNet(SegmentationModel):
                 - dropout (float): Dropout factor in [0, 1)
                 - activation (str): An activation function to apply "sigmoid"/"softmax"
                     (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
 
     Returns:
         ``torch.nn.Module``: **UPerNet**
@@ -56,6 +58,7 @@ def __init__(
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
         aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -64,6 +67,7 @@ def __init__(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
+           **kwargs,
         )
 
         self.decoder = UPerNetDecoder(
diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py
index 9702a7c3..eb008221 100644
--- a/segmentation_models_pytorch/encoders/timm_universal.py
+++ b/segmentation_models_pytorch/encoders/timm_universal.py
@@ -1,11 +1,21 @@
+from typing import Any
+
 import timm
 import torch.nn as nn
 
 
 class TimmUniversalEncoder(nn.Module):
-    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
+    def __init__(
+        self,
+        name: str,
+        pretrained: bool = True,
+        in_channels: int = 3,
+        depth: int = 5,
+        output_stride: int = 32,
+        **kwargs: dict[str, Any],
+    ):
         super().__init__()
-        kwargs = dict(
+        common_kwargs = dict(
             in_chans=in_channels,
             features_only=True,
             output_stride=output_stride,
@@ -15,9 +25,11 @@ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=
         # not all models support output stride argument, drop it by default
         if output_stride == 32:
-            kwargs.pop("output_stride")
+            common_kwargs.pop("output_stride")
 
-        self.model = timm.create_model(name, **kwargs)
+        self.model = timm.create_model(
+            name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
+        )
 
         self._in_channels = in_channels
         self._out_channels = [in_channels] + self.model.feature_info.channels()
@@ -36,3 +48,11 @@ def out_channels(self):
     @property
     def output_stride(self):
         return min(self._output_stride, 2**self._depth)
+
+
+def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
+    duplicates = a.keys() & b.keys()
+    if duplicates:
+        raise ValueError(f"'{duplicates}' already specified internally")
+
+    return a | b
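Usage sketch of the forwarding path introduced above (not part of the diff; the backbone name and the drop_rate value are illustrative, and any keyword accepted by timm.create_model() for the chosen backbone should flow through the same way):

import segmentation_models_pytorch as smp

# The "tu-" prefix selects TimmUniversalEncoder; extra keywords on the model
# constructor are forwarded to timm.create_model() for the backbone.
model = smp.Unet(
    encoder_name="tu-resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=2,
    drop_rate=0.2,  # illustrative timm argument, forwarded to the backbone
)

# A keyword the encoder already sets internally (e.g. features_only) is expected
# to hit the new guard and raise the ValueError from _merge_kwargs_no_duplicates.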