Skip to content

Commit 1ba7ff8

Browse files
committed
implemented the searchsorted() method, w.r.t issue pandas-dev#18433
1 parent cffb863 commit 1ba7ff8

File tree

2 files changed

+61
-9
lines changed

2 files changed

+61
-9
lines changed

pandas/core/indexes/multi.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3814,22 +3814,37 @@ def searchsorted(
38143814
>>> mi.searchsorted(("b", "y"))
38153815
1
38163816
"""
3817-
if isinstance(value, tuple):
3818-
value = list(value)
3817+
if not isinstance(value, (tuple,list)):
3818+
raise TypeError("value must be a tuple or list")
38193819

3820+
if isinstance(value, tuple):
3821+
value = [value]
38203822
if side not in ["left", "right"]:
38213823
raise ValueError("side must be either 'left' or 'right'")
38223824

38233825
if not value:
38243826
raise ValueError("searchsorted requires a non-empty value")
38253827

3826-
3827-
3828-
dtype = np.dtype([(f"level_{i}", level.dtype) for i,level in enumerate(self.levels)])
3829-
3830-
val = np.asarray(value, dtype=dtype)
3831-
3832-
return np.searchsorted(self.values.astype(dtype),val, side=side, sorter=sorter)
3828+
try:
3829+
3830+
indexer = self.get_indexer(value)
3831+
result = []
3832+
3833+
for v, i in zip(value, indexer):
3834+
if i!= -1:
3835+
result.append(i if side == "left" else i + 1)
3836+
else:
3837+
dtype = np.dtype([(f"level_{i}", level.dtype) for i, level in enumerate(self.levels)])
3838+
3839+
val_array = np.array(value, dtype=dtype)
3840+
3841+
pos = np.searchsorted( np.asarray(self.values,dtype=dtype),val_array , side=side, sorter = sorter)
3842+
result.append(pos)
3843+
3844+
return np.array(result, dtype=np.intp)
3845+
3846+
except KeyError:
3847+
pass
38333848

38343849

38353850
def truncate(self, before=None, after=None) -> MultiIndex:

pandas/tests/indexes/multi/test_indexing.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,3 +1029,40 @@ def test_get_loc_namedtuple_behaves_like_tuple():
10291029
assert idx.get_loc(("i1", "i2")) == 0
10301030
assert idx.get_loc(("i3", "i4")) == 1
10311031
assert idx.get_loc(("i5", "i6")) == 2
1032+
1033+
1034+
1035+
1036+
def test_searchsorted():
1037+
mi = MultiIndex.from_tuples([
1038+
('a', 0),
1039+
('a', 1),
1040+
('b', 0),
1041+
('b', 1),
1042+
('c', 0)
1043+
])
1044+
1045+
1046+
assert mi.searchsorted(('b', 0)) == 2
1047+
assert mi.searchsorted(('b', 0), side="right") == 3
1048+
1049+
assert mi.searchsorted(('a', 0)) == 0
1050+
assert mi.searchsorted(('a', -1)) == 0
1051+
assert mi.searchsorted(('c', 1)) == 5 # Beyond the last
1052+
1053+
1054+
result = mi.searchsorted([('a', 1), ('b', 0), ('c', 0)])
1055+
expected = np.array([1, 2, 4], dtype=np.intp)
1056+
np.testing.assert_array_equal(result, expected)
1057+
1058+
1059+
result = mi.searchsorted([('a', 1), ('b', 0), ('c', 0)], side='right')
1060+
expected = np.array([2, 3, 5], dtype=np.intp)
1061+
np.testing.assert_array_equal(result, expected)
1062+
1063+
1064+
with pytest.raises(ValueError, match="side must be either 'left' or 'right'"):
1065+
mi.searchsorted(('a', 1), side='middle')
1066+
1067+
with pytest.raises(TypeError, match="value must be a tuple or list"):
1068+
mi.searchsorted('a') # not a tuple

0 commit comments

Comments
 (0)