From 7a99bdbe3af18e9742a1216ab6d4b33c5d8b600f Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 20 Aug 2024 11:12:12 -0700 Subject: [PATCH 01/10] BUG (string): ArrowStringArray.find corner cases --- pandas/core/arrays/arrow/array.py | 4 ++-- pandas/core/arrays/string_arrow.py | 15 +-------------- pandas/tests/extension/test_arrow.py | 9 +++++---- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 40819ba4ab338..ae6335b4699e0 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2380,7 +2380,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: if sub == "": # GH 56792 result = self._apply_elementwise(lambda val: val.find(sub, start, end)) - return type(self)(pa.chunked_array(result)) + return self._convert_int_result(pa.chunked_array(result)) if start is None: start_offset = 0 start = 0 @@ -2394,7 +2394,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: found = pc.not_equal(result, pa.scalar(-1, type=result.type)) offset_result = pc.add(result, start_offset) result = pc.if_else(found, offset_result, -1) - return type(self)(result) + return self._convert_int_result(result) def _str_join(self, sep: str) -> Self: if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string( diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index e18beb629d0c4..54908181f18b7 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -293,6 +293,7 @@ def astype(self, dtype, copy: bool = True): _str_startswith = ArrowStringArrayMixin._str_startswith _str_endswith = ArrowStringArrayMixin._str_endswith _str_pad = ArrowStringArrayMixin._str_pad + _str_find = ArrowExtensionArray._str_find def _str_contains( self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True @@ -415,20 +416,6 @@ def _str_count(self, pat: str, flags: int = 0): result = pc.count_substring_regex(self._pa_array, pat) return self._convert_int_result(result) - def _str_find(self, sub: str, start: int = 0, end: int | None = None): - if start != 0 and end is not None: - slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) - result = pc.find_substring(slices, sub) - not_found = pc.equal(result, -1) - offset_result = pc.add(result, end - start) - result = pc.if_else(not_found, result, offset_result) - elif start == 0 and end is None: - slices = self._pa_array - result = pc.find_substring(slices, sub) - else: - return super()._str_find(sub, start, end) - return self._convert_int_result(result) - def _str_get_dummies(self, sep: str = "|"): dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) if len(labels) == 0: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3dbdda388d035..a85a2fca8a9ff 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -32,8 +32,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs import lib from pandas._libs.tslibs import timezones from pandas.compat import ( @@ -1978,7 +1976,6 @@ def test_str_find_large_start(): tm.assert_series_equal(result, expected) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) @pytest.mark.skipif( pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" ) @@ -1990,11 +1987,15 @@ def test_str_find_e2e(start, end, sub): ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], dtype=ArrowDtype(pa.string()), ) - object_series = s.astype(pd.StringDtype()) + object_series = s.astype(pd.StringDtype(storage="python")) result = s.str.find(sub, start, end) expected = object_series.str.find(sub, start, end).astype(result.dtype) tm.assert_series_equal(result, expected) + arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow")) + result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result2, expected) + def test_str_find_negative_start_negative_end_no_match(): # GH 56791 From f7f19d32c39a4592bac0e2fd618ec793257c8ced Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 20 Aug 2024 14:49:18 -0700 Subject: [PATCH 02/10] xfail on old pyarrow --- pandas/tests/strings/test_find_replace.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index bf01c4996bb32..9584ca912d6f2 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from pandas.compat import pa_version_under13p0 import pandas.util._test_decorators as td import pandas as pd @@ -981,13 +982,22 @@ def test_find_bad_arg_raises(any_string_dtype): ser.str.rfind(0) -def test_find_nan(any_string_dtype): +def test_find_nan(any_string_dtype, request): ser = Series( ["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype ) expected_dtype = ( np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64" ) + if ( + pa_version_under13p0 + and isinstance(ser.dtype, pd.StringDtype) + and ser.dtype.storage == "pyarrow" + ): + # https://github.com/apache/arrow/issues/36311 + mark = pytest.mark.xfail(reason="https://github.com/apache/arrow/issues/36311") + # raises pa.lib.ArrowInvalid with Negative buffer resize + request.node.add_marker(mark) result = ser.str.find("EF") expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype) From d9f0aa7caaa4daa496df401eedfb945310a8ec23 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 22 Aug 2024 09:47:30 -0700 Subject: [PATCH 03/10] fallback with older pyarrow --- pandas/core/arrays/string_arrow.py | 10 ++++++++++ pandas/tests/strings/test_find_replace.py | 12 +----------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 54908181f18b7..c634ee657a22a 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -416,6 +416,16 @@ def _str_count(self, pat: str, flags: int = 0): result = pc.count_substring_regex(self._pa_array, pat) return self._convert_int_result(result) + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if ( + pa_version_under13p0 + and not (start != 0 and end is not None) + and not (start == 0 and end is None) + ): + # https://github.com/pandas-dev/pandas/pull/59562/files#r1725688888 + return super()._str_find(sub, start, end) + return ArrowExtensionArray._str_find(self, sub, start, end) + def _str_get_dummies(self, sep: str = "|"): dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) if len(labels) == 0: diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index 9584ca912d6f2..bf01c4996bb32 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -4,7 +4,6 @@ import numpy as np import pytest -from pandas.compat import pa_version_under13p0 import pandas.util._test_decorators as td import pandas as pd @@ -982,22 +981,13 @@ def test_find_bad_arg_raises(any_string_dtype): ser.str.rfind(0) -def test_find_nan(any_string_dtype, request): +def test_find_nan(any_string_dtype): ser = Series( ["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype ) expected_dtype = ( np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64" ) - if ( - pa_version_under13p0 - and isinstance(ser.dtype, pd.StringDtype) - and ser.dtype.storage == "pyarrow" - ): - # https://github.com/apache/arrow/issues/36311 - mark = pytest.mark.xfail(reason="https://github.com/apache/arrow/issues/36311") - # raises pa.lib.ArrowInvalid with Negative buffer resize - request.node.add_marker(mark) result = ser.str.find("EF") expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype) From f11921e273599041dd8b25b14305a5f0d7319801 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Aug 2024 07:51:59 -0700 Subject: [PATCH 04/10] REF: move implementation to ArrowStringArrayMixin --- pandas/core/arrays/_arrow_string_mixins.py | 23 ++++++++++++++++++++++ pandas/core/arrays/arrow/array.py | 23 ---------------------- pandas/core/arrays/string_arrow.py | 2 +- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 5b34a7e2c7cef..31bdb478cbb5f 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -205,3 +205,26 @@ def _str_contains( if not isna(na): # pyright: ignore [reportGeneralTypeIssues] result = result.fill_null(na) return self._convert_bool_result(result) + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if (start == 0 or start is None) and end is None: + result = pc.find_substring(self._pa_array, sub) + else: + if sub == "": + # GH#56792 + result = self._apply_elementwise(lambda val: val.find(sub, start, end)) + return self._convert_int_result(pa.chunked_array(result)) + if start is None: + start_offset = 0 + start = 0 + elif start < 0: + start_offset = pc.add(start, pc.utf8_length(self._pa_array)) + start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) + else: + start_offset = start + slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) + result = pc.find_substring(slices, sub) + found = pc.not_equal(result, pa.scalar(-1, type=result.type)) + offset_result = pc.add(result, start_offset) + result = pc.if_else(found, offset_result, -1) + return self._convert_int_result(result) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index ae6335b4699e0..15f9ba611a642 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2373,29 +2373,6 @@ def _str_fullmatch( pat = f"{pat}$" return self._str_match(pat, case, flags, na) - def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: - if (start == 0 or start is None) and end is None: - result = pc.find_substring(self._pa_array, sub) - else: - if sub == "": - # GH 56792 - result = self._apply_elementwise(lambda val: val.find(sub, start, end)) - return self._convert_int_result(pa.chunked_array(result)) - if start is None: - start_offset = 0 - start = 0 - elif start < 0: - start_offset = pc.add(start, pc.utf8_length(self._pa_array)) - start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) - else: - start_offset = start - slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) - result = pc.find_substring(slices, sub) - found = pc.not_equal(result, pa.scalar(-1, type=result.type)) - offset_result = pc.add(result, start_offset) - result = pc.if_else(found, offset_result, -1) - return self._convert_int_result(result) - def _str_join(self, sep: str) -> Self: if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string( self._pa_array.type diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index c634ee657a22a..951437d2e95e3 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -424,7 +424,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): ): # https://github.com/pandas-dev/pandas/pull/59562/files#r1725688888 return super()._str_find(sub, start, end) - return ArrowExtensionArray._str_find(self, sub, start, end) + return ArrowStringArrayMixin._str_find(self, sub, start, end) def _str_get_dummies(self, sep: str = "|"): dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) From c34ae46c0fb06229861a57e16fbf9525196da9ba Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Aug 2024 08:34:18 -0700 Subject: [PATCH 05/10] mypy fixup --- pandas/core/arrays/_arrow_string_mixins.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 31bdb478cbb5f..f21ca225fbbf5 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -3,6 +3,7 @@ from functools import partial from typing import ( TYPE_CHECKING, + Any, Literal, ) @@ -20,7 +21,10 @@ import pyarrow.compute as pc if TYPE_CHECKING: - from collections.abc import Sized + from collections.abc import ( + Callable, + Sized, + ) from pandas._typing import ( Scalar, @@ -42,6 +46,9 @@ def _convert_int_result(self, result): # Convert an integer-dtype result to the appropriate result type raise NotImplementedError + def _apply_elementwise(self, func: Callable) -> list[list[Any]]: + raise NotImplementedError + def _str_pad( self, width: int, From 86ef129f35ef90b74ae3253ca4630db67de0e0ec Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Aug 2024 10:11:55 -0700 Subject: [PATCH 06/10] trim bad link --- pandas/core/arrays/string_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 951437d2e95e3..ef4fa12f565f7 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -422,7 +422,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): and not (start != 0 and end is not None) and not (start == 0 and end is None) ): - # https://github.com/pandas-dev/pandas/pull/59562/files#r1725688888 + # GH#59562 return super()._str_find(sub, start, end) return ArrowStringArrayMixin._str_find(self, sub, start, end) From 472f17a86ada3e50da318152f7888ec523a4ea03 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 29 Aug 2024 08:11:46 -0700 Subject: [PATCH 07/10] fallback to pointwise for ArrowEA --- pandas/core/arrays/_arrow_string_mixins.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index f21ca225fbbf5..9c0cd1fcca51a 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -11,6 +11,7 @@ from pandas.compat import ( pa_version_under10p1, + pa_version_under13p0, pa_version_under17p0, ) @@ -214,6 +215,15 @@ def _str_contains( return self._convert_bool_result(result) def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if ( + pa_version_under13p0 + and not (start != 0 and end is not None) + and not (start == 0 and end is None) + ): + # GH#59562 + result = self._apply_elementwise(lambda val: val.find(sub, start, end)) + return self._convert_int_result(pa.chunked_array(result)) + if (start == 0 or start is None) and end is None: result = pc.find_substring(self._pa_array, sub) else: From e4c782c9a2fbc2dfec6d6e8f258f4fa13add91f4 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 29 Aug 2024 08:35:31 -0700 Subject: [PATCH 08/10] mypy fixup --- pandas/core/arrays/_arrow_string_mixins.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 9c0cd1fcca51a..950d4cd7cc92e 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -221,16 +221,18 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): and not (start == 0 and end is None) ): # GH#59562 - result = self._apply_elementwise(lambda val: val.find(sub, start, end)) - return self._convert_int_result(pa.chunked_array(result)) + res_list = self._apply_elementwise(lambda val: val.find(sub, start, end)) + return self._convert_int_result(pa.chunked_array(res_list)) if (start == 0 or start is None) and end is None: result = pc.find_substring(self._pa_array, sub) else: if sub == "": # GH#56792 - result = self._apply_elementwise(lambda val: val.find(sub, start, end)) - return self._convert_int_result(pa.chunked_array(result)) + res_list = self._apply_elementwise( + lambda val: val.find(sub, start, end) + ) + return self._convert_int_result(pa.chunked_array(res_list)) if start is None: start_offset = 0 start = 0 From e1b7913f75adb4c5c3c3911c4fd94cacff7e54d1 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 29 Aug 2024 15:16:19 -0700 Subject: [PATCH 09/10] TST: un-xfail --- pandas/tests/extension/test_arrow.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index a85a2fca8a9ff..fc4f14882b9d7 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1945,14 +1945,9 @@ def test_str_find_negative_start(): def test_str_find_no_end(): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - if pa_version_under13p0: - # https://github.com/apache/arrow/issues/36311 - with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): - ser.str.find("ab", start=1) - else: - result = ser.str.find("ab", start=1) - expected = pd.Series([-1, None], dtype="int64[pyarrow]") - tm.assert_series_equal(result, expected) + result = ser.str.find("ab", start=1) + expected = pd.Series([-1, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) def test_str_find_negative_start_negative_end(): @@ -1966,14 +1961,9 @@ def test_str_find_negative_start_negative_end(): def test_str_find_large_start(): # GH 56791 ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) - if pa_version_under13p0: - # https://github.com/apache/arrow/issues/36311 - with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): - ser.str.find(sub="d", start=16) - else: - result = ser.str.find(sub="d", start=16) - expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) - tm.assert_series_equal(result, expected) + result = ser.str.find(sub="d", start=16) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) @pytest.mark.skipif( From 8f07638ad606e771ee95d038e624ab957595cafd Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 3 Sep 2024 07:13:45 -0700 Subject: [PATCH 10/10] fixup post-rebase --- pandas/core/arrays/string_arrow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index ef4fa12f565f7..97381b82ceab9 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -293,7 +293,6 @@ def astype(self, dtype, copy: bool = True): _str_startswith = ArrowStringArrayMixin._str_startswith _str_endswith = ArrowStringArrayMixin._str_endswith _str_pad = ArrowStringArrayMixin._str_pad - _str_find = ArrowExtensionArray._str_find def _str_contains( self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True