REF: do extract_array earlier in series arith/comparison ops #28066

Merged
merged 10 commits on Sep 2, 2019
100 changes: 63 additions & 37 deletions pandas/core/ops/__init__.py
@@ -5,7 +5,7 @@
"""
import datetime
import operator
from typing import Any, Callable, Tuple
from typing import Any, Callable, Tuple, Union

import numpy as np

@@ -34,10 +34,11 @@
ABCIndexClass,
ABCSeries,
ABCSparseSeries,
ABCTimedeltaArray,
ABCTimedeltaIndex,
)
from pandas.core.dtypes.missing import isna, notna

import pandas as pd
from pandas._typing import ArrayLike
from pandas.core.construction import array, extract_array
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY, define_na_arithmetic_op
@@ -148,6 +149,8 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
Be careful to call this *after* determining the `name` attribute to be
attached to the result of the arithmetic operation.
"""
from pandas.core.arrays import TimedeltaArray

if type(obj) is datetime.timedelta:
# GH#22390 cast up to Timedelta to rely on Timedelta
# implementation; otherwise operation against numeric-dtype
@@ -157,12 +160,10 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
if isna(obj):
# wrapping timedelta64("NaT") in Timedelta returns NaT,
# which would incorrectly be treated as a datetime-NaT, so
# we broadcast and wrap in a Series
# we broadcast and wrap in a TimedeltaArray
obj = obj.astype("timedelta64[ns]")
right = np.broadcast_to(obj, shape)

# Note: we use Series instead of TimedeltaIndex to avoid having
# to worry about catching NullFrequencyError.
return pd.Series(right)
return TimedeltaArray(right)

# In particular non-nanosecond timedelta64 needs to be cast to
# nanoseconds, or else we get undesired behavior like
@@ -173,7 +174,7 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
# GH#22390 Unfortunately we need to special-case right-hand
# timedelta64 dtypes because numpy casts integer dtypes to
# timedelta64 when operating with timedelta64
return pd.TimedeltaIndex(obj)
return TimedeltaArray._from_sequence(obj)
return obj
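
For context, here is a minimal public-API sketch of the path this hunk changes (assuming a pandas build that includes it): a stdlib timedelta and a timedelta64("NaT") scalar both pass through maybe_upcast_for_op before the array op runs, and the NaT case is now wrapped in a TimedeltaArray instead of a Series.

import datetime
import numpy as np
import pandas as pd

ser = pd.Series(pd.to_timedelta(["1 day", "2 days"]))

# A stdlib timedelta is upcast to pd.Timedelta so the timedelta64
# implementation handles the op.
print(ser + datetime.timedelta(days=1))

# timedelta64("NaT") is broadcast and wrapped (now as a TimedeltaArray rather
# than a Series), so it is treated as timedelta-NaT, not datetime-NaT.
print(ser + np.timedelta64("NaT"))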


@@ -520,13 +521,34 @@ def column_op(a, b):
return result


def dispatch_to_extension_op(op, left, right):
def dispatch_to_extension_op(
op,
left: Union[ABCExtensionArray, np.ndarray],
right: Any,
keep_null_freq: bool = False,
):
"""
Assume that left or right is a Series backed by an ExtensionArray,
apply the operator defined by op.

Parameters
----------
op : binary operator
left : ExtensionArray or np.ndarray
right : object
keep_null_freq : bool, default False
Whether to re-raise a NullFrequencyError unchanged, as opposed to
catching and raising TypeError.

Returns
-------
ExtensionArray or np.ndarray
2-tuple of these if op is divmod or rdivmod
"""
# NB: left and right should already be unboxed, so neither should be
# a Series or Index.

if left.dtype.kind in "mM":
if left.dtype.kind in "mM" and isinstance(left, np.ndarray):
# We need to cast datetime64 and timedelta64 ndarrays to
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
# PandasArray as that behaves poorly with e.g. IntegerArray.
@@ -535,15 +557,15 @@ def dispatch_to_extension_op(op, left, right):
# The op calls will raise TypeError if the op is not defined
# on the ExtensionArray

# unbox Series and Index to arrays
new_left = extract_array(left, extract_numpy=True)
new_right = extract_array(right, extract_numpy=True)

try:
res_values = op(new_left, new_right)
res_values = op(left, right)
except NullFrequencyError:
# DatetimeIndex and TimedeltaIndex with freq == None raise ValueError
# on add/sub of integers (or int-like). We re-raise as a TypeError.
if keep_null_freq:
# TODO: remove keep_null_freq after Timestamp+int deprecation
# GH#22535 is enforced
raise
raise TypeError(
"incompatible type for a datetime/timedelta "
"operation [{name}]".format(name=op.__name__)
@@ -615,25 +637,29 @@ def wrapper(left, right):
if isinstance(right, ABCDataFrame):
return NotImplemented

keep_null_freq = isinstance(
right,
(ABCDatetimeIndex, ABCDatetimeArray, ABCTimedeltaIndex, ABCTimedeltaArray),
)

left, right = _align_method_SERIES(left, right)
res_name = get_op_result_name(left, right)
right = maybe_upcast_for_op(right, left.shape)

if should_extension_dispatch(left, right):
result = dispatch_to_extension_op(op, left, right)
lvalues = extract_array(left, extract_numpy=True)
rvalues = extract_array(right, extract_numpy=True)

elif is_timedelta64_dtype(right) or isinstance(
right, (ABCDatetimeArray, ABCDatetimeIndex)
):
# We should only get here with td64 right with non-scalar values
# for right upcast by maybe_upcast_for_op
assert not isinstance(right, (np.timedelta64, np.ndarray))
result = op(left._values, right)
rvalues = maybe_upcast_for_op(rvalues, lvalues.shape)

else:
lvalues = extract_array(left, extract_numpy=True)
rvalues = extract_array(right, extract_numpy=True)
if should_extension_dispatch(lvalues, rvalues):
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq)

elif is_timedelta64_dtype(rvalues) or isinstance(rvalues, ABCDatetimeArray):
# We should only get here with td64 rvalues with non-scalar values
# for rvalues upcast by maybe_upcast_for_op
assert not isinstance(rvalues, (np.timedelta64, np.ndarray))
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq)

else:
with np.errstate(all="ignore"):
result = na_op(lvalues, rvalues)
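
The net effect of unboxing both operands at the top of the wrapper is that the array-level paths no longer care whether the right-hand side arrived as a Series, Index, or bare ndarray: all three reduce to the same rvalues. A small public-API check (assuming a build containing this refactor):

import numpy as np
import pandas as pd

left = pd.Series([1, 2, 3])

# Series, Index and ndarray right-hand operands all reduce to the same
# rvalues after extract_array, so the results match.
as_series = left + pd.Series([10, 20, 30])
as_index = left + pd.Index([10, 20, 30])
as_array = left + np.array([10, 20, 30])

assert as_series.equals(as_index) and as_index.equals(as_array)
print(as_array)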

@@ -708,25 +734,25 @@ def wrapper(self, other, axis=None):
if len(self) != len(other):
raise ValueError("Lengths must match to compare")

if should_extension_dispatch(self, other):
res_values = dispatch_to_extension_op(op, self, other)
lvalues = extract_array(self, extract_numpy=True)
rvalues = extract_array(other, extract_numpy=True)

elif is_scalar(other) and isna(other):
if should_extension_dispatch(lvalues, rvalues):
res_values = dispatch_to_extension_op(op, lvalues, rvalues)

elif is_scalar(rvalues) and isna(rvalues):
# numpy does not like comparisons vs None
if op is operator.ne:
res_values = np.ones(len(self), dtype=bool)
res_values = np.ones(len(lvalues), dtype=bool)
else:
res_values = np.zeros(len(self), dtype=bool)
res_values = np.zeros(len(lvalues), dtype=bool)

else:
lvalues = extract_array(self, extract_numpy=True)
rvalues = extract_array(other, extract_numpy=True)

with np.errstate(all="ignore"):
res_values = na_op(lvalues, rvalues)
if is_scalar(res_values):
raise TypeError(
"Could not compare {typ} type with Series".format(typ=type(other))
"Could not compare {typ} type with Series".format(typ=type(rvalues))
)

result = self._constructor(res_values, index=self.index)
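
The comparison wrapper keeps its special case for null scalars, now keyed off the unboxed rvalues: comparing a Series against None or np.nan short-circuits to all-False (all-True for !=) rather than handing None to numpy. A quick sketch of the observable behavior:

import numpy as np
import pandas as pd

ser = pd.Series([1.0, 2.0, 3.0])

# Null scalars never compare equal; != is True everywhere.
print(ser == None)    # noqa: E711  -> all False
print(ser != np.nan)  # -> all True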