Skip to content

Commit b5ba2ed

Browse files
committed
merge with upstream
2 parents fc705de + 53c299f commit b5ba2ed

File tree

7 files changed

+116
-47
lines changed

7 files changed

+116
-47
lines changed

pandas-stubs/_libs/tslibs/dtypes.pyi

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import Enum
2+
from typing import cast
23

34
from .offsets import BaseOffset
45

@@ -29,16 +30,16 @@ class FreqGroup:
2930
def get_freq_group(code: int) -> int: ...
3031

3132
class Resolution(Enum):
32-
RESO_NS: int
33-
RESO_US: int
34-
RESO_MS: int
35-
RESO_SEC: int
36-
RESO_MIN: int
37-
RESO_HR: int
38-
RESO_DAY: int
39-
RESO_MTH: int
40-
RESO_QTR: int
41-
RESO_YR: int
33+
RESO_NS = cast(int, ...)
34+
RESO_US = cast(int, ...)
35+
RESO_MS = cast(int, ...)
36+
RESO_SEC = cast(int, ...)
37+
RESO_MIN = cast(int, ...)
38+
RESO_HR = cast(int, ...)
39+
RESO_DAY = cast(int, ...)
40+
RESO_MTH = cast(int, ...)
41+
RESO_QTR = cast(int, ...)
42+
RESO_YR = cast(int, ...)
4243

4344
def __lt__(self, other) -> bool: ...
4445
def __ge__(self, other) -> bool: ...

pandas-stubs/core/frame.pyi

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,29 +1054,53 @@ class DataFrame(NDFrame, OpsMixin):
10541054
errors: IgnoreRaise = ...,
10551055
) -> None: ...
10561056
@overload
1057+
def groupby( # pyright: ignore reportOverlappingOverload
1058+
self,
1059+
by: Scalar,
1060+
axis: AxisIndex | NoDefault = ...,
1061+
level: IndexLabel | None = ...,
1062+
as_index: Literal[True] = True,
1063+
sort: _bool = ...,
1064+
group_keys: _bool = ...,
1065+
observed: _bool | NoDefault = ...,
1066+
dropna: _bool = ...,
1067+
) -> DataFrameGroupBy[Scalar, Literal[True]]: ...
1068+
@overload
10571069
def groupby(
10581070
self,
10591071
by: Scalar,
10601072
axis: AxisIndex | NoDefault = ...,
10611073
level: IndexLabel | None = ...,
1062-
as_index: _bool = ...,
1074+
as_index: Literal[False] = ...,
10631075
sort: _bool = ...,
10641076
group_keys: _bool = ...,
10651077
observed: _bool | NoDefault = ...,
10661078
dropna: _bool = ...,
1067-
) -> DataFrameGroupBy[Scalar]: ...
1079+
) -> DataFrameGroupBy[Scalar, Literal[False]]: ...
1080+
@overload
1081+
def groupby( # pyright: ignore reportOverlappingOverload
1082+
self,
1083+
by: DatetimeIndex,
1084+
axis: AxisIndex | NoDefault = ...,
1085+
level: IndexLabel | None = ...,
1086+
as_index: Literal[True] = True,
1087+
sort: _bool = ...,
1088+
group_keys: _bool = ...,
1089+
observed: _bool | NoDefault = ...,
1090+
dropna: _bool = ...,
1091+
) -> DataFrameGroupBy[Timestamp, Literal[True]]: ...
10681092
@overload
10691093
def groupby(
10701094
self,
10711095
by: DatetimeIndex,
10721096
axis: AxisIndex | NoDefault = ...,
10731097
level: IndexLabel | None = ...,
1074-
as_index: _bool = ...,
1098+
as_index: Literal[False] = ...,
10751099
sort: _bool = ...,
10761100
group_keys: _bool = ...,
10771101
observed: _bool | NoDefault = ...,
10781102
dropna: _bool = ...,
1079-
) -> DataFrameGroupBy[Timestamp]: ...
1103+
) -> DataFrameGroupBy[Timestamp, Literal[False]]: ...
10801104
@overload
10811105
def groupby(
10821106
self,
@@ -1088,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin):
10881112
group_keys: _bool = ...,
10891113
observed: _bool | NoDefault = ...,
10901114
dropna: _bool = ...,
1091-
) -> DataFrameGroupBy[Timedelta]: ...
1115+
) -> DataFrameGroupBy[Timedelta, bool]: ...
10921116
@overload
10931117
def groupby(
10941118
self,
@@ -1100,7 +1124,7 @@ class DataFrame(NDFrame, OpsMixin):
11001124
group_keys: _bool = ...,
11011125
observed: _bool | NoDefault = ...,
11021126
dropna: _bool = ...,
1103-
) -> DataFrameGroupBy[Period]: ...
1127+
) -> DataFrameGroupBy[Period, bool]: ...
11041128
@overload
11051129
def groupby(
11061130
self,
@@ -1112,7 +1136,7 @@ class DataFrame(NDFrame, OpsMixin):
11121136
group_keys: _bool = ...,
11131137
observed: _bool | NoDefault = ...,
11141138
dropna: _bool = ...,
1115-
) -> DataFrameGroupBy[IntervalT]: ...
1139+
) -> DataFrameGroupBy[IntervalT, bool]: ...
11161140
@overload
11171141
def groupby(
11181142
self,
@@ -1124,7 +1148,7 @@ class DataFrame(NDFrame, OpsMixin):
11241148
group_keys: _bool = ...,
11251149
observed: _bool | NoDefault = ...,
11261150
dropna: _bool = ...,
1127-
) -> DataFrameGroupBy[tuple]: ...
1151+
) -> DataFrameGroupBy[tuple, bool]: ...
11281152
@overload
11291153
def groupby(
11301154
self,
@@ -1136,7 +1160,7 @@ class DataFrame(NDFrame, OpsMixin):
11361160
group_keys: _bool = ...,
11371161
observed: _bool | NoDefault = ...,
11381162
dropna: _bool = ...,
1139-
) -> DataFrameGroupBy[SeriesByT]: ...
1163+
) -> DataFrameGroupBy[SeriesByT, bool]: ...
11401164
@overload
11411165
def groupby(
11421166
self,
@@ -1148,7 +1172,7 @@ class DataFrame(NDFrame, OpsMixin):
11481172
group_keys: _bool = ...,
11491173
observed: _bool | NoDefault = ...,
11501174
dropna: _bool = ...,
1151-
) -> DataFrameGroupBy[Any]: ...
1175+
) -> DataFrameGroupBy[Any, bool]: ...
11521176
def pivot(
11531177
self,
11541178
*,

pandas-stubs/core/groupby/generic.pyi

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from typing import (
1111
Generic,
1212
Literal,
1313
NamedTuple,
14+
TypeVar,
1415
final,
1516
overload,
1617
)
@@ -29,6 +30,7 @@ from typing_extensions import (
2930
)
3031

3132
from pandas._libs.lib import NoDefault
33+
from pandas._libs.tslibs.timestamps import Timestamp
3234
from pandas._typing import (
3335
S1,
3436
AggFuncTypeBase,
@@ -182,7 +184,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
182184
self,
183185
) -> Iterator[tuple[ByT, Series[S1]]]: ...
184186

185-
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
187+
_TT = TypeVar("_TT", bound=Literal[True, False])
188+
189+
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
186190
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
187191
@overload # type: ignore[override]
188192
def apply(
@@ -234,7 +238,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
234238
@overload
235239
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride]
236240
self, key: Iterable[Hashable]
237-
) -> DataFrameGroupBy[ByT]: ...
241+
) -> DataFrameGroupBy[ByT, bool]: ...
238242
def nunique(self, dropna: bool = ...) -> DataFrame: ...
239243
def idxmax(
240244
self,
@@ -386,3 +390,11 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
386390
def __iter__( # pyright: ignore[reportIncompatibleMethodOverride]
387391
self,
388392
) -> Iterator[tuple[ByT, DataFrame]]: ...
393+
@overload
394+
def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ...
395+
@overload
396+
def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ...
397+
@overload
398+
def size(self: DataFrameGroupBy[Timestamp, Literal[True]]) -> Series[int]: ...
399+
@overload
400+
def size(self: DataFrameGroupBy[Timestamp, Literal[False]]) -> DataFrame: ...

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
225225
def sem(
226226
self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ...
227227
) -> DataFrame: ...
228-
@final
229-
@overload
230228
def size(self: GroupBy[Series]) -> Series[int]: ...
231-
@overload # return type depends on `as_index` for dataframe groupby
232-
def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ...
233229
@final
234230
def sum(
235231
self,

pandas-stubs/core/interchange/dataframe_protocol.pyi

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,34 @@ import enum
1111
from typing import (
1212
Any,
1313
TypedDict,
14+
cast,
1415
)
1516

1617
class DlpackDeviceType(enum.IntEnum):
17-
CPU: int
18-
CUDA: int
19-
CPU_PINNED: int
20-
OPENCL: int
21-
VULKAN: int
22-
METAL: int
23-
VPI: int
24-
ROCM: int
18+
CPU = cast(int, ...)
19+
CUDA = cast(int, ...)
20+
CPU_PINNED = cast(int, ...)
21+
OPENCL = cast(int, ...)
22+
VULKAN = cast(int, ...)
23+
METAL = cast(int, ...)
24+
VPI = cast(int, ...)
25+
ROCM = cast(int, ...)
2526

2627
class DtypeKind(enum.IntEnum):
27-
INT: int
28-
UINT: int
29-
FLOAT: int
30-
BOOL: int
31-
STRING: int
32-
DATETIME: int
33-
CATEGORICAL: int
28+
INT = cast(int, ...)
29+
UINT = cast(int, ...)
30+
FLOAT = cast(int, ...)
31+
BOOL = cast(int, ...)
32+
STRING = cast(int, ...)
33+
DATETIME = cast(int, ...)
34+
CATEGORICAL = cast(int, ...)
3435

3536
class ColumnNullType(enum.IntEnum):
36-
NON_NULLABLE: int
37-
USE_NAN: int
38-
USE_SENTINEL: int
39-
USE_BITMASK: int
40-
USE_BYTEMASK: int
37+
NON_NULLABLE = cast(int, ...)
38+
USE_NAN = cast(int, ...)
39+
USE_SENTINEL = cast(int, ...)
40+
USE_BITMASK = cast(int, ...)
41+
USE_BYTEMASK = cast(int, ...)
4142

4243
class ColumnBuffers(TypedDict):
4344
data: tuple[Buffer, Any]

scripts/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import sys
2+
from typing import Any
23

34
from loguru import logger
45

56
# Config the format of log message
6-
config = {
7+
config: dict[str, Any] = {
78
"handlers": [
89
{
910
"sink": sys.stderr,

tests/test_frame.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,40 @@ def test_types_pivot_table() -> None:
10251025
)
10261026

10271027

1028+
def test_types_groupby_as_index() -> None:
1029+
df = pd.DataFrame({"a": [1, 2, 3]})
1030+
check(
1031+
assert_type(
1032+
df.groupby("a", as_index=False).size(),
1033+
pd.DataFrame,
1034+
),
1035+
pd.DataFrame,
1036+
)
1037+
check(
1038+
assert_type(
1039+
df.groupby("a", as_index=True).size(),
1040+
"pd.Series[int]",
1041+
),
1042+
pd.Series,
1043+
)
1044+
1045+
1046+
def test_types_groupby_size() -> None:
1047+
"""Test for GH886."""
1048+
data = [
1049+
{"date": "2023-12-01", "val": 12},
1050+
{"date": "2023-12-02", "val": 2},
1051+
{"date": "2023-12-03", "val": 1},
1052+
{"date": "2023-12-03", "val": 10},
1053+
]
1054+
1055+
df = pd.DataFrame(data)
1056+
groupby = df.groupby("date")
1057+
size = groupby.size()
1058+
frame = size.to_frame()
1059+
check(assert_type(frame.reset_index(), pd.DataFrame), pd.DataFrame)
1060+
1061+
10281062
def test_types_groupby() -> None:
10291063
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
10301064
df.index.name = "ind"

0 commit comments

Comments
 (0)