diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 45f02ca25..e9bf1a9bb 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -170,7 +170,7 @@ class _LocIndexerFrame(_LocIndexer): | slice | _IndexSliceTuple | Callable, - list[HashableT] | slice | Series[bool] | Callable, + MaskType | list[HashableT] | slice | Callable, ], ) -> DataFrame: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index cd4ace1ff..2d7a45b4e 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -259,8 +259,17 @@ def test_types_loc_at() -> None: def test_types_boolean_indexing() -> None: df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) - df[df > 1] - df[~(df > 1.0)] + check(assert_type(df[df > 1], pd.DataFrame), pd.DataFrame) + check(assert_type(df[~(df > 1.0)], pd.DataFrame), pd.DataFrame) + + row_mask = df["col1"] >= 2 + col_mask = df.columns.isin(["col2"]) + check(assert_type(df.loc[row_mask], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[~row_mask], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[row_mask, :], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[:, col_mask], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[row_mask, col_mask], pd.DataFrame), pd.DataFrame) + check(assert_type(df.loc[~row_mask, ~col_mask], pd.DataFrame), pd.DataFrame) def test_types_df_to_df_comparison() -> None: