diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 8ffe3bac2..a036bf176 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -157,6 +157,7 @@ class _LocIndexerFrame(_LocIndexer): self, idx: IndexType | MaskType + | Callable[[DataFrame], IndexType | MaskType | list[HashableT]] | list[HashableT] | tuple[ IndexType | MaskType | list[HashableT] | Hashable, @@ -167,14 +168,22 @@ class _LocIndexerFrame(_LocIndexer): def __getitem__( self, idx: tuple[ - int | StrLike | tuple[Scalar, ...], int | StrLike | tuple[Scalar, ...] + int | StrLike | tuple[Scalar, ...] | Callable[[DataFrame], ScalarT], + int | StrLike | tuple[Scalar, ...], ], ) -> Scalar: ... @overload def __getitem__( self, idx: ScalarT - | tuple[IndexType | MaskType | _IndexSliceTuple, ScalarT | None] + | Callable[[DataFrame], ScalarT] + | tuple[ + IndexType + | MaskType + | _IndexSliceTuple + | Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType], + ScalarT | None, + ] | None, ) -> Series: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index 66716d79a..e31c05a21 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -22,6 +22,7 @@ TypedDict, TypeVar, Union, + cast, ) import numpy as np @@ -2366,6 +2367,27 @@ def test_frame_dropna_subset() -> None: ) +def test_loc_callable() -> None: + # GH 256 + df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + + def select1(df: pd.DataFrame) -> pd.Series: + return df["x"] > 2.0 + + check(assert_type(df.loc[select1], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[select1, :], pd.DataFrame), pd.DataFrame) + + def select2(df: pd.DataFrame) -> list[Hashable]: + return [i for i in df.index if cast(int, i) % 2 == 1] + + check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series) + + def select3(df: pd.DataFrame) -> int: + return 1 + + check(assert_type(df.loc[select3, "x"], Scalar), np.integer) + + def test_npint_loc_indexer() -> None: # GH 508