Skip to content

Commit 5a9515b

Browse files
bashtageKevin Sheppard
andauthored
ENH: Improve typing for pivot and pivot_table (#379)
* ENH: Improve typing for pivot and pivot_table * TST: Add test for pivot * TST: Add tests for pivot_table * TST: Add test for pivot_table * ENH: Improve function definitions * BUG: Correct type * TST: Fix test issue * TYP: Improve concat * ENH: Allow selected use of Index or ndarray * TYP: Final typing of concat * TYP: Final typing of concat * TYP: Final typing of concat * BUG: Correct values type * CLN: Remove comments Co-authored-by: Kevin Sheppard <kevin.sheppard@gmail.com>
1 parent c5d6648 commit 5a9515b

File tree

3 files changed

+639
-47
lines changed

3 files changed

+639
-47
lines changed

pandas-stubs/core/reshape/concat.pyi

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ from typing import (
22
Iterable,
33
Literal,
44
Mapping,
5+
Sequence,
56
overload,
67
)
78

@@ -10,43 +11,48 @@ from pandas import (
1011
Series,
1112
)
1213

13-
from pandas._typing import HashableT
14+
from pandas._typing import (
15+
HashableT1,
16+
HashableT2,
17+
HashableT3,
18+
HashableT4,
19+
)
1420

1521
@overload
1622
def concat(
17-
objs: Iterable[DataFrame] | Mapping[HashableT, DataFrame],
23+
objs: Iterable[DataFrame] | Mapping[HashableT1, DataFrame],
1824
axis: Literal[0, "index"] = ...,
19-
join: str = ...,
25+
join: Literal["inner", "outer"] = ...,
2026
ignore_index: bool = ...,
21-
keys=...,
22-
levels=...,
23-
names=...,
27+
keys: list[HashableT2] = ...,
28+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] = ...,
29+
names: list[HashableT4] = ...,
2430
verify_integrity: bool = ...,
2531
sort: bool = ...,
2632
copy: bool = ...,
2733
) -> DataFrame: ...
2834
@overload
2935
def concat(
30-
objs: Iterable[Series] | Mapping[HashableT, Series],
36+
objs: Iterable[Series] | Mapping[HashableT1, Series],
3137
axis: Literal[0, "index"] = ...,
32-
join: str = ...,
38+
join: Literal["inner", "outer"] = ...,
3339
ignore_index: bool = ...,
34-
keys=...,
35-
levels=...,
36-
names=...,
40+
keys: list[HashableT2] = ...,
41+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] = ...,
42+
names: list[HashableT4] = ...,
3743
verify_integrity: bool = ...,
3844
sort: bool = ...,
3945
copy: bool = ...,
4046
) -> Series: ...
4147
@overload
4248
def concat(
43-
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
49+
objs: Iterable[Series | DataFrame] | Mapping[HashableT1, Series | DataFrame],
4450
axis: Literal[1, "columns"],
45-
join: str = ...,
51+
join: Literal["inner", "outer"] = ...,
4652
ignore_index: bool = ...,
47-
keys=...,
48-
levels=...,
49-
names=...,
53+
keys: list[HashableT2] = ...,
54+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] = ...,
55+
names: list[HashableT4] = ...,
5056
verify_integrity: bool = ...,
5157
sort: bool = ...,
5258
copy: bool = ...,

pandas-stubs/core/reshape/pivot.pyi

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import datetime
12
from typing import (
23
Callable,
34
Hashable,
45
Literal,
6+
Mapping,
57
Sequence,
6-
TypeVar,
78
Union,
89
overload,
910
)
@@ -12,47 +13,116 @@ import numpy as np
1213
import pandas as pd
1314
from pandas.core.frame import DataFrame
1415
from pandas.core.groupby.grouper import Grouper
16+
from pandas.core.indexes.base import Index
1517
from pandas.core.series import Series
1618
from typing_extensions import TypeAlias
1719

1820
from pandas._typing import (
1921
AnyArrayLike,
2022
ArrayLike,
21-
HashableT,
22-
IndexLabel,
23+
HashableT1,
24+
HashableT2,
25+
HashableT3,
26+
Label,
2327
Scalar,
28+
ScalarT,
29+
npt,
2430
)
2531

32+
_PivotAggCallable: TypeAlias = Callable[[Series], ScalarT]
33+
34+
_PivotAggFunc: TypeAlias = Union[
35+
_PivotAggCallable,
36+
np.ufunc,
37+
Literal["mean", "sum", "count", "min", "max", "median", "std", "var"],
38+
]
39+
40+
_NonIterableHashable: TypeAlias = Union[
41+
str,
42+
datetime.date,
43+
datetime.datetime,
44+
datetime.timedelta,
45+
bool,
46+
int,
47+
float,
48+
complex,
49+
pd.Timestamp,
50+
pd.Timedelta,
51+
]
52+
53+
_PivotTableIndexTypes: TypeAlias = Union[Label, list[HashableT1], Series, Grouper, None]
54+
_PivotTableColumnsTypes: TypeAlias = Union[
55+
Label, list[HashableT2], Series, Grouper, None
56+
]
57+
2658
_ExtendedAnyArrayLike: TypeAlias = Union[AnyArrayLike, ArrayLike]
2759

28-
_HashableT2 = TypeVar("_HashableT2", bound=Hashable)
60+
@overload
61+
def pivot_table(
62+
data: DataFrame,
63+
values: Label | list[HashableT3] | None = ...,
64+
index: _PivotTableIndexTypes = ...,
65+
columns: _PivotTableColumnsTypes = ...,
66+
aggfunc: _PivotAggFunc
67+
| list[_PivotAggFunc]
68+
| Mapping[Hashable, _PivotAggFunc] = ...,
69+
fill_value: Scalar | None = ...,
70+
margins: bool = ...,
71+
dropna: bool = ...,
72+
margins_name: str = ...,
73+
observed: bool = ...,
74+
sort: bool = ...,
75+
) -> DataFrame: ...
2976

77+
# Can only use Index or ndarray when index or columns is a Grouper
78+
@overload
79+
def pivot_table(
80+
data: DataFrame,
81+
values: Label | list[HashableT3] | None = ...,
82+
*,
83+
index: Grouper,
84+
columns: _PivotTableColumnsTypes | Index | npt.NDArray = ...,
85+
aggfunc: _PivotAggFunc
86+
| list[_PivotAggFunc]
87+
| Mapping[Hashable, _PivotAggFunc] = ...,
88+
fill_value: Scalar | None = ...,
89+
margins: bool = ...,
90+
dropna: bool = ...,
91+
margins_name: str = ...,
92+
observed: bool = ...,
93+
sort: bool = ...,
94+
) -> DataFrame: ...
95+
@overload
3096
def pivot_table(
3197
data: DataFrame,
32-
values: str | None = ...,
33-
index: str | Sequence | Grouper | None = ...,
34-
columns: str | Sequence | Grouper | None = ...,
35-
aggfunc=...,
98+
values: Label | list[HashableT3] | None = ...,
99+
index: _PivotTableIndexTypes | Index | npt.NDArray = ...,
100+
*,
101+
columns: Grouper,
102+
aggfunc: _PivotAggFunc
103+
| list[_PivotAggFunc]
104+
| Mapping[Hashable, _PivotAggFunc] = ...,
36105
fill_value: Scalar | None = ...,
37106
margins: bool = ...,
38107
dropna: bool = ...,
39108
margins_name: str = ...,
40109
observed: bool = ...,
110+
sort: bool = ...,
41111
) -> DataFrame: ...
42112
def pivot(
43113
data: DataFrame,
44114
*,
45-
index: IndexLabel = ...,
46-
columns: IndexLabel = ...,
47-
values: IndexLabel = ...,
115+
index: _NonIterableHashable | list[HashableT1] = ...,
116+
columns: _NonIterableHashable | list[HashableT2] = ...,
117+
values: _NonIterableHashable | list[HashableT3] = ...,
48118
) -> DataFrame: ...
49119
@overload
50120
def crosstab(
51121
index: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
52122
columns: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
53123
values: list | _ExtendedAnyArrayLike,
54-
rownames: list[HashableT] | None = ...,
55-
colnames: list[_HashableT2] | None = ...,
124+
rownames: list[HashableT1] | None = ...,
125+
colnames: list[HashableT2] | None = ...,
56126
*,
57127
aggfunc: str | np.ufunc | Callable[[Series], float],
58128
margins: bool = ...,
@@ -65,8 +135,8 @@ def crosstab(
65135
index: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
66136
columns: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
67137
values: None = ...,
68-
rownames: list[HashableT] | None = ...,
69-
colnames: list[_HashableT2] | None = ...,
138+
rownames: list[HashableT1] | None = ...,
139+
colnames: list[HashableT2] | None = ...,
70140
aggfunc: None = ...,
71141
margins: bool = ...,
72142
margins_name: str = ...,

0 commit comments

Comments
 (0)