Skip to content

Commit 32db176

Browse files
bashtage and Kevin Sheppard authored
BUG: Complete str accessor methods (#157)
* BUG: Complete str accessor methods
  Add remaining methods
  closes #155
* MAINT: Fix unused import
* BUG/ENH: Clean string accessor methods
  Remove invalid methods
  Correct all types
* TST: Add types for testing
* TST: Add many tests
* TST: Test string accessor overloads
  Test overloads for correctness
* TYP: Final typing for string accessor
* CLN: Simplify overload and add test
  Simplify overload
  Add test for other forms
  Change return type to Series
* CLN: Clean up after rebase

Co-authored-by: Kevin Sheppard <kevin.sheppard@gmail.com>
1 parent 9383884 commit 32db176

File tree

2 files changed: 276 additions and 91 deletions.

pandas-stubs/core/strings.pyi

Lines changed: 159 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -1,104 +1,172 @@
11
from __future__ import annotations
22

3-
from typing import Generic
3+
import re
4+
from typing import (
5+
Any,
6+
Callable,
7+
Generic,
8+
Literal,
9+
Sequence,
10+
overload,
11+
)
412

13+
import numpy as np
14+
import pandas as pd
515
from pandas import Series
616
from pandas.core.base import NoNewAttributesMixin as NoNewAttributesMixin
717

8-
from pandas._typing import T
18+
from pandas._typing import (
19+
F,
20+
T,
21+
)
922

10-
def cat_core(list_of_columns: list, sep: str): ...
11-
def cat_safe(list_of_columns: list, sep: str): ...
12-
def str_count(arr, pat, flags: int = ...): ...
13-
def str_contains(
14-
arr, pat, case: bool = ..., flags: int = ..., na=..., regex: bool = ...
15-
): ...
16-
def str_startswith(arr, pat, na=...): ...
17-
def str_endswith(arr, pat, na=...): ...
18-
def str_replace(
19-
arr, pat, repl, n: int = ..., case=..., flags: int = ..., regex: bool = ...
20-
): ...
21-
def str_repeat(arr, repeats): ...
22-
def str_match(arr, pat, case: bool = ..., flags: int = ..., na=...): ...
23-
def str_extract(arr, pat, flags: int = ..., expand: bool = ...): ...
24-
def str_extractall(arr, pat, flags: int = ...): ...
25-
def str_get_dummies(arr, sep: str = ...): ...
26-
def str_join(arr, sep): ...
27-
def str_findall(arr, pat, flags: int = ...): ...
28-
def str_find(arr, sub, start: int = ..., end=..., side: str = ...): ...
29-
def str_index(arr, sub, start: int = ..., end=..., side: str = ...): ...
30-
def str_pad(arr, width, side: str = ..., fillchar: str = ...): ...
31-
def str_split(arr, pat=..., n=...): ...
32-
def str_rsplit(arr, pat=..., n=...): ...
33-
def str_slice(arr, start=..., stop=..., step=...): ...
34-
def str_slice_replace(arr, start=..., stop=..., repl=...): ...
35-
def str_strip(arr, to_strip=..., side: str = ...): ...
36-
def str_wrap(arr, width, **kwargs): ...
37-
def str_translate(arr, table): ...
38-
def str_get(arr, i): ...
39-
def str_decode(arr, encoding, errors: str = ...): ...
40-
def str_encode(arr, encoding, errors: str = ...): ...
41-
def forbid_nonstring_types(forbidden, name=...): ...
42-
def copy(source): ...
23+
def cat_core(list_of_columns: list[np.ndarray], sep: str) -> np.ndarray: ...
24+
def cat_safe(list_of_columns: list[np.ndarray], sep: str) -> np.ndarray: ...
25+
def forbid_nonstring_types(
26+
forbidden: list[str] | None, name: str | None = ...
27+
) -> Callable[[F], F]: ...
4328

4429
class StringMethods(NoNewAttributesMixin, Generic[T]):
45-
def __init__(self, data) -> None: ...
46-
def __getitem__(self, key) -> T: ...
47-
def __iter__(self): ...
48-
def cat(self, others=..., sep=..., na_rep=..., join: str = ...) -> T: ...
49-
def split(self, pat=..., n: int = ..., expand: bool = ...) -> T: ...
50-
def rsplit(self, pat=..., n: int = ..., expand: bool = ...) -> T: ...
51-
def partition(self, sep: str = ..., expand: bool = ...) -> T: ...
52-
def rpartition(self, sep: str = ..., expand: bool = ...) -> T: ...
53-
def get(self, i) -> T: ...
54-
def join(self, sep) -> T: ...
30+
def __init__(self, data: T) -> None: ...
31+
def __getitem__(self, key: slice | int) -> T: ...
32+
def __iter__(self) -> T: ...
33+
@overload
34+
def cat(
35+
self,
36+
*,
37+
sep: str,
38+
na_rep: str | None = ...,
39+
join: Literal["left", "right", "outer", "inner"] = ...,
40+
) -> str: ...
41+
@overload
42+
def cat(
43+
self,
44+
others: Literal[None] = ...,
45+
*,
46+
sep: str,
47+
na_rep: str | None = ...,
48+
join: Literal["left", "right", "outer", "inner"] = ...,
49+
) -> str: ...
50+
@overload
51+
def cat(
52+
self,
53+
others: Series | pd.Index | pd.DataFrame | np.ndarray | list[Any],
54+
sep: str = ...,
55+
na_rep: str | None = ...,
56+
join: Literal["left", "right", "outer", "inner"] = ...,
57+
) -> T: ...
58+
def split(
59+
self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ...
60+
) -> T: ...
61+
def rsplit(
62+
self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ...
63+
) -> T: ...
64+
@overload
65+
def partition(self, sep: str = ...) -> pd.DataFrame: ...
66+
@overload
67+
def partition(self, *, expand: Literal[True]) -> pd.DataFrame: ...
68+
@overload
69+
def partition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
70+
@overload
71+
def partition(self, sep: str, expand: Literal[False]) -> T: ...
72+
@overload
73+
def partition(self, *, expand: Literal[False]) -> T: ...
74+
@overload
75+
def rpartition(self, sep: str = ...) -> pd.DataFrame: ...
76+
@overload
77+
def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ...
78+
@overload
79+
def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
80+
@overload
81+
def rpartition(self, sep: str, expand: Literal[False]) -> T: ...
82+
@overload
83+
def rpartition(self, *, expand: Literal[False]) -> T: ...
84+
def get(self, i: int) -> T: ...
85+
def join(self, sep: str) -> T: ...
5586
def contains(
56-
self, pat, case: bool = ..., flags: int = ..., na=..., regex: bool = ...
87+
self, pat: str, case: bool = ..., flags: int = ..., na=..., regex: bool = ...
5788
) -> Series[bool]: ...
58-
def match(self, pat, case: bool = ..., flags: int = ..., na=...) -> T: ...
89+
def match(
90+
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
91+
) -> T: ...
5992
def replace(
60-
self, pat, repl, n: int = ..., case=..., flags: int = ..., regex: bool = ...
93+
self,
94+
pat: str,
95+
repl: str | Callable[[re.Match], str],
96+
n: int = ...,
97+
case: bool | None = ...,
98+
flags: int = ...,
99+
regex: bool = ...,
100+
) -> T: ...
101+
def repeat(self, repeats: int | Sequence[int]) -> T: ...
102+
def pad(
103+
self,
104+
width: int,
105+
side: Literal["left", "right", "both"] = ...,
106+
fillchar: str = ...,
61107
) -> T: ...
62-
def repeat(self, repeats) -> T: ...
63-
def pad(self, width, side: str = ..., fillchar: str = ...) -> T: ...
64-
def center(self, width, fillchar: str = ...) -> T: ...
65-
def ljust(self, width, fillchar: str = ...) -> T: ...
66-
def rjust(self, width, fillchar: str = ...) -> T: ...
67-
def zfill(self, width) -> T: ...
68-
def slice(self, start=..., stop=..., step=...) -> T: ...
69-
def slice_replace(self, start=..., stop=..., repl=...) -> T: ...
70-
def decode(self, encoding, errors: str = ...) -> T: ...
71-
def encode(self, encoding, errors: str = ...) -> T: ...
72-
def strip(self, to_strip=...) -> T: ...
73-
def lstrip(self, to_strip=...) -> T: ...
74-
def rstrip(self, to_strip=...) -> T: ...
75-
def wrap(self, width, **kwargs) -> T: ...
76-
def get_dummies(self, sep: str = ...) -> T: ...
77-
def translate(self, table) -> T: ...
78-
count = ...
79-
startswith = ...
80-
endswith = ...
81-
findall = ...
82-
def extract(self, pat, flags: int = ..., expand: bool = ...) -> T: ...
83-
def extractall(self, pat, flags: int = ...) -> T: ...
84-
def find(self, sub, start: int = ..., end=...) -> T: ...
85-
def rfind(self, sub, start: int = ..., end=...) -> T: ...
86-
def normalize(self, form) -> T: ...
87-
def index(self, sub, start: int = ..., end=...) -> T: ...
88-
def rindex(self, sub, start: int = ..., end=...) -> T: ...
89-
len = ...
90-
lower = ...
91-
upper = ...
92-
title = ...
93-
capitalize = ...
94-
swapcase = ...
95-
casefold = ...
96-
isalnum = ...
97-
isalpha = ...
98-
isdigit = ...
99-
isspace = ...
100-
islower = ...
101-
isupper = ...
102-
istitle = ...
103-
isnumeric = ...
104-
isdecimal = ...
108+
def center(self, width: int, fillchar: str = ...) -> T: ...
109+
def ljust(self, width: int, fillchar: str = ...) -> T: ...
110+
def rjust(self, width: int, fillchar: str = ...) -> T: ...
111+
def zfill(self, width: int) -> T: ...
112+
def slice(
113+
self, start: int | None = ..., stop: int | None = ..., step: int | None = ...
114+
) -> T: ...
115+
def slice_replace(
116+
self, start: int | None = ..., stop: int | None = ..., repl: str | None = ...
117+
) -> T: ...
118+
def decode(self, encoding: str, errors: str = ...) -> T: ...
119+
def encode(self, encoding: str, errors: str = ...) -> T: ...
120+
def strip(self, to_strip: str | None = ...) -> T: ...
121+
def lstrip(self, to_strip: str | None = ...) -> T: ...
122+
def rstrip(self, to_strip: str | None = ...) -> T: ...
123+
def wrap(
124+
self,
125+
width: int,
126+
expand_tabs: bool | None = ...,
127+
replace_whitespace: bool | None = ...,
128+
drop_whitespace: bool | None = ...,
129+
break_long_words: bool | None = ...,
130+
break_on_hyphens: bool | None = ...,
131+
) -> T: ...
132+
def get_dummies(self, sep: str = ...) -> pd.DataFrame: ...
133+
def translate(self, table: dict[int, int | str | None] | None) -> T: ...
134+
def count(self, pat: str, flags: int = ...) -> Series[int]: ...
135+
def startswith(self, pat: str, na: Any = ...) -> Series[bool]: ...
136+
def endswith(self, pat: str, na: Any = ...) -> Series[bool]: ...
137+
def findall(self, pat: str, flags: int = ...) -> Series: ...
138+
@overload
139+
def extract(
140+
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
141+
) -> pd.DataFrame: ...
142+
@overload
143+
def extract(self, pat: str, flags: int, expand: Literal[False]) -> T: ...
144+
@overload
145+
def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> T: ...
146+
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
147+
def find(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
148+
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
149+
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ...
150+
def index(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
151+
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
152+
def len(self) -> Series[int]: ...
153+
def lower(self) -> T: ...
154+
def upper(self) -> T: ...
155+
def title(self) -> T: ...
156+
def capitalize(self) -> T: ...
157+
def swapcase(self) -> T: ...
158+
def casefold(self) -> T: ...
159+
def isalnum(self) -> Series[bool]: ...
160+
def isalpha(self) -> Series[bool]: ...
161+
def isdigit(self) -> Series[bool]: ...
162+
def isspace(self) -> Series[bool]: ...
163+
def islower(self) -> Series[bool]: ...
164+
def isupper(self) -> Series[bool]: ...
165+
def istitle(self) -> Series[bool]: ...
166+
def isnumeric(self) -> Series[bool]: ...
167+
def isdecimal(self) -> Series[bool]: ...
168+
def fullmatch(
169+
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
170+
) -> Series[bool]: ...
171+
def removeprefix(self, prefix: str) -> T: ...
172+
def removesuffix(self, suffix: str) -> T: ...

tests/test_series.py

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from pathlib import Path
4+
import re
45
import tempfile
56
from typing import (
67
TYPE_CHECKING,
@@ -831,3 +832,119 @@ def test_categorical_codes():
831832
# GH-111
832833
cat = pd.Categorical(["a", "b", "a"])
833834
assert_type(cat.codes, "np_ndarray_int")
835+
836+
837+
def test_string_accessors():
838+
s = pd.Series(
839+
["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
840+
)
841+
s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
842+
s3 = pd.Series(["a1", "b2", "c3"])
843+
check(assert_type(s.str.capitalize(), pd.Series), pd.Series)
844+
check(assert_type(s.str.casefold(), pd.Series), pd.Series)
845+
check(assert_type(s.str.cat(sep="X"), str), str)
846+
check(assert_type(s.str.center(10), pd.Series), pd.Series)
847+
check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, bool)
848+
check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, int)
849+
check(assert_type(s.str.decode("utf-8"), pd.Series), pd.Series)
850+
check(assert_type(s.str.encode("latin-1"), pd.Series), pd.Series)
851+
check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, bool)
852+
check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
853+
check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
854+
check(assert_type(s.str.find("p"), pd.Series), pd.Series)
855+
check(assert_type(s.str.findall("pp"), pd.Series), pd.Series)
856+
check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, bool)
857+
check(assert_type(s.str.get(2), pd.Series), pd.Series)
858+
check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame)
859+
check(assert_type(s.str.index("p"), pd.Series), pd.Series)
860+
check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, bool)
861+
check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, bool)
862+
check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, bool)
863+
check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, bool)
864+
check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, bool)
865+
check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, bool)
866+
check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, bool)
867+
check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, bool)
868+
check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, bool)
869+
check(assert_type(s2.str.join("-"), pd.Series), pd.Series)
870+
check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, int)
871+
check(assert_type(s.str.ljust(80), pd.Series), pd.Series)
872+
check(assert_type(s.str.lower(), pd.Series), pd.Series)
873+
check(assert_type(s.str.lstrip("a"), pd.Series), pd.Series)
874+
check(assert_type(s.str.match("pp"), pd.Series), pd.Series)
875+
check(assert_type(s.str.normalize("NFD"), pd.Series), pd.Series)
876+
check(assert_type(s.str.pad(80, "right"), pd.Series), pd.Series)
877+
check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame)
878+
check(assert_type(s.str.removeprefix("a"), pd.Series), pd.Series)
879+
check(assert_type(s.str.removesuffix("e"), pd.Series), pd.Series)
880+
check(assert_type(s.str.repeat(2), pd.Series), pd.Series)
881+
check(assert_type(s.str.replace("a", "X"), pd.Series), pd.Series)
882+
check(assert_type(s.str.rfind("e"), pd.Series), pd.Series)
883+
check(assert_type(s.str.rindex("p"), pd.Series), pd.Series)
884+
check(assert_type(s.str.rjust(80), pd.Series), pd.Series)
885+
check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame)
886+
check(assert_type(s.str.rsplit("a"), pd.Series), pd.Series)
887+
check(assert_type(s.str.rstrip(), pd.Series), pd.Series)
888+
check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series)
889+
check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series)
890+
check(assert_type(s.str.split("a"), pd.Series), pd.Series)
891+
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, bool)
892+
check(assert_type(s.str.strip(), pd.Series), pd.Series)
893+
check(assert_type(s.str.swapcase(), pd.Series), pd.Series)
894+
check(assert_type(s.str.title(), pd.Series), pd.Series)
895+
check(assert_type(s.str.translate(None), pd.Series), pd.Series)
896+
check(assert_type(s.str.upper(), pd.Series), pd.Series)
897+
check(assert_type(s.str.wrap(80), pd.Series), pd.Series)
898+
check(assert_type(s.str.zfill(10), pd.Series), pd.Series)
899+
900+
901+
def test_series_overloads_cat():
902+
s = pd.Series(
903+
["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
904+
)
905+
check(assert_type(s.str.cat(sep=";"), str), str)
906+
check(assert_type(s.str.cat(None, sep=";"), str), str)
907+
check(
908+
assert_type(s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), pd.Series),
909+
pd.Series,
910+
)
911+
912+
913+
def test_series_overloads_partition():
914+
s = pd.Series(
915+
[
916+
"ap;pl;ep",
917+
"ban;an;ap",
918+
"Che;rr;yp",
919+
"DA;TEp",
920+
"eGGp;LANT;p",
921+
"12;3p",
922+
"23.45p",
923+
]
924+
)
925+
check(assert_type(s.str.partition(sep=";"), pd.DataFrame), pd.DataFrame)
926+
check(
927+
assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame
928+
)
929+
check(assert_type(s.str.partition(sep=";", expand=False), pd.Series), pd.Series)
930+
931+
check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame)
932+
check(
933+
assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame
934+
)
935+
check(assert_type(s.str.rpartition(sep=";", expand=False), pd.Series), pd.Series)
936+
937+
938+
def test_series_overloads_extract():
939+
s = pd.Series(
940+
["appl;ep", "ban;anap", "Cherr;yp", "DATEp", "eGGp;LANTp", "12;3p", "23.45p"]
941+
)
942+
check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame)
943+
check(
944+
assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame
945+
)
946+
check(assert_type(s.str.extract(r"[ab](\d)", expand=False), pd.Series), pd.Series)
947+
check(
948+
assert_type(s.str.extract(r"[ab](\d)", re.IGNORECASE, False), pd.Series),
949+
pd.Series,
950+
)

0 commit comments

Comments
 (0)