diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 5b34a7e2c7cef..950d4cd7cc92e 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -3,6 +3,7 @@ from functools import partial from typing import ( TYPE_CHECKING, + Any, Literal, ) @@ -10,6 +11,7 @@ from pandas.compat import ( pa_version_under10p1, + pa_version_under13p0, pa_version_under17p0, ) @@ -20,7 +22,10 @@ import pyarrow.compute as pc if TYPE_CHECKING: - from collections.abc import Sized + from collections.abc import ( + Callable, + Sized, + ) from pandas._typing import ( Scalar, @@ -42,6 +47,9 @@ def _convert_int_result(self, result): # Convert an integer-dtype result to the appropriate result type raise NotImplementedError + def _apply_elementwise(self, func: Callable) -> list[list[Any]]: + raise NotImplementedError + def _str_pad( self, width: int, @@ -205,3 +213,37 @@ def _str_contains( if not isna(na): # pyright: ignore [reportGeneralTypeIssues] result = result.fill_null(na) return self._convert_bool_result(result) + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if ( + pa_version_under13p0 + and not (start != 0 and end is not None) + and not (start == 0 and end is None) + ): + # GH#59562 + res_list = self._apply_elementwise(lambda val: val.find(sub, start, end)) + return self._convert_int_result(pa.chunked_array(res_list)) + + if (start == 0 or start is None) and end is None: + result = pc.find_substring(self._pa_array, sub) + else: + if sub == "": + # GH#56792 + res_list = self._apply_elementwise( + lambda val: val.find(sub, start, end) + ) + return self._convert_int_result(pa.chunked_array(res_list)) + if start is None: + start_offset = 0 + start = 0 + elif start < 0: + start_offset = pc.add(start, pc.utf8_length(self._pa_array)) + start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) + else: + start_offset = start + slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) + result = pc.find_substring(slices, sub) + found = pc.not_equal(result, pa.scalar(-1, type=result.type)) + offset_result = pc.add(result, start_offset) + result = pc.if_else(found, offset_result, -1) + return self._convert_int_result(result) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 40819ba4ab338..15f9ba611a642 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2373,29 +2373,6 @@ def _str_fullmatch( pat = f"{pat}$" return self._str_match(pat, case, flags, na) - def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: - if (start == 0 or start is None) and end is None: - result = pc.find_substring(self._pa_array, sub) - else: - if sub == "": - # GH 56792 - result = self._apply_elementwise(lambda val: val.find(sub, start, end)) - return type(self)(pa.chunked_array(result)) - if start is None: - start_offset = 0 - start = 0 - elif start < 0: - start_offset = pc.add(start, pc.utf8_length(self._pa_array)) - start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) - else: - start_offset = start - slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) - result = pc.find_substring(slices, sub) - found = pc.not_equal(result, pa.scalar(-1, type=result.type)) - offset_result = pc.add(result, start_offset) - result = pc.if_else(found, offset_result, -1) - return type(self)(result) - def _str_join(self, sep: str) -> Self: if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string( self._pa_array.type diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index e18beb629d0c4..97381b82ceab9 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0): return self._convert_int_result(result) def _str_find(self, sub: str, start: int = 0, end: int | None = None): - if start != 0 and end is not None: - slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) - result = pc.find_substring(slices, sub) - not_found = pc.equal(result, -1) - offset_result = pc.add(result, end - start) - result = pc.if_else(not_found, result, offset_result) - elif start == 0 and end is None: - slices = self._pa_array - result = pc.find_substring(slices, sub) - else: + if ( + pa_version_under13p0 + and not (start != 0 and end is not None) + and not (start == 0 and end is None) + ): + # GH#59562 return super()._str_find(sub, start, end) - return self._convert_int_result(result) + return ArrowStringArrayMixin._str_find(self, sub, start, end) def _str_get_dummies(self, sep: str = "|"): dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3dbdda388d035..fc4f14882b9d7 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -32,8 +32,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs import lib from pandas._libs.tslibs import timezones from pandas.compat import ( @@ -1947,14 +1945,9 @@ def test_str_find_negative_start(): def test_str_find_no_end(): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - if pa_version_under13p0: - # https://github.com/apache/arrow/issues/36311 - with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): - ser.str.find("ab", start=1) - else: - result = ser.str.find("ab", start=1) - expected = pd.Series([-1, None], dtype="int64[pyarrow]") - tm.assert_series_equal(result, expected) + result = ser.str.find("ab", start=1) + expected = pd.Series([-1, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) def test_str_find_negative_start_negative_end(): @@ -1968,17 +1961,11 @@ def test_str_find_negative_start_negative_end(): def test_str_find_large_start(): # GH 56791 ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) - if pa_version_under13p0: - # https://github.com/apache/arrow/issues/36311 - with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): - ser.str.find(sub="d", start=16) - else: - result = ser.str.find(sub="d", start=16) - expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) - tm.assert_series_equal(result, expected) + result = ser.str.find(sub="d", start=16) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) @pytest.mark.skipif( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) @@ -1990,11 +1977,15 @@ def test_str_find_e2e(start, end, sub): ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], dtype=ArrowDtype(pa.string()), ) - object_series = s.astype(pd.StringDtype()) + object_series = s.astype(pd.StringDtype(storage="python")) result = s.str.find(sub, start, end) expected = object_series.str.find(sub, start, end).astype(result.dtype) tm.assert_series_equal(result, expected) + arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow")) + result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result2, expected) + def test_str_find_negative_start_negative_end_no_match(): # GH 56791