Skip to content

Commit cdec539

Browse files
authored
Fix groupby.apply() and sum(axis=1) (#168)
* Fix groupby.apply() and sum(axis=1) * reorder sum overloads, change apply overloads * make the catchall return a Union
1 parent 77b357f commit cdec539

File tree

4 files changed

+34
-7
lines changed

4 files changed

+34
-7
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,22 +1906,22 @@ class DataFrame(NDFrame, OpsMixin):
19061906
self,
19071907
axis: AxisType | None = ...,
19081908
skipna: _bool | None = ...,
1909+
level: None = ...,
19091910
numeric_only: _bool | None = ...,
19101911
min_count: int = ...,
1911-
*,
1912-
level: Level,
19131912
**kwargs,
1914-
) -> DataFrame: ...
1913+
) -> Series: ...
19151914
@overload
19161915
def sum(
19171916
self,
19181917
axis: AxisType | None = ...,
19191918
skipna: _bool | None = ...,
1920-
level: None = ...,
19211919
numeric_only: _bool | None = ...,
19221920
min_count: int = ...,
1921+
*,
1922+
level: Level,
19231923
**kwargs,
1924-
) -> Series: ...
1924+
) -> DataFrame: ...
19251925
def swapaxes(
19261926
self, axis1: AxisType, axis2: AxisType, copy: _bool = ...
19271927
) -> DataFrame: ...

pandas-stubs/core/groupby/generic.pyi

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from __future__ import annotations
33
from typing import (
44
Any,
55
Callable,
6+
Iterable,
67
Iterator,
78
Literal,
89
NamedTuple,
@@ -111,7 +112,17 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy):
111112
class DataFrameGroupBy(GroupBy):
112113
def any(self, skipna: bool = ...) -> DataFrame: ...
113114
def all(self, skipna: bool = ...) -> DataFrame: ...
114-
def apply(self, func, *args, **kwargs) -> DataFrame: ...
115+
# mypy sees the two overloads as overlapping
116+
@overload
117+
def apply( # type: ignore[misc]
118+
self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs
119+
) -> Series: ...
120+
@overload
121+
def apply( # type: ignore[misc]
122+
self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs
123+
) -> DataFrame: ...
124+
@overload
125+
def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ...
115126
@overload
116127
def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ...
117128
@overload

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT], GroupByIndexingMixin):
6565
def pipe(self, func: Callable, *args, **kwargs): ...
6666
plot = ...
6767
def get_group(self, name, obj: DataFrame | None = ...) -> DataFrame: ...
68-
def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ...
6968

7069
class GroupBy(BaseGroupBy[NDFrameT]):
7170
def count(self) -> FrameOrSeriesUnion: ...

tests/test_frame.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import (
88
TYPE_CHECKING,
99
Any,
10+
Callable,
1011
Hashable,
1112
Iterable,
1213
Iterator,
@@ -1339,3 +1340,19 @@ def test_setitem_list():
13391340
iter2: Iterator[tuple[str, int]] = (v for v in lst4)
13401341
check(assert_type(df.set_index(iter1), pd.DataFrame), pd.DataFrame)
13411342
check(assert_type(df.set_index(iter2), pd.DataFrame), pd.DataFrame)
1343+
1344+
1345+
def test_groupby_apply() -> None:
1346+
# GH 167
1347+
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
1348+
1349+
def summean(x: pd.DataFrame) -> float:
1350+
return x.sum().mean()
1351+
1352+
check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series)
1353+
1354+
lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean()
1355+
check(
1356+
assert_type(df.groupby("col1").apply(lfunc), pd.Series),
1357+
pd.Series,
1358+
)

0 commit comments

Comments
 (0)