Skip to content

Commit fc705de

Browse files
committed
remove slice and tuple from groupby getitem
1 parent 0f12dc5 commit fc705de

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,8 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]):
229229
def filter(
230230
self, func: Callable, dropna: bool = ..., *args, **kwargs
231231
) -> DataFrame: ...
232-
@overload # type: ignore[override]
233-
def __getitem__(self, key: slice) -> DataFrameGroupBy[ByT]: ...
234232
@overload
235-
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> SeriesGroupBy[Any, ByT]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
233+
def __getitem__(self, key: Scalar) -> SeriesGroupBy[Any, ByT]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
236234
@overload
237235
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride]
238236
self, key: Iterable[Hashable]

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,22 +139,15 @@ class BaseGroupBy(SelectionMixin[NDFrameT], GroupByIndexingMixin):
139139
@final
140140
def __iter__(self) -> Iterator[tuple[Hashable, NDFrameT]]: ...
141141
@overload
142-
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar | tuple[Hashable, ...]) -> generic.SeriesGroupBy: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
142+
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
143143
@overload
144144
def __getitem__(
145-
self: BaseGroupBy[DataFrame], key: Iterable[Hashable] | slice
145+
self: BaseGroupBy[DataFrame], key: Iterable[Hashable]
146146
) -> generic.DataFrameGroupBy: ...
147147
@overload
148148
def __getitem__(
149149
self: BaseGroupBy[Series[S1]],
150-
idx: (
151-
list[str]
152-
| Index
153-
| Series[S1]
154-
| slice
155-
| MaskType
156-
| tuple[Hashable | slice, ...]
157-
),
150+
idx: list[str] | Index | Series[S1] | MaskType | tuple[Hashable | slice, ...],
158151
) -> generic.SeriesGroupBy: ...
159152
@overload
160153
def __getitem__(self: BaseGroupBy[Series[S1]], idx: Scalar) -> S1: ...

tests/test_groupby.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,3 +1052,9 @@ def test_engine() -> None:
10521052
other_kwarg="",
10531053
)
10541054
GB_DF.aggregate("size", engine="cython", engine_kwargs={})
1055+
1056+
1057+
def test_groupby_getitem() -> None:
1058+
df = DataFrame(np.random.random((3, 4)), columns=["a", "b", "c", "d"])
1059+
check(assert_type(df.groupby("a")["b"].sum(), Series), Series, float)
1060+
check(assert_type(df.groupby("a")[["b", "c"]].sum(), DataFrame), DataFrame)

0 commit comments

Comments
 (0)