diff --git a/pandas-stubs/core/indexing.pyi b/pandas-stubs/core/indexing.pyi index 938696292..90374106b 100644 --- a/pandas-stubs/core/indexing.pyi +++ b/pandas-stubs/core/indexing.pyi @@ -1,5 +1,4 @@ from typing import ( - Generic, TypeVar, Union, ) @@ -9,14 +8,19 @@ from pandas.core.indexes.api import Index from pandas._libs.indexing import _NDFrameIndexerBase from pandas._typing import ( + MaskType, Scalar, - StrLike, + ScalarT, ) -_IndexSliceT = TypeVar("_IndexSliceT", bound=Union[StrLike, Scalar, slice]) +_IndexSliceTuple = Union[ + slice, tuple[Union[Index, MaskType, Scalar, list[ScalarT], slice], ...] +] -class _IndexSlice(Generic[_IndexSliceT]): - def __getitem__(self, arg) -> tuple[_IndexSliceT, ...]: ... +_IndexSliceTupleT = TypeVar("_IndexSliceTupleT", bound=_IndexSliceTuple) + +class _IndexSlice: + def __getitem__(self, arg: _IndexSliceTupleT) -> _IndexSliceTupleT: ... IndexSlice: _IndexSlice diff --git a/tests/test_frame.py b/tests/test_frame.py index c1e66c3fe..d4c4f2f6c 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1278,6 +1278,28 @@ def test_indexslice_setitem(): df.loc[pd.IndexSlice[2, :], "z"] = [200, 300] +def test_indexslice_getitem(): + # GH 300 + df = ( + pd.DataFrame({"x": [1, 2, 2, 3, 4], "y": [10, 20, 30, 40, 10]}) + .assign(z=lambda df: df.x * df.y) + .set_index(["x", "y"]) + ) + ind = pd.Index([2, 3]) + check(assert_type(pd.IndexSlice[ind, :], "tuple[pd.Index, slice]"), tuple) + check(assert_type(df.loc[pd.IndexSlice[ind, :]], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[pd.IndexSlice[1:2]], pd.DataFrame), pd.DataFrame) + check( + assert_type(df.loc[pd.IndexSlice[:, df["z"] > 40], :], pd.DataFrame), + pd.DataFrame, + ) + check(assert_type(df.loc[pd.IndexSlice[2, 30], "z"], Scalar), np.int64) + check( + assert_type(df.loc[pd.IndexSlice[[2, 4], [20, 40]], :], pd.DataFrame), + pd.DataFrame, + ) + + def test_compute_values(): df = pd.DataFrame({"x": [1, 2, 3, 4]}) s: pd.Series = pd.Series([10, 20, 30, 40])