Skip to content

Commit 4bb1fd5

Browse files
authored
TYP: Missing return annotations in util/tseries/plotting (#47510)
* TYP: Missing return annotations in util/tseries/plotting * the more tricky parts
1 parent cb67837 commit 4bb1fd5

File tree

5 files changed

+79
-33
lines changed

5 files changed

+79
-33
lines changed

pandas/plotting/_misc.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
from __future__ import annotations
2+
13
from contextlib import contextmanager
4+
from typing import (
5+
TYPE_CHECKING,
6+
Iterator,
7+
)
28

39
from pandas.plotting._core import _get_plot_backend
410

11+
if TYPE_CHECKING:
12+
from matplotlib.axes import Axes
13+
from matplotlib.figure import Figure
14+
import numpy as np
15+
516

617
def table(ax, data, rowLabels=None, colLabels=None, **kwargs):
718
"""
@@ -27,7 +38,7 @@ def table(ax, data, rowLabels=None, colLabels=None, **kwargs):
2738
)
2839

2940

30-
def register():
41+
def register() -> None:
3142
"""
3243
Register pandas formatters and converters with matplotlib.
3344
@@ -49,7 +60,7 @@ def register():
4960
plot_backend.register()
5061

5162

52-
def deregister():
63+
def deregister() -> None:
5364
"""
5465
Remove pandas formatters and converters.
5566
@@ -81,7 +92,7 @@ def scatter_matrix(
8192
hist_kwds=None,
8293
range_padding=0.05,
8394
**kwargs,
84-
):
95+
) -> np.ndarray:
8596
"""
8697
Draw a matrix of scatter plots.
8798
@@ -156,7 +167,7 @@ def scatter_matrix(
156167
)
157168

158169

159-
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
170+
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds) -> Axes:
160171
"""
161172
Plot a multidimensional dataset in 2D.
162173
@@ -239,7 +250,7 @@ def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
239250

240251
def andrews_curves(
241252
frame, class_column, ax=None, samples=200, color=None, colormap=None, **kwargs
242-
):
253+
) -> Axes:
243254
"""
244255
Generate a matplotlib plot of Andrews curves, for visualising clusters of
245256
multivariate data.
@@ -297,7 +308,7 @@ def andrews_curves(
297308
)
298309

299310

300-
def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
311+
def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds) -> Figure:
301312
"""
302313
Bootstrap plot on mean, median and mid-range statistics.
303314
@@ -364,7 +375,7 @@ def parallel_coordinates(
364375
axvlines_kwds=None,
365376
sort_labels=False,
366377
**kwargs,
367-
):
378+
) -> Axes:
368379
"""
369380
Parallel coordinates plotting.
370381
@@ -430,7 +441,7 @@ def parallel_coordinates(
430441
)
431442

432443

433-
def lag_plot(series, lag=1, ax=None, **kwds):
444+
def lag_plot(series, lag=1, ax=None, **kwds) -> Axes:
434445
"""
435446
Lag plot for time series.
436447
@@ -474,7 +485,7 @@ def lag_plot(series, lag=1, ax=None, **kwds):
474485
return plot_backend.lag_plot(series=series, lag=lag, ax=ax, **kwds)
475486

476487

477-
def autocorrelation_plot(series, ax=None, **kwargs):
488+
def autocorrelation_plot(series, ax=None, **kwargs) -> Axes:
478489
"""
479490
Autocorrelation plot for time series.
480491
@@ -531,21 +542,21 @@ def __getitem__(self, key):
531542
raise ValueError(f"{key} is not a valid pandas plotting option")
532543
return super().__getitem__(key)
533544

534-
def __setitem__(self, key, value):
545+
def __setitem__(self, key, value) -> None:
535546
key = self._get_canonical_key(key)
536-
return super().__setitem__(key, value)
547+
super().__setitem__(key, value)
537548

538-
def __delitem__(self, key):
549+
def __delitem__(self, key) -> None:
539550
key = self._get_canonical_key(key)
540551
if key in self._DEFAULT_KEYS:
541552
raise ValueError(f"Cannot remove default parameter {key}")
542-
return super().__delitem__(key)
553+
super().__delitem__(key)
543554

544555
def __contains__(self, key) -> bool:
545556
key = self._get_canonical_key(key)
546557
return super().__contains__(key)
547558

548-
def reset(self):
559+
def reset(self) -> None:
549560
"""
550561
Reset the option store to its initial state
551562
@@ -560,7 +571,7 @@ def _get_canonical_key(self, key):
560571
return self._ALIASES.get(key, key)
561572

562573
@contextmanager
563-
def use(self, key, value):
574+
def use(self, key, value) -> Iterator[_Options]:
564575
"""
565576
Temporarily set a parameter value using the with statement.
566577
Aliasing allowed.

pandas/tseries/frequencies.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,12 @@ def get_freq(self) -> str | None:
314314
return _maybe_add_count("N", delta)
315315

316316
@cache_readonly
317-
def day_deltas(self):
317+
def day_deltas(self) -> list[int]:
318318
ppd = periods_per_day(self._reso)
319319
return [x / ppd for x in self.deltas]
320320

321321
@cache_readonly
322-
def hour_deltas(self):
322+
def hour_deltas(self) -> list[int]:
323323
pph = periods_per_day(self._reso) // 24
324324
return [x / pph for x in self.deltas]
325325

@@ -328,10 +328,10 @@ def fields(self) -> np.ndarray: # structured array of fields
328328
return build_field_sarray(self.i8values, reso=self._reso)
329329

330330
@cache_readonly
331-
def rep_stamp(self):
331+
def rep_stamp(self) -> Timestamp:
332332
return Timestamp(self.i8values[0])
333333

334-
def month_position_check(self):
334+
def month_position_check(self) -> str | None:
335335
return month_position_check(self.fields, self.index.dayofweek)
336336

337337
@cache_readonly
@@ -394,7 +394,11 @@ def _get_annual_rule(self) -> str | None:
394394
return None
395395

396396
pos_check = self.month_position_check()
397-
return {"cs": "AS", "bs": "BAS", "ce": "A", "be": "BA"}.get(pos_check)
397+
# error: Argument 1 to "get" of "dict" has incompatible type
398+
# "Optional[str]"; expected "str"
399+
return {"cs": "AS", "bs": "BAS", "ce": "A", "be": "BA"}.get(
400+
pos_check # type: ignore[arg-type]
401+
)
398402

399403
def _get_quarterly_rule(self) -> str | None:
400404
if len(self.mdiffs) > 1:
@@ -404,13 +408,21 @@ def _get_quarterly_rule(self) -> str | None:
404408
return None
405409

406410
pos_check = self.month_position_check()
407-
return {"cs": "QS", "bs": "BQS", "ce": "Q", "be": "BQ"}.get(pos_check)
411+
# error: Argument 1 to "get" of "dict" has incompatible type
412+
# "Optional[str]"; expected "str"
413+
return {"cs": "QS", "bs": "BQS", "ce": "Q", "be": "BQ"}.get(
414+
pos_check # type: ignore[arg-type]
415+
)
408416

409417
def _get_monthly_rule(self) -> str | None:
410418
if len(self.mdiffs) > 1:
411419
return None
412420
pos_check = self.month_position_check()
413-
return {"cs": "MS", "bs": "BMS", "ce": "M", "be": "BM"}.get(pos_check)
421+
# error: Argument 1 to "get" of "dict" has incompatible type
422+
# "Optional[str]"; expected "str"
423+
return {"cs": "MS", "bs": "BMS", "ce": "M", "be": "BM"}.get(
424+
pos_check # type: ignore[arg-type]
425+
)
414426

415427
def _is_business_daily(self) -> bool:
416428
# quick check: cannot be business daily

pandas/util/_exceptions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import contextlib
44
import inspect
55
import os
6+
from typing import Iterator
67

78

89
@contextlib.contextmanager
9-
def rewrite_exception(old_name: str, new_name: str):
10+
def rewrite_exception(old_name: str, new_name: str) -> Iterator[None]:
1011
"""
1112
Rewrite the message of an exception.
1213
"""

pandas/util/_test_decorators.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def test_foo():
2727

2828
from contextlib import contextmanager
2929
import locale
30-
from typing import Callable
30+
from typing import (
31+
Callable,
32+
Iterator,
33+
)
3134
import warnings
3235

3336
import numpy as np
@@ -253,7 +256,7 @@ def check_file_leaks(func) -> Callable:
253256

254257

255258
@contextmanager
256-
def file_leak_context():
259+
def file_leak_context() -> Iterator[None]:
257260
"""
258261
ContextManager analogue to check_file_leaks.
259262
"""
@@ -290,7 +293,7 @@ def async_mark():
290293
return async_mark
291294

292295

293-
def mark_array_manager_not_yet_implemented(request):
296+
def mark_array_manager_not_yet_implemented(request) -> None:
294297
mark = pytest.mark.xfail(reason="Not yet implemented for ArrayManager")
295298
request.node.add_marker(mark)
296299

pandas/util/_validators.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import (
88
Iterable,
99
Sequence,
10+
TypeVar,
11+
overload,
1012
)
1113
import warnings
1214

@@ -19,6 +21,9 @@
1921
is_integer,
2022
)
2123

24+
BoolishT = TypeVar("BoolishT", bool, int)
25+
BoolishNoneT = TypeVar("BoolishNoneT", bool, int, None)
26+
2227

2328
def _check_arg_length(fname, args, max_fname_arg_count, compat_args):
2429
"""
@@ -78,7 +83,7 @@ def _check_for_default_values(fname, arg_val_dict, compat_args):
7883
)
7984

8085

81-
def validate_args(fname, args, max_fname_arg_count, compat_args):
86+
def validate_args(fname, args, max_fname_arg_count, compat_args) -> None:
8287
"""
8388
Checks whether the length of the `*args` argument passed into a function
8489
has at most `len(compat_args)` arguments and whether or not all of these
@@ -132,7 +137,7 @@ def _check_for_invalid_keys(fname, kwargs, compat_args):
132137
raise TypeError(f"{fname}() got an unexpected keyword argument '{bad_arg}'")
133138

134139

135-
def validate_kwargs(fname, kwargs, compat_args):
140+
def validate_kwargs(fname, kwargs, compat_args) -> None:
136141
"""
137142
Checks whether parameters passed to the **kwargs argument in a
138143
function `fname` are valid parameters as specified in `*compat_args`
@@ -159,7 +164,9 @@ def validate_kwargs(fname, kwargs, compat_args):
159164
_check_for_default_values(fname, kwds, compat_args)
160165

161166

162-
def validate_args_and_kwargs(fname, args, kwargs, max_fname_arg_count, compat_args):
167+
def validate_args_and_kwargs(
168+
fname, args, kwargs, max_fname_arg_count, compat_args
169+
) -> None:
163170
"""
164171
Checks whether parameters passed to the *args and **kwargs argument in a
165172
function `fname` are valid parameters as specified in `*compat_args`
@@ -215,7 +222,9 @@ def validate_args_and_kwargs(fname, args, kwargs, max_fname_arg_count, compat_ar
215222
validate_kwargs(fname, kwargs, compat_args)
216223

217224

218-
def validate_bool_kwarg(value, arg_name, none_allowed=True, int_allowed=False):
225+
def validate_bool_kwarg(
226+
value: BoolishNoneT, arg_name, none_allowed=True, int_allowed=False
227+
) -> BoolishNoneT:
219228
"""
220229
Ensure that argument passed in arg_name can be interpreted as boolean.
221230
@@ -424,12 +433,22 @@ def validate_percentile(q: float | Iterable[float]) -> np.ndarray:
424433
return q_arr
425434

426435

436+
@overload
437+
def validate_ascending(ascending: BoolishT) -> BoolishT:
438+
...
439+
440+
441+
@overload
442+
def validate_ascending(ascending: Sequence[BoolishT]) -> list[BoolishT]:
443+
...
444+
445+
427446
def validate_ascending(
428-
ascending: bool | int | Sequence[bool | int] = True,
429-
):
447+
ascending: bool | int | Sequence[BoolishT],
448+
) -> bool | int | list[BoolishT]:
430449
"""Validate ``ascending`` kwargs for ``sort_index`` method."""
431450
kwargs = {"none_allowed": False, "int_allowed": True}
432-
if not isinstance(ascending, (list, tuple)):
451+
if not isinstance(ascending, Sequence):
433452
return validate_bool_kwarg(ascending, "ascending", **kwargs)
434453

435454
return [validate_bool_kwarg(item, "ascending", **kwargs) for item in ascending]

0 commit comments

Comments
 (0)