diff --git a/doc/source/whatsnew/v2.0.3.rst b/doc/source/whatsnew/v2.0.3.rst index 2c63d7d20ed1c..89c64b02e0cb5 100644 --- a/doc/source/whatsnew/v2.0.3.rst +++ b/doc/source/whatsnew/v2.0.3.rst @@ -22,6 +22,8 @@ Fixed regressions Bug fixes ~~~~~~~~~ - Bug in :func:`read_csv` when defining ``dtype`` with ``bool[pyarrow]`` for the ``"c"`` and ``"python"`` engines (:issue:`53390`) +- Bug in :meth:`Series.str.split` and :meth:`Series.str.rsplit` with ``expand=True`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`53532`) +- .. --------------------------------------------------------------------------- .. _whatsnew_203.other: diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 699a32fe0c028..e544bde16da9c 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -273,14 +273,40 @@ def _wrap_result( if isinstance(result.dtype, ArrowDtype): import pyarrow as pa + from pandas.compat import pa_version_under11p0 + from pandas.core.arrays.arrow.array import ArrowExtensionArray - max_len = pa.compute.max( - result._data.combine_chunks().value_lengths() - ).as_py() - if result.isna().any(): + value_lengths = result._data.combine_chunks().value_lengths() + max_len = pa.compute.max(value_lengths).as_py() + min_len = pa.compute.min(value_lengths).as_py() + if result._hasna: # ArrowExtensionArray.fillna doesn't work for list scalars - result._data = result._data.fill_null([None] * max_len) + result = ArrowExtensionArray( + result._data.fill_null([None] * max_len) + ) + if min_len < max_len: + # append nulls to each scalar list element up to max_len + if not pa_version_under11p0: + result = ArrowExtensionArray( + pa.compute.list_slice( + result._data, + start=0, + stop=max_len, + return_fixed_size_list=True, + ) + ) + else: + all_null = np.full(max_len, fill_value=None, dtype=object) + values = result.to_numpy() + new_values = [] + for row in values: + if len(row) < max_len: + nulls = all_null[: max_len - len(row)] + row = np.append(row, nulls) + new_values.append(row) + pa_type = result._data.type + result = ArrowExtensionArray(pa.array(new_values, type=pa_type)) if name is not None: labels = name else: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3efb59fc6afce..03734626c8f95 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2315,6 +2315,15 @@ def test_str_split(): ) tm.assert_frame_equal(result, expected) + result = ser.str.split("1", expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", None, None])), + } + ) + tm.assert_frame_equal(result, expected) + def test_str_rsplit(): # GH 52401 @@ -2340,6 +2349,15 @@ def test_str_rsplit(): ) tm.assert_frame_equal(result, expected) + result = ser.str.rsplit("1", expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", None, None])), + } + ) + tm.assert_frame_equal(result, expected) + def test_str_unsupported_extract(): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))