diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index 46a780befcede..7b4dc890da3e1 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -36,6 +36,7 @@ Bug fixes Other ~~~~~ +- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) - :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) - :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 891b25f47000d..c46420eb0b349 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1929,16 +1929,24 @@ def _str_rfind(self, sub, start: int = 0, end=None): ) def _str_split( - self, pat=None, n=-1, expand: bool = False, regex: bool | None = None + self, + pat: str | None = None, + n: int | None = -1, + expand: bool = False, + regex: bool | None = None, ): - raise NotImplementedError( - "str.split not supported with pd.ArrowDtype(pa.string())." - ) + if n in {-1, 0}: + n = None + if regex: + split_func = pc.split_pattern_regex + else: + split_func = pc.split_pattern + return type(self)(split_func(self._data, pat, max_splits=n)) - def _str_rsplit(self, pat=None, n=-1): - raise NotImplementedError( - "str.rsplit not supported with pd.ArrowDtype(pa.string())." - ) + def _str_rsplit(self, pat: str | None = None, n: int | None = -1): + if n in {-1, 0}: + n = None + return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True)) def _str_translate(self, table): raise NotImplementedError( diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 2acfdcefed055..e44ccd9aede83 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -41,6 +41,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.base import NoNewAttributesMixin from pandas.core.construction import extract_array @@ -267,27 +268,39 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 - elif ( - expand is True - and is_object_dtype(result) - and not isinstance(self._orig, ABCIndex) - ): + elif expand is True and not isinstance(self._orig, ABCIndex): # required when expand=True is explicitly specified # not needed when inferred - - def cons_row(x): - if is_list_like(x): - return x - else: - return [x] - - result = [cons_row(x) for x in result] - if result and not self._is_string: - # propagate nan values to match longest sequence (GH 18450) - max_len = max(len(x) for x in result) - result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result - ] + if isinstance(result.dtype, ArrowDtype): + import pyarrow as pa + + from pandas.core.arrays.arrow.array import ArrowExtensionArray + + max_len = pa.compute.max( + result._data.combine_chunks().value_lengths() + ).as_py() + if result.isna().any(): + result._data = result._data.fill_null([None] * max_len) + result = { + i: ArrowExtensionArray(pa.array(res)) + for i, res in enumerate(zip(*result.tolist())) + } + elif is_object_dtype(result): + + def cons_row(x): + if is_list_like(x): + return x + else: + return [x] + + result = [cons_row(x) for x in result] + if result and not self._is_string: + # propagate nan values to match longest sequence (GH 18450) + max_len = max(len(x) for x in result) + result = [ + x * max_len if len(x) == 0 or x[0] is np.nan else x + for x in result + ] if not isinstance(expand, bool): raise ValueError("expand must be True or False") diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 6bfd57938abc0..da1362a7bd26b 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2098,6 +2098,62 @@ def test_str_removesuffix(val): tm.assert_series_equal(result, expected) +def test_str_split(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.split("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_rsplit(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.rsplit("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])), + 1: ArrowExtensionArray(pa.array(["b", "b", None])), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( "method, args", [ @@ -2113,8 +2169,6 @@ def test_str_removesuffix(val): ["rindex", ("abc",)], ["normalize", ("abc",)], ["rfind", ("abc",)], - ["split", ()], - ["rsplit", ()], ["translate", ("abc",)], ["wrap", ("abc",)], ],