Skip to content

Commit 3266678

Browse files
committed
More specific typing for DataFrameGroupBy.apply.
1 parent d7307e5 commit 3266678

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -110,18 +110,20 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy):
110110
class DataFrameGroupBy(GroupBy):
111111
def any(self, skipna: bool = ...) -> DataFrame: ...
112112
def all(self, skipna: bool = ...) -> DataFrame: ...
113-
# mypy sees the two overloads as overlapping
113+
# mypy and pyright see these overloads as overlapping
114114
@overload
115115
def apply( # type: ignore[misc]
116-
self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs
116+
self, func: Callable[[Iterable], float], *args, **kwargs
117+
) -> DataFrame: ...
118+
@overload
119+
def apply( # type: ignore[misc]
120+
self, func: Callable[[DataFrame], Scalar | list | dict], *args, **kwargs
117121
) -> Series: ...
118122
@overload
119123
def apply( # type: ignore[misc]
120-
self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs
124+
self, func: Callable[[DataFrame], Series | DataFrame], *args, **kwargs
121125
) -> DataFrame: ...
122126
@overload
123-
def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ...
124-
@overload
125127
def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ...
126128
@overload
127129
def aggregate(self, arg: dict, *args, **kwargs) -> DataFrame: ...

tests/test_frame.py

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1346,13 +1346,32 @@ def test_groupby_apply() -> None:
13461346
# GH 167
13471347
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
13481348

1349-
def summean(x: pd.DataFrame) -> float:
1349+
def sum_mean(x: pd.DataFrame) -> float:
13501350
return x.sum().mean()
13511351

1352-
check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series)
1352+
check(assert_type(df.groupby("col1").apply(sum_mean), pd.Series), pd.Series)
13531353

13541354
lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean()
13551355
check(
13561356
assert_type(df.groupby("col1").apply(lfunc), pd.Series),
13571357
pd.Series,
13581358
)
1359+
1360+
def sum_to_list(x: pd.DataFrame) -> list:
1361+
return x.sum().tolist()
1362+
1363+
check(assert_type(df.groupby("col1").apply(sum_to_list), pd.Series), pd.Series)
1364+
1365+
def sum_to_series(x: pd.DataFrame) -> pd.Series:
1366+
return x.sum()
1367+
1368+
check(
1369+
assert_type(df.groupby("col1").apply(sum_to_series), pd.DataFrame), pd.DataFrame
1370+
)
1371+
1372+
def sample_to_df(x: pd.DataFrame) -> pd.DataFrame:
1373+
return x.sample()
1374+
1375+
check(
1376+
assert_type(df.groupby("col1").apply(sample_to_df), pd.DataFrame), pd.DataFrame
1377+
)

0 commit comments

Comments
 (0)