From 905ec288dc812ca7ab9dd18ddbdecf8c7cce8af0 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Thu, 12 Jan 2023 11:53:53 -0500 Subject: [PATCH] allow callable in .loc --- pandas-stubs/core/frame.pyi | 13 +++++++++++-- tests/test_frame.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 8ffe3bac2..a036bf176 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -157,6 +157,7 @@ class _LocIndexerFrame(_LocIndexer): self, idx: IndexType | MaskType + | Callable[[DataFrame], IndexType | MaskType | list[HashableT]] | list[HashableT] | tuple[ IndexType | MaskType | list[HashableT] | Hashable, @@ -167,14 +168,22 @@ class _LocIndexerFrame(_LocIndexer): def __getitem__( self, idx: tuple[ - int | StrLike | tuple[Scalar, ...], int | StrLike | tuple[Scalar, ...] + int | StrLike | tuple[Scalar, ...] | Callable[[DataFrame], ScalarT], + int | StrLike | tuple[Scalar, ...], ], ) -> Scalar: ... @overload def __getitem__( self, idx: ScalarT - | tuple[IndexType | MaskType | _IndexSliceTuple, ScalarT | None] + | Callable[[DataFrame], ScalarT] + | tuple[ + IndexType + | MaskType + | _IndexSliceTuple + | Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType], + ScalarT | None, + ] | None, ) -> Series: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index 7b13ee77f..37e1238d9 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -22,6 +22,7 @@ TypedDict, TypeVar, Union, + cast, ) import numpy as np @@ -2363,3 +2364,24 @@ def test_frame_dropna_subset() -> None: assert_type(df.dropna(subset=df.columns.drop("col1")), pd.DataFrame), pd.DataFrame, ) + + +def test_loc_callable() -> None: + # GH 256 + df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + + def select1(df: pd.DataFrame) -> pd.Series: + return df["x"] > 2.0 + + check(assert_type(df.loc[select1], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[select1, :], pd.DataFrame), pd.DataFrame) + + def select2(df: pd.DataFrame) -> list[Hashable]: + return [i for i in df.index if cast(int, i) % 2 == 1] + + check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series) + + def select3(df: pd.DataFrame) -> int: + return 1 + + check(assert_type(df.loc[select3, "x"], Scalar), np.integer)