Skip to content

Commit 3da08b1

Browse files
committed
ENH: Improve typing of some general functions
1 parent ef736be · commit 3da08b1

File tree

3 files changed

+160
-5
lines changed

3 files changed

+160
-5
lines changed

pandas-stubs/core/algorithms.pyi

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,17 +4,30 @@ from typing import (
44
)
55

66
import numpy as np
7+
import pandas as pd
78
from pandas import (
89
Categorical,
10+
CategoricalIndex,
11+
DatetimeIndex,
912
Index,
13+
PeriodIndex,
14+
RangeIndex,
1015
Series,
1116
)
1217
from pandas.api.extensions import ExtensionArray
1318

1419
from pandas._typing import AnyArrayLike
1520

1621
@overload
17-
def unique(values: Index) -> Index: ...
22+
def unique(values: DatetimeIndex) -> DatetimeIndex: ...
23+
@overload
24+
def unique(values: PeriodIndex) -> PeriodIndex: ...
25+
@overload
26+
def unique(values: CategoricalIndex) -> CategoricalIndex: ...
27+
@overload
28+
def unique(values: RangeIndex | pd.Float64Index) -> np.ndarray: ...
29+
@overload
30+
def unique(values: Index) -> Index | np.ndarray: ...
1831
@overload
1932
def unique(values: Categorical) -> Categorical: ...
2033
@overload

pandas-stubs/core/reshape/melt.pyi

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,27 @@
1+
from typing import Hashable
2+
13
import numpy as np
24
from pandas.core.frame import DataFrame
35

6+
from pandas._typing import HashableT
7+
48
def melt(
59
frame: DataFrame,
610
id_vars: tuple | list | np.ndarray | None = ...,
711
value_vars: tuple | list | np.ndarray | None = ...,
812
var_name: str | None = ...,
9-
value_name: str = ...,
13+
value_name: Hashable = ...,
1014
col_level: int | str | None = ...,
1115
ignore_index: bool = ...,
1216
) -> DataFrame: ...
13-
def lreshape(data: DataFrame, groups, dropna: bool = ..., label=...) -> DataFrame: ...
17+
def lreshape(
18+
data: DataFrame, groups: dict[HashableT, list[HashableT]], dropna: bool = ...
19+
) -> DataFrame: ...
1420
def wide_to_long(
15-
df: DataFrame, stubnames, i, j, sep: str = ..., suffix: str = ...
21+
df: DataFrame,
22+
stubnames: str | list[str],
23+
i: str | list[str],
24+
j: str,
25+
sep: str = ...,
26+
suffix: str = ...,
1627
) -> DataFrame: ...

tests/test_pandas.py

Lines changed: 132 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,7 @@ def test_unique() -> None:
208208
]
209209
)
210210
),
211-
pd.Index,
211+
Union[pd.Index, np.ndarray],
212212
),
213213
pd.DatetimeIndex,
214214
)
@@ -246,6 +246,34 @@ def test_unique() -> None:
246246
),
247247
np.ndarray,
248248
)
249+
check(
250+
assert_type(
251+
pd.unique(pd.Index(["a", "b", "c", "a"])), Union[pd.Index, np.ndarray]
252+
),
253+
np.ndarray,
254+
)
255+
check(
256+
assert_type(pd.unique(pd.RangeIndex(0, 10)), np.ndarray),
257+
np.ndarray,
258+
)
259+
check(
260+
assert_type(pd.unique(pd.Categorical(["a", "b", "c", "a"])), pd.Categorical),
261+
pd.Categorical,
262+
)
263+
check(
264+
assert_type(
265+
pd.unique(pd.period_range("2001Q1", periods=10, freq="D")),
266+
pd.PeriodIndex,
267+
),
268+
pd.PeriodIndex,
269+
)
270+
check(
271+
assert_type(
272+
pd.unique(pd.timedelta_range(start="1 day", periods=4)),
273+
Union[pd.Index, np.ndarray],
274+
),
275+
np.ndarray,
276+
)
249277

250278

251279
# GH 200
@@ -423,3 +451,106 @@ def test_to_numeric_array_series() -> None:
423451
assert_type(pd.to_numeric(pd.Series([1, 2, 3]), downcast="float"), pd.Series),
424452
pd.Series,
425453
)
454+
455+
456+
def test_wide_to_long():
457+
df = pd.DataFrame(
458+
{
459+
"A1970": {0: "a", 1: "b", 2: "c"},
460+
"A1980": {0: "d", 1: "e", 2: "f"},
461+
"B1970": {0: 2.5, 1: 1.2, 2: 0.7},
462+
"B1980": {0: 3.2, 1: 1.3, 2: 0.1},
463+
"X": dict(zip(range(3), np.random.randn(3))),
464+
}
465+
)
466+
df["id"] = df.index
467+
df["id2"] = df.index + 1
468+
check(
469+
assert_type(pd.wide_to_long(df, ["A", "B"], i="id", j="year"), pd.DataFrame),
470+
pd.DataFrame,
471+
)
472+
check(
473+
assert_type(
474+
pd.wide_to_long(df, ["A", "B"], i=["id", "id2"], j="year"), pd.DataFrame
475+
),
476+
pd.DataFrame,
477+
)
478+
479+
480+
def test_melt():
481+
df = pd.DataFrame(
482+
{
483+
"A": {0: "a", 1: "b", 2: "c"},
484+
"B": {0: 1, 1: 3, 2: 5},
485+
"C": {0: 2, 1: 4, 2: 6},
486+
"D": {0: 3, 1: 6, 2: 9},
487+
"E": {0: 3, 1: 6, 2: 9},
488+
}
489+
)
490+
check(
491+
assert_type(
492+
pd.melt(df, id_vars=["A"], value_vars=["B"], ignore_index=False),
493+
pd.DataFrame,
494+
),
495+
pd.DataFrame,
496+
)
497+
check(
498+
assert_type(
499+
pd.melt(df, id_vars=["A"], value_vars=["B"], value_name=("F",)),
500+
pd.DataFrame,
501+
),
502+
pd.DataFrame,
503+
)
504+
df.columns = pd.MultiIndex.from_arrays([list("ABCDE"), list("FGHIJ")])
505+
check(
506+
assert_type(
507+
pd.melt(
508+
df, id_vars=["A"], value_vars=["B"], ignore_index=False, col_level=0
509+
),
510+
pd.DataFrame,
511+
),
512+
pd.DataFrame,
513+
)
514+
515+
516+
def test_lreshape() -> None:
517+
data = pd.DataFrame(
518+
{
519+
"hr1": [514, 573],
520+
"hr2": [545, 526],
521+
"team": ["Red Sox", "Yankees"],
522+
"year1": [2007, 2007],
523+
"year2": [2008, 2008],
524+
}
525+
)
526+
check(
527+
assert_type(
528+
pd.lreshape(
529+
data, {"year": ["year1", "year2"], "hr": ["hr1", "hr2"]}, dropna=True
530+
),
531+
pd.DataFrame,
532+
),
533+
pd.DataFrame,
534+
)
535+
data2 = pd.DataFrame(
536+
{
537+
"hr1": [514, 573],
538+
("hr2",): [545, 526],
539+
"team": ["Red Sox", "Yankees"],
540+
("year1",): [2007, 2007],
541+
"year2": [2008, 2008],
542+
}
543+
)
544+
from typing import Hashable
545+
546+
groups: dict[Hashable, list[Hashable]] = {
547+
("year",): [("year1",), "year2"],
548+
("hr",): ["hr1", ("hr2",)],
549+
}
550+
check(
551+
assert_type(
552+
pd.lreshape(data2, groups=groups),
553+
pd.DataFrame,
554+
),
555+
pd.DataFrame,
556+
)

0 commit comments

Comments (0)