diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 45f02ca25..ce9e5ba88 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -61,6 +61,7 @@ from pandas._typing import ( S1, AggFuncTypeBase, AggFuncTypeDictFrame, + AggFuncTypeDictSeries, AggFuncTypeFrame, AnyArrayLike, ArrayLike, @@ -1139,7 +1140,9 @@ class DataFrame(NDFrame, OpsMixin): ) -> DataFrame: ... def diff(self, periods: int = ..., axis: Axis = ...) -> DataFrame: ... @overload - def agg(self, func: AggFuncTypeBase, axis: Axis = ..., **kwargs) -> Series: ... + def agg( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + self, func: AggFuncTypeBase | AggFuncTypeDictSeries, axis: Axis = ..., **kwargs + ) -> Series: ... @overload def agg( self, @@ -1148,8 +1151,8 @@ class DataFrame(NDFrame, OpsMixin): **kwargs, ) -> DataFrame: ... @overload - def aggregate( - self, func: AggFuncTypeBase, axis: Axis = ..., **kwargs + def aggregate( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + self, func: AggFuncTypeBase | AggFuncTypeDictSeries, axis: Axis = ..., **kwargs ) -> Series: ... @overload def aggregate( diff --git a/tests/test_frame.py b/tests/test_frame.py index cd4ace1ff..d53d1a0d4 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1267,7 +1267,10 @@ def test_types_agg() -> None: assert_type(df.agg({"A": ["min", "max"], "B": "min"}), pd.DataFrame), pd.DataFrame, ) + check(assert_type(df.agg({"A": ["mean"]}), pd.DataFrame), pd.DataFrame) check(assert_type(df.agg("mean", axis=1), pd.Series), pd.Series) + check(assert_type(df.agg({"A": "mean"}), pd.Series), pd.Series) + check(assert_type(df.agg({"A": "mean", "B": "sum"}), pd.Series), pd.Series) def test_types_aggregate() -> None: @@ -1290,6 +1293,10 @@ def test_types_aggregate() -> None: assert_type(df.aggregate({"A": ["min", "max"], "B": "min"}), pd.DataFrame), pd.DataFrame, ) + check(assert_type(df.aggregate({"A": ["mean"]}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.aggregate("mean", axis=1), pd.Series), pd.Series) + check(assert_type(df.aggregate({"A": "mean"}), pd.Series), pd.Series) + check(assert_type(df.aggregate({"A": "mean", "B": "sum"}), pd.Series), pd.Series) def test_types_transform() -> None: