Skip to content

Commit 9fa7148

Browse files
committed
BUG (string): ArrowStringArray.find corner cases
1 parent bc9b1c3 commit 9fa7148

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,7 +2395,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23952395
if sub == "":
23962396
# GH 56792
23972397
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2398-
return type(self)(pa.chunked_array(result))
2398+
return self._convert_int_result(pa.chunked_array(result))
23992399
if start is None:
24002400
start_offset = 0
24012401
start = 0
@@ -2409,7 +2409,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
24092409
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
24102410
offset_result = pc.add(result, start_offset)
24112411
result = pc.if_else(found, offset_result, -1)
2412-
return type(self)(result)
2412+
return self._convert_int_result(result)
24132413

24142414
def _str_join(self, sep: str) -> Self:
24152415
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def astype(self, dtype, copy: bool = True):
295295
_str_startswith = ArrowStringArrayMixin._str_startswith
296296
_str_endswith = ArrowStringArrayMixin._str_endswith
297297
_str_pad = ArrowStringArrayMixin._str_pad
298+
_str_find = ArrowExtensionArray._str_find
298299

299300
def _str_contains(
300301
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -421,20 +422,6 @@ def _str_count(self, pat: str, flags: int = 0):
421422
result = pc.count_substring_regex(self._pa_array, pat)
422423
return self._convert_int_result(result)
423424

424-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
425-
if start != 0 and end is not None:
426-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
427-
result = pc.find_substring(slices, sub)
428-
not_found = pc.equal(result, -1)
429-
offset_result = pc.add(result, end - start)
430-
result = pc.if_else(not_found, result, offset_result)
431-
elif start == 0 and end is None:
432-
slices = self._pa_array
433-
result = pc.find_substring(slices, sub)
434-
else:
435-
return super()._str_find(sub, start, end)
436-
return self._convert_int_result(result)
437-
438425
def _str_get_dummies(self, sep: str = "|"):
439426
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
440427
if len(labels) == 0:

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1978,7 +1976,6 @@ def test_str_find_large_start():
19781976
tm.assert_series_equal(result, expected)
19791977

19801978

1981-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19821979
@pytest.mark.skipif(
19831980
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
19841981
)
@@ -1990,11 +1987,15 @@ def test_str_find_e2e(start, end, sub):
19901987
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
19911988
dtype=ArrowDtype(pa.string()),
19921989
)
1993-
object_series = s.astype(pd.StringDtype())
1990+
object_series = s.astype(pd.StringDtype(storage="python"))
19941991
result = s.str.find(sub, start, end)
19951992
expected = object_series.str.find(sub, start, end).astype(result.dtype)
19961993
tm.assert_series_equal(result, expected)
19971994

1995+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
1996+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
1997+
tm.assert_series_equal(result2, expected)
1998+
19981999

19992000
def test_str_find_negative_start_negative_end_no_match():
20002001
# GH 56791

0 commit comments

Comments
 (0)