diff --git a/pandas/_libs/tslibs/offsets.pyi b/pandas/_libs/tslibs/offsets.pyi index 12b113f0b73b1..4567dde4e056b 100644 --- a/pandas/_libs/tslibs/offsets.pyi +++ b/pandas/_libs/tslibs/offsets.pyi @@ -104,7 +104,9 @@ class SingleConstructorOffset(BaseOffset): @overload def to_offset(freq: None) -> None: ... @overload -def to_offset(freq: timedelta | BaseOffset | str) -> BaseOffset: ... +def to_offset(freq: _BaseOffsetT) -> _BaseOffsetT: ... +@overload +def to_offset(freq: timedelta | str) -> BaseOffset: ... class Tick(SingleConstructorOffset): _reso: int diff --git a/pandas/compat/numpy/function.py b/pandas/compat/numpy/function.py index e3aa5bb52f2ba..140d41782e6d3 100644 --- a/pandas/compat/numpy/function.py +++ b/pandas/compat/numpy/function.py @@ -17,7 +17,11 @@ """ from __future__ import annotations -from typing import Any +from typing import ( + Any, + TypeVar, + overload, +) from numpy import ndarray @@ -25,6 +29,7 @@ is_bool, is_integer, ) +from pandas._typing import Axis from pandas.errors import UnsupportedFunctionCall from pandas.util._validators import ( validate_args, @@ -32,6 +37,8 @@ validate_kwargs, ) +AxisNoneT = TypeVar("AxisNoneT", Axis, None) + class CompatValidator: def __init__( @@ -84,7 +91,7 @@ def __call__( ) -def process_skipna(skipna, args): +def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]: if isinstance(skipna, ndarray) or skipna is None: args = (skipna,) + args skipna = True @@ -92,7 +99,7 @@ def process_skipna(skipna, args): return skipna, args -def validate_argmin_with_skipna(skipna, args, kwargs): +def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool: """ If 'Series.argmin' is called via the 'numpy' library, the third parameter in its signature is 'out', which takes either an ndarray or 'None', so @@ -104,7 +111,7 @@ def validate_argmin_with_skipna(skipna, args, kwargs): return skipna -def validate_argmax_with_skipna(skipna, args, kwargs): +def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool: """ If 'Series.argmax' is called via the 'numpy' library, the third parameter in its signature is 'out', which takes either an ndarray or 'None', so @@ -137,7 +144,7 @@ def validate_argmax_with_skipna(skipna, args, kwargs): ) -def validate_argsort_with_ascending(ascending, args, kwargs): +def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool: """ If 'Categorical.argsort' is called via the 'numpy' library, the first parameter in its signature is 'axis', which takes either an integer or @@ -149,7 +156,8 @@ def validate_argsort_with_ascending(ascending, args, kwargs): ascending = True validate_argsort_kind(args, kwargs, max_fname_arg_count=3) - return ascending + # error: Incompatible return value type (got "int", expected "bool") + return ascending # type: ignore[return-value] CLIP_DEFAULTS: dict[str, Any] = {"out": None} @@ -158,7 +166,19 @@ def validate_argsort_with_ascending(ascending, args, kwargs): ) -def validate_clip_with_axis(axis, args, kwargs): +@overload +def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None: + ... + + +@overload +def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT: + ... + + +def validate_clip_with_axis( + axis: ndarray | AxisNoneT, args, kwargs +) -> AxisNoneT | None: """ If 'NDFrame.clip' is called via the numpy library, the third parameter in its signature is 'out', which can takes an ndarray, so check if the 'axis' @@ -167,10 +187,14 @@ def validate_clip_with_axis(axis, args, kwargs): """ if isinstance(axis, ndarray): args = (axis,) + args - axis = None + # error: Incompatible types in assignment (expression has type "None", + # variable has type "Union[ndarray[Any, Any], str, int]") + axis = None # type: ignore[assignment] validate_clip(args, kwargs) - return axis + # error: Incompatible return value type (got "Union[ndarray[Any, Any], + # str, int]", expected "Union[str, int, None]") + return axis # type: ignore[return-value] CUM_FUNC_DEFAULTS: dict[str, Any] = {} @@ -184,7 +208,7 @@ def validate_clip_with_axis(axis, args, kwargs): ) -def validate_cum_func_with_skipna(skipna, args, kwargs, name): +def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool: """ If this function is called via the 'numpy' library, the third parameter in its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so @@ -288,7 +312,7 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name): validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs") -def validate_take_with_convert(convert, args, kwargs): +def validate_take_with_convert(convert: ndarray | bool | None, args, kwargs) -> bool: """ If this function is called via the 'numpy' library, the third parameter in its signature is 'axis', which takes either an ndarray or 'None', so check diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 0f88ad9811bf0..325c94d0ea267 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -2164,7 +2164,17 @@ def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str): return data, copy -def validate_periods(periods): +@overload +def validate_periods(periods: None) -> None: + ... + + +@overload +def validate_periods(periods: int | float) -> int: + ... + + +def validate_periods(periods: int | float | None) -> int | None: """ If a `periods` argument is passed to the Datetime/Timedelta Array/Index constructor, cast it to an integer. @@ -2187,7 +2197,9 @@ def validate_periods(periods): periods = int(periods) elif not lib.is_integer(periods): raise TypeError(f"periods must be a number, got {periods}") - return periods + # error: Incompatible return value type (got "Optional[float]", + # expected "Optional[int]") + return periods # type: ignore[return-value] def validate_inferred_freq(freq, inferred_freq, freq_infer): diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 106afcc3c12ea..d9f6cecc8d61d 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -251,7 +251,7 @@ def _scalar_type(self) -> type[Timestamp]: # Constructors _dtype: np.dtype | DatetimeTZDtype - _freq = None + _freq: BaseOffset | None = None _default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__ @classmethod diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index fa7c4e0d0aa70..6e6de8399cc38 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -8,6 +8,8 @@ Callable, Literal, Sequence, + TypeVar, + overload, ) import numpy as np @@ -92,6 +94,8 @@ TimedeltaArray, ) +BaseOffsetT = TypeVar("BaseOffsetT", bound=BaseOffset) + _shared_doc_kwargs = { "klass": "PeriodArray", @@ -976,7 +980,19 @@ def period_array( return PeriodArray._from_sequence(data, dtype=dtype) -def validate_dtype_freq(dtype, freq): +@overload +def validate_dtype_freq(dtype, freq: BaseOffsetT) -> BaseOffsetT: + ... + + +@overload +def validate_dtype_freq(dtype, freq: timedelta | str | None) -> BaseOffset: + ... + + +def validate_dtype_freq( + dtype, freq: BaseOffsetT | timedelta | str | None +) -> BaseOffsetT: """ If both a dtype and a freq are available, ensure they match. If only dtype is available, extract the implied freq. @@ -996,7 +1012,10 @@ def validate_dtype_freq(dtype, freq): IncompatibleFrequency : mismatch between dtype and freq """ if freq is not None: - freq = to_offset(freq) + # error: Incompatible types in assignment (expression has type + # "BaseOffset", variable has type "Union[BaseOffsetT, timedelta, + # str, None]") + freq = to_offset(freq) # type: ignore[assignment] if dtype is not None: dtype = pandas_dtype(dtype) @@ -1006,7 +1025,9 @@ def validate_dtype_freq(dtype, freq): freq = dtype.freq elif freq != dtype.freq: raise IncompatibleFrequency("specified freq and dtype are different") - return freq + # error: Incompatible return value type (got "Union[BaseOffset, Any, None]", + # expected "BaseOffset") + return freq # type: ignore[return-value] def dt64arr_to_periodarr( diff --git a/pandas/util/_validators.py b/pandas/util/_validators.py index 3676e6eb0091e..fc3439a57a002 100644 --- a/pandas/util/_validators.py +++ b/pandas/util/_validators.py @@ -5,6 +5,7 @@ from __future__ import annotations from typing import ( + Any, Iterable, Sequence, TypeVar, @@ -265,7 +266,9 @@ def validate_bool_kwarg( return value -def validate_axis_style_args(data, args, kwargs, arg_name, method_name): +def validate_axis_style_args( + data, args, kwargs, arg_name, method_name +) -> dict[str, Any]: """ Argument handler for mixed index, columns / axis functions