Skip to content

Commit cb9a1c7

Browse files
jbrockmendeljreback
authored andcommitted
BUG: TimedeltaIndex.searchsorted accepting invalid types/dtypes (#30831)
1 parent 5c12d4f commit cb9a1c7

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

pandas/core/indexes/timedeltas.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,25 @@ def _partial_td_slice(self, key):
357357
@Appender(_shared_docs["searchsorted"])
358358
def searchsorted(self, value, side="left", sorter=None):
359359
if isinstance(value, (np.ndarray, Index)):
360-
value = np.array(value, dtype=_TD_DTYPE, copy=False)
361-
else:
362-
value = Timedelta(value).asm8.view(_TD_DTYPE)
360+
if not type(self._data)._is_recognized_dtype(value):
361+
raise TypeError(
362+
"searchsorted requires compatible dtype or scalar, "
363+
f"not {type(value).__name__}"
364+
)
365+
value = type(self._data)(value)
366+
self._data._check_compatible_with(value)
367+
368+
elif isinstance(value, self._data._recognized_scalars):
369+
self._data._check_compatible_with(value)
370+
value = self._data._scalar_type(value)
371+
372+
elif not isinstance(value, TimedeltaArray):
373+
raise TypeError(
374+
"searchsorted requires compatible dtype or scalar, "
375+
f"not {type(value).__name__}"
376+
)
363377

364-
return self.values.searchsorted(value, side=side, sorter=sorter)
378+
return self._data.searchsorted(value, side=side, sorter=sorter)
365379

366380
def is_type_compatible(self, typ) -> bool:
367381
return typ == self.inferred_type or typ == "timedelta"

pandas/tests/arrays/test_timedeltas.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,42 @@ def test_setitem_objects(self, obj):
140140
arr[0] = obj
141141
assert arr[0] == pd.Timedelta(seconds=1)
142142

143+
@pytest.mark.parametrize(
144+
"other",
145+
[
146+
1,
147+
np.int64(1),
148+
1.0,
149+
np.datetime64("NaT"),
150+
pd.Timestamp.now(),
151+
"invalid",
152+
np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9,
153+
(np.arange(10) * 24 * 3600 * 10 ** 9).view("datetime64[ns]"),
154+
pd.Timestamp.now().to_period("D"),
155+
],
156+
)
157+
@pytest.mark.parametrize(
158+
"index",
159+
[
160+
True,
161+
pytest.param(
162+
False,
163+
marks=pytest.mark.xfail(
164+
reason="Raises ValueError instead of TypeError", raises=ValueError
165+
),
166+
),
167+
],
168+
)
169+
def test_searchsorted_invalid_types(self, other, index):
170+
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
171+
arr = TimedeltaArray(data, freq="D")
172+
if index:
173+
arr = pd.Index(arr)
174+
175+
msg = "searchsorted requires compatible dtype or scalar"
176+
with pytest.raises(TypeError, match=msg):
177+
arr.searchsorted(other)
178+
143179

144180
class TestReductions:
145181
@pytest.mark.parametrize("name", ["sum", "std", "min", "max", "median"])

0 commit comments

Comments
 (0)