Skip to content

Commit d7307e5

Browse files
authored
Make relational operators work with all scalars (#176)
* Make relational operators work with all scalars * use S1 rather than Scalar
1 parent a8bc6c6 commit d7307e5

File tree

2 files changed

+74
-12
lines changed

2 files changed

+74
-12
lines changed

pandas-stubs/core/series.pyi

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,12 +1143,8 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
11431143
def __div__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
11441144
def __eq__(self, other: object) -> Series[_bool]: ... # type: ignore[override]
11451145
def __floordiv__(self, other: num | _ListLike | Series[S1]) -> Series[int]: ...
1146-
def __ge__(
1147-
self, other: num | _ListLike | Series[S1] | Timestamp
1148-
) -> Series[_bool]: ...
1149-
def __gt__(
1150-
self, other: num | _ListLike | Series[S1] | Timestamp
1151-
) -> Series[_bool]: ...
1146+
def __ge__(self, other: S1 | _ListLike | Series[S1]) -> Series[_bool]: ...
1147+
def __gt__(self, other: S1 | _ListLike | Series[S1]) -> Series[_bool]: ...
11521148
# def __iadd__(self, other: S1) -> Series[S1]: ...
11531149
# def __iand__(self, other: S1) -> Series[_bool]: ...
11541150
# def __idiv__(self, other: S1) -> Series[S1]: ...
@@ -1161,12 +1157,8 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
11611157
# def __itruediv__(self, other: S1) -> Series[S1]: ...
11621158
# def __itruediv__(self, other) -> None: ...
11631159
# def __ixor__(self, other: S1) -> Series[_bool]: ...
1164-
def __le__(
1165-
self, other: num | _ListLike | Series[S1] | Timestamp
1166-
) -> Series[_bool]: ...
1167-
def __lt__(
1168-
self, other: num | _ListLike | Series[S1] | Timestamp
1169-
) -> Series[_bool]: ...
1160+
def __le__(self, other: S1 | _ListLike | Series[S1]) -> Series[_bool]: ...
1161+
def __lt__(self, other: S1 | _ListLike | Series[S1]) -> Series[_bool]: ...
11701162
@overload
11711163
def __mul__(self, other: Timedelta | TimedeltaSeries) -> TimedeltaSeries: ...
11721164
@overload

tests/test_series.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import datetime
34
from pathlib import Path
45
import re
56
import tempfile
@@ -948,3 +949,72 @@ def test_series_overloads_extract():
948949
assert_type(s.str.extract(r"[ab](\d)", re.IGNORECASE, False), pd.Series),
949950
pd.Series,
950951
)
952+
953+
954+
def test_relops() -> None:
955+
# GH 175
956+
s: str = "abc"
957+
check(assert_type(pd.Series([s]) > s, "pd.Series[bool]"), pd.Series, bool)
958+
check(assert_type(pd.Series([s]) < s, "pd.Series[bool]"), pd.Series, bool)
959+
check(assert_type(pd.Series([s]) <= s, "pd.Series[bool]"), pd.Series, bool)
960+
check(assert_type(pd.Series([s]) >= s, "pd.Series[bool]"), pd.Series, bool)
961+
962+
b: bytes = b"def"
963+
check(assert_type(pd.Series([b]) > b, "pd.Series[bool]"), pd.Series, bool)
964+
check(assert_type(pd.Series([b]) < b, "pd.Series[bool]"), pd.Series, bool)
965+
check(assert_type(pd.Series([b]) <= b, "pd.Series[bool]"), pd.Series, bool)
966+
check(assert_type(pd.Series([b]) >= b, "pd.Series[bool]"), pd.Series, bool)
967+
968+
dtd = datetime.date(2022, 7, 31)
969+
check(assert_type(pd.Series([dtd]) > dtd, "pd.Series[bool]"), pd.Series, bool)
970+
check(assert_type(pd.Series([dtd]) < dtd, "pd.Series[bool]"), pd.Series, bool)
971+
check(assert_type(pd.Series([dtd]) <= dtd, "pd.Series[bool]"), pd.Series, bool)
972+
check(assert_type(pd.Series([dtd]) >= dtd, "pd.Series[bool]"), pd.Series, bool)
973+
974+
dtdt = datetime.datetime(2022, 7, 31, 8, 32, 21)
975+
check(assert_type(pd.Series([dtdt]) > dtdt, "pd.Series[bool]"), pd.Series, bool)
976+
check(assert_type(pd.Series([dtdt]) < dtdt, "pd.Series[bool]"), pd.Series, bool)
977+
check(assert_type(pd.Series([dtdt]) <= dtdt, "pd.Series[bool]"), pd.Series, bool)
978+
check(assert_type(pd.Series([dtdt]) >= dtdt, "pd.Series[bool]"), pd.Series, bool)
979+
980+
dttd = datetime.timedelta(seconds=10)
981+
check(assert_type(pd.Series([dttd]) > dttd, "pd.Series[bool]"), pd.Series, bool)
982+
check(assert_type(pd.Series([dttd]) < dttd, "pd.Series[bool]"), pd.Series, bool)
983+
check(assert_type(pd.Series([dttd]) <= dttd, "pd.Series[bool]"), pd.Series, bool)
984+
check(assert_type(pd.Series([dttd]) >= dttd, "pd.Series[bool]"), pd.Series, bool)
985+
986+
bo: bool = True
987+
check(assert_type(pd.Series([bo]) > bo, "pd.Series[bool]"), pd.Series, bool)
988+
check(assert_type(pd.Series([bo]) < bo, "pd.Series[bool]"), pd.Series, bool)
989+
check(assert_type(pd.Series([bo]) <= bo, "pd.Series[bool]"), pd.Series, bool)
990+
check(assert_type(pd.Series([bo]) >= bo, "pd.Series[bool]"), pd.Series, bool)
991+
992+
ai: int = 10
993+
check(assert_type(pd.Series([ai]) > ai, "pd.Series[bool]"), pd.Series, bool)
994+
check(assert_type(pd.Series([ai]) < ai, "pd.Series[bool]"), pd.Series, bool)
995+
check(assert_type(pd.Series([ai]) <= ai, "pd.Series[bool]"), pd.Series, bool)
996+
check(assert_type(pd.Series([ai]) >= ai, "pd.Series[bool]"), pd.Series, bool)
997+
998+
af: float = 3.14
999+
check(assert_type(pd.Series([af]) > af, "pd.Series[bool]"), pd.Series, bool)
1000+
check(assert_type(pd.Series([af]) < af, "pd.Series[bool]"), pd.Series, bool)
1001+
check(assert_type(pd.Series([af]) <= af, "pd.Series[bool]"), pd.Series, bool)
1002+
check(assert_type(pd.Series([af]) >= af, "pd.Series[bool]"), pd.Series, bool)
1003+
1004+
ac: complex = 1 + 2j
1005+
check(assert_type(pd.Series([ac]) > ac, "pd.Series[bool]"), pd.Series, bool)
1006+
check(assert_type(pd.Series([ac]) < ac, "pd.Series[bool]"), pd.Series, bool)
1007+
check(assert_type(pd.Series([ac]) <= ac, "pd.Series[bool]"), pd.Series, bool)
1008+
check(assert_type(pd.Series([ac]) >= ac, "pd.Series[bool]"), pd.Series, bool)
1009+
1010+
ts = pd.Timestamp("2022-07-31 08:35:12")
1011+
check(assert_type(pd.Series([ts]) > ts, "pd.Series[bool]"), pd.Series, bool)
1012+
check(assert_type(pd.Series([ts]) < ts, "pd.Series[bool]"), pd.Series, bool)
1013+
check(assert_type(pd.Series([ts]) <= ts, "pd.Series[bool]"), pd.Series, bool)
1014+
check(assert_type(pd.Series([ts]) >= ts, "pd.Series[bool]"), pd.Series, bool)
1015+
1016+
td = pd.Timedelta(seconds=10)
1017+
check(assert_type(pd.Series([td]) > td, "pd.Series[bool]"), pd.Series, bool)
1018+
check(assert_type(pd.Series([td]) < td, "pd.Series[bool]"), pd.Series, bool)
1019+
check(assert_type(pd.Series([td]) <= td, "pd.Series[bool]"), pd.Series, bool)
1020+
check(assert_type(pd.Series([td]) >= td, "pd.Series[bool]"), pd.Series, bool)

0 commit comments

Comments
 (0)