Skip to content

Commit e2c2c5e

Browse files
committed
Closes the GH 14833 issue
1 parent 9ac62ab commit e2c2c5e

File tree

3 files changed

+41
-16
lines changed

3 files changed

+41
-16
lines changed

pandas/core/indexes/multi.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Any,
1717
Literal,
1818
cast,
19+
overload,
1920
)
2021
import warnings
2122

@@ -44,6 +45,14 @@
4445
Shape,
4546
npt,
4647
)
48+
if TYPE_CHECKING:
49+
from pandas._typing import (
50+
NumpySorter,
51+
NumpyValueArrayLike,
52+
ScalarLike_co,
53+
)
54+
55+
4756
from pandas.compat.numpy import function as nv
4857
from pandas.errors import (
4958
InvalidIndexError,
@@ -3778,9 +3787,25 @@ def _reorder_indexer(
37783787
ind = np.lexsort(keys)
37793788
return indexer[ind]
37803789

3790+
@overload
3791+
def searchsorted(
3792+
self,
3793+
value: ScalarLike_co,
3794+
side: Literal["left", "right"] = ...,
3795+
sirter:NumpySorter = ...,
3796+
) -> np.intp:...
3797+
3798+
@overload
3799+
def searchsorted(
3800+
self,
3801+
value: npt.ArrayLike | ExtensionArray,
3802+
side: Literal["left", "right"] = ...,
3803+
sorter: NumpySorter = ...,
3804+
) -> npt.NDArray[np.intp]:...
3805+
37813806
def searchsorted(
37823807
self,
3783-
value: Any,
3808+
value: NumpyValueArrayLike | ExtensionArray,
37843809
side: Literal["left", "right"] = "left",
37853810
sorter: npt.NDArray[np.intp] | None = None,
37863811
) -> npt.NDArray[np.intp]:
@@ -3831,6 +3856,7 @@ def searchsorted(
38313856

38323857
for v, i in zip(value, indexer):
38333858
if i != -1:
3859+
38343860
result.append(i if side == "left" else i + 1)
38353861
else:
38363862
dtype = np.dtype(
@@ -3839,7 +3865,7 @@ def searchsorted(
38393865
for i, level in enumerate(self.levels)
38403866
]
38413867
)
3842-
3868+
38433869
val_array = np.array([v], dtype=dtype)
38443870

38453871
pos = np.searchsorted(

pandas/tests/base/test_misc.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,20 +141,20 @@ def test_memory_usage_components_narrow_series(any_real_numpy_dtype):
141141
index_usage = series.index.memory_usage()
142142
assert total_usage == non_index_usage + index_usage
143143

144-
144+
145145
def test_searchsorted(request, index_or_series_obj):
146146
# numpy.searchsorted calls obj.searchsorted under the hood.
147147
# See gh-12238
148148
obj = index_or_series_obj
149149

150-
if isinstance(obj, pd.MultiIndex):
151-
# See gh-14833
152-
request.applymarker(
153-
pytest.mark.xfail(
154-
reason="np.searchsorted doesn't work on pd.MultiIndex: GH 14833"
155-
)
156-
)
157-
elif obj.dtype.kind == "c" and isinstance(obj, Index):
150+
# if isinstance(obj, pd.MultiIndex):
151+
# # See gh-14833
152+
# request.applymarker(
153+
# pytest.mark.xfail(
154+
# reason="np.searchsorted doesn't work on pd.MultiIndex: GH 14833"
155+
# )
156+
# )
157+
if obj.dtype.kind == "c" and isinstance(obj, Index):
158158
# TODO: Should Series cases also raise? Looks like they use numpy
159159
# comparison semantics https://github.com/numpy/numpy/issues/15981
160160
mark = pytest.mark.xfail(reason="complex objects are not comparable")

pandas/tests/indexes/multi/test_indexing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,15 +1032,14 @@ def test_get_loc_namedtuple_behaves_like_tuple():
10321032

10331033

10341034
def test_searchsorted():
1035-
mi = MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0),
1036-
("b", 1), ("c", 0)])
1035+
# GH14833
1036+
mi = MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0), ("b", 1), ("c", 0)])
10371037

10381038
assert np.all(mi.searchsorted(("b", 0)) == 2)
10391039
assert np.all(mi.searchsorted(("b", 0), side="right") == 3)
1040-
10411040
assert np.all(mi.searchsorted(("a", 0)) == 0)
10421041
assert np.all(mi.searchsorted(("a", -1)) == 0)
1043-
assert np.all(mi.searchsorted(("c", 1)) == 5) # Beyond the last
1042+
assert np.all(mi.searchsorted(("c", 1)) == 5)
10441043

10451044
result = mi.searchsorted([("a", 1), ("b", 0), ("c", 0)])
10461045
expected = np.array([1, 2, 4], dtype=np.intp)
@@ -1054,4 +1053,4 @@ def test_searchsorted():
10541053
mi.searchsorted(("a", 1), side="middle")
10551054

10561055
with pytest.raises(TypeError, match="value must be a tuple or list"):
1057-
mi.searchsorted("a") # not a tuple
1056+
mi.searchsorted("a")

0 commit comments

Comments
 (0)