Skip to content

Commit 6d68683

Browse files
author
Rohan Jain
committed
fix find
1 parent a311f77 commit 6d68683

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2328,19 +2328,25 @@ def _str_fullmatch(
23282328
return self._str_match(pat, case, flags, na)
23292329

23302330
def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
    """
    Return the lowest index of ``sub`` inside ``[start, end)`` of each string.

    Follows the element-wise semantics of Python's ``str.find``: ``start``
    and ``end`` are interpreted like slice bounds (negative values count
    from the end of each string), the whole match must fit inside the
    window, and the reported index is relative to the full string.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window. ``None`` is treated as 0.
    end : int or None, default None
        Right edge (exclusive) of the search window. ``None`` means the
        end of each string.

    Returns
    -------
    Self
        Integer array holding the match index, or -1 where ``sub`` does
        not occur inside the window. The dtype is whatever
        ``pc.find_substring`` produces for the underlying string type.
    """
    if (start == 0 or start is None) and end is None:
        # No window requested: search each full string directly.
        return type(self)(pc.find_substring(self._pa_array, sub))

    if start is None:
        start = 0

    length = pc.utf8_length(self._pa_array)

    # Resolve ``start`` to an absolute, non-negative offset per element.
    # A negative start that reaches past the beginning clamps to 0.
    if start < 0:
        start_offset = pc.add(length, start)
        start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
    else:
        start_offset = start

    # Resolve ``end`` the same way, additionally clamping it to the length.
    if end is None:
        end_offset = length
    elif end < 0:
        end_offset = pc.add(length, end)
        end_offset = pc.if_else(pc.less(end_offset, 0), 0, end_offset)
    else:
        end_offset = pc.if_else(pc.less(length, end), length, end)

    # Search only inside the window. Searching the full string and then
    # range-checking the first hit would miss a valid later occurrence
    # (``"abcab".find("ab", 1)``) and would accept a match that runs past
    # ``end`` (``"abc".find("bc", 1, 2)``).
    # NOTE(review): ``stop=None`` is assumed to mean "to the end of the
    # string" on every supported pyarrow version - confirm for old releases.
    window = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
    result = pc.find_substring(window, sub)
    result_type = result.type

    found = pc.not_equal(result, pa.scalar(-1, type=result_type))
    # An empty or inverted window cannot contain a match. For a non-empty
    # ``sub`` the slice is already empty there; this mask matters when
    # ``sub`` is the empty string, which would otherwise match at offset 0.
    found = pc.and_(found, pc.less_equal(start_offset, end_offset))

    # Translate window-relative hits back to full-string positions.
    shifted = pc.add(result, start_offset)
    result = pc.if_else(found, shifted, -1)
    # Arithmetic with Python ints may widen the integer type; restore the
    # native ``find_substring`` type so both code paths agree on dtype.
    return type(self)(pc.cast(result, result_type))
23462352

pandas/tests/extension/test_arrow.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,28 +1919,30 @@ def test_str_fullmatch(pat, case, na, exp):
19191919

19201920

19211921
@pytest.mark.parametrize(
    "sub, start, end, exp",
    [
        ("ab", 0, None, [0, None]),
        ("bc", 1, 3, [1, None]),
        ("ab", 1, None, [-1, None]),
    ],
)
def test_str_find(sub, start, end, exp):
    """Check ``Series.str.find`` on Arrow strings for several search windows."""
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    # Every case is expected to come back as int32, with the null preserved.
    expected = pd.Series(exp, dtype=ArrowDtype(pa.int32()))
    found = strings.str.find(sub, start=start, end=end)
    tm.assert_series_equal(found, expected)
19301930

19311931

19321932
def test_str_find_negative_start():
    """A negative ``start`` far beyond the string length still finds the match."""
    # GH 56411
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32()))
    found = strings.str.find(sub="b", start=-1000, end=3)
    tm.assert_series_equal(found, expected)
19381938

19391939

1940-
def test_str_find_negative_start_negative_end():
    """Negative ``start`` and ``end`` are both resolved from the string's end."""
    # GH 56411
    strings = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    expected = pd.Series([3, None], dtype=ArrowDtype(pa.int32()))
    found = strings.str.find(sub="d", start=-6, end=-3)
    tm.assert_series_equal(found, expected)
19441946

19451947

19461948
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)