Skip to content

Commit 905ec28

Browse files
committed
allow callable in .loc
1 parent 261eabb commit 905ec28

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class _LocIndexerFrame(_LocIndexer):
157157
self,
158158
idx: IndexType
159159
| MaskType
160+
| Callable[[DataFrame], IndexType | MaskType | list[HashableT]]
160161
| list[HashableT]
161162
| tuple[
162163
IndexType | MaskType | list[HashableT] | Hashable,
@@ -167,14 +168,22 @@ class _LocIndexerFrame(_LocIndexer):
167168
def __getitem__(
168169
self,
169170
idx: tuple[
170-
int | StrLike | tuple[Scalar, ...], int | StrLike | tuple[Scalar, ...]
171+
int | StrLike | tuple[Scalar, ...] | Callable[[DataFrame], ScalarT],
172+
int | StrLike | tuple[Scalar, ...],
171173
],
172174
) -> Scalar: ...
173175
@overload
174176
def __getitem__(
175177
self,
176178
idx: ScalarT
177-
| tuple[IndexType | MaskType | _IndexSliceTuple, ScalarT | None]
179+
| Callable[[DataFrame], ScalarT]
180+
| tuple[
181+
IndexType
182+
| MaskType
183+
| _IndexSliceTuple
184+
| Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType],
185+
ScalarT | None,
186+
]
178187
| None,
179188
) -> Series: ...
180189
@overload

tests/test_frame.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TypedDict,
2323
TypeVar,
2424
Union,
25+
cast,
2526
)
2627

2728
import numpy as np
@@ -2363,3 +2364,24 @@ def test_frame_dropna_subset() -> None:
23632364
assert_type(df.dropna(subset=df.columns.drop("col1")), pd.DataFrame),
23642365
pd.DataFrame,
23652366
)
2367+
2368+
2369+
def test_loc_callable() -> None:
2370+
# GH 256
2371+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
2372+
2373+
def select1(df: pd.DataFrame) -> pd.Series:
2374+
return df["x"] > 2.0
2375+
2376+
check(assert_type(df.loc[select1], pd.DataFrame), pd.DataFrame)
2377+
check(assert_type(df.loc[select1, :], pd.DataFrame), pd.DataFrame)
2378+
2379+
def select2(df: pd.DataFrame) -> list[Hashable]:
2380+
return [i for i in df.index if cast(int, i) % 2 == 1]
2381+
2382+
check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series)
2383+
2384+
def select3(df: pd.DataFrame) -> int:
2385+
return 1
2386+
2387+
check(assert_type(df.loc[select3, "x"], Scalar), np.integer)

0 commit comments

Comments
 (0)