Skip to content

Commit 53c299f

Browse files
GH203 Split groupby with as_index (temptative) (#1014)
* GH203 Split groupby with as_index * Update to the fix * Update to the fix * Experiment for size * Experiment for size * GH203 Create new overload for DatetimeIndex * GH203 Fix lint
1 parent 1a314a0 commit 53c299f

File tree

4 files changed

+83
-17
lines changed

4 files changed

+83
-17
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,29 +1054,53 @@ class DataFrame(NDFrame, OpsMixin):
10541054
errors: IgnoreRaise = ...,
10551055
) -> None: ...
10561056
@overload
1057-
def groupby(
1057+
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
10581058
self,
10591059
by: Scalar,
10601060
axis: AxisIndex | NoDefault = ...,
10611061
level: IndexLabel | None = ...,
1062-
as_index: _bool = ...,
1062+
as_index: Literal[True] = True,
10631063
sort: _bool = ...,
10641064
group_keys: _bool = ...,
10651065
observed: _bool | NoDefault = ...,
10661066
dropna: _bool = ...,
1067-
) -> DataFrameGroupBy[Scalar]: ...
1067+
) -> DataFrameGroupBy[Scalar, Literal[True]]: ...
10681068
@overload
10691069
def groupby(
1070+
self,
1071+
by: Scalar,
1072+
axis: AxisIndex | NoDefault = ...,
1073+
level: IndexLabel | None = ...,
1074+
as_index: Literal[False] = ...,
1075+
sort: _bool = ...,
1076+
group_keys: _bool = ...,
1077+
observed: _bool | NoDefault = ...,
1078+
dropna: _bool = ...,
1079+
) -> DataFrameGroupBy[Scalar, Literal[False]]: ...
1080+
@overload
1081+
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
10701082
self,
10711083
by: DatetimeIndex,
10721084
axis: AxisIndex | NoDefault = ...,
10731085
level: IndexLabel | None = ...,
1074-
as_index: _bool = ...,
1086+
as_index: Literal[True] = True,
1087+
sort: _bool = ...,
1088+
group_keys: _bool = ...,
1089+
observed: _bool | NoDefault = ...,
1090+
dropna: _bool = ...,
1091+
) -> DataFrameGroupBy[Timestamp, Literal[True]]: ...
1092+
@overload
1093+
def groupby( # type: ignore[overload-overlap]
1094+
self,
1095+
by: DatetimeIndex,
1096+
axis: AxisIndex | NoDefault = ...,
1097+
level: IndexLabel | None = ...,
1098+
as_index: Literal[False] = ...,
10751099
sort: _bool = ...,
10761100
group_keys: _bool = ...,
10771101
observed: _bool | NoDefault = ...,
10781102
dropna: _bool = ...,
1079-
) -> DataFrameGroupBy[Timestamp]: ...
1103+
) -> DataFrameGroupBy[Timestamp, Literal[False]]: ...
10801104
@overload
10811105
def groupby(
10821106
self,
@@ -1088,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin):
10881112
group_keys: _bool = ...,
10891113
observed: _bool | NoDefault = ...,
10901114
dropna: _bool = ...,
1091-
) -> DataFrameGroupBy[Timedelta]: ...
1115+
) -> DataFrameGroupBy[Timedelta, bool]: ...
10921116
@overload
10931117
def groupby(
10941118
self,
@@ -1100,7 +1124,7 @@ class DataFrame(NDFrame, OpsMixin):
11001124
group_keys: _bool = ...,
11011125
observed: _bool | NoDefault = ...,
11021126
dropna: _bool = ...,
1103-
) -> DataFrameGroupBy[Period]: ...
1127+
) -> DataFrameGroupBy[Period, bool]: ...
11041128
@overload
11051129
def groupby(
11061130
self,
@@ -1112,7 +1136,7 @@ class DataFrame(NDFrame, OpsMixin):
11121136
group_keys: _bool = ...,
11131137
observed: _bool | NoDefault = ...,
11141138
dropna: _bool = ...,
1115-
) -> DataFrameGroupBy[IntervalT]: ...
1139+
) -> DataFrameGroupBy[IntervalT, bool]: ...
11161140
@overload
11171141
def groupby(
11181142
self,
@@ -1124,7 +1148,7 @@ class DataFrame(NDFrame, OpsMixin):
11241148
group_keys: _bool = ...,
11251149
observed: _bool | NoDefault = ...,
11261150
dropna: _bool = ...,
1127-
) -> DataFrameGroupBy[tuple]: ...
1151+
) -> DataFrameGroupBy[tuple, bool]: ...
11281152
@overload
11291153
def groupby(
11301154
self,
@@ -1136,7 +1160,7 @@ class DataFrame(NDFrame, OpsMixin):
11361160
group_keys: _bool = ...,
11371161
observed: _bool | NoDefault = ...,
11381162
dropna: _bool = ...,
1139-
) -> DataFrameGroupBy[SeriesByT]: ...
1163+
) -> DataFrameGroupBy[SeriesByT, bool]: ...
11401164
@overload
11411165
def groupby(
11421166
self,
@@ -1148,7 +1172,7 @@ class DataFrame(NDFrame, OpsMixin):
11481172
group_keys: _bool = ...,
11491173
observed: _bool | NoDefault = ...,
11501174
dropna: _bool = ...,
1151-
) -> DataFrameGroupBy[Any]: ...
1175+
) -> DataFrameGroupBy[Any, bool]: ...
11521176
def pivot(
11531177
self,
11541178
*,

pandas-stubs/core/groupby/generic.pyi

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from typing import (
1111
Generic,
1212
Literal,
1313
NamedTuple,
14+
TypeVar,
1415
final,
1516
overload,
1617
)
@@ -29,6 +30,7 @@ from typing_extensions import (
2930
)
3031

3132
from pandas._libs.lib import NoDefault
33+
from pandas._libs.tslibs.timestamps import Timestamp
3234
from pandas._typing import (
3335
S1,
3436
AggFuncTypeBase,
@@ -182,7 +184,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
182184
self,
183185
) -> Iterator[tuple[ByT, Series[S1]]]: ...
184186

185-
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
187+
_TT = TypeVar("_TT", bound=Literal[True, False])
188+
189+
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
186190
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
187191
@overload # type: ignore[override]
188192
def apply( # type: ignore[overload-overlap]
@@ -236,7 +240,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
236240
@overload
237241
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride, reportOverlappingOverload]
238242
self, key: Iterable[Hashable] | slice
239-
) -> DataFrameGroupBy[ByT]: ...
243+
) -> DataFrameGroupBy[ByT, bool]: ...
240244
def nunique(self, dropna: bool = ...) -> DataFrame: ...
241245
def idxmax(
242246
self,
@@ -388,3 +392,11 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
388392
def __iter__( # pyright: ignore[reportIncompatibleMethodOverride]
389393
self,
390394
) -> Iterator[tuple[ByT, DataFrame]]: ...
395+
@overload
396+
def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ...
397+
@overload
398+
def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ...
399+
@overload
400+
def size(self: DataFrameGroupBy[Timestamp, Literal[True]]) -> Series[int]: ...
401+
@overload
402+
def size(self: DataFrameGroupBy[Timestamp, Literal[False]]) -> DataFrame: ...

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
232232
def sem(
233233
self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ...
234234
) -> DataFrame: ...
235-
@final
236-
@overload
237235
def size(self: GroupBy[Series]) -> Series[int]: ...
238-
@overload # return type depends on `as_index` for dataframe groupby
239-
def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ...
240236
@final
241237
def sum(
242238
self,

tests/test_frame.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,40 @@ def test_types_pivot_table() -> None:
10251025
)
10261026

10271027

1028+
def test_types_groupby_as_index() -> None:
1029+
df = pd.DataFrame({"a": [1, 2, 3]})
1030+
check(
1031+
assert_type(
1032+
df.groupby("a", as_index=False).size(),
1033+
pd.DataFrame,
1034+
),
1035+
pd.DataFrame,
1036+
)
1037+
check(
1038+
assert_type(
1039+
df.groupby("a", as_index=True).size(),
1040+
"pd.Series[int]",
1041+
),
1042+
pd.Series,
1043+
)
1044+
1045+
1046+
def test_types_groupby_size() -> None:
1047+
"""Test for GH886."""
1048+
data = [
1049+
{"date": "2023-12-01", "val": 12},
1050+
{"date": "2023-12-02", "val": 2},
1051+
{"date": "2023-12-03", "val": 1},
1052+
{"date": "2023-12-03", "val": 10},
1053+
]
1054+
1055+
df = pd.DataFrame(data)
1056+
groupby = df.groupby("date")
1057+
size = groupby.size()
1058+
frame = size.to_frame()
1059+
check(assert_type(frame.reset_index(), pd.DataFrame), pd.DataFrame)
1060+
1061+
10281062
def test_types_groupby() -> None:
10291063
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
10301064
df.index.name = "ind"

0 commit comments

Comments
 (0)