diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 2efc6c9167a83..042452c79230e 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -150,6 +150,7 @@ Other enhancements - Added ``validate`` argument to :meth:`DataFrame.join` (:issue:`46622`) - A :class:`errors.PerformanceWarning` is now thrown when using ``string[pyarrow]`` dtype with methods that don't dispatch to ``pyarrow.compute`` methods (:issue:`42613`) - Added ``numeric_only`` argument to :meth:`Resampler.sum`, :meth:`Resampler.prod`, :meth:`Resampler.min`, :meth:`Resampler.max`, :meth:`Resampler.first`, and :meth:`Resampler.last` (:issue:`46442`) +- Implemented :meth:`nlargest` and :meth:`nsmallest` methods for :class:`DataFrameGroupBy` (:issue:`46924`) .. --------------------------------------------------------------------------- .. _whatsnew_150.notable_bug_fixes: diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index ec9a2e4a4b5c0..8eb6ad061cc25 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -33,6 +33,8 @@ class OutputKey: "corr", "cov", "diff", + "nlargest", + "nsmallest", ] ) | plotting_methods @@ -40,9 +42,7 @@ class OutputKey: series_apply_allowlist: frozenset[str] = ( common_apply_allowlist - | frozenset( - {"nlargest", "nsmallest", "is_monotonic_increasing", "is_monotonic_decreasing"} - ) + | frozenset({"is_monotonic_increasing", "is_monotonic_decreasing"}) ) | frozenset(["dtype", "unique"]) dataframe_apply_allowlist: frozenset[str] = common_apply_allowlist | frozenset( @@ -155,6 +155,8 @@ def maybe_normalize_deprecated_kernels(kernel): "transform", "sample", "value_counts", + "nlargest", + "nsmallest", ] ) # Valid values of `name` for `groupby.transform(name)` diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 245e33fb1a23b..9aa840def1294 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1812,6 +1812,24 @@ def value_counts( result = result_frame return result.__finalize__(self.obj, method="value_counts") + @doc(DataFrame.nlargest) + def nlargest(self, n, columns, keep: str = "first"): + f = partial(DataFrame.nlargest, n=n, columns=columns, keep=keep) + data = self._obj_with_exclusions + # Don't change behavior if result index happens to be the same, i.e. + # already ordered and n >= all group sizes. + result = self._python_apply_general(f, data, not_indexed_same=True) + return result + + @doc(DataFrame.nsmallest) + def nsmallest(self, n, columns, keep: str = "first"): + f = partial(DataFrame.nsmallest, n=n, columns=columns, keep=keep) + data = self._obj_with_exclusions + # Don't change behavior if result index happens to be the same, i.e. + # already ordered and n >= all group sizes. + result = self._python_apply_general(f, data, not_indexed_same=True) + return result + def _wrap_transform_general_frame( obj: DataFrame, group: DataFrame, res: DataFrame | Series diff --git a/pandas/tests/groupby/test_allowlist.py b/pandas/tests/groupby/test_allowlist.py index 7c64d82608c9e..cce7f85afb9e1 100644 --- a/pandas/tests/groupby/test_allowlist.py +++ b/pandas/tests/groupby/test_allowlist.py @@ -51,6 +51,8 @@ "corr", "cov", "diff", + "nlargest", + "nsmallest", ] @@ -322,6 +324,8 @@ def test_tab_completion(mframe): "sample", "ewm", "value_counts", + "nlargest", + "nsmallest", } assert results == expected diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 016e817e43402..f92761e93f9d8 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -2721,3 +2721,70 @@ def test_by_column_values_with_same_starting_value(): ).set_index("Name") tm.assert_frame_equal(result, expected_result) + + +@pytest.mark.parametrize( + "function, keep, indices, name, data", + [ + ( + "nlargest", + "first", + [("bar", 1), ("bar", 2), ("foo", 5), ("foo", 3)], + ["b2", "b3", "f3", "f1"], + [3, 3, 3, 1], + ), + ( + "nlargest", + "last", + [("bar", 2), ("bar", 1), ("foo", 5), ("foo", 4)], + ["b3", "b2", "f3", "f2"], + [3, 3, 3, 1], + ), + ( + "nlargest", + "all", + [("bar", 1), ("bar", 2), ("foo", 5), ("foo", 3), ("foo", 4)], + ["b2", "b3", "f3", "f1", "f2"], + [3, 3, 3, 1, 1], + ), + ( + "nsmallest", + "first", + [("bar", 0), ("bar", 1), ("foo", 3), ("foo", 4)], + ["b1", "b2", "f1", "f2"], + [1, 3, 1, 1], + ), + ( + "nsmallest", + "last", + [("bar", 0), ("bar", 2), ("foo", 4), ("foo", 3)], + ["b1", "b3", "f2", "f1"], + [1, 3, 1, 1], + ), + ( + "nsmallest", + "all", + [("bar", 0), ("bar", 1), ("bar", 2), ("foo", 3), ("foo", 4)], + ["b1", "b2", "b3", "f1", "f2"], + [1, 3, 3, 1, 1], + ), + ], +) +def test_nlargest_nsmallest(function, keep, indices, name, data): + # test nlargest and nsmallest for DataFrameGroupBy + # GH46924 + df = DataFrame( + { + "group": ["bar", "bar", "bar", "foo", "foo", "foo"], + "name": ["b1", "b2", "b3", "f1", "f2", "f3"], + "data": [1, 3, 3, 1, 1, 3], + } + ) + grouped = df.groupby("group") + func = getattr(grouped, function) + result = func(n=2, keep=keep, columns="data") + + expected_index = MultiIndex.from_tuples(indices, names=["group", None]) + expected = DataFrame({"name": name, "data": data}, index=expected_index) + + tm.assert_frame_equal(result, expected)