From cd3c048aa5ff0fccdcab2af8e3a21b597e0e8223 Mon Sep 17 00:00:00 2001
From: Brock
Date: Thu, 29 Sep 2022 07:44:15 -0700
Subject: [PATCH] REF: helpers to de-duplicate datetimelike arithmetic

---
 pandas/core/arrays/datetimelike.py | 119 +++++++++++++++++------------
 1 file changed, 70 insertions(+), 49 deletions(-)

diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py
index 997cef7b09576..b3f9569e7ab63 100644
--- a/pandas/core/arrays/datetimelike.py
+++ b/pandas/core/arrays/datetimelike.py
@@ -36,6 +36,7 @@
     Period,
     Resolution,
     Tick,
+    Timedelta,
     Timestamp,
     astype_overflowsafe,
     delta_to_nanoseconds,
@@ -1112,6 +1113,41 @@ def _cmp_method(self, other, op):
     __divmod__ = make_invalid_op("__divmod__")
     __rdivmod__ = make_invalid_op("__rdivmod__")
 
+    @final
+    def _get_i8_values_and_mask(
+        self, other
+    ) -> tuple[int | npt.NDArray[np.int64], None | npt.NDArray[np.bool_]]:
+        """
+        Get the int64 values and b_mask to pass to checked_add_with_arr.
+        """
+        if isinstance(other, Period):
+            i8values = other.ordinal
+            mask = None
+        elif isinstance(other, (Timestamp, Timedelta)):
+            i8values = other.value
+            mask = None
+        else:
+            # PeriodArray, DatetimeArray, TimedeltaArray
+            mask = other._isnan
+            i8values = other.asi8
+        return i8values, mask
+
+    @final
+    def _get_arithmetic_result_freq(self, other) -> BaseOffset | None:
+        """
+        Check if we can preserve self.freq in addition or subtraction.
+        """
+        # Adding or subtracting a Timedelta/Timestamp scalar is freq-preserving
+        # whenever self.freq is a Tick
+        if is_period_dtype(self.dtype):
+            return self.freq
+        elif not lib.is_scalar(other):
+            return None
+        elif isinstance(self.freq, Tick):
+            # In these cases
+            return self.freq
+        return None
+
     @final
     def _add_datetimelike_scalar(self, other) -> DatetimeArray:
         if not is_timedelta64_dtype(self.dtype):
@@ -1168,33 +1204,12 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64):
         self = cast("DatetimeArray", self)
         # subtract a datetime from myself, yielding a ndarray[timedelta64[ns]]
 
-        # error: Non-overlapping identity check (left operand type: "Union[datetime,
-        # datetime64]", right operand type: "NaTType")  [comparison-overlap]
-        assert other is not NaT  # type: ignore[comparison-overlap]
-        other = Timestamp(other)
-        # error: Non-overlapping identity check (left operand type: "Timestamp",
-        # right operand type: "NaTType")
-        if other is NaT:  # type: ignore[comparison-overlap]
+        if isna(other):
+            # i.e. np.datetime64("NaT")
             return self - NaT
 
-        try:
-            self._assert_tzawareness_compat(other)
-        except TypeError as err:
-            new_message = str(err).replace("compare", "subtract")
-            raise type(err)(new_message) from err
-
-        i8 = self.asi8
-        result = checked_add_with_arr(i8, -other.value, arr_mask=self._isnan)
-        res_m8 = result.view(f"timedelta64[{self._unit}]")
-
-        new_freq = None
-        if isinstance(self.freq, Tick):
-            # adding a scalar preserves freq
-            new_freq = self.freq
-
-        from pandas.core.arrays import TimedeltaArray
-
-        return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)
+        other = Timestamp(other)
+        return self._sub_datetimelike(other)
 
     @final
     def _sub_datetime_arraylike(self, other):
@@ -1206,6 +1221,13 @@ def _sub_datetime_arraylike(self, other):
 
         self = cast("DatetimeArray", self)
         other = ensure_wrapped_if_datetimelike(other)
+        return self._sub_datetimelike(other)
+
+    @final
+    def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:
+        self = cast("DatetimeArray", self)
+
+        from pandas.core.arrays import TimedeltaArray
 
         try:
             self._assert_tzawareness_compat(other)
@@ -1213,12 +1235,14 @@ def _sub_datetime_arraylike(self, other):
             new_message = str(err).replace("compare", "subtract")
             raise type(err)(new_message) from err
 
-        self_i8 = self.asi8
-        other_i8 = other.asi8
-        new_values = checked_add_with_arr(
-            self_i8, -other_i8, arr_mask=self._isnan, b_mask=other._isnan
+        other_i8, o_mask = self._get_i8_values_and_mask(other)
+        res_values = checked_add_with_arr(
+            self.asi8, -other_i8, arr_mask=self._isnan, b_mask=o_mask
         )
-        return new_values.view("timedelta64[ns]")
+        res_m8 = res_values.view(f"timedelta64[{self._unit}]")
+
+        new_freq = self._get_arithmetic_result_freq(other)
+        return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)
 
     @final
     def _sub_period(self, other: Period) -> npt.NDArray[np.object_]:
@@ -1262,27 +1286,16 @@ def _add_timedeltalike_scalar(self, other):
         Same type as self
         """
         if isna(other):
-            # i.e np.timedelta64("NaT"), not recognized by delta_to_nanoseconds
+            # i.e np.timedelta64("NaT")
             new_values = np.empty(self.shape, dtype="i8").view(self._ndarray.dtype)
             new_values.fill(iNaT)
             return type(self)._simple_new(new_values, dtype=self.dtype)
 
         # PeriodArray overrides, so we only get here with DTA/TDA
-        # error: "DatetimeLikeArrayMixin" has no attribute "_reso"
-        inc = delta_to_nanoseconds(other, reso=self._reso)  # type: ignore[attr-defined]
-
-        new_values = checked_add_with_arr(self.asi8, inc, arr_mask=self._isnan)
-        new_values = new_values.view(self._ndarray.dtype)
-
-        new_freq = None
-        if isinstance(self.freq, Tick) or is_period_dtype(self.dtype):
-            # adding a scalar preserves freq
-            new_freq = self.freq
+        self = cast("DatetimeArray | TimedeltaArray", self)
 
-        # error: Unexpected keyword argument "freq" for "_simple_new" of "NDArrayBacked"
-        return type(self)._simple_new(  # type: ignore[call-arg]
-            new_values, dtype=self.dtype, freq=new_freq
-        )
+        other = Timedelta(other)._as_unit(self._unit)
+        return self._add_timedeltalike(other)
 
     def _add_timedelta_arraylike(
         self, other: TimedeltaArray | npt.NDArray[np.timedelta64]
@@ -1310,13 +1323,21 @@ def _add_timedelta_arraylike(
             else:
                 other = other._as_unit(self._unit)
 
-        self_i8 = self.asi8
-        other_i8 = other.asi8
+        return self._add_timedeltalike(other)
+
+    @final
+    def _add_timedeltalike(self, other: Timedelta | TimedeltaArray):
+        self = cast("DatetimeArray | TimedeltaArray", self)
+
+        other_i8, o_mask = self._get_i8_values_and_mask(other)
         new_values = checked_add_with_arr(
-            self_i8, other_i8, arr_mask=self._isnan, b_mask=other._isnan
+            self.asi8, other_i8, arr_mask=self._isnan, b_mask=o_mask
         )
         res_values = new_values.view(self._ndarray.dtype)
-        return type(self)._simple_new(res_values, dtype=self.dtype)
+
+        new_freq = self._get_arithmetic_result_freq(other)
+
+        return type(self)._simple_new(res_values, dtype=self.dtype, freq=new_freq)
 
     @final
     def _add_nat(self):