From bb381ed0caf5643e45d01a6979fc0b3ab324216d Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 12 May 2020 14:52:11 -0700 Subject: [PATCH] CLN: avoid hasattr check in Timestamp.__richcmp__ --- pandas/_libs/tslibs/timestamps.pyx | 86 +++++++------------ pandas/tests/arithmetic/test_numeric.py | 4 +- pandas/tests/frame/test_arithmetic.py | 2 +- .../scalar/timestamp/test_comparisons.py | 4 +- 4 files changed, 35 insertions(+), 61 deletions(-) diff --git a/pandas/_libs/tslibs/timestamps.pyx b/pandas/_libs/tslibs/timestamps.pyx index e656d654461c9..52ff3d6a11b5f 100644 --- a/pandas/_libs/tslibs/timestamps.pyx +++ b/pandas/_libs/tslibs/timestamps.pyx @@ -37,7 +37,7 @@ from pandas._libs.tslibs.fields import get_start_end_field, get_date_name_field from pandas._libs.tslibs.nattype cimport NPY_NAT, c_NaT as NaT from pandas._libs.tslibs.np_datetime cimport ( check_dts_bounds, npy_datetimestruct, dt64_to_dtstruct, - reverse_ops, cmp_scalar, + cmp_scalar, ) from pandas._libs.tslibs.np_datetime import OutOfBoundsDatetime from pandas._libs.tslibs.offsets cimport to_offset @@ -228,8 +228,8 @@ cdef class _Timestamp(ABCTimestamp): ots = other elif other is NaT: return op == Py_NE - elif PyDateTime_Check(other): - if self.nanosecond == 0: + elif PyDateTime_Check(other) or is_datetime64_object(other): + if self.nanosecond == 0 and PyDateTime_Check(other): val = self.to_pydatetime() return PyObject_RichCompareBool(val, other, op) @@ -237,44 +237,31 @@ cdef class _Timestamp(ABCTimestamp): ots = type(self)(other) except ValueError: return self._compare_outside_nanorange(other, op) + + elif is_array(other): + # avoid recursion error GH#15183 + if other.dtype.kind == "M": + if self.tz is None: + return PyObject_RichCompare(self.asm8, other, op) + raise TypeError( + "Cannot compare tz-naive and tz-aware timestamps" + ) + elif other.dtype.kind == "O": + # Operate element-wise + return np.array( + [PyObject_RichCompare(self, x, op) for x in other], + dtype=bool, + ) + elif op == Py_NE: + return np.ones(other.shape, dtype=np.bool_) + elif op == Py_EQ: + return np.zeros(other.shape, dtype=np.bool_) + return NotImplemented + else: - ndim = getattr(other, "ndim", -1) - - if ndim != -1: - if ndim == 0: - if is_datetime64_object(other): - other = type(self)(other) - elif is_array(other): - # zero-dim array, occurs if try comparison with - # datetime64 scalar on the left hand side - # Unfortunately, for datetime64 values, other.item() - # incorrectly returns an integer, so we need to use - # the numpy C api to extract it. - other = cnp.PyArray_ToScalar(cnp.PyArray_DATA(other), - other) - other = type(self)(other) - else: - return NotImplemented - elif is_array(other): - # avoid recursion error GH#15183 - if other.dtype.kind == "M": - if self.tz is None: - return PyObject_RichCompare(self.asm8, other, op) - raise TypeError( - "Cannot compare tz-naive and tz-aware timestamps" - ) - if other.dtype.kind == "O": - # Operate element-wise - return np.array( - [PyObject_RichCompare(self, x, op) for x in other], - dtype=bool, - ) - return PyObject_RichCompare(np.array([self]), other, op) - return PyObject_RichCompare(other, self, reverse_ops[op]) - else: - return NotImplemented + return NotImplemented - self._assert_tzawareness_compat(other) + self._assert_tzawareness_compat(ots) return cmp_scalar(self.value, ots.value, op) def __reduce_ex__(self, protocol): @@ -314,22 +301,7 @@ cdef class _Timestamp(ABCTimestamp): datetime dtval = self.to_pydatetime() self._assert_tzawareness_compat(other) - - if self.nanosecond == 0: - return PyObject_RichCompareBool(dtval, other, op) - else: - if op == Py_EQ: - return False - elif op == Py_NE: - return True - elif op == Py_LT: - return dtval < other - elif op == Py_LE: - return dtval < other - elif op == Py_GT: - return dtval >= other - elif op == Py_GE: - return dtval >= other + return PyObject_RichCompareBool(dtval, other, op) cdef _assert_tzawareness_compat(_Timestamp self, datetime other): if self.tzinfo is None: @@ -406,10 +378,10 @@ cdef class _Timestamp(ABCTimestamp): elif is_tick_object(other): try: nanos = other.nanos - except OverflowError: + except OverflowError as err: raise OverflowError( f"the add operation between {other} and {self} will overflow" - ) + ) from err result = type(self)(self.value + nanos, tz=self.tzinfo, freq=self.freq) return result diff --git a/pandas/tests/arithmetic/test_numeric.py b/pandas/tests/arithmetic/test_numeric.py index 269235b943e46..b6456a2141c06 100644 --- a/pandas/tests/arithmetic/test_numeric.py +++ b/pandas/tests/arithmetic/test_numeric.py @@ -66,7 +66,9 @@ def test_df_numeric_cmp_dt64_raises(self): ts = pd.Timestamp.now() df = pd.DataFrame({"x": range(5)}) - msg = "'[<>]' not supported between instances of 'Timestamp' and 'int'" + msg = ( + "'[<>]' not supported between instances of 'numpy.ndarray' and 'Timestamp'" + ) with pytest.raises(TypeError, match=msg): df > ts with pytest.raises(TypeError, match=msg): diff --git a/pandas/tests/frame/test_arithmetic.py b/pandas/tests/frame/test_arithmetic.py index d75f1f14b6369..b9102b1f84c4a 100644 --- a/pandas/tests/frame/test_arithmetic.py +++ b/pandas/tests/frame/test_arithmetic.py @@ -106,7 +106,7 @@ def test_timestamp_compare(self): else: msg = ( "'(<|>)=?' not supported between " - "instances of 'Timestamp' and 'float'" + "instances of 'numpy.ndarray' and 'Timestamp'" ) with pytest.raises(TypeError, match=msg): left_f(df, pd.Timestamp("20010109")) diff --git a/pandas/tests/scalar/timestamp/test_comparisons.py b/pandas/tests/scalar/timestamp/test_comparisons.py index 27aef8c4a9eb7..71693a9ca61ce 100644 --- a/pandas/tests/scalar/timestamp/test_comparisons.py +++ b/pandas/tests/scalar/timestamp/test_comparisons.py @@ -211,9 +211,9 @@ def test_compare_zerodim_array(self): assert arr.ndim == 0 result = arr < ts - assert result is True + assert result is np.bool_(True) result = arr > ts - assert result is False + assert result is np.bool_(False) def test_rich_comparison_with_unsupported_type():