diff --git a/doc/source/whatsnew/v2.1.1.rst b/doc/source/whatsnew/v2.1.1.rst index fa83c8ee400c2..8b4833f6ce043 100644 --- a/doc/source/whatsnew/v2.1.1.rst +++ b/doc/source/whatsnew/v2.1.1.rst @@ -15,6 +15,7 @@ Fixed regressions ~~~~~~~~~~~~~~~~~ - Fixed regression in :func:`merge` when merging over a PyArrow string index (:issue:`54894`) - Fixed regression in :func:`read_csv` when ``usecols`` is given and ``dtypes`` is a dict for ``engine="python"`` (:issue:`54868`) +- Fixed regression in :meth:`.GroupBy.get_group` raising for ``axis=1`` (:issue:`54858`) - Fixed regression in :meth:`DataFrame.__setitem__` raising ``AssertionError`` when setting a :class:`Series` with a partial :class:`MultiIndex` (:issue:`54875`) - Fixed regression in :meth:`Series.value_counts` raising for numeric data if ``bins`` was specified (:issue:`54857`) - Fixed regression when comparing a :class:`Series` with ``datetime64`` dtype with ``None`` (:issue:`54870`) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index dd0c995e6acca..7fa69a70431e4 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1059,7 +1059,8 @@ def get_group(self, name, obj=None) -> DataFrame | Series: raise KeyError(name) if obj is None: - return self._selected_obj.iloc[inds] + indexer = inds if self.axis == 0 else (slice(None), inds) + return self._selected_obj.iloc[indexer] else: warnings.warn( "obj is deprecated and will be removed in a future version. " diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 9b260d5757767..bcd8323a26767 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -3159,3 +3159,26 @@ def test_groupby_series_with_datetimeindex_month_name(): expected = Series([2, 1], name="jan") expected.index.name = "jan" tm.assert_series_equal(result, expected) + + +def test_get_group_axis_1(): + # GH#54858 + df = DataFrame( + { + "col1": [0, 3, 2, 3], + "col2": [4, 1, 6, 7], + "col3": [3, 8, 2, 10], + "col4": [1, 13, 6, 15], + "col5": [-4, 5, 6, -7], + } + ) + with tm.assert_produces_warning(FutureWarning, match="deprecated"): + grouped = df.groupby(axis=1, by=[1, 2, 3, 2, 1]) + result = grouped.get_group(1) + expected = DataFrame( + { + "col1": [0, 3, 2, 3], + "col5": [-4, 5, 6, -7], + } + ) + tm.assert_frame_equal(result, expected)