Skip to content

Commit a6dd774

Browse files
authored
groupby.__iter__() fix types (#148)
* groupby.__iter__() fix types
* WIP: try splitting by label and otherwise
* fix tests to avoid cast
* make new classes private. Change tests to test iterator and next
1 parent 33fbbe1 commit a6dd774

File tree

7 files changed

+101
-15
lines changed

7 files changed

+101
-15
lines changed

pandas-stubs/_typing.pyi

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ from typing import (
2020
Optional,
2121
Protocol,
2222
Sequence,
23+
Tuple,
2324
Type,
2425
TypeVar,
2526
Union,
@@ -166,6 +167,12 @@ IndexingInt = Union[
166167
int, np.int_, np.integer, np.unsignedinteger, np.signedinteger, np.int8
167168
]
168169

170+
# NDFrameT is stricter and ensures that the same subclass of NDFrame always is
171+
# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a
172+
# Series is passed into a function, a Series is always returned and if a DataFrame is
173+
# passed in, a DataFrame is always returned.
174+
NDFrameT = TypeVar("NDFrameT", bound=NDFrame)
175+
169176
# Interval closed type
170177

171178
IntervalClosedType = Literal["left", "right", "both", "neither"]
@@ -197,6 +204,7 @@ XMLParsers = Literal["lxml", "etree"]
197204

198205
# Any plain Python or numpy function
199206
Function = Union[np.ufunc, Callable[..., Any]]
200-
GroupByObject = Union[
201-
Label, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index
207+
GroupByObjectNonScalar = Union[
208+
Tuple, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index
202209
]
210+
GroupByObject = Union[Scalar, GroupByObjectNonScalar]

pandas-stubs/core/base.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ from __future__ import annotations
22

33
from typing import (
44
Callable,
5+
Generic,
56
List,
67
Literal,
78
Optional,
@@ -20,6 +21,7 @@ from pandas.core.arrays import ExtensionArray
2021
from pandas.core.arrays.categorical import Categorical
2122

2223
from pandas._typing import (
24+
NDFrameT,
2325
Scalar,
2426
SeriesAxisType,
2527
)
@@ -34,7 +36,7 @@ class GroupByError(Exception): ...
3436
class DataError(GroupByError): ...
3537
class SpecificationError(GroupByError): ...
3638

37-
class SelectionMixin:
39+
class SelectionMixin(Generic[NDFrameT]):
3840
def ndim(self) -> int: ...
3941
def __getitem__(self, key): ...
4042
def aggregate(

pandas-stubs/core/frame.pyi

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ from pandas import (
3030
)
3131
from pandas.core.arraylike import OpsMixin
3232
from pandas.core.generic import NDFrame
33-
from pandas.core.groupby.generic import DataFrameGroupBy
33+
from pandas.core.groupby.generic import (
34+
_DataFrameGroupByNonScalar,
35+
_DataFrameGroupByScalar,
36+
)
3437
from pandas.core.groupby.grouper import Grouper
3538
from pandas.core.indexes.base import Index
3639
from pandas.core.indexing import (
@@ -54,7 +57,7 @@ from pandas._typing import (
5457
DtypeNp,
5558
FilePathOrBuffer,
5659
FilePathOrBytesBuffer,
57-
GroupByObject,
60+
GroupByObjectNonScalar,
5861
IgnoreRaise,
5962
IndexingInt,
6063
IndexLabel,
@@ -862,9 +865,23 @@ class DataFrame(NDFrame, OpsMixin):
862865
filter_func: Optional[Callable] = ...,
863866
errors: Union[_str, Literal["raise", "ignore"]] = ...,
864867
) -> None: ...
868+
@overload
869+
def groupby(
870+
self,
871+
by: Scalar,
872+
axis: AxisType = ...,
873+
level: Optional[Level] = ...,
874+
as_index: _bool = ...,
875+
sort: _bool = ...,
876+
group_keys: _bool = ...,
877+
squeeze: _bool = ...,
878+
observed: _bool = ...,
879+
dropna: _bool = ...,
880+
) -> _DataFrameGroupByScalar: ...
881+
@overload
865882
def groupby(
866883
self,
867-
by: Optional[GroupByObject] = ...,
884+
by: Optional[GroupByObjectNonScalar] = ...,
868885
axis: AxisType = ...,
869886
level: Optional[Level] = ...,
870887
as_index: _bool = ...,
@@ -873,7 +890,7 @@ class DataFrame(NDFrame, OpsMixin):
873890
squeeze: _bool = ...,
874891
observed: _bool = ...,
875892
dropna: _bool = ...,
876-
) -> DataFrameGroupBy: ...
893+
) -> _DataFrameGroupByNonScalar: ...
877894
def pivot(
878895
self,
879896
index=...,

pandas-stubs/core/groupby/generic.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from typing import (
33
Callable,
44
Dict,
55
FrozenSet,
6+
Iterator,
67
List,
78
Literal,
89
NamedTuple,
@@ -32,6 +33,7 @@ from pandas._typing import (
3233
FrameOrSeries,
3334
FuncType,
3435
Level,
36+
Scalar,
3537
)
3638

3739
AggScalar = Union[str, Callable[..., Any]]
@@ -46,6 +48,12 @@ def pin_whitelisted_properties(
4648
klass: Type[FrameOrSeries], whitelist: FrozenSet[str]
4749
): ...
4850

51+
class _SeriesGroupByScalar(SeriesGroupBy):
52+
def __iter__(self) -> Iterator[Tuple[Scalar, Series]]: ...
53+
54+
class _SeriesGroupByNonScalar(SeriesGroupBy):
55+
def __iter__(self) -> Iterator[Tuple[Tuple, Series]]: ...
56+
4957
class SeriesGroupBy(GroupBy):
5058
def any(self, skipna: bool = ...) -> Series[bool]: ...
5159
def all(self, skipna: bool = ...) -> Series[bool]: ...
@@ -100,6 +108,12 @@ class SeriesGroupBy(GroupBy):
100108
self, n: Union[int, Sequence[int]], dropna: Optional[str] = ...
101109
) -> Series[S1]: ...
102110

111+
class _DataFrameGroupByScalar(DataFrameGroupBy):
112+
def __iter__(self) -> Iterator[Tuple[Scalar, DataFrame]]: ...
113+
114+
class _DataFrameGroupByNonScalar(DataFrameGroupBy):
115+
def __iter__(self) -> Iterator[Tuple[Tuple, DataFrame]]: ...
116+
103117
class DataFrameGroupBy(GroupBy):
104118
def any(self, skipna: bool = ...) -> DataFrame: ...
105119
def all(self, skipna: bool = ...) -> DataFrame: ...

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
from typing import (
2-
Any,
32
Callable,
43
Dict,
5-
Generator,
64
List,
75
Optional,
8-
Tuple,
96
Union,
107
)
118

129
from pandas.core.base import PandasObject
1310
from pandas.core.frame import DataFrame
1411
from pandas.core.generic import NDFrame
1512
from pandas.core.groupby import ops
13+
from pandas.core.groupby.indexing import GroupByIndexingMixin
1614
from pandas.core.indexes.api import Index
1715
from pandas.core.series import Series
1816

@@ -27,7 +25,7 @@ class GroupByPlot(PandasObject):
2725
def __call__(self, *args, **kwargs): ...
2826
def __getattr__(self, name: str): ...
2927

30-
class _GroupBy(PandasObject):
28+
class _GroupBy(PandasObject, GroupByIndexingMixin):
3129
level = ...
3230
as_index = ...
3331
keys = ...
@@ -67,7 +65,6 @@ class _GroupBy(PandasObject):
6765
def pipe(self, func: Callable, *args, **kwargs): ...
6866
plot = ...
6967
def get_group(self, name, obj: Optional[DataFrame] = ...) -> DataFrame: ...
70-
def __iter__(self) -> Generator[Tuple[str, Any], None, None]: ...
7168
def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ...
7269

7370
class GroupBy(_GroupBy):

pandas-stubs/core/series.pyi

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ from pandas import (
3434
)
3535
from pandas.core.arrays.base import ExtensionArray
3636
from pandas.core.arrays.categorical import CategoricalAccessor
37-
from pandas.core.groupby.generic import SeriesGroupBy
37+
from pandas.core.groupby.generic import (
38+
_SeriesGroupByNonScalar,
39+
_SeriesGroupByScalar,
40+
)
3841
from pandas.core.indexes.accessors import (
3942
CombinedDatetimelikeProperties,
4043
DatetimeProperties,
@@ -65,6 +68,7 @@ from pandas._typing import (
6568
Dtype,
6669
DtypeNp,
6770
FilePathOrBuffer,
71+
GroupByObjectNonScalar,
6872
IgnoreRaise,
6973
IndexingInt,
7074
Label,
@@ -363,9 +367,23 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
363367
def keys(self) -> List: ...
364368
def to_dict(self, into: Hashable = ...) -> Dict[Any, S1]: ...
365369
def to_frame(self, name: Optional[object] = ...) -> DataFrame: ...
370+
@overload
371+
def groupby(
372+
self,
373+
by: Scalar,
374+
axis: SeriesAxisType = ...,
375+
level: Optional[Level] = ...,
376+
as_index: _bool = ...,
377+
sort: _bool = ...,
378+
group_keys: _bool = ...,
379+
squeeze: _bool = ...,
380+
observed: _bool = ...,
381+
dropna: _bool = ...,
382+
) -> _SeriesGroupByScalar: ...
383+
@overload
366384
def groupby(
367385
self,
368-
by=...,
386+
by: GroupByObjectNonScalar = ...,
369387
axis: SeriesAxisType = ...,
370388
level: Optional[Level] = ...,
371389
as_index: _bool = ...,
@@ -374,7 +392,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
374392
squeeze: _bool = ...,
375393
observed: _bool = ...,
376394
dropna: _bool = ...,
377-
) -> SeriesGroupBy: ...
395+
) -> _SeriesGroupByNonScalar: ...
378396
@overload
379397
def count(self, level: None = ...) -> int: ...
380398
@overload

tests/test_frame.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Dict,
1010
Hashable,
1111
Iterable,
12+
Iterator,
1213
List,
1314
Tuple,
1415
Union,
@@ -20,6 +21,8 @@
2021
import pytest
2122
from typing_extensions import assert_type
2223

24+
from pandas._typing import Scalar
25+
2326
from tests import check
2427

2528
from pandas.io.parsers import TextFileReader
@@ -1251,3 +1254,30 @@ def test_boolean_loc() -> None:
12511254
df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False])
12521255
check(assert_type(df.loc[True], pd.Series), pd.Series)
12531256
check(assert_type(df.loc[:, False], pd.Series), pd.Series)
1257+
1258+
1259+
def test_groupby_result() -> None:
1260+
# GH 142
1261+
df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]})
1262+
iterator = df.groupby(["a", "b"]).__iter__()
1263+
assert_type(iterator, Iterator[Tuple[Tuple, pd.DataFrame]])
1264+
index, value = next(iterator)
1265+
assert_type((index, value), Tuple[Tuple, pd.DataFrame])
1266+
1267+
check(assert_type(index, Tuple), tuple, np.int64)
1268+
check(assert_type(value, pd.DataFrame), pd.DataFrame)
1269+
1270+
iterator2 = df.groupby("a").__iter__()
1271+
assert_type(iterator2, Iterator[Tuple[Scalar, pd.DataFrame]])
1272+
index2, value2 = next(iterator2)
1273+
assert_type((index2, value2), Tuple[Scalar, pd.DataFrame])
1274+
1275+
check(assert_type(index2, Scalar), int)
1276+
check(assert_type(value2, pd.DataFrame), pd.DataFrame)
1277+
1278+
# Want to make sure these cases are differentiated
1279+
for (k1, k2), g in df.groupby(["a", "b"]):
1280+
pass
1281+
1282+
for kk, g in df.groupby("a"):
1283+
pass

0 commit comments

Comments
 (0)