From ab35982e4bba6b49a32e619c733a5507484f8fe3 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler
Date: Tue, 3 Oct 2023 22:50:07 +0200
Subject: [PATCH 1/2] BUG: idxmax raising for arrow strings

---
 pandas/core/arrays/arrow/array.py     | 11 ++++++++++-
 pandas/core/arrays/string_arrow.py    |  8 ++++++++
 pandas/tests/frame/test_reductions.py |  9 +++++++++
 3 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index 2c788411eb089..9743ca891d4b8 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -1627,6 +1627,15 @@ def _reduce(
         ------
         TypeError : subclass does not define reductions
         """
+        result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        if isinstance(result, pa.Array):
+            return type(self)(result)
+        else:
+            return result
+
+    def _reduce_calc(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
         pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
 
         if keepdims:
@@ -1637,7 +1646,7 @@ def _reduce(
                     [pa_result],
                     type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
                 )
-            return type(self)(result)
+            return result
 
         if pc.is_null(pa_result).as_py():
             return self.dtype.na_value
diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py
index e904123849821..33fcdf56d31cb 100644
--- a/pandas/core/arrays/string_arrow.py
+++ b/pandas/core/arrays/string_arrow.py
@@ -501,6 +501,14 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
     def _convert_int_dtype(self, result):
         return Int64Dtype().__from_arrow__(result)
 
+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
+        result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        if name in ("argmin", "argmax") and isinstance(result, pa.Array):
+            return self._convert_int_dtype(result)
+        return type(self)(result)
+
 
 class ArrowStringArrayNumpySemantics(ArrowStringArray):
     _storage = "pyarrow_numpy"
diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py
index 77f64b18a82f8..1fcc08946cb04 100644
--- a/pandas/tests/frame/test_reductions.py
+++ b/pandas/tests/frame/test_reductions.py
@@ -1069,6 +1069,15 @@ def test_idxmax_arrow_types(self):
         expected = Series([2, 1], index=["a", "b"])
         tm.assert_series_equal(result, expected)
 
+        df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
+        result = df.idxmax(numeric_only=False)
+        expected = Series([1], index=["a"])
+        tm.assert_series_equal(result, expected)
+
+        result = df.idxmin(numeric_only=False)
+        expected = Series([2], index=["a"])
+        tm.assert_series_equal(result, expected)
+
     def test_idxmax_axis_2(self, float_frame):
         frame = float_frame
         msg = "No axis named 2 for object type DataFrame"

From d910efac8126b1845dad9f64269b09ba525a41b1 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler
Date: Tue, 3 Oct 2023 22:51:10 +0200
Subject: [PATCH 2/2] Fix

---
 pandas/core/arrays/string_arrow.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py
index 33fcdf56d31cb..5f800e781d2fa 100644
--- a/pandas/core/arrays/string_arrow.py
+++ b/pandas/core/arrays/string_arrow.py
@@ -507,7 +507,10 @@ def _reduce(
         result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
         if name in ("argmin", "argmax") and isinstance(result, pa.Array):
             return self._convert_int_dtype(result)
-        return type(self)(result)
+        elif isinstance(result, pa.Array):
+            return type(self)(result)
+        else:
+            return result
 
 
 class ArrowStringArrayNumpySemantics(ArrowStringArray):