From 7edec5ff27a3c6d189f247a7c3b0e5590abc5218 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:27:53 -0700 Subject: [PATCH 1/6] Backport PR #52499: ENH: Implement str.r/split for ArrowDtype --- doc/source/whatsnew/v2.0.1.rst | 1 + pandas/core/arrays/arrow/array.py | 24 ++++++++---- pandas/core/strings/accessor.py | 51 +++++++++++++++--------- pandas/tests/extension/test_arrow.py | 58 +++++++++++++++++++++++++++- 4 files changed, 106 insertions(+), 28 deletions(-) diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index 46a780befcede..7b4dc890da3e1 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -36,6 +36,7 @@ Bug fixes Other ~~~~~ +- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) - :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) - :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 891b25f47000d..450f8b0a5439f 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1929,15 +1929,25 @@ def _str_rfind(self, sub, start: int = 0, end=None): ) def _str_split( - self, pat=None, n=-1, expand: bool = False, regex: bool | None = None + self, + pat: str | None = None, + n: int | None = -1, + expand: bool = False, + regex: bool | None = None, ): - raise NotImplementedError( - "str.split not supported with pd.ArrowDtype(pa.string())." - ) + if n in {-1, 0}: + n = None + if regex: + split_func = pc.split_pattern_regex + else: + split_func = pc.split_pattern + return type(self)(split_func(self._pa_array, pat, max_splits=n)) - def _str_rsplit(self, pat=None, n=-1): - raise NotImplementedError( - "str.rsplit not supported with pd.ArrowDtype(pa.string())." + def _str_rsplit(self, pat: str | None = None, n: int | None = -1): + if n in {-1, 0}: + n = None + return type(self)( + pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) ) def _str_translate(self, table): diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 2acfdcefed055..cdb77ecd7a4a5 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -41,6 +41,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.base import NoNewAttributesMixin from pandas.core.construction import extract_array @@ -267,27 +268,39 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 - elif ( - expand is True - and is_object_dtype(result) - and not isinstance(self._orig, ABCIndex) - ): + elif expand is True and not isinstance(self._orig, ABCIndex): # required when expand=True is explicitly specified # not needed when inferred - - def cons_row(x): - if is_list_like(x): - return x - else: - return [x] - - result = [cons_row(x) for x in result] - if result and not self._is_string: - # propagate nan values to match longest sequence (GH 18450) - max_len = max(len(x) for x in result) - result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result - ] + if isinstance(result.dtype, ArrowDtype): + import pyarrow as pa + + from pandas.core.arrays.arrow.array import ArrowExtensionArray + + max_len = pa.compute.max( + result._pa_array.combine_chunks().value_lengths() + ).as_py() + if result.isna().any(): + result._pa_array = result._pa_array.fill_null([None] * max_len) + result = { + i: ArrowExtensionArray(pa.array(res)) + for i, res in enumerate(zip(*result.tolist())) + } + elif is_object_dtype(result): + + def cons_row(x): + if is_list_like(x): + return x + else: + return [x] + + result = [cons_row(x) for x in result] + if result and not self._is_string: + # propagate nan values to match longest sequence (GH 18450) + max_len = max(len(x) for x in result) + result = [ + x * max_len if len(x) == 0 or x[0] is np.nan else x + for x in result + ] if not isinstance(expand, bool): raise ValueError("expand must be True or False") diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 6bfd57938abc0..da1362a7bd26b 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2098,6 +2098,62 @@ def test_str_removesuffix(val): tm.assert_series_equal(result, expected) +def test_str_split(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.split("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_rsplit(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.rsplit("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])), + 1: ArrowExtensionArray(pa.array(["b", "b", None])), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( "method, args", [ @@ -2113,8 +2169,6 @@ def test_str_removesuffix(val): ["rindex", ("abc",)], ["normalize", ("abc",)], ["rfind", ("abc",)], - ["split", ()], - ["rsplit", ()], ["translate", ("abc",)], ["wrap", ("abc",)], ], From a1bf5a545e26ebab5c343be1d14b5b36f2d11b4a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:30:09 -0700 Subject: [PATCH 2/6] Update pandas/core/arrays/arrow/array.py --- pandas/core/arrays/arrow/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 450f8b0a5439f..6ffba91232e63 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1941,7 +1941,7 @@ def _str_split( split_func = pc.split_pattern_regex else: split_func = pc.split_pattern - return type(self)(split_func(self._pa_array, pat, max_splits=n)) + return type(self)(split_func(self._data, pat, max_splits=n)) def _str_rsplit(self, pat: str | None = None, n: int | None = -1): if n in {-1, 0}: From 04b017a14de1382849e62d2a140a2d89c657b2d3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:30:16 -0700 Subject: [PATCH 3/6] Update pandas/core/arrays/arrow/array.py --- pandas/core/arrays/arrow/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 6ffba91232e63..b0c89f12e0675 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1947,7 +1947,7 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1): if n in {-1, 0}: n = None return type(self)( - pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) + pc.split_pattern(self._data, pat, max_splits=n, reverse=True) ) def _str_translate(self, table): From 204fa2ebe8c8117f8ce19fec3dcb671caf389cda Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:30:31 -0700 Subject: [PATCH 4/6] Update pandas/core/strings/accessor.py --- pandas/core/strings/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index cdb77ecd7a4a5..b6e0a1776cbd0 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -277,7 +277,7 @@ def _wrap_result( from pandas.core.arrays.arrow.array import ArrowExtensionArray max_len = pa.compute.max( - result._pa_array.combine_chunks().value_lengths() + result._data.combine_chunks().value_lengths() ).as_py() if result.isna().any(): result._pa_array = result._pa_array.fill_null([None] * max_len) From ad6287891a713e3118b8209fc976d811d0f9d392 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:30:43 -0700 Subject: [PATCH 5/6] Update pandas/core/strings/accessor.py --- pandas/core/strings/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index b6e0a1776cbd0..e44ccd9aede83 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -280,7 +280,7 @@ def _wrap_result( result._data.combine_chunks().value_lengths() ).as_py() if result.isna().any(): - result._pa_array = result._pa_array.fill_null([None] * max_len) + result._data = result._data.fill_null([None] * max_len) result = { i: ArrowExtensionArray(pa.array(res)) for i, res in enumerate(zip(*result.tolist())) From c20c6aa4c95beb7ed1b3698acb537699ad1517ed Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 11 Apr 2023 12:45:42 -0700 Subject: [PATCH 6/6] Update pandas/core/arrays/arrow/array.py --- pandas/core/arrays/arrow/array.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index b0c89f12e0675..c46420eb0b349 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1946,9 +1946,7 @@ def _str_split( def _str_rsplit(self, pat: str | None = None, n: int | None = -1): if n in {-1, 0}: n = None - return type(self)( - pc.split_pattern(self._data, pat, max_splits=n, reverse=True) - ) + return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True)) def _str_translate(self, table): raise NotImplementedError(