Skip to content

Commit f3e1f32

Browse files
committed
BUG (string): ArrowStringArray.find corner cases
1 parent 27c7d51 commit f3e1f32

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,7 +2395,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23952395
if sub == "":
23962396
# GH 56792
23972397
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2398-
return type(self)(pa.chunked_array(result))
2398+
return self._convert_int_result(pa.chunked_array(result))
23992399
if start is None:
24002400
start_offset = 0
24012401
start = 0
@@ -2409,7 +2409,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
24092409
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
24102410
offset_result = pc.add(result, start_offset)
24112411
result = pc.if_else(found, offset_result, -1)
2412-
return type(self)(result)
2412+
return self._convert_int_result(result)
24132413

24142414
def _str_join(self, sep: str) -> Self:
24152415
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def astype(self, dtype, copy: bool = True):
282282
_str_map = BaseStringArray._str_map
283283
_str_startswith = ArrowStringArrayMixin._str_startswith
284284
_str_endswith = ArrowStringArrayMixin._str_endswith
285+
_str_find = ArrowExtensionArray._str_find
285286

286287
def _str_contains(
287288
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -436,20 +437,6 @@ def _str_count(self, pat: str, flags: int = 0):
436437
result = pc.count_substring_regex(self._pa_array, pat)
437438
return self._convert_int_result(result)
438439

439-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
440-
if start != 0 and end is not None:
441-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
442-
result = pc.find_substring(slices, sub)
443-
not_found = pc.equal(result, -1)
444-
offset_result = pc.add(result, end - start)
445-
result = pc.if_else(not_found, result, offset_result)
446-
elif start == 0 and end is None:
447-
slices = self._pa_array
448-
result = pc.find_substring(slices, sub)
449-
else:
450-
return super()._str_find(sub, start, end)
451-
return self._convert_int_result(result)
452-
453440
def _str_get_dummies(self, sep: str = "|"):
454441
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
455442
if len(labels) == 0:

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1978,7 +1976,6 @@ def test_str_find_large_start():
19781976
tm.assert_series_equal(result, expected)
19791977

19801978

1981-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19821979
@pytest.mark.skipif(
19831980
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
19841981
)
@@ -1990,11 +1987,15 @@ def test_str_find_e2e(start, end, sub):
19901987
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
19911988
dtype=ArrowDtype(pa.string()),
19921989
)
1993-
object_series = s.astype(pd.StringDtype())
1990+
object_series = s.astype(pd.StringDtype(storage="python"))
19941991
result = s.str.find(sub, start, end)
19951992
expected = object_series.str.find(sub, start, end).astype(result.dtype)
19961993
tm.assert_series_equal(result, expected)
19971994

1995+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
1996+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
1997+
tm.assert_series_equal(result2, expected)
1998+
19981999

19992000
def test_str_find_negative_start_negative_end_no_match():
20002001
# GH 56791

0 commit comments

Comments
 (0)