Skip to content

Commit 6aae41c

Browse files
authored
Fix calling df.agg with scalar dict values (#861)
Fixes #846
1 parent 176805d commit 6aae41c

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ from pandas._typing import (
6161
S1,
6262
AggFuncTypeBase,
6363
AggFuncTypeDictFrame,
64+
AggFuncTypeDictSeries,
6465
AggFuncTypeFrame,
6566
AnyArrayLike,
6667
ArrayLike,
@@ -1140,7 +1141,9 @@ class DataFrame(NDFrame, OpsMixin):
11401141
) -> DataFrame: ...
11411142
def diff(self, periods: int = ..., axis: Axis = ...) -> DataFrame: ...
11421143
@overload
1143-
def agg(self, func: AggFuncTypeBase, axis: Axis = ..., **kwargs) -> Series: ...
1144+
def agg( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
1145+
self, func: AggFuncTypeBase | AggFuncTypeDictSeries, axis: Axis = ..., **kwargs
1146+
) -> Series: ...
11441147
@overload
11451148
def agg(
11461149
self,
@@ -1149,8 +1152,8 @@ class DataFrame(NDFrame, OpsMixin):
11491152
**kwargs,
11501153
) -> DataFrame: ...
11511154
@overload
1152-
def aggregate(
1153-
self, func: AggFuncTypeBase, axis: Axis = ..., **kwargs
1155+
def aggregate( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
1156+
self, func: AggFuncTypeBase | AggFuncTypeDictSeries, axis: Axis = ..., **kwargs
11541157
) -> Series: ...
11551158
@overload
11561159
def aggregate(

tests/test_frame.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,10 @@ def test_types_agg() -> None:
12761276
assert_type(df.agg({"A": ["min", "max"], "B": "min"}), pd.DataFrame),
12771277
pd.DataFrame,
12781278
)
1279+
check(assert_type(df.agg({"A": ["mean"]}), pd.DataFrame), pd.DataFrame)
12791280
check(assert_type(df.agg("mean", axis=1), pd.Series), pd.Series)
1281+
check(assert_type(df.agg({"A": "mean"}), pd.Series), pd.Series)
1282+
check(assert_type(df.agg({"A": "mean", "B": "sum"}), pd.Series), pd.Series)
12801283

12811284

12821285
def test_types_aggregate() -> None:
@@ -1299,6 +1302,10 @@ def test_types_aggregate() -> None:
12991302
assert_type(df.aggregate({"A": ["min", "max"], "B": "min"}), pd.DataFrame),
13001303
pd.DataFrame,
13011304
)
1305+
check(assert_type(df.aggregate({"A": ["mean"]}), pd.DataFrame), pd.DataFrame)
1306+
check(assert_type(df.aggregate("mean", axis=1), pd.Series), pd.Series)
1307+
check(assert_type(df.aggregate({"A": "mean"}), pd.Series), pd.Series)
1308+
check(assert_type(df.aggregate({"A": "mean", "B": "sum"}), pd.Series), pd.Series)
13021309

13031310

13041311
def test_types_transform() -> None:

0 commit comments

Comments
 (0)