Skip to content

groupby.__iter__() fix types #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from typing import (
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -166,6 +167,12 @@ IndexingInt = Union[
int, np.int_, np.integer, np.unsignedinteger, np.signedinteger, np.int8
]

# NDFrameT is stricter and ensures that the same subclass of NDFrame always is
# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a
# Series is passed into a function, a Series is always returned and if a DataFrame is
# passed in, a DataFrame is always returned.
NDFrameT = TypeVar("NDFrameT", bound=NDFrame)

# Interval closed type

IntervalClosedType = Literal["left", "right", "both", "neither"]
Expand Down Expand Up @@ -197,6 +204,7 @@ XMLParsers = Literal["lxml", "etree"]

# Any plain Python or numpy function
Function = Union[np.ufunc, Callable[..., Any]]
GroupByObject = Union[
Label, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index
GroupByObjectNonScalar = Union[
Tuple, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index
]
GroupByObject = Union[Scalar, GroupByObjectNonScalar]
4 changes: 3 additions & 1 deletion pandas-stubs/core/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from __future__ import annotations

from typing import (
Callable,
Generic,
List,
Literal,
Optional,
Expand All @@ -20,6 +21,7 @@ from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.categorical import Categorical

from pandas._typing import (
NDFrameT,
Scalar,
SeriesAxisType,
)
Expand All @@ -34,7 +36,7 @@ class GroupByError(Exception): ...
class DataError(GroupByError): ...
class SpecificationError(GroupByError): ...

class SelectionMixin:
class SelectionMixin(Generic[NDFrameT]):
def ndim(self) -> int: ...
def __getitem__(self, key): ...
def aggregate(
Expand Down
25 changes: 21 additions & 4 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ from pandas import (
)
from pandas.core.arraylike import OpsMixin
from pandas.core.generic import NDFrame
from pandas.core.groupby.generic import DataFrameGroupBy
from pandas.core.groupby.generic import (
_DataFrameGroupByNonScalar,
_DataFrameGroupByScalar,
)
from pandas.core.groupby.grouper import Grouper
from pandas.core.indexes.base import Index
from pandas.core.indexing import (
Expand All @@ -54,7 +57,7 @@ from pandas._typing import (
DtypeNp,
FilePathOrBuffer,
FilePathOrBytesBuffer,
GroupByObject,
GroupByObjectNonScalar,
IgnoreRaise,
IndexingInt,
IndexLabel,
Expand Down Expand Up @@ -862,9 +865,23 @@ class DataFrame(NDFrame, OpsMixin):
filter_func: Optional[Callable] = ...,
errors: Union[_str, Literal["raise", "ignore"]] = ...,
) -> None: ...
@overload
def groupby(
self,
by: Scalar,
axis: AxisType = ...,
level: Optional[Level] = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
squeeze: _bool = ...,
observed: _bool = ...,
dropna: _bool = ...,
) -> _DataFrameGroupByScalar: ...
@overload
def groupby(
self,
by: Optional[GroupByObject] = ...,
by: Optional[GroupByObjectNonScalar] = ...,
axis: AxisType = ...,
level: Optional[Level] = ...,
as_index: _bool = ...,
Expand All @@ -873,7 +890,7 @@ class DataFrame(NDFrame, OpsMixin):
squeeze: _bool = ...,
observed: _bool = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy: ...
) -> _DataFrameGroupByNonScalar: ...
def pivot(
self,
index=...,
Expand Down
14 changes: 14 additions & 0 deletions pandas-stubs/core/groupby/generic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from typing import (
Callable,
Dict,
FrozenSet,
Iterator,
List,
Literal,
NamedTuple,
Expand Down Expand Up @@ -32,6 +33,7 @@ from pandas._typing import (
FrameOrSeries,
FuncType,
Level,
Scalar,
)

AggScalar = Union[str, Callable[..., Any]]
Expand All @@ -46,6 +48,12 @@ def pin_whitelisted_properties(
klass: Type[FrameOrSeries], whitelist: FrozenSet[str]
): ...

class _SeriesGroupByScalar(SeriesGroupBy):
def __iter__(self) -> Iterator[Tuple[Scalar, Series]]: ...

class _SeriesGroupByNonScalar(SeriesGroupBy):
def __iter__(self) -> Iterator[Tuple[Tuple, Series]]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong opinion: I see the point from a user-perspective, but I think deviating too much from the pandas implementation will make it challenging to use stubtest later (and maybe also maintaining the stubs).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably make these classes that are only in the stubs but not in pandas private

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong opinion

I think I develop a slightly stronger opinion. I would prefer not to introduce too many non-pandas classes: I believe that this will keep pandas-stubs more maintainable in the long-run.

Personally, I prefer the previous state of the PR where you returned tuple[Hashable, NDFrameT] (or making SeriesGroupBy generic, might be more invasive).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I develop a slightly stronger opinion. I would prefer not to introduce too many non-pandas classes: I believe that this will keep pandas-stubs more maintainable in the long-run.

Personally, I prefer the previous state of the PR where you returned tuple[Hashable, NDFrameT] (or making SeriesGroupBy generic, might be more invasive).

This is a tough call. There are lots of things in pandas that are designed without static typing taken into consideration. In this case, we want to differentiate when someone writes df.groupby(["a", "b"]) versus df.groupby("a") so that __iter__() returns a different result. The implementation has __iter__() in a base class, so there isn't an easy way from a static typing perspective to differentiate between the different kinds of arguments of groupby() when you get down to the base class.

My philosophy on the stubs has been to make it useful for end users with the most common ways of using pandas. To do that, we have to deviate from the implementation. A good example of this is Series.dt, where the accessors are all dynamically hooked in, but we have to create static declarations for each of the accessors. I think this groupby() example is similar.

I tried some experiments making DataFrameGroupBy generic, but the types that are then returned by __iter__() become too wide. Now we just make the return types either Iterator[Tuple[Tuple, DataFrame]] or Iterator[Tuple[Scalar, DataFrame]], which covers the majority of use cases.

My suggestion is you accept this PR as it is - it addresses the issue reported, and we could create a new issue to see if we can figure out a way to make a generic implementation work.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably make these classes that are only in the stubs but not in pandas private

Did that in the next commit

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I prefer pandas and pandas-stubs to be aligned (I would like if pandas and pandas-stub could in the future converge and only be different in challenging cases: __getitem__, Timestamp, ... - might be a long list) I'm fine with introducing classes/variables that are not in pandas as long as they are private (or some other prefix).


class SeriesGroupBy(GroupBy):
def any(self, skipna: bool = ...) -> Series[bool]: ...
def all(self, skipna: bool = ...) -> Series[bool]: ...
Expand Down Expand Up @@ -100,6 +108,12 @@ class SeriesGroupBy(GroupBy):
self, n: Union[int, Sequence[int]], dropna: Optional[str] = ...
) -> Series[S1]: ...

class _DataFrameGroupByScalar(DataFrameGroupBy):
def __iter__(self) -> Iterator[Tuple[Scalar, DataFrame]]: ...

class _DataFrameGroupByNonScalar(DataFrameGroupBy):
def __iter__(self) -> Iterator[Tuple[Tuple, DataFrame]]: ...

class DataFrameGroupBy(GroupBy):
def any(self, skipna: bool = ...) -> DataFrame: ...
def all(self, skipna: bool = ...) -> DataFrame: ...
Expand Down
7 changes: 2 additions & 5 deletions pandas-stubs/core/groupby/groupby.pyi
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Union,
)

from pandas.core.base import PandasObject
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import ops
from pandas.core.groupby.indexing import GroupByIndexingMixin
from pandas.core.indexes.api import Index
from pandas.core.series import Series

Expand All @@ -27,7 +25,7 @@ class GroupByPlot(PandasObject):
def __call__(self, *args, **kwargs): ...
def __getattr__(self, name: str): ...

class _GroupBy(PandasObject):
class _GroupBy(PandasObject, GroupByIndexingMixin):
level = ...
as_index = ...
keys = ...
Expand Down Expand Up @@ -67,7 +65,6 @@ class _GroupBy(PandasObject):
def pipe(self, func: Callable, *args, **kwargs): ...
plot = ...
def get_group(self, name, obj: Optional[DataFrame] = ...) -> DataFrame: ...
def __iter__(self) -> Generator[Tuple[str, Any], None, None]: ...
def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ...

class GroupBy(_GroupBy):
Expand Down
24 changes: 21 additions & 3 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ from pandas import (
)
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.groupby.generic import SeriesGroupBy
from pandas.core.groupby.generic import (
_SeriesGroupByNonScalar,
_SeriesGroupByScalar,
)
from pandas.core.indexes.accessors import (
CombinedDatetimelikeProperties,
DatetimeProperties,
Expand Down Expand Up @@ -65,6 +68,7 @@ from pandas._typing import (
Dtype,
DtypeNp,
FilePathOrBuffer,
GroupByObjectNonScalar,
IgnoreRaise,
IndexingInt,
Label,
Expand Down Expand Up @@ -363,9 +367,23 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
def keys(self) -> List: ...
def to_dict(self, into: Hashable = ...) -> Dict[Any, S1]: ...
def to_frame(self, name: Optional[object] = ...) -> DataFrame: ...
@overload
def groupby(
self,
by: Scalar,
axis: SeriesAxisType = ...,
level: Optional[Level] = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
squeeze: _bool = ...,
observed: _bool = ...,
dropna: _bool = ...,
) -> _SeriesGroupByScalar: ...
@overload
def groupby(
self,
by=...,
by: GroupByObjectNonScalar = ...,
axis: SeriesAxisType = ...,
level: Optional[Level] = ...,
as_index: _bool = ...,
Expand All @@ -374,7 +392,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
squeeze: _bool = ...,
observed: _bool = ...,
dropna: _bool = ...,
) -> SeriesGroupBy: ...
) -> _SeriesGroupByNonScalar: ...
@overload
def count(self, level: None = ...) -> int: ...
@overload
Expand Down
30 changes: 30 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Dict,
Hashable,
Iterable,
Iterator,
List,
Tuple,
Union,
Expand All @@ -20,6 +21,8 @@
import pytest
from typing_extensions import assert_type

from pandas._typing import Scalar

from tests import check

from pandas.io.parsers import TextFileReader
Expand Down Expand Up @@ -1251,3 +1254,30 @@ def test_boolean_loc() -> None:
df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False])
check(assert_type(df.loc[True], pd.Series), pd.Series)
check(assert_type(df.loc[:, False], pd.Series), pd.Series)


def test_groupby_result() -> None:
# GH 142
df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]})
iterator = df.groupby(["a", "b"]).__iter__()
assert_type(iterator, Iterator[Tuple[Tuple, pd.DataFrame]])
index, value = next(iterator)
assert_type((index, value), Tuple[Tuple, pd.DataFrame])

check(assert_type(index, Tuple), tuple, np.int64)
check(assert_type(value, pd.DataFrame), pd.DataFrame)

iterator2 = df.groupby("a").__iter__()
assert_type(iterator2, Iterator[Tuple[Scalar, pd.DataFrame]])
index2, value2 = next(iterator2)
assert_type((index2, value2), Tuple[Scalar, pd.DataFrame])

check(assert_type(index2, Scalar), int)
check(assert_type(value2, pd.DataFrame), pd.DataFrame)

# Want to make sure these cases are differentiated
for (k1, k2), g in df.groupby(["a", "b"]):
pass

for kk, g in df.groupby("a"):
pass