Skip to content

Commit 6a6a21f

Browse files
committed
overload for __getitem__ and use pattern with ExtensionArrayT as self and return type
1 parent e0e0131 commit 6a6a21f

File tree

3 files changed

+49
-16
lines changed

3 files changed

+49
-16
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Type,
99
TypeVar,
1010
Union,
11+
overload,
1112
)
1213

1314
import numpy as np
@@ -248,6 +249,22 @@ def __setitem__(self, key, value):
248249
def _validate_setitem_value(self, value):
249250
return value
250251

252+
@overload
253+
def __getitem__(self: NDArrayBackedExtensionArrayT, key: int) -> Any:
254+
...
255+
256+
@overload
257+
def __getitem__(
258+
self: NDArrayBackedExtensionArrayT, key: slice
259+
) -> NDArrayBackedExtensionArrayT:
260+
...
261+
262+
@overload
263+
def __getitem__(
264+
self: NDArrayBackedExtensionArrayT, key: np.ndarray
265+
) -> NDArrayBackedExtensionArrayT:
266+
...
267+
251268
def __getitem__(
252269
self: NDArrayBackedExtensionArrayT, key: Union[int, slice, np.ndarray]
253270
) -> Union[NDArrayBackedExtensionArrayT, Any]:

pandas/core/arrays/base.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TypeVar,
2323
Union,
2424
cast,
25+
overload,
2526
)
2627

2728
import numpy as np
@@ -287,8 +288,20 @@ def _from_factorized(cls, values, original):
287288
# Must be a Sequence
288289
# ------------------------------------------------------------------------
289290

291+
@overload
292+
def __getitem__(self: ExtensionArrayT, item: int) -> Any:
293+
...
294+
295+
@overload
296+
def __getitem__(self: ExtensionArrayT, item: slice) -> ExtensionArrayT:
297+
...
298+
299+
@overload
300+
def __getitem__(self: ExtensionArrayT, item: np.ndarray) -> ExtensionArrayT:
301+
...
302+
290303
def __getitem__(
291-
self, item: Union[int, slice, np.ndarray]
304+
self: ExtensionArrayT, item: Union[int, slice, np.ndarray]
292305
) -> Union[ExtensionArray, Any]:
293306
"""
294307
Select a subset of self.
@@ -673,11 +686,11 @@ def argmax(self, skipna: bool = True) -> int:
673686
return nargminmax(self, "argmax")
674687

675688
def fillna(
676-
self,
689+
self: ExtensionArrayT,
677690
value: Optional[Union[Any, ArrayLike]] = None,
678691
method: Optional[Literal["backfill", "bfill", "ffill", "pad"]] = None,
679692
limit: Optional[int] = None,
680-
) -> ExtensionArray:
693+
) -> ExtensionArrayT:
681694
"""
682695
Fill NA/NaN values using the specified method.
683696
@@ -722,7 +735,7 @@ def fillna(
722735
new_values = self.copy()
723736
return new_values
724737

725-
def dropna(self) -> ExtensionArray:
738+
def dropna(self) -> ExtensionArrayT:
726739
"""
727740
Return ExtensionArray without NA values.
728741
@@ -732,7 +745,9 @@ def dropna(self) -> ExtensionArray:
732745
"""
733746
return self[~self.isna()]
734747

735-
def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
748+
def shift(
749+
self: ExtensionArrayT, periods: int = 1, fill_value: object = None
750+
) -> ExtensionArrayT:
736751
"""
737752
Shift values by desired number.
738753
@@ -780,13 +795,13 @@ def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
780795
)
781796
if periods > 0:
782797
a = empty
783-
b = self[:-periods]
798+
b: ExtensionArrayT = self[:-periods]
784799
else:
785800
a = self[abs(periods) :]
786801
b = empty
787802
return self._concat_same_type([a, b])
788803

789-
def unique(self) -> ExtensionArray:
804+
def unique(self: ExtensionArrayT) -> ExtensionArrayT:
790805
"""
791806
Compute the ExtensionArray of unique values.
792807
@@ -1018,8 +1033,10 @@ def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, ExtensionArray]:
10181033
@Substitution(klass="ExtensionArray")
10191034
@Appender(_extension_array_shared_docs["repeat"])
10201035
def repeat(
1021-
self, repeats: Union[int, Sequence[int]], axis: Literal[None] = None
1022-
) -> ExtensionArray:
1036+
self: ExtensionArrayT,
1037+
repeats: Union[int, Sequence[int]],
1038+
axis: Literal[None] = None,
1039+
) -> ExtensionArrayT:
10231040
nv.validate_repeat((), {"axis": axis})
10241041
ind = np.arange(len(self)).repeat(repeats)
10251042
return self.take(ind)
@@ -1203,7 +1220,7 @@ def _formatter(self, boxed: bool = False) -> Callable[[Any], Optional[str]]:
12031220
# Reshaping
12041221
# ------------------------------------------------------------------------
12051222

1206-
def transpose(self, *axes: int) -> ExtensionArray:
1223+
def transpose(self: ExtensionArrayT, *axes: int) -> ExtensionArrayT:
12071224
"""
12081225
Return a transposed view on this array.
12091226
@@ -1213,12 +1230,12 @@ def transpose(self, *axes: int) -> ExtensionArray:
12131230
return self[:]
12141231

12151232
@property
1216-
def T(self) -> ExtensionArray:
1233+
def T(self: ExtensionArrayT) -> ExtensionArrayT:
12171234
return self.transpose()
12181235

12191236
def ravel(
1220-
self, order: Optional[Literal["C", "F", "A", "K"]] = "C"
1221-
) -> ExtensionArray:
1237+
self: ExtensionArrayT, order: Optional[Literal["C", "F", "A", "K"]] = "C"
1238+
) -> ExtensionArrayT:
12221239
"""
12231240
Return a flattened view on this array.
12241241
@@ -1240,7 +1257,7 @@ def ravel(
12401257
@classmethod
12411258
def _concat_same_type(
12421259
cls: Type[ExtensionArrayT], to_concat: Sequence[ExtensionArrayT]
1243-
) -> ExtensionArrayT:
1260+
):
12441261
"""
12451262
Concatenate multiple array of this dtype.
12461263

pandas/core/arrays/sparse/array.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Type,
1515
TypeVar,
1616
Union,
17-
cast,
1817
)
1918
import warnings
2019

@@ -1204,7 +1203,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
12041203
if skipna:
12051204
arr = self
12061205
else:
1207-
arr = cast(SparseArray, self.dropna())
1206+
arr = self.dropna()
12081207

12091208
# we don't support these kwargs.
12101209
# They should only be present when called via pandas, so do it here.

0 commit comments

Comments
 (0)