diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index aedcb49af..d40c63716 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -100,6 +100,7 @@ class SeriesGroupBy(GroupBy): def nlargest(self, n: int = ..., keep: str = ...) -> Series[S1]: ... def nsmallest(self, n: int = ..., keep: str = ...) -> Series[S1]: ... def nth(self, n: int | Sequence[int], dropna: str | None = ...) -> Series[S1]: ... + def sum(self, **kwargs) -> Series[S1]: ... class _DataFrameGroupByScalar(DataFrameGroupBy): def __iter__(self) -> Iterator[tuple[Scalar, DataFrame]]: ... @@ -110,17 +111,19 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy): class DataFrameGroupBy(GroupBy): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... - # mypy sees the two overloads as overlapping + # mypy and pyright see these overloads as overlapping @overload def apply( # type: ignore[misc] - self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs + self, func: Callable[[DataFrame], Scalar | list | dict], *args, **kwargs ) -> Series: ... @overload def apply( # type: ignore[misc] - self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs + self, func: Callable[[DataFrame], Series | DataFrame], *args, **kwargs ) -> DataFrame: ... @overload - def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ... + def apply( # type: ignore[misc] + self, func: Callable[[Iterable], float], *args, **kwargs + ) -> DataFrame: ... @overload def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index 27ced9263..b6b5f6156 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1346,13 +1346,32 @@ def test_groupby_apply() -> None: # GH 167 df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) - def summean(x: pd.DataFrame) -> float: + def sum_mean(x: pd.DataFrame) -> float: return x.sum().mean() - check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series) + check(assert_type(df.groupby("col1").apply(sum_mean), pd.Series), pd.Series) lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean() check( assert_type(df.groupby("col1").apply(lfunc), pd.Series), pd.Series, ) + + def sum_to_list(x: pd.DataFrame) -> list: + return x.sum().tolist() + + check(assert_type(df.groupby("col1").apply(sum_to_list), pd.Series), pd.Series) + + def sum_to_series(x: pd.DataFrame) -> pd.Series: + return x.sum() + + check( + assert_type(df.groupby("col1").apply(sum_to_series), pd.DataFrame), pd.DataFrame + ) + + def sample_to_df(x: pd.DataFrame) -> pd.DataFrame: + return x.sample() + + check( + assert_type(df.groupby("col1").apply(sample_to_df), pd.DataFrame), pd.DataFrame + )