Skip to content

Commit 5fa77ae

Browse files
authored
REF: combine all alignment into _align_method_FRAME (#31288)
1 parent 4cb9044 commit 5fa77ae

File tree

3 files changed

+52
-31
lines changed

3 files changed

+52
-31
lines changed

pandas/core/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7206,7 +7206,7 @@ def _clip_with_one_bound(self, threshold, method, axis, inplace):
72067206
if isinstance(self, ABCSeries):
72077207
threshold = self._constructor(threshold, index=self.index)
72087208
else:
7209-
threshold = _align_method_FRAME(self, threshold, axis)
7209+
threshold = _align_method_FRAME(self, threshold, axis, flex=None)[1]
72107210
return self.where(subset, threshold, axis=axis, inplace=inplace)
72117211

72127212
def clip(

pandas/core/ops/__init__.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
"""
66
import datetime
77
import operator
8-
from typing import Set, Tuple, Union
8+
from typing import Optional, Set, Tuple, Union
99

1010
import numpy as np
1111

1212
from pandas._libs import Timedelta, Timestamp, lib
1313
from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op # noqa:F401
14+
from pandas._typing import Level
1415
from pandas.util._decorators import Appender
1516

1617
from pandas.core.dtypes.common import is_list_like, is_timedelta64_dtype
@@ -615,8 +616,27 @@ def _combine_series_frame(left, right, func, axis: int):
615616
return left._construct_result(new_data)
616617

617618

618-
def _align_method_FRAME(left, right, axis):
619-
""" convert rhs to meet lhs dims if input is list, tuple or np.ndarray """
619+
def _align_method_FRAME(
620+
left, right, axis, flex: Optional[bool] = False, level: Level = None
621+
):
622+
"""
623+
Convert rhs to meet lhs dims if input is list, tuple or np.ndarray.
624+
625+
Parameters
626+
----------
627+
left : DataFrame
628+
right : Any
629+
axis: int, str, or None
630+
flex: bool or None, default False
631+
Whether this is a flex op, in which case we reindex.
632+
None indicates not to check for alignment.
633+
level : int or level name, default None
634+
635+
Returns
636+
-------
637+
left : DataFrame
638+
right : Any
639+
"""
620640

621641
def to_series(right):
622642
msg = "Unable to coerce to Series, length must be {req_len}: given {given_len}"
@@ -667,7 +687,22 @@ def to_series(right):
667687
# GH17901
668688
right = to_series(right)
669689

670-
return right
690+
if flex is not None and isinstance(right, ABCDataFrame):
691+
if not left._indexed_same(right):
692+
if flex:
693+
left, right = left.align(right, join="outer", level=level, copy=False)
694+
else:
695+
raise ValueError(
696+
"Can only compare identically-labeled DataFrame objects"
697+
)
698+
elif isinstance(right, ABCSeries):
699+
# axis=1 is default for DataFrame-with-Series op
700+
axis = left._get_axis_number(axis) if axis is not None else 1
701+
left, right = left.align(
702+
right, join="outer", axis=axis, level=level, copy=False
703+
)
704+
705+
return left, right
671706

672707

673708
def _arith_method_FRAME(cls, op, special):
@@ -687,16 +722,15 @@ def _arith_method_FRAME(cls, op, special):
687722
@Appender(doc)
688723
def f(self, other, axis=default_axis, level=None, fill_value=None):
689724

690-
other = _align_method_FRAME(self, other, axis)
725+
self, other = _align_method_FRAME(self, other, axis, flex=True, level=level)
691726

692727
if isinstance(other, ABCDataFrame):
693728
# Another DataFrame
694729
pass_op = op if should_series_dispatch(self, other, op) else na_op
695730
pass_op = pass_op if not is_logical else op
696731

697-
left, right = self.align(other, join="outer", level=level, copy=False)
698-
new_data = left._combine_frame(right, pass_op, fill_value)
699-
return left._construct_result(new_data)
732+
new_data = self._combine_frame(other, pass_op, fill_value)
733+
return self._construct_result(new_data)
700734

701735
elif isinstance(other, ABCSeries):
702736
# For these values of `axis`, we end up dispatching to Series op,
@@ -708,9 +742,6 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
708742
raise NotImplementedError(f"fill_value {fill_value} not supported.")
709743

710744
axis = self._get_axis_number(axis) if axis is not None else 1
711-
self, other = self.align(
712-
other, join="outer", axis=axis, level=level, copy=False
713-
)
714745
return _combine_series_frame(self, other, pass_op, axis=axis)
715746
else:
716747
# in this case we always have `np.ndim(other) == 0`
@@ -737,20 +768,15 @@ def _flex_comp_method_FRAME(cls, op, special):
737768
@Appender(doc)
738769
def f(self, other, axis=default_axis, level=None):
739770

740-
other = _align_method_FRAME(self, other, axis)
771+
self, other = _align_method_FRAME(self, other, axis, flex=True, level=level)
741772

742773
if isinstance(other, ABCDataFrame):
743774
# Another DataFrame
744-
if not self._indexed_same(other):
745-
self, other = self.align(other, "outer", level=level, copy=False)
746775
new_data = dispatch_to_series(self, other, op, str_rep)
747776
return self._construct_result(new_data)
748777

749778
elif isinstance(other, ABCSeries):
750779
axis = self._get_axis_number(axis) if axis is not None else 1
751-
self, other = self.align(
752-
other, join="outer", axis=axis, level=level, copy=False
753-
)
754780
return _combine_series_frame(self, other, op, axis=axis)
755781
else:
756782
# in this case we always have `np.ndim(other) == 0`
@@ -769,21 +795,15 @@ def _comp_method_FRAME(cls, op, special):
769795
@Appender(f"Wrapper for comparison method {op_name}")
770796
def f(self, other):
771797

772-
other = _align_method_FRAME(self, other, axis=None)
798+
self, other = _align_method_FRAME(
799+
self, other, axis=None, level=None, flex=False
800+
)
773801

774802
if isinstance(other, ABCDataFrame):
775803
# Another DataFrame
776-
if not self._indexed_same(other):
777-
raise ValueError(
778-
"Can only compare identically-labeled DataFrame objects"
779-
)
780804
new_data = dispatch_to_series(self, other, op, str_rep)
781805

782806
elif isinstance(other, ABCSeries):
783-
# axis=1 is default for DataFrame-with-Series op
784-
self, other = self.align(
785-
other, join="outer", axis=1, level=None, copy=False
786-
)
787807
new_data = dispatch_to_series(self, other, op, axis="columns")
788808

789809
else:

pandas/tests/frame/test_operators.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -864,10 +864,10 @@ def test_alignment_non_pandas(self):
864864
]:
865865

866866
tm.assert_series_equal(
867-
align(df, val, "index"), Series([1, 2, 3], index=df.index)
867+
align(df, val, "index")[1], Series([1, 2, 3], index=df.index)
868868
)
869869
tm.assert_series_equal(
870-
align(df, val, "columns"), Series([1, 2, 3], index=df.columns)
870+
align(df, val, "columns")[1], Series([1, 2, 3], index=df.columns)
871871
)
872872

873873
# length mismatch
@@ -882,10 +882,11 @@ def test_alignment_non_pandas(self):
882882

883883
val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
884884
tm.assert_frame_equal(
885-
align(df, val, "index"), DataFrame(val, index=df.index, columns=df.columns)
885+
align(df, val, "index")[1],
886+
DataFrame(val, index=df.index, columns=df.columns),
886887
)
887888
tm.assert_frame_equal(
888-
align(df, val, "columns"),
889+
align(df, val, "columns")[1],
889890
DataFrame(val, index=df.index, columns=df.columns),
890891
)
891892

0 commit comments

Comments
 (0)