Skip to content

Commit 52ce001

Browse files
committed
Merge branch 'idxmax_string' into string_dtype_tests
# Conflicts:
#   pandas/core/arrays/string_arrow.py
2 parents ca296ec + d910efa commit 52ce001

File tree

3 files changed: 30 additions and 1 deletion

3 files changed: 30 additions and 1 deletion

pandas/core/arrays/arrow/array.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1627,6 +1627,15 @@ def _reduce(
16271627
------
16281628
TypeError : subclass does not define reductions
16291629
"""
1630+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
1631+
if isinstance(result, pa.Array):
1632+
return type(self)(result)
1633+
else:
1634+
return result
1635+
1636+
def _reduce_calc(
1637+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1638+
):
16301639
pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
16311640

16321641
if keepdims:
@@ -1637,7 +1646,7 @@ def _reduce(
16371646
[pa_result],
16381647
type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
16391648
)
1640-
return type(self)(result)
1649+
return result
16411650

16421651
if pc.is_null(pa_result).as_py():
16431652
return self.dtype.na_value

pandas/core/arrays/string_arrow.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -524,6 +524,17 @@ def _rank(
524524
)
525525
)
526526

527+
def _reduce(
528+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
529+
):
530+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
531+
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
532+
return self._convert_int_dtype(result)
533+
elif isinstance(result, pa.Array):
534+
return type(self)(result)
535+
else:
536+
return result
537+
527538

528539
class ArrowStringArrayNumpySemantics(ArrowStringArray):
529540
_storage = "pyarrow_numpy"

pandas/tests/frame/test_reductions.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1083,6 +1083,15 @@ def test_idxmax_arrow_types(self):
10831083
expected = Series([2, 1], index=["a", "b"])
10841084
tm.assert_series_equal(result, expected)
10851085

1086+
df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
1087+
result = df.idxmax(numeric_only=False)
1088+
expected = Series([1], index=["a"])
1089+
tm.assert_series_equal(result, expected)
1090+
1091+
result = df.idxmin(numeric_only=False)
1092+
expected = Series([2], index=["a"])
1093+
tm.assert_series_equal(result, expected)
1094+
10861095
def test_idxmax_axis_2(self, float_frame):
10871096
frame = float_frame
10881097
msg = "No axis named 2 for object type DataFrame"

0 commit comments

Comments
 (0)