@@ -2208,16 +2208,16 @@ def test_frame_scalars_slice() -> None:
2208
2208
2209
2209
# Note: bool_ cannot be tested since the index is object and pandas does not
2210
2210
# support boolean access using loc except when the index is boolean
2211
- check (assert_type (df .loc [str_ ], pd .Series ), pd .Series )
2212
- check (assert_type (df .loc [bytes_ ], pd .Series ), pd .Series )
2213
- check (assert_type (df .loc [date ], pd .Series ), pd .Series )
2214
- check (assert_type (df .loc [datetime_ ], pd .Series ), pd .Series )
2215
- check (assert_type (df .loc [timedelta ], pd .Series ), pd .Series )
2216
- check (assert_type (df .loc [int_ ], pd .Series ), pd .Series )
2217
- check (assert_type (df .loc [float_ ], pd .Series ), pd .Series )
2218
- check (assert_type (df .loc [complex_ ], pd .Series ), pd .Series )
2219
- check (assert_type (df .loc [timestamp ], pd .Series ), pd .Series )
2220
- check (assert_type (df .loc [pd_timedelta ], pd .Series ), pd .Series )
2211
+ check (assert_type (df .loc [str_ ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2212
+ check (assert_type (df .loc [bytes_ ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2213
+ check (assert_type (df .loc [date ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2214
+ check (assert_type (df .loc [datetime_ ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2215
+ check (assert_type (df .loc [timedelta ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2216
+ check (assert_type (df .loc [int_ ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2217
+ check (assert_type (df .loc [float_ ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2218
+ check (assert_type (df .loc [complex_ ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2219
+ check (assert_type (df .loc [timestamp ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2220
+ check (assert_type (df .loc [pd_timedelta ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2221
2221
check (assert_type (df .loc [none ], pd .Series ), pd .Series )
2222
2222
2223
2223
check (assert_type (df .loc [:, str_ ], pd .Series ), pd .Series )
@@ -2232,11 +2232,20 @@ def test_frame_scalars_slice() -> None:
2232
2232
check (assert_type (df .loc [:, pd_timedelta ], pd .Series ), pd .Series )
2233
2233
check (assert_type (df .loc [:, none ], pd .Series ), pd .Series )
2234
2234
2235
+ # GH749
2236
+
2237
+ multi_idx = pd .MultiIndex .from_product ([["a" , "b" ], [1 , 2 ]], names = ["alpha" , "num" ])
2238
+ df2 = pd .DataFrame ({"col1" : range (4 )}, index = multi_idx )
2239
+ check (assert_type (df2 .loc [str_ ], Union [pd .Series , pd .DataFrame ]), pd .DataFrame )
2240
+
2241
+ df3 = pd .DataFrame ({"x" : range (2 )}, index = pd .Index (["a" , "b" ]))
2242
+ check (assert_type (df3 .loc [str_ ], Union [pd .Series , pd .DataFrame ]), pd .Series )
2243
+
2235
2244
2236
2245
def test_boolean_loc () -> None :
2237
2246
# Booleans can only be used in loc when the index is boolean
2238
2247
df = pd .DataFrame ([[0 , 1 ], [1 , 0 ]], columns = [True , False ], index = [True , False ])
2239
- check (assert_type (df .loc [True ], pd .Series ), pd .Series )
2248
+ check (assert_type (df .loc [True ], Union [ pd .Series , pd . DataFrame ] ), pd .Series )
2240
2249
check (assert_type (df .loc [:, False ], pd .Series ), pd .Series )
2241
2250
2242
2251
0 commit comments