From aa7403cbd656ae0630345f0bec550ef646273106 Mon Sep 17 00:00:00 2001 From: Brock Mendel Date: Wed, 8 Jan 2020 14:12:01 -0800 Subject: [PATCH 1/2] BUG: DTI.searchsorted accepting invalid types/dtypes --- pandas/core/indexes/datetimes.py | 24 +++++++-- pandas/tests/arrays/test_datetimes.py | 71 +++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 88b841e7d4a88..1d7f84352d9da 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -359,7 +359,7 @@ def _convert_for_op(self, value): Convert value to be insertable to ndarray. """ if self._has_same_tz(value): - return _to_M8(value) + return Timestamp(value).asm8 raise ValueError("Passed item and index have different timezone") # -------------------------------------------------------------------- @@ -892,11 +892,25 @@ def __getitem__(self, key): @Appender(_shared_docs["searchsorted"]) def searchsorted(self, value, side="left", sorter=None): if isinstance(value, (np.ndarray, Index)): - value = np.array(value, dtype=_NS_DTYPE, copy=False) - else: - value = _to_M8(value, tz=self.tz) + if not type(self._data)._is_recognized_dtype(value): + raise TypeError( + "searchsorted requires compatible dtype or scalar, " + f"not {type(value).__name__}" + ) + value = type(self._data)(value) + self._data._check_compatible_with(value) + + elif isinstance(value, self._data._recognized_scalars): + self._data._check_compatible_with(value) + value = self._data._scalar_type(value) + + elif not isinstance(value, DatetimeArray): + raise TypeError( + "searchsorted requires compatible dtype or scalar, " + f"not {type(value).__name__}" + ) - return self.values.searchsorted(value, side=side) + return self._data.searchsorted(value, side=side) def is_type_compatible(self, typ) -> bool: return typ == self.inferred_type or typ == "datetime" diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index b5f9c8957c2b8..45c9b3240e051 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -282,6 +282,77 @@ def test_array_interface(self): ) tm.assert_numpy_array_equal(result, expected) + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_different_tz(self, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 + arr = DatetimeArray(data, freq="D").tz_localize("Asia/Tokyo") + if index: + arr = pd.Index(arr) + + expected = arr.searchsorted(arr[2]) + result = arr.searchsorted(arr[2].tz_convert("UTC")) + assert result == expected + + expected = arr.searchsorted(arr[2:6]) + result = arr.searchsorted(arr[2:6].tz_convert("UTC")) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_tzawareness_compat(self, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 + arr = DatetimeArray(data, freq="D") + if index: + arr = pd.Index(arr) + + mismatch = arr.tz_localize("Asia/Tokyo") + + msg = "Cannot compare tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + arr.searchsorted(mismatch[0]) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(mismatch) + + with pytest.raises(TypeError, match=msg): + mismatch.searchsorted(arr[0]) + with pytest.raises(TypeError, match=msg): + mismatch.searchsorted(arr) + + @pytest.mark.parametrize( + "other", + [ + 1, + np.int64(1), + 1.0, + np.timedelta64("NaT"), + pd.Timedelta(days=2), + "invalid", + np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9, + np.arange(10, dtype="timedelta64[ns]") * 24 * 3600 * 10 ** 9, + pd.Timestamp.now().to_period("D"), + ], + ) + @pytest.mark.parametrize( + "index", + [ + True, + pytest.param( + False, + marks=pytest.mark.xfail( + reason="Raises ValueError instead of TypeError", raises=ValueError + ), + ), + ], + ) + def test_searchsorted_invalid_types(self, other, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 + arr = DatetimeArray(data, freq="D") + if index: + arr = pd.Index(arr) + + msg = "searchsorted requires compatible dtype or scalar" + with pytest.raises(TypeError, match=msg): + arr.searchsorted(other) + class TestSequenceToDT64NS: def test_tz_dtype_mismatch_raises(self): From 9565b5666000c86b1a7c82fc9b4b4ce169429ab3 Mon Sep 17 00:00:00 2001 From: Brock Mendel Date: Wed, 8 Jan 2020 16:50:38 -0800 Subject: [PATCH 2/2] np version compat --- pandas/tests/arrays/test_datetimes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index 45c9b3240e051..5608ab5fbd9db 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -327,7 +327,7 @@ def test_searchsorted_tzawareness_compat(self, index): pd.Timedelta(days=2), "invalid", np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9, - np.arange(10, dtype="timedelta64[ns]") * 24 * 3600 * 10 ** 9, + np.arange(10).view("timedelta64[ns]") * 24 * 3600 * 10 ** 9, pd.Timestamp.now().to_period("D"), ], )