Skip to content

Commit ae4bb1a

Browse files
Kevin Sheppardbashtage
Kevin Sheppard
authored andcommitted
ENH: Improve Period typing and testing
1 parent e29a94a commit ae4bb1a

File tree

2 files changed

+97
-38
lines changed

2 files changed

+97
-38
lines changed

pandas-stubs/_libs/tslibs/period.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ from pandas import (
1010
Index,
1111
PeriodIndex,
1212
Timedelta,
13+
TimedeltaIndex,
1314
)
1415
from pandas.core.series import (
1516
PeriodSeries,
1617
TimedeltaSeries,
1718
)
1819
from typing_extensions import TypeAlias
1920

21+
from pandas._libs.tslibs import NaTType
2022
from pandas._typing import npt
2123

2224
from .timestamps import Timestamp
@@ -75,13 +77,19 @@ class Period(PeriodMixin):
7577
@overload
7678
def __sub__(self, other: Period) -> BaseOffset: ...
7779
@overload
80+
def __sub__(self, other: NaTType) -> NaTType: ...
81+
@overload
7882
def __sub__(self, other: PeriodIndex) -> Index: ...
7983
@overload
8084
def __sub__(self, other: TimedeltaSeries) -> PeriodSeries: ...
85+
@overload
86+
def __sub__(self, other: TimedeltaIndex) -> PeriodIndex: ...
8187
def __rsub__(self, other: PeriodIndex) -> Index: ...
8288
@overload
8389
def __add__(self, other: _PeriodAddSub) -> Period: ...
8490
@overload
91+
def __add__(self, other: NaTType) -> NaTType: ...
92+
@overload
8593
def __add__(self, other: Index) -> PeriodIndex: ...
8694
@overload
8795
def __add__(self, other: TimedeltaSeries) -> PeriodSeries: ...
@@ -122,6 +130,8 @@ class Period(PeriodMixin):
122130
def __radd__(self, other: Index) -> Index: ...
123131
@overload
124132
def __radd__(self, other: TimedeltaSeries) -> PeriodSeries: ...
133+
@overload
134+
def __radd__(self, other: NaTType) -> NaTType: ...
125135
@property
126136
def day(self) -> int: ...
127137
@property

tests/test_scalars.py

Lines changed: 87 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,27 @@
1111
import pandas as pd
1212
from typing_extensions import assert_type
1313

14-
from pandas._libs.tslibs import BaseOffset
14+
from pandas._libs.tslibs import (
15+
BaseOffset,
16+
NaTType,
17+
)
1518

1619
if TYPE_CHECKING:
17-
from pandas.core.series import PeriodSeries # noqa: F401
18-
from pandas.core.series import TimedeltaSeries # noqa: F401
20+
from pandas.core.series import (
21+
PeriodSeries,
22+
TimedeltaSeries,
23+
)
1924

2025
from pandas._typing import np_ndarray_bool
2126
else:
22-
np_ndarray_bool = Any
27+
PeriodSeries = TimedeltaSeries = np_ndarray_bool = Any
2328

2429
from tests import check
2530

2631
from pandas.tseries.offsets import Day
2732

2833

29-
def test_period() -> None:
34+
def test_period_construction() -> None:
3035
p = pd.Period("2012-1-1", freq="D")
3136
check(assert_type(p, pd.Period), pd.Period)
3237
check(assert_type(pd.Period(p), pd.Period), pd.Period)
@@ -53,6 +58,11 @@ def test_period() -> None:
5358
pd.Period,
5459
)
5560
check(assert_type(pd.Period(freq="Q", year=2012, quarter=2), pd.Period), pd.Period)
61+
62+
63+
def test_period_properties() -> None:
64+
p = pd.Period("2012-1-1", freq="D")
65+
5666
check(assert_type(p.day, int), int)
5767
check(assert_type(p.day_of_week, int), int)
5868
check(assert_type(p.day_of_year, int), int)
@@ -80,15 +90,22 @@ def test_period() -> None:
8090
p2 = pd.Period("2012-1-1", freq="2D")
8191
check(assert_type(p2.freq, BaseOffset), Day)
8292

93+
94+
def test_periof_add_subtract() -> None:
95+
p = pd.Period("2012-1-1", freq="D")
96+
8397
as0 = pd.Timedelta(1, "D")
8498
as1 = dt.timedelta(days=1)
8599
as2 = np.timedelta64(1, "D")
86100
as3 = np.int64(1)
87101
as4 = int(1)
88102
as5 = pd.period_range("2012-1-1", periods=10, freq="D")
89103
as6 = pd.Period("2012-1-1", freq="D")
90-
as7 = cast("TimedeltaSeries", pd.Series([pd.Timedelta(days=1)]))
91-
as8 = cast("PeriodSeries", pd.Series([as6]))
104+
scale = 24 * 60 * 60 * 10**9
105+
as7 = cast(TimedeltaSeries, pd.Series(pd.timedelta_range(scale, scale, freq="D")))
106+
as8 = pd.Series(as5)
107+
as9 = pd.timedelta_range(scale, scale, freq="D")
108+
as10 = pd.NaT
92109

93110
check(assert_type(p + as0, pd.Period), pd.Period)
94111
check(assert_type(p + as1, pd.Period), pd.Period)
@@ -97,25 +114,32 @@ def test_period() -> None:
97114
check(assert_type(p + as4, pd.Period), pd.Period)
98115
check(assert_type(p + p.freq, pd.Period), pd.Period)
99116
check(assert_type(p + (p - as5), pd.PeriodIndex), pd.PeriodIndex)
100-
check(assert_type(p + as7, "PeriodSeries"), pd.Series)
101-
das8 = cast("TimedeltaSeries", (as8 - as8))
102-
check(assert_type(p + das8, "PeriodSeries"), pd.Series)
117+
check(assert_type(p + as7, PeriodSeries), pd.Series, pd.Period)
118+
check(assert_type(p + as9, pd.PeriodIndex), pd.PeriodIndex)
119+
check(assert_type(p + as10, NaTType), NaTType)
120+
das8 = cast(TimedeltaSeries, (as8 - as8))
121+
check(assert_type(p + das8, PeriodSeries), pd.Series, pd.Period)
103122
check(assert_type(p - as0, pd.Period), pd.Period)
104123
check(assert_type(p - as1, pd.Period), pd.Period)
105124
check(assert_type(p - as2, pd.Period), pd.Period)
106125
check(assert_type(p - as3, pd.Period), pd.Period)
107126
check(assert_type(p - as4, pd.Period), pd.Period)
108127
check(assert_type(p - as5, pd.Index), pd.Index)
109128
check(assert_type(p - as6, BaseOffset), Day)
110-
check(assert_type(p - as7, "PeriodSeries"), pd.Series)
129+
check(assert_type(p - as7, PeriodSeries), pd.Series, pd.Period)
130+
check(assert_type(p - as9, pd.PeriodIndex), pd.PeriodIndex)
131+
check(assert_type(p - as10, NaTType), NaTType)
111132
check(assert_type(p - p.freq, pd.Period), pd.Period)
112133

113134
check(assert_type(as0 + p, pd.Period), pd.Period)
114135
check(assert_type(as1 + p, pd.Period), pd.Period)
115136
check(assert_type(as2 + p, pd.Period), pd.Period)
116137
check(assert_type(as3 + p, pd.Period), pd.Period)
117138
check(assert_type(as4 + p, pd.Period), pd.Period)
118-
check(assert_type(as7 + p, "PeriodSeries"), pd.Series)
139+
check(assert_type(as7 + p, PeriodSeries), pd.Series, pd.Period)
140+
# TODO: Improve Index to not handle __add__(period)
141+
check(assert_type(as9 + p, pd.Index), pd.PeriodIndex)
142+
check(assert_type(as10 + p, NaTType), NaTType)
119143
check(assert_type(p.freq + p, pd.Period), pd.Period)
120144

121145
check(assert_type(as5 - p, pd.Index), pd.Index)
@@ -125,41 +149,66 @@ def test_period() -> None:
125149
check(assert_type(p.__radd__(as2), pd.Period), pd.Period)
126150
check(assert_type(p.__radd__(as3), pd.Period), pd.Period)
127151
check(assert_type(p.__radd__(as4), pd.Period), pd.Period)
152+
check(assert_type(p.__radd__(as10), NaTType), NaTType)
128153
check(assert_type(p.__radd__(p.freq), pd.Period), pd.Period)
129154

155+
156+
def test_period_cmp() -> None:
157+
p = pd.Period("2012-1-1", freq="D")
158+
130159
c0 = pd.Period("2012-1-1", freq="D")
131160
c1 = pd.period_range("2012-1-1", periods=10, freq="D")
132161

133-
check(assert_type(p == c0, bool), bool)
134-
check(assert_type(p == c1, np_ndarray_bool), np.ndarray)
135-
check(assert_type(c0 == p, bool), bool)
136-
check(assert_type(c1 == p, np_ndarray_bool), np.ndarray)
162+
eq = check(assert_type(p == c0, bool), bool)
163+
ne = check(assert_type(p != c0, bool), bool)
164+
assert eq != ne
165+
166+
eq_a = check(assert_type(p == c1, np_ndarray_bool), np.ndarray)
167+
ne_q = check(assert_type(p != c1, np_ndarray_bool), np.ndarray)
168+
assert (eq_a != ne_q).all()
169+
170+
eq = check(assert_type(c0 == p, bool), bool)
171+
ne = check(assert_type(c0 != p, bool), bool)
172+
assert eq != ne
173+
174+
eq_a = check(assert_type(c1 == p, np_ndarray_bool), np.ndarray)
175+
ne_a = check(assert_type(c1 != p, np_ndarray_bool), np.ndarray)
176+
assert (eq_a != ne_q).all()
177+
178+
gt = check(assert_type(p > c0, bool), bool)
179+
le = check(assert_type(p <= c0, bool), bool)
180+
assert gt != le
181+
182+
gt_a = check(assert_type(p > c1, np_ndarray_bool), np.ndarray)
183+
le_a = check(assert_type(p <= c1, np_ndarray_bool), np.ndarray)
184+
assert (gt_a != le_a).all()
185+
186+
gt = check(assert_type(c0 > p, bool), bool)
187+
le = check(assert_type(c0 <= p, bool), bool)
188+
assert gt != le
189+
190+
gt_a = check(assert_type(c1 > p, np_ndarray_bool), np.ndarray)
191+
le_a = check(assert_type(c1 <= p, np_ndarray_bool), np.ndarray)
192+
assert (gt_a != le_a).all()
137193

138-
check(assert_type(p != c0, bool), bool)
139-
check(assert_type(p != c1, np_ndarray_bool), np.ndarray)
140-
check(assert_type(c0 != p, bool), bool)
141-
check(assert_type(c1 != p, np_ndarray_bool), np.ndarray)
194+
lt = check(assert_type(p < c0, bool), bool)
195+
ge = check(assert_type(p >= c0, bool), bool)
196+
assert lt != ge
142197

143-
check(assert_type(p > c0, bool), bool)
144-
check(assert_type(p > c1, np_ndarray_bool), np.ndarray)
145-
check(assert_type(c0 > p, bool), bool)
146-
check(assert_type(c1 > p, np_ndarray_bool), np.ndarray)
198+
lt_a = check(assert_type(p < c1, np_ndarray_bool), np.ndarray)
199+
ge_a = check(assert_type(p >= c1, np_ndarray_bool), np.ndarray)
200+
assert (lt_a != ge_a).all()
147201

148-
check(assert_type(p < c0, bool), bool)
149-
check(assert_type(p < c1, np_ndarray_bool), np.ndarray)
150-
check(assert_type(c0 < p, bool), bool)
151-
check(assert_type(c1 < p, np_ndarray_bool), np.ndarray)
202+
lt = check(assert_type(c0 < p, bool), bool)
203+
ge = check(assert_type(c0 >= p, bool), bool)
204+
assert lt != ge
152205

153-
check(assert_type(p <= c0, bool), bool)
154-
check(assert_type(p <= c1, np_ndarray_bool), np.ndarray)
155-
check(assert_type(c0 <= p, bool), bool)
156-
check(assert_type(c1 <= p, np_ndarray_bool), np.ndarray)
206+
lt_a = check(assert_type(c1 < p, np_ndarray_bool), np.ndarray)
207+
ge_a = check(assert_type(c1 >= p, np_ndarray_bool), np.ndarray)
208+
assert (lt_a != ge_a).all()
157209

158-
check(assert_type(p >= c0, bool), bool)
159-
check(assert_type(p >= c1, np_ndarray_bool), np.ndarray)
160-
check(assert_type(c0 >= p, bool), bool)
161-
check(assert_type(c1 >= p, np_ndarray_bool), np.ndarray)
162210

211+
def test_period_methods():
163212
p3 = pd.Period("2007-01", freq="M")
164213
check(assert_type(p3.to_timestamp("D", "S"), pd.Timestamp), pd.Timestamp)
165214
check(assert_type(p3.to_timestamp("D", "E"), pd.Timestamp), pd.Timestamp)
@@ -181,5 +230,5 @@ def test_period() -> None:
181230
check(assert_type(pd.Period.now("D"), pd.Period), pd.Period)
182231
check(assert_type(pd.Period.now(Day()), pd.Period), pd.Period)
183232

184-
check(assert_type(p.strftime("%Y-%m-%d"), str), str)
185-
check(assert_type(hash(p), int), int)
233+
check(assert_type(p3.strftime("%Y-%m-%d"), str), str)
234+
check(assert_type(hash(p3), int), int)

0 commit comments

Comments
 (0)