From 1ab02f5fc60b2cf70883523c9a5f3e45956e51a6 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 17 Jul 2022 20:17:08 -0400 Subject: [PATCH 1/4] groupby.__iter__() fix types --- pandas-stubs/_typing.pyi | 6 ++++++ pandas-stubs/core/base.pyi | 4 +++- pandas-stubs/core/groupby/generic.pyi | 2 +- pandas-stubs/core/groupby/groupby.pyi | 13 ++++++++----- tests/test_frame.py | 11 +++++++++++ 5 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 8415187a8..2fe627c1d 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -166,6 +166,12 @@ IndexingInt = Union[ int, np.int_, np.integer, np.unsignedinteger, np.signedinteger, np.int8 ] +# NDFrameT is stricter and ensures that the same subclass of NDFrame always is +# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a +# Series is passed into a function, a Series is always returned and if a DataFrame is +# passed in, a DataFrame is always returned. +NDFrameT = TypeVar("NDFrameT", bound=NDFrame) + # Interval closed type IntervalClosedType = Literal["left", "right", "both", "neither"] diff --git a/pandas-stubs/core/base.pyi b/pandas-stubs/core/base.pyi index 780c7bc16..bd1ee2681 100644 --- a/pandas-stubs/core/base.pyi +++ b/pandas-stubs/core/base.pyi @@ -2,6 +2,7 @@ from __future__ import annotations from typing import ( Callable, + Generic, List, Literal, Optional, @@ -20,6 +21,7 @@ from pandas.core.arrays import ExtensionArray from pandas.core.arrays.categorical import Categorical from pandas._typing import ( + NDFrameT, Scalar, SeriesAxisType, ) @@ -34,7 +36,7 @@ class GroupByError(Exception): ... class DataError(GroupByError): ... class SpecificationError(GroupByError): ... -class SelectionMixin: +class SelectionMixin(Generic[NDFrameT]): def ndim(self) -> int: ... def __getitem__(self, key): ... def aggregate( diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index d27536c6a..44af2c4aa 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -100,7 +100,7 @@ class SeriesGroupBy(GroupBy): self, n: Union[int, Sequence[int]], dropna: Optional[str] = ... ) -> Series[S1]: ... -class DataFrameGroupBy(GroupBy): +class DataFrameGroupBy(GroupBy[DataFrame]): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... def apply(self, func, *args, **kwargs) -> DataFrame: ... diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index b765a6e44..117ee2906 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -1,8 +1,9 @@ from typing import ( - Any, Callable, Dict, - Generator, + Generic, + Hashable, + Iterator, List, Optional, Tuple, @@ -13,6 +14,7 @@ from pandas.core.base import PandasObject from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame from pandas.core.groupby import ops +from pandas.core.groupby.indexing import GroupByIndexingMixin from pandas.core.indexes.api import Index from pandas.core.series import Series @@ -20,6 +22,7 @@ from pandas._typing import ( AxisType, FrameOrSeriesUnion, KeysArgType, + NDFrameT, ) class GroupByPlot(PandasObject): @@ -27,7 +30,7 @@ class GroupByPlot(PandasObject): def __call__(self, *args, **kwargs): ... def __getattr__(self, name: str): ... -class _GroupBy(PandasObject): +class _GroupBy(PandasObject, GroupByIndexingMixin, Generic[NDFrameT]): level = ... as_index = ... keys = ... @@ -67,10 +70,10 @@ class _GroupBy(PandasObject): def pipe(self, func: Callable, *args, **kwargs): ... plot = ... def get_group(self, name, obj: Optional[DataFrame] = ...) -> DataFrame: ... - def __iter__(self) -> Generator[Tuple[str, Any], None, None]: ... + def __iter__(self) -> Iterator[Tuple[Hashable, NDFrameT]]: ... def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ... -class GroupBy(_GroupBy): +class GroupBy(_GroupBy[NDFrameT]): def count(self) -> FrameOrSeriesUnion: ... def mean(self, **kwargs) -> FrameOrSeriesUnion: ... def median(self, **kwargs) -> FrameOrSeriesUnion: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 6c799c20e..3163985f8 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -12,6 +12,7 @@ List, Tuple, Union, + cast, ) import numpy as np @@ -1251,3 +1252,13 @@ def test_boolean_loc() -> None: df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False]) check(assert_type(df.loc[True], pd.Series), pd.Series) check(assert_type(df.loc[:, False], pd.Series), pd.Series) + + +def test_groupby_result() -> None: + # GH 142 + df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) + lresult = [(cast(Tuple[int, int], k), g) for k, g in df.groupby(["a", "b"])] + check(assert_type(lresult, List[Tuple[Tuple[int, int], pd.DataFrame]]), list, tuple) + + lresult2 = [(cast(int, k), g) for k, g in df.groupby("a")] + check(assert_type(lresult2, List[Tuple[int, pd.DataFrame]]), list, tuple) From 62fd358fa127c419e92da4beb1a2a8241177927a Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 17 Jul 2022 20:54:36 -0400 Subject: [PATCH 2/4] WIP: try splitting by label and otherwise --- pandas-stubs/_typing.pyi | 5 +++-- pandas-stubs/core/frame.pyi | 23 ++++++++++++++++++++--- pandas-stubs/core/groupby/generic.pyi | 5 +++++ tests/test_frame.py | 4 ++-- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 2fe627c1d..eb52c57fc 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -203,6 +203,7 @@ XMLParsers = Literal["lxml", "etree"] # Any plain Python or numpy function Function = Union[np.ufunc, Callable[..., Any]] -GroupByObject = Union[ - Label, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index +GroupByObjectNonLabel = Union[ + List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index ] +GroupByObject = Union[Label, GroupByObjectNonLabel] diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 6a946e267..c16b32118 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -30,7 +30,10 @@ from pandas import ( ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import DataFrameGroupBy +from pandas.core.groupby.generic import ( + DataFrameGroupBy, + DataFrameGroupByLabel, +) from pandas.core.groupby.grouper import Grouper from pandas.core.indexes.base import Index from pandas.core.indexing import ( @@ -54,7 +57,7 @@ from pandas._typing import ( DtypeNp, FilePathOrBuffer, FilePathOrBytesBuffer, - GroupByObject, + GroupByObjectNonLabel, IgnoreRaise, IndexingInt, IndexLabel, @@ -862,9 +865,23 @@ class DataFrame(NDFrame, OpsMixin): filter_func: Optional[Callable] = ..., errors: Union[_str, Literal["raise", "ignore"]] = ..., ) -> None: ... + @overload + def groupby( + self, + by: Optional[Label] = ..., + axis: AxisType = ..., + level: Optional[Level] = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupByLabel: ... + @overload def groupby( self, - by: Optional[GroupByObject] = ..., + by: Optional[GroupByObjectNonLabel] = ..., axis: AxisType = ..., level: Optional[Level] = ..., as_index: _bool = ..., diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 44af2c4aa..36818418b 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -3,6 +3,7 @@ from typing import ( Callable, Dict, FrozenSet, + Iterator, List, Literal, NamedTuple, @@ -31,6 +32,7 @@ from pandas._typing import ( AxisType, FrameOrSeries, FuncType, + Label, Level, ) @@ -100,6 +102,9 @@ class SeriesGroupBy(GroupBy): self, n: Union[int, Sequence[int]], dropna: Optional[str] = ... ) -> Series[S1]: ... +class DataFrameGroupByLabel(DataFrameGroupBy): + def __iter__(self) -> Iterator[Tuple[Label, DataFrame]]: ... + class DataFrameGroupBy(GroupBy[DataFrame]): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 3163985f8..2d9ce49d0 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1260,5 +1260,5 @@ def test_groupby_result() -> None: lresult = [(cast(Tuple[int, int], k), g) for k, g in df.groupby(["a", "b"])] check(assert_type(lresult, List[Tuple[Tuple[int, int], pd.DataFrame]]), list, tuple) - lresult2 = [(cast(int, k), g) for k, g in df.groupby("a")] - check(assert_type(lresult2, List[Tuple[int, pd.DataFrame]]), list, tuple) + lresult2 = [(k, g) for k, g in df.groupby("a")] + check(assert_type(lresult2, List[Tuple[Label, pd.DataFrame]]), list, tuple) From 167ba506d7b9c676f4477b8f9fb594e90da7c9cf Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 17 Jul 2022 21:21:36 -0400 Subject: [PATCH 3/4] fix tests to avoid cast --- tests/test_frame.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/test_frame.py b/tests/test_frame.py index 3163985f8..ef74364f3 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -12,7 +12,6 @@ List, Tuple, Union, - cast, ) import numpy as np @@ -1257,8 +1256,14 @@ def test_boolean_loc() -> None: def test_groupby_result() -> None: # GH 142 df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) - lresult = [(cast(Tuple[int, int], k), g) for k, g in df.groupby(["a", "b"])] - check(assert_type(lresult, List[Tuple[Tuple[int, int], pd.DataFrame]]), list, tuple) + index, value = next(df.groupby(["a", "b"]).__iter__()) + assert_type((index, value), Tuple[Hashable, pd.DataFrame]) - lresult2 = [(cast(int, k), g) for k, g in df.groupby("a")] - check(assert_type(lresult2, List[Tuple[int, pd.DataFrame]]), list, tuple) + check(assert_type(index, Hashable), tuple, np.int64) + check(assert_type(value, pd.DataFrame), pd.DataFrame) + + index2, value2 = next(df.groupby("a").__iter__()) + assert_type((index2, value2), Tuple[Hashable, pd.DataFrame]) + + check(assert_type(index2, Hashable), int) + check(assert_type(value2, pd.DataFrame), pd.DataFrame) From 981bf9aff982f3f2ddac5cca535cd9d4c912045c Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 18 Jul 2022 08:18:47 -0400 Subject: [PATCH 4/4] make new classes private. Change tests to test iterator and next --- pandas-stubs/core/frame.pyi | 8 ++++---- pandas-stubs/core/groupby/generic.pyi | 8 ++++---- pandas-stubs/core/series.pyi | 8 ++++---- tests/test_frame.py | 13 +++++++++---- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 54ec79068..b701c2540 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -31,8 +31,8 @@ from pandas import ( from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame from pandas.core.groupby.generic import ( - DataFrameGroupByNonScalar, - DataFrameGroupByScalar, + _DataFrameGroupByNonScalar, + _DataFrameGroupByScalar, ) from pandas.core.groupby.grouper import Grouper from pandas.core.indexes.base import Index @@ -877,7 +877,7 @@ class DataFrame(NDFrame, OpsMixin): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> DataFrameGroupByScalar: ... + ) -> _DataFrameGroupByScalar: ... @overload def groupby( self, @@ -890,7 +890,7 @@ class DataFrame(NDFrame, OpsMixin): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> DataFrameGroupByNonScalar: ... + ) -> _DataFrameGroupByNonScalar: ... def pivot( self, index=..., diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 5bae37cc4..0eb7de9b4 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -48,10 +48,10 @@ def pin_whitelisted_properties( klass: Type[FrameOrSeries], whitelist: FrozenSet[str] ): ... -class SeriesGroupByScalar(SeriesGroupBy): +class _SeriesGroupByScalar(SeriesGroupBy): def __iter__(self) -> Iterator[Tuple[Scalar, Series]]: ... -class SeriesGroupByNonScalar(SeriesGroupBy): +class _SeriesGroupByNonScalar(SeriesGroupBy): def __iter__(self) -> Iterator[Tuple[Tuple, Series]]: ... class SeriesGroupBy(GroupBy): @@ -108,10 +108,10 @@ class SeriesGroupBy(GroupBy): self, n: Union[int, Sequence[int]], dropna: Optional[str] = ... ) -> Series[S1]: ... -class DataFrameGroupByScalar(DataFrameGroupBy): +class _DataFrameGroupByScalar(DataFrameGroupBy): def __iter__(self) -> Iterator[Tuple[Scalar, DataFrame]]: ... -class DataFrameGroupByNonScalar(DataFrameGroupBy): +class _DataFrameGroupByNonScalar(DataFrameGroupBy): def __iter__(self) -> Iterator[Tuple[Tuple, DataFrame]]: ... class DataFrameGroupBy(GroupBy): diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index d7e90aca1..a1a4ae9d4 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -35,8 +35,8 @@ from pandas import ( from pandas.core.arrays.base import ExtensionArray from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.groupby.generic import ( - SeriesGroupByNonScalar, - SeriesGroupByScalar, + _SeriesGroupByNonScalar, + _SeriesGroupByScalar, ) from pandas.core.indexes.accessors import ( CombinedDatetimelikeProperties, @@ -379,7 +379,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> SeriesGroupByScalar: ... + ) -> _SeriesGroupByScalar: ... @overload def groupby( self, @@ -392,7 +392,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> SeriesGroupByNonScalar: ... + ) -> _SeriesGroupByNonScalar: ... @overload def count(self, level: None = ...) -> int: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index 7dca98b46..ac437a707 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -9,6 +9,7 @@ Dict, Hashable, Iterable, + Iterator, List, Tuple, Union, @@ -1258,13 +1259,17 @@ def test_boolean_loc() -> None: def test_groupby_result() -> None: # GH 142 df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) - index, value = next(df.groupby(["a", "b"]).__iter__()) + iterator = df.groupby(["a", "b"]).__iter__() + assert_type(iterator, Iterator[Tuple[Tuple, pd.DataFrame]]) + index, value = next(iterator) assert_type((index, value), Tuple[Tuple, pd.DataFrame]) check(assert_type(index, Tuple), tuple, np.int64) check(assert_type(value, pd.DataFrame), pd.DataFrame) - index2, value2 = next(df.groupby("a").__iter__()) + iterator2 = df.groupby("a").__iter__() + assert_type(iterator2, Iterator[Tuple[Scalar, pd.DataFrame]]) + index2, value2 = next(iterator2) assert_type((index2, value2), Tuple[Scalar, pd.DataFrame]) check(assert_type(index2, Scalar), int) @@ -1272,7 +1277,7 @@ def test_groupby_result() -> None: # Want to make sure these cases are differentiated for (k1, k2), g in df.groupby(["a", "b"]): - print(k1, k2) + pass for kk, g in df.groupby("a"): - print(kk) + pass