diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 25ac3e445822e..c0ac0a02d2a0c 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta import operator -from typing import Any, Sequence, Type, TypeVar, Union, cast +from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast import warnings import numpy as np @@ -10,7 +10,7 @@ from pandas._libs.tslibs.period import DIFFERENT_FREQ, IncompatibleFrequency, Period from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64 -from pandas._typing import DatetimeLikeScalar +from pandas._typing import DatetimeLikeScalar, DtypeObj from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning @@ -86,24 +86,10 @@ def _validate_comparison_value(self, other): raise ValueError("Lengths must match") else: - if isinstance(other, list): - # TODO: could use pd.Index to do inference? - other = np.array(other) - - if not isinstance(other, (np.ndarray, type(self))): - raise InvalidComparison(other) - - elif is_object_dtype(other.dtype): - pass - - elif not type(self)._is_recognized_dtype(other.dtype): - raise InvalidComparison(other) - - else: - # For PeriodDType this casting is unnecessary - # TODO: use Index to do inference? 
- other = type(self)._from_sequence(other) - self._check_compatible_with(other) + try: + other = self._validate_listlike(other, opname, allow_object=True) + except TypeError as err: + raise InvalidComparison(other) from err return other @@ -451,6 +437,8 @@ class DatetimeLikeArrayMixin( _generate_range """ + _is_recognized_dtype: Callable[[DtypeObj], bool] + # ------------------------------------------------------------------ # NDArrayBackedExtensionArray compat @@ -761,6 +749,48 @@ def _validate_shift_value(self, fill_value): return self._unbox(fill_value) + def _validate_listlike( + self, + value, + opname: str, + cast_str: bool = False, + cast_cat: bool = False, + allow_object: bool = False, + ): + if isinstance(value, type(self)): + return value + + # Do type inference if necessary up front + # e.g. we passed PeriodIndex.values and got an ndarray of Periods + value = array(value) + value = extract_array(value, extract_numpy=True) + + if cast_str and is_dtype_equal(value.dtype, "string"): + # We got a StringArray + try: + # TODO: Could use from_sequence_of_strings if implemented + # Note: passing dtype is necessary for PeriodArray tests + value = type(self)._from_sequence(value, dtype=self.dtype) + except ValueError: + pass + + if cast_cat and is_categorical_dtype(value.dtype): + # e.g. we have a Categorical holding self.dtype + if is_dtype_equal(value.categories.dtype, self.dtype): + # TODO: do we need equal dtype or just comparable? 
+ value = value._internal_get_values() + + if allow_object and is_object_dtype(value.dtype): + pass + + elif not type(self)._is_recognized_dtype(value.dtype): + raise TypeError( + f"{opname} requires compatible dtype or scalar, " + f"not {type(value).__name__}" + ) + + return value + def _validate_searchsorted_value(self, value): if isinstance(value, str): try: @@ -776,41 +806,19 @@ def _validate_searchsorted_value(self, value): elif isinstance(value, self._recognized_scalars): value = self._scalar_type(value) - elif isinstance(value, type(self)): - pass - - elif is_list_like(value) and not isinstance(value, type(self)): - value = array(value) - - if not type(self)._is_recognized_dtype(value.dtype): - raise TypeError( - "searchsorted requires compatible dtype or scalar, " - f"not {type(value).__name__}" - ) + elif not is_list_like(value): + raise TypeError(f"Unexpected type for 'value': {type(value)}") else: - raise TypeError(f"Unexpected type for 'value': {type(value)}") + # TODO: cast_str? we accept it for scalar + value = self._validate_listlike(value, "searchsorted") return self._unbox(value) def _validate_setitem_value(self, value): if is_list_like(value): - value = array(value) - if is_dtype_equal(value.dtype, "string"): - # We got a StringArray - try: - # TODO: Could use from_sequence_of_strings if implemented - # Note: passing dtype is necessary for PeriodArray tests - value = type(self)._from_sequence(value, dtype=self.dtype) - except ValueError: - pass - - if not type(self)._is_recognized_dtype(value.dtype): - raise TypeError( - "setitem requires compatible dtype or scalar, " - f"not {type(value).__name__}" - ) + value = self._validate_listlike(value, "setitem", cast_str=True) elif isinstance(value, self._recognized_scalars): value = self._scalar_type(value) @@ -851,18 +859,7 @@ def _validate_where_value(self, other): raise TypeError(f"Where requires matching dtype, not {type(other)}") else: - # Do type inference if necessary up front - # e.g. 
we passed PeriodIndex.values and got an ndarray of Periods - other = array(other) - other = extract_array(other, extract_numpy=True) - - if is_categorical_dtype(other.dtype): - # e.g. we have a Categorical holding self.dtype - if is_dtype_equal(other.categories.dtype, self.dtype): - other = other._internal_get_values() - - if not type(self)._is_recognized_dtype(other.dtype): - raise TypeError(f"Where requires matching dtype, not {other.dtype}") + other = self._validate_listlike(other, "where", cast_cat=True) self._check_compatible_with(other, setitem=True) return self._unbox(other)