Skip to content

Commit bd2d98f

Browse files
authored
Fix issue with Series.loc[x] where x is a tuple specifiying a specific index to get from a MultiIndex (#350)
* WIP: TimedeltaIndex accessors * remove dtl to make mypy happy * demonstrate possible bug * fix series loc on multiindex scalars * use IndexSliceTuple, add comments
1 parent 511876d commit bd2d98f

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

pandas-stubs/core/series.pyi

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ from pandas.core.indexes.timedeltas import TimedeltaIndex
4848
from pandas.core.indexing import (
4949
_AtIndexer,
5050
_iAtIndexer,
51+
_IndexSliceTuple,
5152
)
5253
from pandas.core.resample import Resampler
5354
from pandas.core.strings import StringMethods
@@ -131,21 +132,21 @@ class _iLocIndexerSeries(_iLocIndexer, Generic[S1]):
131132
) -> None: ...
132133

133134
class _LocIndexerSeries(_LocIndexer, Generic[S1]):
135+
# ignore needed because of mypy. Overlapping, but we want to distinguish
136+
# having a tuple of just scalars, versus tuples that include slices or Index
134137
@overload
135-
def __getitem__(
138+
def __getitem__( # type: ignore[misc]
136139
self,
137-
idx: MaskType
138-
| Index
139-
| Sequence[float]
140-
| list[str]
141-
| slice
142-
| tuple[str | float | slice | Index, ...],
143-
) -> Series[S1]: ...
140+
idx: Scalar | tuple[Scalar, ...],
141+
# tuple case is for getting a specific element when using a MultiIndex
142+
) -> S1: ...
144143
@overload
145144
def __getitem__(
146145
self,
147-
idx: str | float,
148-
) -> S1: ...
146+
idx: MaskType | Index | Sequence[float] | list[str] | slice | _IndexSliceTuple,
147+
# _IndexSliceTuple is when having a tuple that includes a slice. Could just
148+
# be s.loc[1, :], or s.loc[pd.IndexSlice[1, :]]
149+
) -> Series[S1]: ...
149150
@overload
150151
def __setitem__(
151152
self,

tests/test_series.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,12 @@ def test_types_loc_at() -> None:
120120

121121

122122
def test_multiindex_loc() -> None:
123-
s = pd.Series([1, 2, 3, 4], index=pd.MultiIndex.from_product([[1, 2], ["a", "b"]]))
124-
check(assert_type(s.loc[1, :], pd.Series), pd.Series)
125-
check(assert_type(s.loc[pd.Index([1]), :], pd.Series), pd.Series)
123+
s = pd.Series(
124+
[1, 2, 3, 4], index=pd.MultiIndex.from_product([[1, 2], ["a", "b"]]), dtype=int
125+
)
126+
check(assert_type(s.loc[1, :], "pd.Series[int]"), pd.Series, int)
127+
check(assert_type(s.loc[pd.Index([1]), :], "pd.Series[int]"), pd.Series, int)
128+
check(assert_type(s.loc[1, "a"], int), np.int_)
126129

127130

128131
def test_types_boolean_indexing() -> None:

0 commit comments

Comments
 (0)