Skip to content

TYP: def validate_* #47750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandas/_libs/tslibs/offsets.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ class SingleConstructorOffset(BaseOffset):
@overload
def to_offset(freq: None) -> None: ...
@overload
def to_offset(freq: timedelta | BaseOffset | str) -> BaseOffset: ...
def to_offset(freq: _BaseOffsetT) -> _BaseOffsetT: ...
@overload
def to_offset(freq: timedelta | str) -> BaseOffset: ...

class Tick(SingleConstructorOffset):
_reso: int
Expand Down
46 changes: 35 additions & 11 deletions pandas/compat/numpy/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,28 @@
"""
from __future__ import annotations

from typing import Any
from typing import (
Any,
TypeVar,
overload,
)

from numpy import ndarray

from pandas._libs.lib import (
is_bool,
is_integer,
)
from pandas._typing import Axis
from pandas.errors import UnsupportedFunctionCall
from pandas.util._validators import (
validate_args,
validate_args_and_kwargs,
validate_kwargs,
)

AxisNoneT = TypeVar("AxisNoneT", Axis, None)


class CompatValidator:
def __init__(
Expand Down Expand Up @@ -84,15 +91,15 @@ def __call__(
)


def process_skipna(skipna, args):
def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]:
if isinstance(skipna, ndarray) or skipna is None:
args = (skipna,) + args
skipna = True

return skipna, args


def validate_argmin_with_skipna(skipna, args, kwargs):
def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
"""
If 'Series.argmin' is called via the 'numpy' library, the third parameter
in its signature is 'out', which takes either an ndarray or 'None', so
Expand All @@ -104,7 +111,7 @@ def validate_argmin_with_skipna(skipna, args, kwargs):
return skipna


def validate_argmax_with_skipna(skipna, args, kwargs):
def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
"""
If 'Series.argmax' is called via the 'numpy' library, the third parameter
in its signature is 'out', which takes either an ndarray or 'None', so
Expand Down Expand Up @@ -137,7 +144,7 @@ def validate_argmax_with_skipna(skipna, args, kwargs):
)


def validate_argsort_with_ascending(ascending, args, kwargs):
def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool:
"""
If 'Categorical.argsort' is called via the 'numpy' library, the first
parameter in its signature is 'axis', which takes either an integer or
Expand All @@ -149,7 +156,8 @@ def validate_argsort_with_ascending(ascending, args, kwargs):
ascending = True

validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
return ascending
# error: Incompatible return value type (got "int", expected "bool")
return ascending # type: ignore[return-value]


CLIP_DEFAULTS: dict[str, Any] = {"out": None}
Expand All @@ -158,7 +166,19 @@ def validate_argsort_with_ascending(ascending, args, kwargs):
)


def validate_clip_with_axis(axis, args, kwargs):
@overload
def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None:
...


@overload
def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT:
...


def validate_clip_with_axis(
axis: ndarray | AxisNoneT, args, kwargs
) -> AxisNoneT | None:
"""
If 'NDFrame.clip' is called via the numpy library, the third parameter in
its signature is 'out', which can takes an ndarray, so check if the 'axis'
Expand All @@ -167,10 +187,14 @@ def validate_clip_with_axis(axis, args, kwargs):
"""
if isinstance(axis, ndarray):
args = (axis,) + args
axis = None
# error: Incompatible types in assignment (expression has type "None",
# variable has type "Union[ndarray[Any, Any], str, int]")
axis = None # type: ignore[assignment]

validate_clip(args, kwargs)
return axis
# error: Incompatible return value type (got "Union[ndarray[Any, Any],
# str, int]", expected "Union[str, int, None]")
return axis # type: ignore[return-value]


CUM_FUNC_DEFAULTS: dict[str, Any] = {}
Expand All @@ -184,7 +208,7 @@ def validate_clip_with_axis(axis, args, kwargs):
)


def validate_cum_func_with_skipna(skipna, args, kwargs, name):
def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
"""
If this function is called via the 'numpy' library, the third parameter in
its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
Expand Down Expand Up @@ -288,7 +312,7 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")


def validate_take_with_convert(convert, args, kwargs):
def validate_take_with_convert(convert: ndarray | bool | None, args, kwargs) -> bool:
"""
If this function is called via the 'numpy' library, the third parameter in
its signature is 'axis', which takes either an ndarray or 'None', so check
Expand Down
16 changes: 14 additions & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,7 +2164,17 @@ def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str):
return data, copy


def validate_periods(periods):
@overload
def validate_periods(periods: None) -> None:
...


@overload
def validate_periods(periods: int | float) -> int:
...


def validate_periods(periods: int | float | None) -> int | None:
"""
If a `periods` argument is passed to the Datetime/Timedelta Array/Index
constructor, cast it to an integer.
Expand All @@ -2187,7 +2197,9 @@ def validate_periods(periods):
periods = int(periods)
elif not lib.is_integer(periods):
raise TypeError(f"periods must be a number, got {periods}")
return periods
# error: Incompatible return value type (got "Optional[float]",
# expected "Optional[int]")
return periods # type: ignore[return-value]


def validate_inferred_freq(freq, inferred_freq, freq_infer):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _scalar_type(self) -> type[Timestamp]:
# Constructors

_dtype: np.dtype | DatetimeTZDtype
_freq = None
_freq: BaseOffset | None = None
_default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__

@classmethod
Expand Down
27 changes: 24 additions & 3 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Callable,
Literal,
Sequence,
TypeVar,
overload,
)

import numpy as np
Expand Down Expand Up @@ -92,6 +94,8 @@
TimedeltaArray,
)

BaseOffsetT = TypeVar("BaseOffsetT", bound=BaseOffset)


_shared_doc_kwargs = {
"klass": "PeriodArray",
Expand Down Expand Up @@ -976,7 +980,19 @@ def period_array(
return PeriodArray._from_sequence(data, dtype=dtype)


def validate_dtype_freq(dtype, freq):
@overload
def validate_dtype_freq(dtype, freq: BaseOffsetT) -> BaseOffsetT:
...


@overload
def validate_dtype_freq(dtype, freq: timedelta | str | None) -> BaseOffset:
...


def validate_dtype_freq(
dtype, freq: BaseOffsetT | timedelta | str | None
) -> BaseOffsetT:
"""
If both a dtype and a freq are available, ensure they match. If only
dtype is available, extract the implied freq.
Expand All @@ -996,7 +1012,10 @@ def validate_dtype_freq(dtype, freq):
IncompatibleFrequency : mismatch between dtype and freq
"""
if freq is not None:
freq = to_offset(freq)
# error: Incompatible types in assignment (expression has type
# "BaseOffset", variable has type "Union[BaseOffsetT, timedelta,
# str, None]")
freq = to_offset(freq) # type: ignore[assignment]

if dtype is not None:
dtype = pandas_dtype(dtype)
Expand All @@ -1006,7 +1025,9 @@ def validate_dtype_freq(dtype, freq):
freq = dtype.freq
elif freq != dtype.freq:
raise IncompatibleFrequency("specified freq and dtype are different")
return freq
# error: Incompatible return value type (got "Union[BaseOffset, Any, None]",
# expected "BaseOffset")
return freq # type: ignore[return-value]


def dt64arr_to_periodarr(
Expand Down
5 changes: 4 additions & 1 deletion pandas/util/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

from typing import (
Any,
Iterable,
Sequence,
TypeVar,
Expand Down Expand Up @@ -265,7 +266,9 @@ def validate_bool_kwarg(
return value


def validate_axis_style_args(data, args, kwargs, arg_name, method_name):
def validate_axis_style_args(
data, args, kwargs, arg_name, method_name
) -> dict[str, Any]:
"""
Argument handler for mixed index, columns / axis functions

Expand Down