Skip to content

Commit 176805d

Browse files
authored
Make DataFrame#loc return Series or DataFrame if a scalar is given (#866)
Return Series or DataFrame if a scalar is given to DataFrame#loc Fix #749.
1 parent fe28163 commit 176805d

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ class _iLocIndexerFrame(_iLocIndexer):
156156
) -> None: ...
157157

158158
class _LocIndexerFrame(_LocIndexer):
159+
@overload
160+
def __getitem__(self, idx: Scalar) -> Series | DataFrame: ...
159161
@overload
160162
def __getitem__(
161163
self,
@@ -184,8 +186,7 @@ class _LocIndexerFrame(_LocIndexer):
184186
@overload
185187
def __getitem__(
186188
self,
187-
idx: ScalarT
188-
| Callable[[DataFrame], ScalarT]
189+
idx: Callable[[DataFrame], ScalarT]
189190
| tuple[
190191
IndexType
191192
| MaskType

tests/test_frame.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,16 +2208,16 @@ def test_frame_scalars_slice() -> None:
22082208

22092209
# Note: bool_ cannot be tested since the index is object and pandas does not
22102210
# support boolean access using loc except when the index is boolean
2211-
check(assert_type(df.loc[str_], pd.Series), pd.Series)
2212-
check(assert_type(df.loc[bytes_], pd.Series), pd.Series)
2213-
check(assert_type(df.loc[date], pd.Series), pd.Series)
2214-
check(assert_type(df.loc[datetime_], pd.Series), pd.Series)
2215-
check(assert_type(df.loc[timedelta], pd.Series), pd.Series)
2216-
check(assert_type(df.loc[int_], pd.Series), pd.Series)
2217-
check(assert_type(df.loc[float_], pd.Series), pd.Series)
2218-
check(assert_type(df.loc[complex_], pd.Series), pd.Series)
2219-
check(assert_type(df.loc[timestamp], pd.Series), pd.Series)
2220-
check(assert_type(df.loc[pd_timedelta], pd.Series), pd.Series)
2211+
check(assert_type(df.loc[str_], Union[pd.Series, pd.DataFrame]), pd.Series)
2212+
check(assert_type(df.loc[bytes_], Union[pd.Series, pd.DataFrame]), pd.Series)
2213+
check(assert_type(df.loc[date], Union[pd.Series, pd.DataFrame]), pd.Series)
2214+
check(assert_type(df.loc[datetime_], Union[pd.Series, pd.DataFrame]), pd.Series)
2215+
check(assert_type(df.loc[timedelta], Union[pd.Series, pd.DataFrame]), pd.Series)
2216+
check(assert_type(df.loc[int_], Union[pd.Series, pd.DataFrame]), pd.Series)
2217+
check(assert_type(df.loc[float_], Union[pd.Series, pd.DataFrame]), pd.Series)
2218+
check(assert_type(df.loc[complex_], Union[pd.Series, pd.DataFrame]), pd.Series)
2219+
check(assert_type(df.loc[timestamp], Union[pd.Series, pd.DataFrame]), pd.Series)
2220+
check(assert_type(df.loc[pd_timedelta], Union[pd.Series, pd.DataFrame]), pd.Series)
22212221
check(assert_type(df.loc[none], pd.Series), pd.Series)
22222222

22232223
check(assert_type(df.loc[:, str_], pd.Series), pd.Series)
@@ -2232,11 +2232,20 @@ def test_frame_scalars_slice() -> None:
22322232
check(assert_type(df.loc[:, pd_timedelta], pd.Series), pd.Series)
22332233
check(assert_type(df.loc[:, none], pd.Series), pd.Series)
22342234

2235+
# GH749
2236+
2237+
multi_idx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["alpha", "num"])
2238+
df2 = pd.DataFrame({"col1": range(4)}, index=multi_idx)
2239+
check(assert_type(df2.loc[str_], Union[pd.Series, pd.DataFrame]), pd.DataFrame)
2240+
2241+
df3 = pd.DataFrame({"x": range(2)}, index=pd.Index(["a", "b"]))
2242+
check(assert_type(df3.loc[str_], Union[pd.Series, pd.DataFrame]), pd.Series)
2243+
22352244

22362245
def test_boolean_loc() -> None:
22372246
# Booleans can only be used in loc when the index is boolean
22382247
df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False])
2239-
check(assert_type(df.loc[True], pd.Series), pd.Series)
2248+
check(assert_type(df.loc[True], Union[pd.Series, pd.DataFrame]), pd.Series)
22402249
check(assert_type(df.loc[:, False], pd.Series), pd.Series)
22412250

22422251

0 commit comments

Comments
 (0)