diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 012fb1d0c2eb7..6b1f8e7f1c6a5 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5341,7 +5341,7 @@ def reorder_levels(self, order, axis=0) -> "DataFrame": # ---------------------------------------------------------------------- # Arithmetic / combination related - def _combine_frame(self, other, func, fill_value=None, level=None): + def _combine_frame(self, other: "DataFrame", func, fill_value=None): # at this point we have `self._indexed_same(other)` if fill_value is None: @@ -5368,7 +5368,7 @@ def _arith_op(left, right): return new_data - def _combine_match_index(self, other, func): + def _combine_match_index(self, other: Series, func): # at this point we have `self.index.equals(other.index)` if ops.should_series_dispatch(self, other, func): diff --git a/pandas/core/generic.py b/pandas/core/generic.py index c502adcdb09e0..2312a77e94465 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -7197,7 +7197,7 @@ def _clip_with_one_bound(self, threshold, method, axis, inplace): if isinstance(self, ABCSeries): threshold = self._constructor(threshold, index=self.index) else: - threshold = _align_method_FRAME(self, threshold, axis) + threshold = _align_method_FRAME(self, threshold, axis, flex=None)[1] return self.where(subset, threshold, axis=axis, inplace=inplace) def clip( diff --git a/pandas/core/ops/__init__.py b/pandas/core/ops/__init__.py index 9ed233cad65ce..393bb2f1900ff 100644 --- a/pandas/core/ops/__init__.py +++ b/pandas/core/ops/__init__.py @@ -5,12 +5,13 @@ """ import datetime import operator -from typing import Set, Tuple, Union +from typing import Optional, Set, Tuple, Union import numpy as np from pandas._libs import Timedelta, Timestamp, lib from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op # noqa:F401 +from pandas._typing import Level from pandas.util._decorators import Appender from pandas.core.dtypes.common import is_list_like, is_timedelta64_dtype @@ -609,8 +610,27 @@ def _combine_series_frame(left, right, func, axis: int): return left._construct_result(new_data) -def _align_method_FRAME(left, right, axis): - """ convert rhs to meet lhs dims if input is list, tuple or np.ndarray """ +def _align_method_FRAME( + left, right, axis, flex: Optional[bool] = False, level: Level = None +): + """ + Convert rhs to meet lhs dims if input is list, tuple or np.ndarray. + + Parameters + ---------- + left : DataFrame + right : Any + axis: int, str, or None + flex: bool or None, default False + Whether this is a flex op, in which case we reindex. + None indicates not to check for alignment. + level : int or level name, default None + + Returns + ------- + left : DataFrame + right : Any + """ def to_series(right): msg = "Unable to coerce to Series, length must be {req_len}: given {given_len}" @@ -661,7 +681,22 @@ def to_series(right): # GH17901 right = to_series(right) - return right + if flex is not None and isinstance(right, ABCDataFrame): + if not left._indexed_same(right): + if flex: + left, right = left.align(right, join="outer", level=level, copy=False) + else: + raise ValueError( + "Can only compare identically-labeled DataFrame objects" + ) + elif isinstance(right, ABCSeries): + # axis=1 is default for DataFrame-with-Series op + axis = left._get_axis_number(axis) if axis is not None else 1 + left, right = left.align( + right, join="outer", axis=axis, level=level, copy=False + ) + + return left, right def _arith_method_FRAME(cls, op, special): @@ -681,16 +716,15 @@ def _arith_method_FRAME(cls, op, special): @Appender(doc) def f(self, other, axis=default_axis, level=None, fill_value=None): - other = _align_method_FRAME(self, other, axis) + self, other = _align_method_FRAME(self, other, axis, flex=True, level=level) if isinstance(other, ABCDataFrame): # Another DataFrame pass_op = op if should_series_dispatch(self, other, op) else na_op pass_op = pass_op if not is_logical else op - left, right = self.align(other, join="outer", level=level, copy=False) - new_data = left._combine_frame(right, pass_op, fill_value) - return left._construct_result(new_data) + new_data = self._combine_frame(other, pass_op, fill_value) + return self._construct_result(new_data) elif isinstance(other, ABCSeries): # For these values of `axis`, we end up dispatching to Series op, @@ -702,9 +736,6 @@ def f(self, other, axis=default_axis, level=None, fill_value=None): raise NotImplementedError(f"fill_value {fill_value} not supported.") axis = self._get_axis_number(axis) if axis is not None else 1 - self, other = self.align( - other, join="outer", axis=axis, level=level, copy=False - ) return _combine_series_frame(self, other, pass_op, axis=axis) else: # in this case we always have `np.ndim(other) == 0` @@ -731,20 +762,15 @@ def _flex_comp_method_FRAME(cls, op, special): @Appender(doc) def f(self, other, axis=default_axis, level=None): - other = _align_method_FRAME(self, other, axis) + self, other = _align_method_FRAME(self, other, axis, flex=True, level=level) if isinstance(other, ABCDataFrame): # Another DataFrame - if not self._indexed_same(other): - self, other = self.align(other, "outer", level=level, copy=False) new_data = dispatch_to_series(self, other, op, str_rep) return self._construct_result(new_data) elif isinstance(other, ABCSeries): axis = self._get_axis_number(axis) if axis is not None else 1 - self, other = self.align( - other, join="outer", axis=axis, level=level, copy=False - ) return _combine_series_frame(self, other, op, axis=axis) else: # in this case we always have `np.ndim(other) == 0` @@ -763,21 +789,15 @@ def _comp_method_FRAME(cls, op, special): @Appender(f"Wrapper for comparison method {op_name}") def f(self, other): - other = _align_method_FRAME(self, other, axis=None) + self, other = _align_method_FRAME( + self, other, axis=None, level=None, flex=False + ) if isinstance(other, ABCDataFrame): # Another DataFrame - if not self._indexed_same(other): - raise ValueError( - "Can only compare identically-labeled DataFrame objects" - ) new_data = dispatch_to_series(self, other, op, str_rep) elif isinstance(other, ABCSeries): - # axis=1 is default for DataFrame-with-Series op - self, other = self.align( - other, join="outer", axis=1, level=None, copy=False - ) new_data = dispatch_to_series(self, other, op, axis="columns") else: diff --git a/pandas/tests/frame/test_operators.py b/pandas/tests/frame/test_operators.py index 55f1216a0efd7..162f3c114fa5d 100644 --- a/pandas/tests/frame/test_operators.py +++ b/pandas/tests/frame/test_operators.py @@ -864,10 +864,10 @@ def test_alignment_non_pandas(self): ]: tm.assert_series_equal( - align(df, val, "index"), Series([1, 2, 3], index=df.index) + align(df, val, "index")[1], Series([1, 2, 3], index=df.index) ) tm.assert_series_equal( - align(df, val, "columns"), Series([1, 2, 3], index=df.columns) + align(df, val, "columns")[1], Series([1, 2, 3], index=df.columns) ) # length mismatch @@ -882,10 +882,11 @@ def test_alignment_non_pandas(self): val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) tm.assert_frame_equal( - align(df, val, "index"), DataFrame(val, index=df.index, columns=df.columns) + align(df, val, "index")[1], + DataFrame(val, index=df.index, columns=df.columns), ) tm.assert_frame_equal( - align(df, val, "columns"), + align(df, val, "columns")[1], DataFrame(val, index=df.index, columns=df.columns), )