diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 8c475791df64d..645a3cdb967ca 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -601,6 +601,7 @@ Strings - Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`) - Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`) - Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`) +- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for :class:`ArrowDtype` with ``pyarrow.string`` dtype (:issue:`56579`) - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`) Interval diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 633efe43fce1a..de1fb0ec5d4d5 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2174,14 +2174,36 @@ def _str_contains( result = result.fill_null(na) return type(self)(result) - def _str_startswith(self, pat: str, na=None): - result = pc.starts_with(self._pa_array, pattern=pat) + def _str_startswith(self, pat: str | tuple[str, ...], na=None): + if isinstance(pat, str): + result = pc.starts_with(self._pa_array, pattern=pat) + else: + if len(pat) == 0: + # For empty tuple, pd.StringDtype() returns null for missing values + # and false for valid values. + result = pc.if_else(pc.is_null(self._pa_array), None, False) + else: + result = pc.starts_with(self._pa_array, pattern=pat[0]) + + for p in pat[1:]: + result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p)) if not isna(na): result = result.fill_null(na) return type(self)(result) - def _str_endswith(self, pat: str, na=None): - result = pc.ends_with(self._pa_array, pattern=pat) + def _str_endswith(self, pat: str | tuple[str, ...], na=None): + if isinstance(pat, str): + result = pc.ends_with(self._pa_array, pattern=pat) + else: + if len(pat) == 0: + # For empty tuple, pd.StringDtype() returns null for missing values + # and false for valid values. + result = pc.if_else(pc.is_null(self._pa_array), None, False) + else: + result = pc.ends_with(self._pa_array, pattern=pat[0]) + + for p in pat[1:]: + result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p)) if not isna(na): result = result.fill_null(na) return type(self)(result) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 674a5da216011..28ee5446d4b5c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1801,19 +1801,33 @@ def test_str_contains_flags_unsupported(): @pytest.mark.parametrize( "side, pat, na, exp", [ - ["startswith", "ab", None, [True, None]], - ["startswith", "b", False, [False, False]], - ["endswith", "b", True, [False, True]], - ["endswith", "bc", None, [True, None]], + ["startswith", "ab", None, [True, None, False]], + ["startswith", "b", False, [False, False, False]], + ["endswith", "b", True, [False, True, False]], + ["endswith", "bc", None, [True, None, False]], + ["startswith", ("a", "e", "g"), None, [True, None, True]], + ["endswith", ("a", "c", "g"), None, [True, None, True]], + ["startswith", (), None, [False, None, False]], + ["endswith", (), None, [False, None, False]], ], ) def test_str_start_ends_with(side, pat, na, exp): - ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + ser = pd.Series(["abc", None, "efg"], dtype=ArrowDtype(pa.string())) result = getattr(ser.str, side)(pat, na=na) expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) tm.assert_series_equal(result, expected) +@pytest.mark.parametrize("side", ("startswith", "endswith")) +def test_str_starts_ends_with_all_nulls_empty_tuple(side): + ser = pd.Series([None, None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, side)(()) + + # bool datatype preserved for all nulls. + expected = pd.Series([None, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( "arg_name, arg", [["pat", re.compile("b")], ["repl", str], ["case", False], ["flags", 1]],