Skip to content

Commit 6fd6145

Browse files
authored
Add support for DataFrame#loc[..., NDArray[bool]] (#862)
Add support for DataFrame#loc[..., MaskType]
1 parent e35c3ca commit 6fd6145

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class _LocIndexerFrame(_LocIndexer):
170170
| slice
171171
| _IndexSliceTuple
172172
| Callable,
173-
list[HashableT] | slice | Series[bool] | Callable,
173+
MaskType | list[HashableT] | slice | Callable,
174174
],
175175
) -> DataFrame: ...
176176
@overload

tests/test_frame.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,17 @@ def test_types_loc_at() -> None:
259259

260260
def test_types_boolean_indexing() -> None:
261261
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
262-
df[df > 1]
263-
df[~(df > 1.0)]
262+
check(assert_type(df[df > 1], pd.DataFrame), pd.DataFrame)
263+
check(assert_type(df[~(df > 1.0)], pd.DataFrame), pd.DataFrame)
264+
265+
row_mask = df["col1"] >= 2
266+
col_mask = df.columns.isin(["col2"])
267+
check(assert_type(df.loc[row_mask], pd.DataFrame), pd.DataFrame)
268+
check(assert_type(df.loc[~row_mask], pd.DataFrame), pd.DataFrame)
269+
check(assert_type(df.loc[row_mask, :], pd.DataFrame), pd.DataFrame)
270+
check(assert_type(df.loc[:, col_mask], pd.DataFrame), pd.DataFrame)
271+
check(assert_type(df.loc[row_mask, col_mask], pd.DataFrame), pd.DataFrame)
272+
check(assert_type(df.loc[~row_mask, ~col_mask], pd.DataFrame), pd.DataFrame)
264273

265274

266275
def test_types_df_to_df_comparison() -> None:

0 commit comments

Comments
 (0)