From 263fdcf65d951f95ec5e338490ebcb922d75f9f2 Mon Sep 17 00:00:00 2001 From: rhshadrach Date: Sun, 11 Oct 2020 17:49:28 -0400 Subject: [PATCH] TYP: Misc groupby typing --- pandas/core/generic.py | 2 +- pandas/core/groupby/generic.py | 4 ++-- pandas/core/groupby/groupby.py | 19 +++++++++++-------- pandas/core/groupby/ops.py | 16 ++++++++++++++-- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 86b6c4a6cf575..3079bb0b79b33 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5417,7 +5417,7 @@ def __setattr__(self, name: str, value) -> None: ) object.__setattr__(self, name, value) - def _dir_additions(self): + def _dir_additions(self) -> Set[str]: """ add the string-like attributes from the info_axis. If info_axis is a MultiIndex, it's first level values are used. diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index af3aa5d121391..9c78166ce0480 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -30,7 +30,7 @@ import numpy as np from pandas._libs import lib, reduction as libreduction -from pandas._typing import ArrayLike, FrameOrSeries, FrameOrSeriesUnion +from pandas._typing import ArrayLike, FrameOrSeries, FrameOrSeriesUnion, Label from pandas.util._decorators import Appender, Substitution, doc from pandas.core.dtypes.cast import ( @@ -1131,7 +1131,7 @@ def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame: axis = self.axis obj = self._obj_with_exclusions - result: Dict[Union[int, str], Union[NDFrame, np.ndarray]] = {} + result: Dict[Label, Union[NDFrame, np.ndarray]] = {} if axis != obj._info_axis_number: for name, data in self: fres = func(data, *args, **kwargs) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index c758844da3a2b..4ab40e25db84e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -20,6 +20,7 @@ class providing the base-class of operations. Generic, Hashable, Iterable, + Iterator, List, Mapping, Optional, @@ -465,7 +466,7 @@ def f(self): @contextmanager -def group_selection_context(groupby: "BaseGroupBy"): +def group_selection_context(groupby: "BaseGroupBy") -> Iterator["BaseGroupBy"]: """ Set / reset the group_selection_context. """ @@ -486,7 +487,7 @@ def group_selection_context(groupby: "BaseGroupBy"): class BaseGroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]): - _group_selection = None + _group_selection: Optional[IndexLabel] = None _apply_allowlist: FrozenSet[str] = frozenset() def __init__( @@ -570,7 +571,7 @@ def groups(self) -> Dict[Hashable, np.ndarray]: return self.grouper.groups @property - def ngroups(self): + def ngroups(self) -> int: self._assure_grouper() return self.grouper.ngroups @@ -649,7 +650,7 @@ def _selected_obj(self): else: return self.obj[self._selection] - def _reset_group_selection(self): + def _reset_group_selection(self) -> None: """ Clear group based selection. @@ -661,7 +662,7 @@ def _reset_group_selection(self): self._group_selection = None self._reset_cache("_selected_obj") - def _set_group_selection(self): + def _set_group_selection(self) -> None: """ Create group based selection. @@ -686,7 +687,9 @@ def _set_group_selection(self): self._group_selection = ax.difference(Index(groupers), sort=False).tolist() self._reset_cache("_selected_obj") - def _set_result_index_ordered(self, result): + def _set_result_index_ordered( + self, result: "OutputFrameOrSeries" + ) -> "OutputFrameOrSeries": # set the result index on the passed values object and # return the new object, xref 8046 @@ -700,7 +703,7 @@ def _set_result_index_ordered(self, result): result.set_axis(self.obj._get_axis(self.axis), axis=self.axis, inplace=True) return result - def _dir_additions(self): + def _dir_additions(self) -> Set[str]: return self.obj._dir_additions() | self._apply_allowlist def __getattr__(self, attr: str): @@ -818,7 +821,7 @@ def get_group(self, name, obj=None): return obj._take_with_is_copy(inds, axis=self.axis) - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[Label, FrameOrSeries]]: """ Groupby iterator. diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 1c18ef891b8c5..bca71b5c9646b 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -7,7 +7,17 @@ """ import collections -from typing import Dict, Generic, Hashable, List, Optional, Sequence, Tuple, Type +from typing import ( + Dict, + Generic, + Hashable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, +) import numpy as np @@ -116,7 +126,9 @@ def __iter__(self): def nkeys(self) -> int: return len(self.groupings) - def get_iterator(self, data: FrameOrSeries, axis: int = 0): + def get_iterator( + self, data: FrameOrSeries, axis: int = 0 + ) -> Iterator[Tuple[Label, FrameOrSeries]]: """ Groupby iterator