From 5713cdff21055358c6457da104f6308234121c5d Mon Sep 17 00:00:00 2001 From: Soshi Katsuta Date: Mon, 12 Feb 2024 17:04:16 +0900 Subject: [PATCH] Return Series or DataFrame if a scalar is given to DataFrame#loc Fix #749. --- pandas-stubs/core/frame.pyi | 5 +++-- tests/test_frame.py | 31 ++++++++++++++++++++----------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 45f02ca25..116ba52e4 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -156,6 +156,8 @@ class _iLocIndexerFrame(_iLocIndexer): ) -> None: ... class _LocIndexerFrame(_LocIndexer): + @overload + def __getitem__(self, idx: Scalar) -> Series | DataFrame: ... @overload def __getitem__( self, @@ -184,8 +186,7 @@ class _LocIndexerFrame(_LocIndexer): @overload def __getitem__( self, - idx: ScalarT - | Callable[[DataFrame], ScalarT] + idx: Callable[[DataFrame], ScalarT] | tuple[ IndexType | MaskType diff --git a/tests/test_frame.py b/tests/test_frame.py index cd4ace1ff..c193f2b49 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2199,16 +2199,16 @@ def test_frame_scalars_slice() -> None: # Note: bool_ cannot be tested since the index is object and pandas does not # support boolean access using loc except when the index is boolean - check(assert_type(df.loc[str_], pd.Series), pd.Series) - check(assert_type(df.loc[bytes_], pd.Series), pd.Series) - check(assert_type(df.loc[date], pd.Series), pd.Series) - check(assert_type(df.loc[datetime_], pd.Series), pd.Series) - check(assert_type(df.loc[timedelta], pd.Series), pd.Series) - check(assert_type(df.loc[int_], pd.Series), pd.Series) - check(assert_type(df.loc[float_], pd.Series), pd.Series) - check(assert_type(df.loc[complex_], pd.Series), pd.Series) - check(assert_type(df.loc[timestamp], pd.Series), pd.Series) - check(assert_type(df.loc[pd_timedelta], pd.Series), pd.Series) + check(assert_type(df.loc[str_], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[bytes_], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[date], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[datetime_], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[timedelta], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[int_], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[float_], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[complex_], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[timestamp], Union[pd.Series, pd.DataFrame]), pd.Series) + check(assert_type(df.loc[pd_timedelta], Union[pd.Series, pd.DataFrame]), pd.Series) check(assert_type(df.loc[none], pd.Series), pd.Series) check(assert_type(df.loc[:, str_], pd.Series), pd.Series) @@ -2223,11 +2223,20 @@ def test_frame_scalars_slice() -> None: check(assert_type(df.loc[:, pd_timedelta], pd.Series), pd.Series) check(assert_type(df.loc[:, none], pd.Series), pd.Series) + # GH749 + + multi_idx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["alpha", "num"]) + df2 = pd.DataFrame({"col1": range(4)}, index=multi_idx) + check(assert_type(df2.loc[str_], Union[pd.Series, pd.DataFrame]), pd.DataFrame) + + df3 = pd.DataFrame({"x": range(2)}, index=pd.Index(["a", "b"])) + check(assert_type(df3.loc[str_], Union[pd.Series, pd.DataFrame]), pd.Series) + def test_boolean_loc() -> None: # Booleans can only be used in loc when the index is boolean df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False]) - check(assert_type(df.loc[True], pd.Series), pd.Series) + check(assert_type(df.loc[True], Union[pd.Series, pd.DataFrame]), pd.Series) check(assert_type(df.loc[:, False], pd.Series), pd.Series)