Skip to content

Commit ab35982

Browse files
committed
BUG: idxmax raising for arrow strings
1 parent 8664572 commit ab35982

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,15 @@ def _reduce(
16271627
------
16281628
TypeError : subclass does not define reductions
16291629
"""
1630+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
1631+
if isinstance(result, pa.Array):
1632+
return type(self)(result)
1633+
else:
1634+
return result
1635+
1636+
def _reduce_calc(
1637+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1638+
):
16301639
pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
16311640

16321641
if keepdims:
@@ -1637,7 +1646,7 @@ def _reduce(
16371646
[pa_result],
16381647
type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
16391648
)
1640-
return type(self)(result)
1649+
return result
16411650

16421651
if pc.is_null(pa_result).as_py():
16431652
return self.dtype.na_value

pandas/core/arrays/string_arrow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,14 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
501501
def _convert_int_dtype(self, result):
    """Convert ``result`` through ``Int64Dtype().__from_arrow__``.

    ``result`` is handed unchanged to ``__from_arrow__``; it is presumably a
    pyarrow array of integers (TODO confirm against callers).
    """
    int64_dtype = Int64Dtype()
    return int64_dtype.__from_arrow__(result)
503503

504+
def _reduce(
    self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
    """
    Reduce this array with the reduction called ``name``.

    Parameters
    ----------
    name : str
        Name of the reduction, e.g. ``"min"``, ``"max"``, ``"argmin"``,
        ``"argmax"``.
    skipna : bool, default True
        Forwarded to ``_reduce_calc``.
    keepdims : bool, default False
        Forwarded to ``_reduce_calc``.
    **kwargs
        Forwarded to ``_reduce_calc``.

    Returns
    -------
    object
        * If ``_reduce_calc`` returns something that is not a ``pa.Array``
          (a scalar or the NA value), it is returned unchanged.
        * For ``"argmin"`` / ``"argmax"`` with a ``pa.Array`` result, the
          result is converted with ``_convert_int_dtype``, because
          positions are integers and cannot be stored in a string array.
        * Any other ``pa.Array`` result is wrapped in ``type(self)``.
    """
    result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)

    if not isinstance(result, pa.Array):
        # Non-array results must be passed through as-is, like the
        # base-class ``_reduce`` does; wrapping them in ``type(self)``
        # would hand a non-array to the array constructor.
        return result

    if name in ("argmin", "argmax"):
        # Integer positions: use the integer dtype, not the string dtype.
        return self._convert_int_dtype(result)

    return type(self)(result)
511+
504512

505513
class ArrowStringArrayNumpySemantics(ArrowStringArray):
506514
_storage = "pyarrow_numpy"

pandas/tests/frame/test_reductions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,15 @@ def test_idxmax_arrow_types(self):
10691069
expected = Series([2, 1], index=["a", "b"])
10701070
tm.assert_series_equal(result, expected)
10711071

1072+
df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
1073+
result = df.idxmax(numeric_only=False)
1074+
expected = Series([1], index=["a"])
1075+
tm.assert_series_equal(result, expected)
1076+
1077+
result = df.idxmin(numeric_only=False)
1078+
expected = Series([2], index=["a"])
1079+
tm.assert_series_equal(result, expected)
1080+
10721081
def test_idxmax_axis_2(self, float_frame):
10731082
frame = float_frame
10741083
msg = "No axis named 2 for object type DataFrame"

0 commit comments

Comments
 (0)