diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index bd4a35525..bd9dbc7c5 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -143,6 +143,8 @@ class SeriesGroupBy(GroupBy, Generic[S1]): legend: bool = ..., **kwargs, ) -> AxesSubplot: ... + def idxmax(self, axis: AxisType = ..., skipna: bool = ...) -> Series: ... + def idxmin(self, axis: AxisType = ..., skipna: bool = ...) -> Series: ... class _DataFrameGroupByScalar(DataFrameGroupBy): def __iter__(self) -> Iterator[tuple[Scalar, DataFrame]]: ... @@ -273,10 +275,10 @@ class DataFrameGroupBy(GroupBy): ) -> AxesSubplot | Sequence[AxesSubplot]: ... def idxmax( self, axis: AxisType = ..., skipna: bool = ..., numeric_only: bool = ... - ) -> Series: ... + ) -> DataFrame: ... def idxmin( self, axis: AxisType = ..., skipna: bool = ..., numeric_only: bool = ... - ) -> Series: ... + ) -> DataFrame: ... def last(self, **kwargs) -> DataFrame: ... def max(self, **kwargs) -> DataFrame: ... def mean(self, **kwargs) -> DataFrame: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 0e55e3f1b..62aa300a9 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -877,6 +877,8 @@ def test_types_groupby_methods() -> None: pd.Series, float, ) + check(assert_type(df.groupby("col1").idxmax(), pd.DataFrame), pd.DataFrame) + check(assert_type(df.groupby("col1").idxmin(), pd.DataFrame), pd.DataFrame) def test_types_groupby_agg() -> None: diff --git a/tests/test_series.py b/tests/test_series.py index 05fe07776..a843c7473 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -519,6 +519,8 @@ def test_types_groupby_methods() -> None: check(assert_type(s.groupby(level=0).var(), "pd.Series[float]"), pd.Series, float) check(assert_type(s.groupby(level=0).tail(), "pd.Series[int]"), pd.Series, np.int_) check(assert_type(s.groupby(level=0).unique(), pd.Series), pd.Series) + check(assert_type(s.groupby(level=0).idxmax(), pd.Series), pd.Series) + check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series) def test_types_groupby_agg() -> None: