From 1faeca0d4e3f6d8f32d374511f7db7fcd0a86009 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 12 Dec 2022 14:59:00 -0800 Subject: [PATCH 1/2] REF: de-duplicate period-dispatch --- pandas/core/arrays/datetimelike.py | 71 ++++++++++++++++-------------- pandas/core/arrays/period.py | 17 ++----- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index f9ff702a608a4..9c71216f0129a 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -4,6 +4,7 @@ datetime, timedelta, ) +from functools import wraps import operator from typing import ( TYPE_CHECKING, @@ -157,6 +158,31 @@ DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin") +def _period_dispatch(meth): + """ + For PeriodArray methods, dispatch to DatetimeArray and re-wrap the results + in PeriodArray. We cannot use ._ndarray directly for the affected + methods because the i8 data has different semantics on NaT values. + """ + + @wraps(meth) + def new_meth(self, *args, **kwargs): + if not is_period_dtype(self.dtype): + return meth(self, *args, **kwargs) + + arr = self.view("M8[ns]") + result = meth(arr, *args, **kwargs) + if result is NaT: + return NaT + elif isinstance(result, Timestamp): + return self._box_func(result.value) + + res_i8 = result.view("i8") + return self._from_backing_data(res_i8) + + return new_meth + + class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray): """ Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray @@ -1525,6 +1551,15 @@ def __isub__(self: DatetimeLikeArrayT, other) -> DatetimeLikeArrayT: # -------------------------------------------------------------- # Reductions + @_period_dispatch + def _quantile( + self: DatetimeLikeArrayT, + qs: npt.NDArray[np.float64], + interpolation: str, + ) -> DatetimeLikeArrayT: + return super()._quantile(qs=qs, interpolation=interpolation) + + @_period_dispatch def min(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs): """ Return the minimum value of the Array or minimum along @@ -1539,21 +1574,10 @@ def min(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs): nv.validate_min((), kwargs) nv.validate_minmax_axis(axis, self.ndim) - if is_period_dtype(self.dtype): - # pass datetime64 values to nanops to get correct NaT semantics - result = nanops.nanmin( - self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna - ) - if result is NaT: - return NaT - result = result.view("i8") - if axis is None or self.ndim == 1: - return self._box_func(result) - return self._from_backing_data(result) - result = nanops.nanmin(self._ndarray, axis=axis, skipna=skipna) return self._wrap_reduction_result(axis, result) + @_period_dispatch def max(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs): """ Return the maximum value of the Array or maximum along @@ -1568,18 +1592,6 @@ def max(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs): nv.validate_max((), kwargs) nv.validate_minmax_axis(axis, self.ndim) - if is_period_dtype(self.dtype): - # pass datetime64 values to nanops to get correct NaT semantics - result = nanops.nanmax( - self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna - ) - if result is NaT: - return result - result = result.view("i8") - if axis is None or self.ndim == 1: - return self._box_func(result) - return self._from_backing_data(result) - result = nanops.nanmax(self._ndarray, axis=axis, skipna=skipna) return self._wrap_reduction_result(axis, result) @@ -1620,22 +1632,13 @@ def mean(self, *, skipna: bool = True, axis: AxisInt | None = 0): ) return self._wrap_reduction_result(axis, result) + @_period_dispatch def median(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs): nv.validate_median((), kwargs) if axis is not None and abs(axis) >= self.ndim: raise ValueError("abs(axis) must be less than ndim") - if is_period_dtype(self.dtype): - # pass datetime64 values to nanops to get correct NaT semantics - result = nanops.nanmedian( - self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna - ) - result = result.view("i8") - if axis is None or self.ndim == 1: - return self._box_func(result) - return self._from_backing_data(result) - result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna) return self._wrap_reduction_result(axis, result) diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index 41ca630e86c10..859bb53b6489a 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -672,13 +672,15 @@ def searchsorted( ) -> npt.NDArray[np.intp] | np.intp: npvalue = self._validate_setitem_value(value).view("M8[ns]") - # Cast to M8 to get datetime-like NaT placement + # Cast to M8 to get datetime-like NaT placement, + # similar to dtl._period_dispatch m8arr = self._ndarray.view("M8[ns]") return m8arr.searchsorted(npvalue, side=side, sorter=sorter) def fillna(self, value=None, method=None, limit=None) -> PeriodArray: if method is not None: - # view as dt64 so we get treated as timelike in core.missing + # view as dt64 so we get treated as timelike in core.missing, + # similar to dtl._period_dispatch dta = self.view("M8[ns]") result = dta.fillna(value=value, method=method, limit=limit) # error: Incompatible return value type (got "Union[ExtensionArray, @@ -686,17 +688,6 @@ def fillna(self, value=None, method=None, limit=None) -> PeriodArray: return result.view(self.dtype) # type: ignore[return-value] return super().fillna(value=value, method=method, limit=limit) - def _quantile( - self: PeriodArray, - qs: npt.NDArray[np.float64], - interpolation: str, - ) -> PeriodArray: - # dispatch to DatetimeArray implementation - dtres = self.view("M8[ns]")._quantile(qs, interpolation) - # error: Incompatible return value type (got "Union[ExtensionArray, - # ndarray[Any, Any]]", expected "PeriodArray") - return dtres.view(self.dtype) # type: ignore[return-value] - # ------------------------------------------------------------------ # Arithmetic Methods From e3088313632bcd98a0ad059141d7a578854fd73b Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 13 Dec 2022 08:01:59 -0800 Subject: [PATCH 2/2] mypy fixup --- pandas/core/arrays/datetimelike.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 7525f86388244..63940741c3fe3 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -58,6 +58,7 @@ DatetimeLikeScalar, Dtype, DtypeObj, + F, NpDtype, PositionalIndexer2D, PositionalIndexerTuple, @@ -158,7 +159,7 @@ DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin") -def _period_dispatch(meth): +def _period_dispatch(meth: F) -> F: """ For PeriodArray methods, dispatch to DatetimeArray and re-wrap the results in PeriodArray. We cannot use ._ndarray directly for the affected @@ -180,7 +181,7 @@ def new_meth(self, *args, **kwargs): res_i8 = result.view("i8") return self._from_backing_data(res_i8) - return new_meth + return cast(F, new_meth) class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray):