Skip to content

Commit b863dcb

Browse files
committed
BUG (string): ArrowStringArray.find corner cases
1 parent 360597c commit b863dcb

File tree

3 files changed

+11
-20
lines changed

3 files changed

+11
-20
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,14 +2387,17 @@ def _str_fullmatch(
23872387
pat = f"{pat}$"
23882388
return self._str_match(pat, case, flags, na)
23892389

2390+
def _convert_int_dtype(self, result):
2391+
return type(self)(result)
2392+
23902393
def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23912394
if (start == 0 or start is None) and end is None:
23922395
result = pc.find_substring(self._pa_array, sub)
23932396
else:
23942397
if sub == "":
23952398
# GH 56792
23962399
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2397-
return type(self)(pa.chunked_array(result))
2400+
return self._convert_int_dtype(pa.chunked_array(result))
23982401
if start is None:
23992402
start_offset = 0
24002403
start = 0
@@ -2408,7 +2411,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
24082411
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
24092412
offset_result = pc.add(result, start_offset)
24102413
result = pc.if_else(found, offset_result, -1)
2411-
return type(self)(result)
2414+
return self._convert_int_dtype(result)
24122415

24132416
def _str_join(self, sep: str) -> Self:
24142417
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def astype(self, dtype, copy: bool = True):
280280
# String methods interface
281281

282282
_str_map = BaseStringArray._str_map
283+
_str_find = ArrowExtensionArray._str_find
283284

284285
def _str_contains(
285286
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -474,20 +475,6 @@ def _str_count(self, pat: str, flags: int = 0):
474475
result = pc.count_substring_regex(self._pa_array, pat)
475476
return self._convert_int_dtype(result)
476477

477-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
478-
if start != 0 and end is not None:
479-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
480-
result = pc.find_substring(slices, sub)
481-
not_found = pc.equal(result, -1)
482-
offset_result = pc.add(result, end - start)
483-
result = pc.if_else(not_found, result, offset_result)
484-
elif start == 0 and end is None:
485-
slices = self._pa_array
486-
result = pc.find_substring(slices, sub)
487-
else:
488-
return super()._str_find(sub, start, end)
489-
return self._convert_int_dtype(result)
490-
491478
def _str_get_dummies(self, sep: str = "|"):
492479
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
493480
if len(labels) == 0:

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1995,7 +1993,6 @@ def test_str_find_large_start():
19951993
tm.assert_series_equal(result, expected)
19961994

19971995

1998-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19991996
@pytest.mark.skipif(
20001997
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
20011998
)
@@ -2007,11 +2004,15 @@ def test_str_find_e2e(start, end, sub):
20072004
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
20082005
dtype=ArrowDtype(pa.string()),
20092006
)
2010-
object_series = s.astype(pd.StringDtype())
2007+
object_series = s.astype(pd.StringDtype(storage="python"))
20112008
result = s.str.find(sub, start, end)
20122009
expected = object_series.str.find(sub, start, end).astype(result.dtype)
20132010
tm.assert_series_equal(result, expected)
20142011

2012+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
2013+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
2014+
tm.assert_series_equal(result2, expected)
2015+
20152016

20162017
def test_str_find_negative_start_negative_end_no_match():
20172018
# GH 56791

0 commit comments

Comments
 (0)