From 32666780d6b6795cc69affb9fc2c7d07d2658e7d Mon Sep 17 00:00:00 2001 From: Daniel Roseman Date: Mon, 1 Aug 2022 17:06:27 +0100 Subject: [PATCH 1/3] More specific typing for DataFrameGroupBy.apply. --- pandas-stubs/core/groupby/generic.pyi | 12 +++++++----- tests/test_frame.py | 23 +++++++++++++++++++++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index aedcb49af..04df13ddb 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -110,18 +110,20 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy): class DataFrameGroupBy(GroupBy): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... - # mypy sees the two overloads as overlapping + # mypy and pyright see these overloads as overlapping @overload def apply( # type: ignore[misc] - self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs + self, func: Callable[[Iterable], float], *args, **kwargs + ) -> DataFrame: ... + @overload + def apply( # type: ignore[misc] + self, func: Callable[[DataFrame], Scalar | list | dict], *args, **kwargs ) -> Series: ... @overload def apply( # type: ignore[misc] - self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs + self, func: Callable[[DataFrame], Series | DataFrame], *args, **kwargs ) -> DataFrame: ... @overload - def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ... - @overload def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ... @overload def aggregate(self, arg: dict, *args, **kwargs) -> DataFrame: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 27ced9263..b6b5f6156 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1346,13 +1346,32 @@ def test_groupby_apply() -> None: # GH 167 df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) - def summean(x: pd.DataFrame) -> float: + def sum_mean(x: pd.DataFrame) -> float: return x.sum().mean() - check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series) + check(assert_type(df.groupby("col1").apply(sum_mean), pd.Series), pd.Series) lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean() check( assert_type(df.groupby("col1").apply(lfunc), pd.Series), pd.Series, ) + + def sum_to_list(x: pd.DataFrame) -> list: + return x.sum().tolist() + + check(assert_type(df.groupby("col1").apply(sum_to_list), pd.Series), pd.Series) + + def sum_to_series(x: pd.DataFrame) -> pd.Series: + return x.sum() + + check( + assert_type(df.groupby("col1").apply(sum_to_series), pd.DataFrame), pd.DataFrame + ) + + def sample_to_df(x: pd.DataFrame) -> pd.DataFrame: + return x.sample() + + check( + assert_type(df.groupby("col1").apply(sample_to_df), pd.DataFrame), pd.DataFrame + ) From 55d0bef896c0b05698743f90f6ef623024400bd3 Mon Sep 17 00:00:00 2001 From: Daniel Roseman Date: Mon, 1 Aug 2022 17:07:56 +0100 Subject: [PATCH 2/3] Add missing SeriesGroupBy.sum --- pandas-stubs/core/groupby/generic.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 04df13ddb..3bc9f67ec 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -100,6 +100,7 @@ class SeriesGroupBy(GroupBy): def nlargest(self, n: int = ..., keep: str = ...) -> Series[S1]: ... def nsmallest(self, n: int = ..., keep: str = ...) -> Series[S1]: ... def nth(self, n: int | Sequence[int], dropna: str | None = ...) -> Series[S1]: ... + def sum(self, **kwargs) -> Series[S1]: ... class _DataFrameGroupByScalar(DataFrameGroupBy): def __iter__(self) -> Iterator[tuple[Scalar, DataFrame]]: ... From e6ae41e976ca606f303044b5a266baf9ba5458d7 Mon Sep 17 00:00:00 2001 From: Daniel Roseman Date: Wed, 3 Aug 2022 15:33:49 +0100 Subject: [PATCH 3/3] Reorder apply overloads. --- pandas-stubs/core/groupby/generic.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 3bc9f67ec..d40c63716 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -113,10 +113,6 @@ class DataFrameGroupBy(GroupBy): def all(self, skipna: bool = ...) -> DataFrame: ... # mypy and pyright see these overloads as overlapping @overload - def apply( # type: ignore[misc] - self, func: Callable[[Iterable], float], *args, **kwargs - ) -> DataFrame: ... - @overload def apply( # type: ignore[misc] self, func: Callable[[DataFrame], Scalar | list | dict], *args, **kwargs ) -> Series: ... @@ -125,6 +121,10 @@ class DataFrameGroupBy(GroupBy): self, func: Callable[[DataFrame], Series | DataFrame], *args, **kwargs ) -> DataFrame: ... @overload + def apply( # type: ignore[misc] + self, func: Callable[[Iterable], float], *args, **kwargs + ) -> DataFrame: ... + @overload def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ... @overload def aggregate(self, arg: dict, *args, **kwargs) -> DataFrame: ...