Skip to content

Commit 5c12d4f

Browse files
jbrockmendeljreback
authored andcommitted
BUG: DTI.searchsorted accepting invalid types/dtypes (#30826)
1 parent 1740d03 commit 5c12d4f

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

pandas/core/indexes/datetimes.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def _convert_for_op(self, value):
324324
Convert value to be insertable to ndarray.
325325
"""
326326
if self._has_same_tz(value):
327-
return _to_M8(value)
327+
return Timestamp(value).asm8
328328
raise ValueError("Passed item and index have different timezone")
329329

330330
# --------------------------------------------------------------------
@@ -859,11 +859,25 @@ def __getitem__(self, key):
859859
@Appender(_shared_docs["searchsorted"])
860860
def searchsorted(self, value, side="left", sorter=None):
861861
if isinstance(value, (np.ndarray, Index)):
862-
value = np.array(value, dtype=_NS_DTYPE, copy=False)
863-
else:
864-
value = _to_M8(value, tz=self.tz)
862+
if not type(self._data)._is_recognized_dtype(value):
863+
raise TypeError(
864+
"searchsorted requires compatible dtype or scalar, "
865+
f"not {type(value).__name__}"
866+
)
867+
value = type(self._data)(value)
868+
self._data._check_compatible_with(value)
869+
870+
elif isinstance(value, self._data._recognized_scalars):
871+
self._data._check_compatible_with(value)
872+
value = self._data._scalar_type(value)
873+
874+
elif not isinstance(value, DatetimeArray):
875+
raise TypeError(
876+
"searchsorted requires compatible dtype or scalar, "
877+
f"not {type(value).__name__}"
878+
)
865879

866-
return self.values.searchsorted(value, side=side)
880+
return self._data.searchsorted(value, side=side)
867881

868882
def is_type_compatible(self, typ) -> bool:
869883
return typ == self.inferred_type or typ == "datetime"

pandas/tests/arrays/test_datetimes.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,77 @@ def test_array_interface(self):
282282
)
283283
tm.assert_numpy_array_equal(result, expected)
284284

285+
@pytest.mark.parametrize("index", [True, False])
286+
def test_searchsorted_different_tz(self, index):
287+
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
288+
arr = DatetimeArray(data, freq="D").tz_localize("Asia/Tokyo")
289+
if index:
290+
arr = pd.Index(arr)
291+
292+
expected = arr.searchsorted(arr[2])
293+
result = arr.searchsorted(arr[2].tz_convert("UTC"))
294+
assert result == expected
295+
296+
expected = arr.searchsorted(arr[2:6])
297+
result = arr.searchsorted(arr[2:6].tz_convert("UTC"))
298+
tm.assert_equal(result, expected)
299+
300+
@pytest.mark.parametrize("index", [True, False])
301+
def test_searchsorted_tzawareness_compat(self, index):
302+
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
303+
arr = DatetimeArray(data, freq="D")
304+
if index:
305+
arr = pd.Index(arr)
306+
307+
mismatch = arr.tz_localize("Asia/Tokyo")
308+
309+
msg = "Cannot compare tz-naive and tz-aware datetime-like objects"
310+
with pytest.raises(TypeError, match=msg):
311+
arr.searchsorted(mismatch[0])
312+
with pytest.raises(TypeError, match=msg):
313+
arr.searchsorted(mismatch)
314+
315+
with pytest.raises(TypeError, match=msg):
316+
mismatch.searchsorted(arr[0])
317+
with pytest.raises(TypeError, match=msg):
318+
mismatch.searchsorted(arr)
319+
320+
@pytest.mark.parametrize(
321+
"other",
322+
[
323+
1,
324+
np.int64(1),
325+
1.0,
326+
np.timedelta64("NaT"),
327+
pd.Timedelta(days=2),
328+
"invalid",
329+
np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9,
330+
np.arange(10).view("timedelta64[ns]") * 24 * 3600 * 10 ** 9,
331+
pd.Timestamp.now().to_period("D"),
332+
],
333+
)
334+
@pytest.mark.parametrize(
335+
"index",
336+
[
337+
True,
338+
pytest.param(
339+
False,
340+
marks=pytest.mark.xfail(
341+
reason="Raises ValueError instead of TypeError", raises=ValueError
342+
),
343+
),
344+
],
345+
)
346+
def test_searchsorted_invalid_types(self, other, index):
347+
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
348+
arr = DatetimeArray(data, freq="D")
349+
if index:
350+
arr = pd.Index(arr)
351+
352+
msg = "searchsorted requires compatible dtype or scalar"
353+
with pytest.raises(TypeError, match=msg):
354+
arr.searchsorted(other)
355+
285356

286357
class TestSequenceToDT64NS:
287358
def test_tz_dtype_mismatch_raises(self):

0 commit comments

Comments
 (0)