Skip to content

Commit ee017c1

Browse files
authored
BUG: Allow list-like in DatetimeIndex.searchsorted (#32764)
1 parent 7b5e9c8 commit ee017c1

File tree

6 files changed

+97
-3
lines changed

6 files changed

+97
-3
lines changed

doc/source/whatsnew/v1.1.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ Datetimelike
255255
- Bug in :meth:`Period.to_timestamp`, :meth:`Period.start_time` with microsecond frequency returning a timestamp one nanosecond earlier than the correct time (:issue:`31475`)
256256
- :class:`Timestamp` raising confusing error message when year, month or day is missing (:issue:`31200`)
257257
- Bug in :class:`DatetimeIndex` constructor incorrectly accepting ``bool``-dtyped inputs (:issue:`32668`)
258+
- Bug in :meth:`DatetimeIndex.searchsorted` not accepting a ``list`` or :class:`Series` as its argument (:issue:`32762`)
258259

259260
Timedelta
260261
^^^^^^^^^

pandas/core/arrays/datetimelike.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -846,14 +846,14 @@ def searchsorted(self, value, side="left", sorter=None):
846846
elif isinstance(value, self._recognized_scalars):
847847
value = self._scalar_type(value)
848848

849-
elif isinstance(value, np.ndarray):
849+
elif is_list_like(value) and not isinstance(value, type(self)):
850+
value = array(value)
851+
850852
if not type(self)._is_recognized_dtype(value):
851853
raise TypeError(
852854
"searchsorted requires compatible dtype or scalar, "
853855
f"not {type(value).__name__}"
854856
)
855-
value = type(self)(value)
856-
self._check_compatible_with(value)
857857

858858
if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)):
859859
raise TypeError(f"Unexpected type for 'value': {type(value)}")

pandas/tests/arrays/test_datetimelike.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,3 +812,38 @@ def test_to_numpy_extra(array):
812812
assert result[0] == result[1]
813813

814814
tm.assert_equal(array, original)
815+
816+
817+
@pytest.mark.parametrize(
818+
"values",
819+
[
820+
pd.to_datetime(["2020-01-01", "2020-02-01"]),
821+
pd.TimedeltaIndex([1, 2], unit="D"),
822+
pd.PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"),
823+
],
824+
)
825+
@pytest.mark.parametrize("klass", [list, np.array, pd.array, pd.Series])
826+
def test_searchsorted_datetimelike_with_listlike(values, klass):
827+
# https://github.com/pandas-dev/pandas/issues/32762
828+
result = values.searchsorted(klass(values))
829+
expected = np.array([0, 1], dtype=result.dtype)
830+
831+
tm.assert_numpy_array_equal(result, expected)
832+
833+
834+
@pytest.mark.parametrize(
835+
"values",
836+
[
837+
pd.to_datetime(["2020-01-01", "2020-02-01"]),
838+
pd.TimedeltaIndex([1, 2], unit="D"),
839+
pd.PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"),
840+
],
841+
)
842+
@pytest.mark.parametrize(
843+
"arg", [[1, 2], ["a", "b"], [pd.Timestamp("2020-01-01", tz="Europe/London")] * 2]
844+
)
845+
def test_searchsorted_datetimelike_with_listlike_invalid_dtype(values, arg):
846+
# https://github.com/pandas-dev/pandas/issues/32762
847+
# Use a real regex alternation. The previous "[...|...]" form was a character
# class, which matches any single listed character and so accepted virtually
# any TypeError message. The dtype-mismatch message raised by searchsorted is
# included because list-likes of an unrecognized dtype fail with it.
# NOTE(review): confirm each parametrized case hits one of these messages.
msg = "|".join(
    [
        "Unexpected type",
        "Cannot compare",
        "searchsorted requires compatible dtype",
    ]
)
848+
with pytest.raises(TypeError, match=msg):
849+
values.searchsorted(arg)

pandas/tests/indexes/interval/test_interval.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,3 +863,25 @@ def test_dir():
863863
index = IntervalIndex.from_arrays([0, 1], [1, 2])
864864
result = dir(index)
865865
assert "str" not in result
866+
867+
868+
@pytest.mark.parametrize("klass", [list, np.array, pd.array, pd.Series])
869+
def test_searchsorted_different_argument_classes(klass):
870+
# https://github.com/pandas-dev/pandas/issues/32762
871+
values = IntervalIndex([Interval(0, 1), Interval(1, 2)])
872+
result = values.searchsorted(klass(values))
873+
expected = np.array([0, 1], dtype=result.dtype)
874+
tm.assert_numpy_array_equal(result, expected)
875+
876+
result = values._data.searchsorted(klass(values))
877+
tm.assert_numpy_array_equal(result, expected)
878+
879+
880+
@pytest.mark.parametrize(
881+
"arg", [[1, 2], ["a", "b"], [pd.Timestamp("2020-01-01", tz="Europe/London")] * 2]
882+
)
883+
def test_searchsorted_invalid_argument(arg):
884+
values = IntervalIndex([Interval(0, 1), Interval(1, 2)])
885+
msg = "unorderable types"
886+
with pytest.raises(TypeError, match=msg):
887+
values.searchsorted(arg)

pandas/tests/indexes/period/test_tools.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
NaT,
1111
Period,
1212
PeriodIndex,
13+
Series,
1314
Timedelta,
1415
Timestamp,
16+
array,
1517
date_range,
1618
period_range,
1719
)
@@ -64,6 +66,19 @@ def test_searchsorted(self, freq):
6466
with pytest.raises(IncompatibleFrequency, match=msg):
6567
pidx.searchsorted(Period("2014-01-01", freq="5D"))
6668

69+
@pytest.mark.parametrize("klass", [list, np.array, array, Series])
70+
def test_searchsorted_different_argument_classes(self, klass):
71+
pidx = PeriodIndex(
72+
["2014-01-01", "2014-01-02", "2014-01-03", "2014-01-04", "2014-01-05"],
73+
freq="D",
74+
)
75+
result = pidx.searchsorted(klass(pidx))
76+
expected = np.arange(len(pidx), dtype=result.dtype)
77+
tm.assert_numpy_array_equal(result, expected)
78+
79+
result = pidx._data.searchsorted(klass(pidx))
80+
tm.assert_numpy_array_equal(result, expected)
81+
6782
def test_searchsorted_invalid(self):
6883
pidx = PeriodIndex(
6984
["2014-01-01", "2014-01-02", "2014-01-03", "2014-01-04", "2014-01-05"],

pandas/tests/indexes/timedeltas/test_timedelta.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Series,
1212
Timedelta,
1313
TimedeltaIndex,
14+
array,
1415
date_range,
1516
timedelta_range,
1617
)
@@ -111,6 +112,26 @@ def test_sort_values(self):
111112

112113
tm.assert_numpy_array_equal(dexer, np.array([0, 2, 1]), check_dtype=False)
113114

115+
@pytest.mark.parametrize("klass", [list, np.array, array, Series])
116+
def test_searchsorted_different_argument_classes(self, klass):
117+
idx = TimedeltaIndex(["1 day", "2 days", "3 days"])
118+
result = idx.searchsorted(klass(idx))
119+
expected = np.arange(len(idx), dtype=result.dtype)
120+
tm.assert_numpy_array_equal(result, expected)
121+
122+
result = idx._data.searchsorted(klass(idx))
123+
tm.assert_numpy_array_equal(result, expected)
124+
125+
@pytest.mark.parametrize(
126+
"arg",
127+
[[1, 2], ["a", "b"], [pd.Timestamp("2020-01-01", tz="Europe/London")] * 2],
128+
)
129+
def test_searchsorted_invalid_argument_dtype(self, arg):
130+
idx = TimedeltaIndex(["1 day", "2 days", "3 days"])
131+
msg = "searchsorted requires compatible dtype"
132+
with pytest.raises(TypeError, match=msg):
133+
idx.searchsorted(arg)
134+
114135
def test_argmin_argmax(self):
115136
idx = TimedeltaIndex(["1 day 00:00:05", "1 day 00:00:01", "1 day 00:00:02"])
116137
assert idx.argmin() == 1

0 commit comments

Comments
 (0)