Skip to content

Commit ca3e351

Browse files
authored
REF: implement Groupby idxmin, idxmax without fallback (#38264)
1 parent 5b91feb commit ca3e351

File tree

3 files changed

+65
-19
lines changed

3 files changed

+65
-19
lines changed

pandas/core/groupby/generic.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@
5050
)
5151
from pandas.core.dtypes.missing import isna, notna
5252

53+
from pandas.core import algorithms, nanops
5354
from pandas.core.aggregation import (
5455
agg_list_like,
5556
aggregate,
5657
maybe_mangle_lambdas,
5758
reconstruct_func,
5859
validate_func_kwargs,
5960
)
60-
import pandas.core.algorithms as algorithms
6161
from pandas.core.arrays import Categorical, ExtensionArray
6262
from pandas.core.base import DataError, SpecificationError
6363
import pandas.core.common as com
@@ -1826,4 +1826,46 @@ def nunique(self, dropna: bool = True) -> DataFrame:
18261826
self._insert_inaxis_grouper_inplace(results)
18271827
return results
18281828

1829+
@Appender(DataFrame.idxmax.__doc__)
1830+
def idxmax(self, axis=0, skipna: bool = True):
1831+
axis = DataFrame._get_axis_number(axis)
1832+
numeric_only = None if axis == 0 else False
1833+
1834+
def func(df):
1835+
# NB: here we use numeric_only=None, in DataFrame it is False GH#38217
1836+
res = df._reduce(
1837+
nanops.nanargmax,
1838+
"argmax",
1839+
axis=axis,
1840+
skipna=skipna,
1841+
numeric_only=numeric_only,
1842+
)
1843+
indices = res._values
1844+
index = df._get_axis(axis)
1845+
result = [index[i] if i >= 0 else np.nan for i in indices]
1846+
return df._constructor_sliced(result, index=res.index)
1847+
1848+
return self._python_apply_general(func, self._obj_with_exclusions)
1849+
1850+
@Appender(DataFrame.idxmin.__doc__)
1851+
def idxmin(self, axis=0, skipna: bool = True):
1852+
axis = DataFrame._get_axis_number(axis)
1853+
numeric_only = None if axis == 0 else False
1854+
1855+
def func(df):
1856+
# NB: here we use numeric_only=None, in DataFrame it is False GH#38217
1857+
res = df._reduce(
1858+
nanops.nanargmin,
1859+
"argmin",
1860+
axis=axis,
1861+
skipna=skipna,
1862+
numeric_only=numeric_only,
1863+
)
1864+
indices = res._values
1865+
index = df._get_axis(axis)
1866+
result = [index[i] if i >= 0 else np.nan for i in indices]
1867+
return df._constructor_sliced(result, index=res.index)
1868+
1869+
return self._python_apply_general(func, self._obj_with_exclusions)
1870+
18291871
boxplot = boxplot_frame_groupby

pandas/core/groupby/groupby.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ class providing the base-class of operations.
1111
import datetime
1212
from functools import partial, wraps
1313
import inspect
14-
import re
1514
import types
1615
from typing import (
1716
Callable,
@@ -797,23 +796,7 @@ def curried(x):
797796
if name in base.plotting_methods:
798797
return self.apply(curried)
799798

800-
try:
801-
return self._python_apply_general(curried, self._obj_with_exclusions)
802-
except TypeError as err:
803-
if not re.search(
804-
"reduction operation '.*' not allowed for this dtype", str(err)
805-
):
806-
# We don't have a cython implementation
807-
# TODO: is the above comment accurate?
808-
raise
809-
810-
if self.obj.ndim == 1:
811-
# this can be called recursively, so need to raise ValueError
812-
raise ValueError
813-
814-
# GH#3688 try to operate item-by-item
815-
result = self._aggregate_item_by_item(name, *args, **kwargs)
816-
return result
799+
return self._python_apply_general(curried, self._obj_with_exclusions)
817800

818801
wrapper.__name__ = name
819802
return wrapper

pandas/tests/groupby/test_function.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,27 @@ def test_idxmin_idxmax_returns_int_types(func, values):
539539
tm.assert_frame_equal(result, expected)
540540

541541

542+
def test_idxmin_idxmax_axis1():
543+
df = DataFrame(np.random.randn(10, 4), columns=["A", "B", "C", "D"])
544+
df["A"] = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4]
545+
546+
gb = df.groupby("A")
547+
548+
res = gb.idxmax(axis=1)
549+
550+
alt = df.iloc[:, 1:].idxmax(axis=1)
551+
indexer = res.index.get_level_values(1)
552+
553+
tm.assert_series_equal(alt[indexer], res.droplevel("A"))
554+
555+
df["E"] = pd.date_range("2016-01-01", periods=10)
556+
gb2 = df.groupby("A")
557+
558+
msg = "reduction operation 'argmax' not allowed for this dtype"
559+
with pytest.raises(TypeError, match=msg):
560+
gb2.idxmax(axis=1)
561+
562+
542563
def test_groupby_cumprod():
543564
# GH 4095
544565
df = DataFrame({"key": ["b"] * 10, "value": 2})

0 commit comments

Comments
 (0)