diff --git a/doc/source/whatsnew/v2.1.1.rst b/doc/source/whatsnew/v2.1.1.rst index b7e45d043d869..e31d161993824 100644 --- a/doc/source/whatsnew/v2.1.1.rst +++ b/doc/source/whatsnew/v2.1.1.rst @@ -14,6 +14,7 @@ including other versions of pandas. Fixed regressions ~~~~~~~~~~~~~~~~~ - Fixed regression in :func:`read_csv` when ``usecols`` is given and ``dtypes`` is a dict for ``engine="python"`` (:issue:`54868`) +- Fixed regression in :meth:`.GroupBy.get_group` raising for ``axis=1`` (:issue:`54858`) - Fixed regression in :meth:`DataFrame.__setitem__` raising ``AssertionError`` when setting a :class:`Series` with a partial :class:`MultiIndex` (:issue:`54875`) - Fixed regression when comparing a :class:`Series` with ``datetime64`` dtype with ``None`` (:issue:`54870`) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 11c97e30ab5cd..43d200027220b 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1080,7 +1080,8 @@ def get_group(self, name, obj=None) -> DataFrame | Series: raise KeyError(name) if obj is None: - return self._selected_obj.iloc[inds] + indexer = inds if self.axis == 0 else (slice(None), inds) + return self._selected_obj.iloc[indexer] else: warnings.warn( "obj is deprecated and will be removed in a future version. " diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index c0ac94c09e1ea..be226b4466f98 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -3187,3 +3187,26 @@ def test_depr_get_group_len_1_list_likes(test_series, kwarg, value, name, warn): else: expected = DataFrame({"b": [3, 4]}, index=Index([1, 1], name="a")) tm.assert_equal(result, expected) + + +def test_get_group_axis_1(): + # GH#54858 + df = DataFrame( + { + "col1": [0, 3, 2, 3], + "col2": [4, 1, 6, 7], + "col3": [3, 8, 2, 10], + "col4": [1, 13, 6, 15], + "col5": [-4, 5, 6, -7], + } + ) + with tm.assert_produces_warning(FutureWarning, match="deprecated"): + grouped = df.groupby(axis=1, by=[1, 2, 3, 2, 1]) + result = grouped.get_group(1) + expected = DataFrame( + { + "col1": [0, 3, 2, 3], + "col5": [-4, 5, 6, -7], + } + ) + tm.assert_frame_equal(result, expected)