Skip to content

Commit d59fe38

Browse files
authored
add Series.case_when (#957)
add case_when
1 parent 9d5348e commit d59fe38

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

pandas-stubs/core/series.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,17 @@ class Series(IndexOpsMixin[S1], NDFrame):
14261426
axis: AxisIndex | None = ...,
14271427
level: Level | None = ...,
14281428
) -> Series[S1]: ...
1429+
def case_when(
1430+
self,
1431+
caselist: list[
1432+
tuple[
1433+
Sequence[bool]
1434+
| Series[bool]
1435+
| Callable[[Series], Series | np.ndarray | Sequence[bool]],
1436+
ListLikeU | Scalar | Callable[[Series], Series | np.ndarray],
1437+
],
1438+
],
1439+
) -> Series: ...
14291440
def truncate(
14301441
self,
14311442
before: date | _str | int | None = ...,

tests/test_series.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3288,3 +3288,16 @@ def callable(x: int | NAType) -> str | NAType:
32883288

32893289
series = pd.Series(["a", "b", "c"])
32903290
check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str)
3291+
3292+
3293+
def test_case_when() -> None:
3294+
c = pd.Series([6, 7, 8, 9], name="c")
3295+
a = pd.Series([0, 0, 1, 2])
3296+
b = pd.Series([0, 3, 4, 5])
3297+
r = c.case_when(
3298+
caselist=[
3299+
(a.gt(0), a),
3300+
(b.gt(0), b),
3301+
]
3302+
)
3303+
check(assert_type(r, pd.Series), pd.Series)

0 commit comments

Comments
 (0)