diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index 2b2a1029f6544..1bd758110a197 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -32,6 +32,7 @@ Bug fixes Other ~~~~~ +- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) - :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) - :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d33d25fb08069..5b0753fd32bb7 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2025,15 +2025,25 @@ def _str_rfind(self, sub, start: int = 0, end=None): ) def _str_split( - self, pat=None, n=-1, expand: bool = False, regex: bool | None = None + self, + pat: str | None = None, + n: int | None = -1, + expand: bool = False, + regex: bool | None = None, ): - raise NotImplementedError( - "str.split not supported with pd.ArrowDtype(pa.string())." - ) + if n in {-1, 0}: + n = None + if regex: + split_func = pc.split_pattern_regex + else: + split_func = pc.split_pattern + return type(self)(split_func(self._pa_array, pat, max_splits=n)) - def _str_rsplit(self, pat=None, n=-1): - raise NotImplementedError( - "str.rsplit not supported with pd.ArrowDtype(pa.string())." + def _str_rsplit(self, pat: str | None = None, n: int | None = -1): + if n in {-1, 0}: + n = None + return type(self)( + pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) ) def _str_translate(self, table): diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index deb87149ab8cc..e9a5d60a156a6 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -41,6 +41,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.base import NoNewAttributesMixin from pandas.core.construction import extract_array @@ -267,27 +268,39 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 - elif ( - expand is True - and is_object_dtype(result) - and not isinstance(self._orig, ABCIndex) - ): + elif expand is True and not isinstance(self._orig, ABCIndex): # required when expand=True is explicitly specified # not needed when inferred - - def cons_row(x): - if is_list_like(x): - return x - else: - return [x] - - result = [cons_row(x) for x in result] - if result and not self._is_string: - # propagate nan values to match longest sequence (GH 18450) - max_len = max(len(x) for x in result) - result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result - ] + if isinstance(result.dtype, ArrowDtype): + import pyarrow as pa + + from pandas.core.arrays.arrow.array import ArrowExtensionArray + + max_len = pa.compute.max( + result._pa_array.combine_chunks().value_lengths() + ).as_py() + if result.isna().any(): + result._pa_array = result._pa_array.fill_null([None] * max_len) + result = { + i: ArrowExtensionArray(pa.array(res)) + for i, res in enumerate(zip(*result.tolist())) + } + elif is_object_dtype(result): + + def cons_row(x): + if is_list_like(x): + return x + else: + return [x] + + result = [cons_row(x) for x in result] + if result and not self._is_string: + # propagate nan values to match longest sequence (GH 18450) + max_len = max(len(x) for x in result) + result = [ + x * max_len if len(x) == 0 or x[0] is np.nan else x + for x in result + ] if not isinstance(expand, bool): raise ValueError("expand must be True or False") diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 12975499a8d5c..11ce10f19a1e3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2067,6 +2067,62 @@ def test_str_removesuffix(val): tm.assert_series_equal(result, expected) +def test_str_split(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.split("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_rsplit(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.rsplit("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])), + 1: ArrowExtensionArray(pa.array(["b", "b", None])), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( "method, args", [ @@ -2082,8 +2138,6 @@ def test_str_removesuffix(val): ["rindex", ("abc",)], ["normalize", ("abc",)], ["rfind", ("abc",)], - ["split", ()], - ["rsplit", ()], ["translate", ("abc",)], ["wrap", ("abc",)], ],