Skip to content

ENH: Implement str.r/split for ArrowDtype #52499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 11, 2023
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Bug fixes

Other
~~~~~
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)

Expand Down
24 changes: 17 additions & 7 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,15 +2025,25 @@ def _str_rfind(self, sub, start: int = 0, end=None):
)

def _str_split(
    self,
    pat: str | None = None,
    n: int | None = -1,
    expand: bool = False,
    regex: bool | None = None,
):
    """
    Split each string of the underlying pyarrow array into a list of pieces.

    Parameters
    ----------
    pat : str or None, default None
        Separator to split on. ``None`` splits on runs of whitespace.
    n : int or None, default -1
        Maximum number of splits. ``-1``, ``0`` and ``None`` all mean
        "no limit".
    expand : bool, default False
        Not used by this method; it only builds the list-typed result.
        NOTE(review): expansion into columns appears to be handled by the
        string accessor's result wrapping -- confirm against the caller.
    regex : bool or None, default None
        If truthy, ``pat`` is treated as a regular expression; otherwise it
        is matched literally.

    Returns
    -------
    Same array type as ``self``, wrapping the pyarrow split result.
    """
    # pyarrow spells "unlimited" as max_splits=None, pandas as -1 or 0.
    if n in {-1, 0}:
        n = None
    if pat is None:
        # The split_pattern kernels need a real pattern string; splitting on
        # whitespace (the meaning of pat=None) has its own kernel.
        result = pc.utf8_split_whitespace(self._pa_array, max_splits=n)
    elif regex:
        result = pc.split_pattern_regex(self._pa_array, pat, max_splits=n)
    else:
        result = pc.split_pattern(self._pa_array, pat, max_splits=n)
    return type(self)(result)

def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
    """
    Split each string of the underlying pyarrow array, starting from the right.

    Parameters
    ----------
    pat : str or None, default None
        Literal separator to split on. ``None`` splits on runs of whitespace.
    n : int or None, default -1
        Maximum number of splits. ``-1``, ``0`` and ``None`` all mean
        "no limit".

    Returns
    -------
    Same array type as ``self``, wrapping the pyarrow split result.
    """
    # pyarrow spells "unlimited" as max_splits=None, pandas as -1 or 0.
    if n in {-1, 0}:
        n = None
    if pat is None:
        # split_pattern needs a real pattern string; splitting on whitespace
        # (the meaning of pat=None) has its own kernel.
        result = pc.utf8_split_whitespace(
            self._pa_array, max_splits=n, reverse=True
        )
    else:
        result = pc.split_pattern(
            self._pa_array, pat, max_splits=n, reverse=True
        )
    return type(self)(result)

def _str_translate(self, table):
Expand Down
51 changes: 32 additions & 19 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays.arrow.dtype import ArrowDtype
from pandas.core.base import NoNewAttributesMixin
from pandas.core.construction import extract_array

Expand Down Expand Up @@ -267,27 +268,39 @@ def _wrap_result(
# infer from ndim if expand is not specified
expand = result.ndim != 1

elif (
expand is True
and is_object_dtype(result)
and not isinstance(self._orig, ABCIndex)
):
elif expand is True and not isinstance(self._orig, ABCIndex):
# required when expand=True is explicitly specified
# not needed when inferred

def cons_row(x):
if is_list_like(x):
return x
else:
return [x]

result = [cons_row(x) for x in result]
if result and not self._is_string:
# propagate nan values to match longest sequence (GH 18450)
max_len = max(len(x) for x in result)
result = [
x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result
]
if isinstance(result.dtype, ArrowDtype):
import pyarrow as pa

from pandas.core.arrays.arrow.array import ArrowExtensionArray

max_len = pa.compute.max(
result._pa_array.combine_chunks().value_lengths()
).as_py()
if result.isna().any():
result._pa_array = result._pa_array.fill_null([None] * max_len)
result = {
i: ArrowExtensionArray(pa.array(res))
for i, res in enumerate(zip(*result.tolist()))
}
elif is_object_dtype(result):

def cons_row(x):
if is_list_like(x):
return x
else:
return [x]

result = [cons_row(x) for x in result]
if result and not self._is_string:
# propagate nan values to match longest sequence (GH 18450)
max_len = max(len(x) for x in result)
result = [
x * max_len if len(x) == 0 or x[0] is np.nan else x
for x in result
]

if not isinstance(expand, bool):
raise ValueError("expand must be True or False")
Expand Down
58 changes: 56 additions & 2 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,6 +2067,62 @@ def test_str_removesuffix(val):
tm.assert_series_equal(result, expected)


def test_str_split():
    """Exercise ``Series.str.split`` on a ``pa.string()``-backed Series."""
    # GH 52401
    strings = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))

    # Literal separator, no limit on the number of splits.
    tm.assert_series_equal(
        strings.str.split("c"),
        pd.Series(
            ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
        ),
    )

    # Literal separator, at most one split.
    tm.assert_series_equal(
        strings.str.split("c", n=1),
        pd.Series(ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None]))),
    )

    # Regular-expression separator.
    tm.assert_series_equal(
        strings.str.split("[1-2]", regex=True),
        pd.Series(ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None]))),
    )

    # Regular-expression separator, pieces expanded into columns.
    left_column = ArrowExtensionArray(pa.array(["a", "a", None]))
    right_column = ArrowExtensionArray(pa.array(["cbcb", "cbcb", None]))
    tm.assert_frame_equal(
        strings.str.split("[1-2]", regex=True, expand=True),
        pd.DataFrame({0: left_column, 1: right_column}),
    )


def test_str_rsplit():
    """Exercise ``Series.str.rsplit`` on a ``pa.string()``-backed Series."""
    # GH 52401
    strings = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))

    # No limit on the number of splits.
    tm.assert_series_equal(
        strings.str.rsplit("c"),
        pd.Series(
            ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
        ),
    )

    # At most one split, taken from the right-hand end.
    tm.assert_series_equal(
        strings.str.rsplit("c", n=1),
        pd.Series(ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None]))),
    )

    # At most one split, pieces expanded into columns.
    left_column = ArrowExtensionArray(pa.array(["a1cb", "a2cb", None]))
    right_column = ArrowExtensionArray(pa.array(["b", "b", None]))
    tm.assert_frame_equal(
        strings.str.rsplit("c", n=1, expand=True),
        pd.DataFrame({0: left_column, 1: right_column}),
    )


@pytest.mark.parametrize(
"method, args",
[
Expand All @@ -2082,8 +2138,6 @@ def test_str_removesuffix(val):
["rindex", ("abc",)],
["normalize", ("abc",)],
["rfind", ("abc",)],
["split", ()],
["rsplit", ()],
["translate", ("abc",)],
["wrap", ("abc",)],
],
Expand Down