Skip to content

Commit 2311c49

Browse files
bashtageKevin Sheppard
and
Kevin Sheppard
authored
ENH: Imprpve interval_range and IntervalIndex (#351)
* ENH: IMrpve interval_range and IntervalIndex * TYP: Improve typing accuracy * TST: Add tests for IntervalIndex and interval_range Co-authored-by: Kevin Sheppard <kevin.sheppard@gmail.com>
1 parent bd2d98f commit 2311c49

File tree

4 files changed

+378
-21
lines changed

4 files changed

+378
-21
lines changed

pandas-stubs/_typing.pyi

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ PandasScalar: TypeAlias = Union[
4444
]
4545
# Scalar: TypeAlias = Union[PythonScalar, PandasScalar]
4646

47-
DatetimeLike: TypeAlias = Union[
48-
datetime.date, datetime.datetime, np.datetime64, Timestamp
49-
]
47+
DatetimeLike: TypeAlias = Union[datetime.datetime, np.datetime64, Timestamp]
5048

5149
# dtypes
5250
NpDtype: TypeAlias = Union[

pandas-stubs/core/indexes/base.pyi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ from pandas._typing import (
2828
Dtype,
2929
DtypeArg,
3030
DtypeObj,
31+
FillnaOptions,
3132
HashableT,
3233
IndexT,
3334
Label,
@@ -155,7 +156,12 @@ class Index(IndexOpsMixin, PandasObject):
155156
def symmetric_difference(
156157
self, other: list[T1] | Index, result_name=..., sort=...
157158
) -> Index: ...
158-
def get_loc(self, key, tolerance=...): ...
159+
def get_loc(
160+
self,
161+
key: Label,
162+
method: FillnaOptions | Literal["nearest"] | None = ...,
163+
tolerance=...,
164+
): ...
159165
def get_indexer(self, target, method=..., limit=..., tolerance=...): ...
160166
def reindex(self, target, method=..., level=..., limit=..., tolerance=...): ...
161167
def join(
Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,50 @@
1-
from typing import Hashable
1+
import datetime as dt
2+
from typing import (
3+
Any,
4+
Hashable,
5+
Literal,
6+
Sequence,
7+
Union,
8+
overload,
9+
)
210

311
import numpy as np
12+
import pandas as pd
13+
from pandas import Index
414
from pandas.core.indexes.extension import ExtensionIndex
15+
from typing_extensions import TypeAlias
516

617
from pandas._libs.interval import (
718
Interval as Interval,
819
IntervalMixin as IntervalMixin,
920
)
21+
from pandas._libs.tslibs.offsets import DateOffset
1022
from pandas._typing import (
23+
DatetimeLike,
1124
DtypeArg,
25+
FillnaOptions,
1226
IntervalClosedType,
27+
Label,
28+
npt,
1329
)
1430

1531
from pandas.core.dtypes.dtypes import IntervalDtype as IntervalDtype
1632
from pandas.core.dtypes.generic import ABCSeries
1733

34+
_Edges: TypeAlias = Union[
35+
Sequence[int],
36+
Sequence[float],
37+
Sequence[DatetimeLike],
38+
npt.NDArray[np.int_],
39+
npt.NDArray[np.float_],
40+
npt.NDArray[np.datetime64],
41+
pd.Series[int],
42+
pd.Series[float],
43+
pd.Series[pd.Timestamp],
44+
pd.Int64Index,
45+
pd.DatetimeIndex,
46+
]
47+
1848
class IntervalIndex(IntervalMixin, ExtensionIndex):
1949
def __new__(
2050
cls,
@@ -28,7 +58,7 @@ class IntervalIndex(IntervalMixin, ExtensionIndex):
2858
@classmethod
2959
def from_breaks(
3060
cls,
31-
breaks,
61+
breaks: _Edges,
3262
closed: IntervalClosedType = ...,
3363
name: Hashable = ...,
3464
copy: bool = ...,
@@ -37,8 +67,8 @@ class IntervalIndex(IntervalMixin, ExtensionIndex):
3767
@classmethod
3868
def from_arrays(
3969
cls,
40-
left,
41-
right,
70+
left: _Edges,
71+
right: _Edges,
4272
closed: IntervalClosedType = ...,
4373
name: Hashable = ...,
4474
copy: bool = ...,
@@ -47,37 +77,67 @@ class IntervalIndex(IntervalMixin, ExtensionIndex):
4777
@classmethod
4878
def from_tuples(
4979
cls,
50-
data,
80+
data: Sequence[tuple[int, int]]
81+
| Sequence[tuple[float, float]]
82+
| Sequence[tuple[DatetimeLike, DatetimeLike]]
83+
| npt.NDArray,
5184
closed: IntervalClosedType = ...,
5285
name: Hashable = ...,
5386
copy: bool = ...,
5487
dtype: IntervalDtype | None = ...,
5588
) -> IntervalIndex: ...
89+
def __contains__(self, key: Any) -> bool: ...
5690
def astype(self, dtype: DtypeArg, copy: bool = ...) -> IntervalIndex: ...
5791
@property
5892
def inferred_type(self) -> str: ...
5993
def memory_usage(self, deep: bool = ...) -> int: ...
6094
@property
6195
def is_overlapping(self) -> bool: ...
62-
def get_loc(self, key, tolerance=...) -> int | slice | np.ndarray: ...
96+
# Note: tolerance no effect. It is included in all get_loc so
97+
# that signatures are consistent with base even though it is usually not used
98+
def get_loc(
99+
self,
100+
key: Label,
101+
method: FillnaOptions | Literal["nearest"] | None = ...,
102+
tolerance=...,
103+
) -> int | slice | npt.NDArray[np.bool_]: ...
63104
def get_indexer(
64105
self,
65-
targetArrayLike,
66-
method: str | None = ...,
106+
target: Index,
107+
method: FillnaOptions | Literal["nearest"] | None = ...,
67108
limit: int | None = ...,
68109
tolerance=...,
69-
) -> np.ndarray: ...
110+
) -> npt.NDArray[np.intp]: ...
70111
def get_indexer_non_unique(
71-
self, targetArrayLike
72-
) -> tuple[np.ndarray, np.ndarray]: ...
112+
self, target: Index
113+
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ...
114+
@property
115+
def left(self) -> Index: ...
116+
@property
117+
def right(self) -> Index: ...
118+
@property
119+
def mid(self) -> Index: ...
120+
@property
121+
def length(self) -> Index: ...
73122
def get_value(self, series: ABCSeries, key): ...
74123
@property
75124
def is_all_dates(self) -> bool: ...
76-
def __lt__(self, other): ...
77-
def __le__(self, other): ...
78-
def __gt__(self, other): ...
79-
def __ge__(self, other): ...
80125

126+
@overload
127+
def interval_range(
128+
start: int | float | None = ...,
129+
end: int | float | None = ...,
130+
periods: int | None = ...,
131+
freq: int | None = ...,
132+
name: Hashable = ...,
133+
closed: IntervalClosedType = ...,
134+
) -> IntervalIndex: ...
135+
@overload
81136
def interval_range(
82-
start=..., end=..., periods=..., freq=..., name=..., closed: str = ...
83-
): ...
137+
start: DatetimeLike | None = ...,
138+
end: DatetimeLike | None = ...,
139+
periods: int | None = ...,
140+
freq: str | DateOffset | None = ...,
141+
name: Hashable = ...,
142+
closed: IntervalClosedType = ...,
143+
) -> IntervalIndex: ...

0 commit comments

Comments
 (0)