From 3edb9926ebe46c15537977bcedee04f414b1d9df Mon Sep 17 00:00:00 2001 From: Siddhartha Gandhi Date: Wed, 16 Nov 2022 23:26:09 -0500 Subject: [PATCH] Allow tuple[str, ...] pattern to be passed to str.startswith / str.endswith --- pandas-stubs/core/strings.pyi | 4 ++-- tests/test_series.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 800a2a821..9f1c5e455 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -142,8 +142,8 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS]): def get_dummies(self, sep: str = ...) -> pd.DataFrame: ... def translate(self, table: dict[int, int | str | None] | None) -> T: ... def count(self, pat: str, flags: int = ...) -> Series[int]: ... - def startswith(self, pat: str, na: Any = ...) -> Series[bool]: ... - def endswith(self, pat: str, na: Any = ...) -> Series[bool]: ... + def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... + def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... def findall(self, pat: str, flags: int = ...) -> Series: ... @overload def extract( diff --git a/tests/test_series.py b/tests/test_series.py index 1b578e933..f37631d6a 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -995,6 +995,7 @@ def test_string_accessors(): check(assert_type(s.str.decode("utf-8"), pd.Series), pd.Series) check(assert_type(s.str.encode("latin-1"), pd.Series), pd.Series) check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, bool) + check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, bool) check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.find("p"), pd.Series), pd.Series) @@ -1038,6 +1039,11 @@ def test_string_accessors(): # GH 194 check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, bool) + check( + assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), + pd.Series, + bool, + ) check(assert_type(s.str.strip(), pd.Series), pd.Series) check(assert_type(s.str.swapcase(), pd.Series), pd.Series) check(assert_type(s.str.title(), pd.Series), pd.Series)