From 0b1ed491fa49416f1a3dc7c44e7eca126bacdc5c Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Apr 2023 19:10:35 -0700 Subject: [PATCH 1/6] Add str split for ArrowDtype(pa.string()) --- pandas/core/arrays/arrow/array.py | 24 ++++++++---- pandas/core/strings/accessor.py | 51 +++++++++++++++--------- pandas/tests/extension/test_arrow.py | 58 +++++++++++++++++++++++++++- 3 files changed, 105 insertions(+), 28 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index f76fe166dba78..dcfa93877e42f 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1940,15 +1940,25 @@ def _str_rfind(self, sub, start: int = 0, end=None): ) def _str_split( - self, pat=None, n=-1, expand: bool = False, regex: bool | None = None + self, + pat: str | None = None, + n: int = -1, + expand: bool = False, + regex: bool | None = None, ): - raise NotImplementedError( - "str.split not supported with pd.ArrowDtype(pa.string())." - ) + if n in {-1, 0}: + n = None + if regex: + split_func = pc.split_pattern_regex + else: + split_func = pc.split_pattern + return type(self)(split_func(self._pa_array, pat, max_splits=n)) - def _str_rsplit(self, pat=None, n=-1): - raise NotImplementedError( - "str.rsplit not supported with pd.ArrowDtype(pa.string())." + def _str_rsplit(self, pat: str | None = None, n: int = -1): + if n in {-1, 0}: + n = None + return type(self)( + pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) ) def _str_translate(self, table): diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index deb87149ab8cc..b1c945b5b39cd 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -41,6 +41,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.base import NoNewAttributesMixin from pandas.core.construction import extract_array @@ -267,27 +268,39 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 - elif ( - expand is True - and is_object_dtype(result) - and not isinstance(self._orig, ABCIndex) - ): + elif expand is True and not isinstance(self._orig, ABCIndex): # required when expand=True is explicitly specified # not needed when inferred - - def cons_row(x): - if is_list_like(x): - return x - else: - return [x] - - result = [cons_row(x) for x in result] - if result and not self._is_string: - # propagate nan values to match longest sequence (GH 18450) - max_len = max(len(x) for x in result) - result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result - ] + if isinstance(result.dtype, ArrowDtype): + import pyarrow as pa + + from pandas.core.arrays.arrow.array import ArrowExtensionArray + + max_len = pa.compute.max( + result._pa_array.combine_chunks().value_lengths() + ).as_py() + if result.isna().any(): + result = result.fill_null([None] * max_len) + result = { + i: ArrowExtensionArray(pa.array(list(res))) + for i, res in enumerate(zip(*result.tolist())) + } + elif is_object_dtype(result): + + def cons_row(x): + if is_list_like(x): + return x + else: + return [x] + + result = [cons_row(x) for x in result] + if result and not self._is_string: + # propagate nan values to match longest sequence (GH 18450) + max_len = max(len(x) for x in result) + result = [ + x * max_len if len(x) == 0 or x[0] is np.nan else x + for x in result + ] if not isinstance(expand, bool): raise ValueError("expand must be True or False") diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index df470d85a4fad..ca1f6104d4f06 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2059,6 +2059,62 @@ def test_str_removesuffix(val): tm.assert_series_equal(result, expected) +def test_str_split(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=pd.ArrowDtype(pa.string())) + result = ser.str.split("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1" "b" "b"], ["a2" "b" "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1" "bcb"], ["a2" "bcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a" "cbcb"], ["a" "cbcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_rsplit(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=pd.ArrowDtype(pa.string())) + result = ser.str.rsplit("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1" "b" "b"], ["a2" "b" "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1cb" "b"], ["a2cb" "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + } + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( "method, args", [ @@ -2074,8 +2130,6 @@ def test_str_removesuffix(val): ["rindex", ("abc",)], ["normalize", ("abc",)], ["rfind", ("abc",)], - ["split", ()], - ["rsplit", ()], ["translate", ("abc",)], ["wrap", ("abc",)], ], From 427381a089c3eeafe08e86e4f8ca3e9d87403cb8 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 6 Apr 2023 11:40:45 -0700 Subject: [PATCH 2/6] Add tests --- pandas/tests/extension/test_arrow.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index ca1f6104d4f06..c821f1572d213 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2061,22 +2061,22 @@ def test_str_removesuffix(val): def test_str_split(): # GH 52401 - ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=pd.ArrowDtype(pa.string())) + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) result = ser.str.split("c") expected = pd.Series( - ArrowExtensionArray(pa.array([["a1" "b" "b"], ["a2" "b" "b"], None])) + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) ) tm.assert_series_equal(result, expected) result = ser.str.split("c", n=1) expected = pd.Series( - ArrowExtensionArray(pa.array([["a1" "bcb"], ["a2" "bcb"], None])) + ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None])) ) tm.assert_series_equal(result, expected) result = ser.str.split("[1-2]", regex=True) expected = pd.Series( - ArrowExtensionArray(pa.array([["a" "cbcb"], ["a" "cbcb"], None])) + ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None])) ) tm.assert_series_equal(result, expected) @@ -2092,16 +2092,16 @@ def test_str_split(): def test_str_rsplit(): # GH 52401 - ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=pd.ArrowDtype(pa.string())) + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) result = ser.str.rsplit("c") expected = pd.Series( - ArrowExtensionArray(pa.array([["a1" "b" "b"], ["a2" "b" "b"], None])) + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) ) tm.assert_series_equal(result, expected) result = ser.str.rsplit("c", n=1) expected = pd.Series( - ArrowExtensionArray(pa.array([["a1cb" "b"], ["a2cb" "b"], None])) + ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None])) ) tm.assert_series_equal(result, expected) From f205e1a53c3efedfad8363747e7bafc02d5d22ff Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 6 Apr 2023 13:29:39 -0700 Subject: [PATCH 3/6] Fix tests and add whats --- doc/source/whatsnew/v2.1.0.rst | 2 +- pandas/core/strings/accessor.py | 4 ++-- pandas/tests/extension/test_arrow.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 107b22953ff79..2c73d2ad9de3e 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -87,7 +87,7 @@ Other enhancements - :meth:`DataFrame.applymap` now uses the :meth:`~api.extensions.ExtensionArray.map` method of underlying :class:`api.extensions.ExtensionArray` instances (:issue:`52219`) - :meth:`arrays.SparseArray.map` now supports ``na_action`` (:issue:`52096`). - Add dtype of categories to ``repr`` information of :class:`CategoricalDtype` (:issue:`52179`) -- +- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) .. --------------------------------------------------------------------------- .. _whatsnew_210.notable_bug_fixes: diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index b1c945b5b39cd..e9a5d60a156a6 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -280,9 +280,9 @@ def _wrap_result( result._pa_array.combine_chunks().value_lengths() ).as_py() if result.isna().any(): - result = result.fill_null([None] * max_len) + result._pa_array = result._pa_array.fill_null([None] * max_len) result = { - i: ArrowExtensionArray(pa.array(list(res))) + i: ArrowExtensionArray(pa.array(res)) for i, res in enumerate(zip(*result.tolist())) } elif is_object_dtype(result): diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index cf5438436b4e1..1659ffdcb29df 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2126,11 +2126,11 @@ def test_str_rsplit(): result = ser.str.rsplit("c", n=1, expand=True) expected = pd.DataFrame( { - 0: ArrowExtensionArray(pa.array(["a", "a", None])), - 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + 0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])), + 1: ArrowExtensionArray(pa.array(["b", "b", None])), } ) - tm.assert_series_equal(result, expected) + tm.assert_frame_equal(result, expected) @pytest.mark.parametrize( From 59811a546141c44a668f4e43bdccbfbd86b64902 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 6 Apr 2023 18:51:17 -0700 Subject: [PATCH 4/6] Typing --- pandas/core/arrays/arrow/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 961afdf1a10e6..e42fc4db85aad 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1947,7 +1947,7 @@ def _str_rfind(self, sub, start: int = 0, end=None): def _str_split( self, pat: str | None = None, - n: int = -1, + n: int | None = -1, expand: bool = False, regex: bool | None = None, ): @@ -1959,7 +1959,7 @@ def _str_split( split_func = pc.split_pattern return type(self)(split_func(self._pa_array, pat, max_splits=n)) - def _str_rsplit(self, pat: str | None = None, n: int = -1): + def _str_rsplit(self, pat: str | None = None, n: int | None = -1): if n in {-1, 0}: n = None return type(self)( From 06fc6ca56ca9b01ebb4e98fbb93b173d5d368674 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 7 Apr 2023 11:00:25 -0700 Subject: [PATCH 5/6] More whatsnew note --- doc/source/whatsnew/v2.0.1.rst | 2 +- doc/source/whatsnew/v2.1.0.rst | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index 882b600278c88..9f57e98879a34 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -31,7 +31,7 @@ Bug fixes Other ~~~~~ -- +- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) .. --------------------------------------------------------------------------- .. _whatsnew_201.contributors: diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 8ebe07cf44ed3..192cd82271f82 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -87,7 +87,6 @@ Other enhancements - :meth:`DataFrame.applymap` now uses the :meth:`~api.extensions.ExtensionArray.map` method of underlying :class:`api.extensions.ExtensionArray` instances (:issue:`52219`) - :meth:`arrays.SparseArray.map` now supports ``na_action`` (:issue:`52096`). - Add dtype of categories to ``repr`` information of :class:`CategoricalDtype` (:issue:`52179`) -- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) .. --------------------------------------------------------------------------- .. _whatsnew_210.notable_bug_fixes: From 50185abce76a3c7c9bbe9479b0947eae1b4bbcbd Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 7 Apr 2023 14:19:15 -0700 Subject: [PATCH 6/6] undo whatsnew --- doc/source/whatsnew/v2.1.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 136c5869d7b55..fd19c84f8ab23 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -87,6 +87,7 @@ Other enhancements - :meth:`DataFrame.applymap` now uses the :meth:`~api.extensions.ExtensionArray.map` method of underlying :class:`api.extensions.ExtensionArray` instances (:issue:`52219`) - :meth:`arrays.SparseArray.map` now supports ``na_action`` (:issue:`52096`). - Add dtype of categories to ``repr`` information of :class:`CategoricalDtype` (:issue:`52179`) +- .. --------------------------------------------------------------------------- .. _whatsnew_210.notable_bug_fixes: