Skip to content

Commit 371f406

Browse files
committed
Add get method to Series
1 parent 6487b7f commit 371f406

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

pandas-stubs/core/series.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ from typing import (
1818
ClassVar,
1919
Generic,
2020
Literal,
21+
TypeVar,
2122
overload,
2223
)
2324

@@ -160,6 +161,8 @@ from pandas.plotting import PlotAccessor
160161
_bool = bool
161162
_str = str
162163

164+
_T = TypeVar("_T")
165+
163166
class _iLocIndexerSeries(_iLocIndexer, Generic[S1]):
164167
# get item
165168
@overload
@@ -381,6 +384,10 @@ class Series(IndexOpsMixin[S1], NDFrame):
381384
@overload
382385
def __getitem__(self, idx: Scalar) -> S1: ...
383386
def __setitem__(self, key, value) -> None: ...
387+
@overload
388+
def get(self, key: Hashable, default: None = ...) -> S1 | None: ...
389+
@overload
390+
def get(self, key: Hashable, default: S1 | _T = ...) -> S1 | _T: ...
384391
def repeat(
385392
self, repeats: int | list[int], axis: AxisIndex | None = ...
386393
) -> Series[S1]: ...

tests/test_series.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Any,
1717
Generic,
1818
TypeVar,
19+
Union,
1920
cast,
2021
)
2122

@@ -2875,6 +2876,24 @@ def test_round() -> None:
28752876
check(assert_type(round(pd.Series([1], dtype=int)), "pd.Series[int]"), pd.Series)
28762877

28772878

2879+
def test_get() -> None:
2880+
s_int = pd.Series([1, 2, 3], index=[1, 2, 3])
2881+
2882+
check(assert_type(s_int.get(1), Union[int, None]), np.int64)
2883+
check(assert_type(s_int.get(99), Union[int, None]), type(None))
2884+
check(assert_type(s_int.get(1, default=None), Union[int, None]), np.int64)
2885+
check(assert_type(s_int.get(1, default=2), int), np.int64)
2886+
check(assert_type(s_int.get(99, default="a"), Union[int, str]), str)
2887+
2888+
s_str = pd.Series(list("abc"), index=list("abc"))
2889+
2890+
check(assert_type(s_str.get("a"), Union[str, None]), str)
2891+
check(assert_type(s_str.get("z"), Union[str, None]), type(None))
2892+
check(assert_type(s_str.get("a", default=None), Union[str, None]), str)
2893+
check(assert_type(s_str.get("a", default="b"), str), str)
2894+
check(assert_type(s_str.get("z", default=True), Union[str, bool]), bool)
2895+
2896+
28782897
def test_series_new_empty() -> None:
28792898
# GH 826
28802899
check(assert_type(pd.Series(), "pd.Series[Any]"), pd.Series)

0 commit comments

Comments
 (0)