diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index de6d6c8e07144..66885b098c6ff 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2,6 +2,7 @@ from copy import copy as copy_func from datetime import datetime +import functools from itertools import zip_longest import operator from typing import ( @@ -39,6 +40,7 @@ ArrayLike, Dtype, DtypeObj, + F, Shape, T, final, @@ -186,6 +188,33 @@ _o_dtype = np.dtype("object") +def _maybe_return_indexers(meth: F) -> F: + """ + Decorator to simplify 'return_indexers' checks in Index.join. + """ + + @functools.wraps(meth) + def join( + self, + other, + how: str_t = "left", + level=None, + return_indexers: bool = False, + sort: bool = False, + ): + join_index, lidx, ridx = meth(self, other, how=how, level=level, sort=sort) + if not return_indexers: + return join_index + + if lidx is not None: + lidx = ensure_platform_int(lidx) + if ridx is not None: + ridx = ensure_platform_int(ridx) + return join_index, lidx, ridx + + return cast(F, join) + + def disallow_kwargs(kwargs: dict[str, Any]): if kwargs: raise TypeError(f"Unexpected keyword arguments {repr(set(kwargs))}") @@ -3761,9 +3790,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None): if level is not None: if method is not None: raise TypeError("Fill method not supported if level passed") - _, indexer, _ = self._join_level( - target, level, how="right", return_indexers=True - ) + _, indexer, _ = self._join_level(target, level, how="right") else: if self.equals(target): indexer = None @@ -3859,6 +3886,7 @@ def _reindex_non_unique(self, target): # -------------------------------------------------------------------- # Join Methods + @_maybe_return_indexers def join( self, other, @@ -3900,60 +3928,44 @@ def join( if self.names == other.names: pass else: - return self._join_multi(other, how=how, return_indexers=return_indexers) + return self._join_multi(other, how=how) # join on the level if level is not None and (self_is_mi or other_is_mi): - return self._join_level( - other, level, how=how, return_indexers=return_indexers - ) + return self._join_level(other, level, how=how) if len(other) == 0 and how in ("left", "outer"): join_index = self._view() - if return_indexers: - rindexer = np.repeat(np.intp(-1), len(join_index)) - return join_index, None, rindexer - else: - return join_index + rindexer = np.repeat(np.intp(-1), len(join_index)) + return join_index, None, rindexer if len(self) == 0 and how in ("right", "outer"): join_index = other._view() - if return_indexers: - lindexer = np.repeat(np.intp(-1), len(join_index)) - return join_index, lindexer, None - else: - return join_index + lindexer = np.repeat(np.intp(-1), len(join_index)) + return join_index, lindexer, None if self._join_precedence < other._join_precedence: how = {"right": "left", "left": "right"}.get(how, how) - result = other.join( - self, how=how, level=level, return_indexers=return_indexers + join_index, lidx, ridx = other.join( + self, how=how, level=level, return_indexers=True ) - if return_indexers: - x, y, z = result - result = x, z, y - return result + lidx, ridx = ridx, lidx + return join_index, lidx, ridx if not is_dtype_equal(self.dtype, other.dtype): this = self.astype("O") other = other.astype("O") - return this.join(other, how=how, return_indexers=return_indexers) + return this.join(other, how=how, return_indexers=True) _validate_join_method(how) if not self.is_unique and not other.is_unique: - return self._join_non_unique( - other, how=how, return_indexers=return_indexers - ) + return self._join_non_unique(other, how=how) elif not self.is_unique or not other.is_unique: if self.is_monotonic and other.is_monotonic: - return self._join_monotonic( - other, how=how, return_indexers=return_indexers - ) + return self._join_monotonic(other, how=how) else: - return self._join_non_unique( - other, how=how, return_indexers=return_indexers - ) + return self._join_non_unique(other, how=how) elif ( self.is_monotonic and other.is_monotonic @@ -3965,9 +3977,7 @@ def join( # Categorical is monotonic if data are ordered as categories, but join can # not handle this in case of not lexicographically monotonic GH#38502 try: - return self._join_monotonic( - other, how=how, return_indexers=return_indexers - ) + return self._join_monotonic(other, how=how) except TypeError: pass @@ -3987,21 +3997,18 @@ def join( if sort: join_index = join_index.sort_values() - if return_indexers: - if join_index is self: - lindexer = None - else: - lindexer = self.get_indexer(join_index) - if join_index is other: - rindexer = None - else: - rindexer = other.get_indexer(join_index) - return join_index, lindexer, rindexer + if join_index is self: + lindexer = None else: - return join_index + lindexer = self.get_indexer(join_index) + if join_index is other: + rindexer = None + else: + rindexer = other.get_indexer(join_index) + return join_index, lindexer, rindexer @final - def _join_multi(self, other, how, return_indexers=True): + def _join_multi(self, other, how): from pandas.core.indexes.multi import MultiIndex from pandas.core.reshape.merge import restore_dropped_levels_multijoin @@ -4054,10 +4061,7 @@ def _join_multi(self, other, how, return_indexers=True): multi_join_idx = multi_join_idx.remove_unused_levels() - if return_indexers: - return multi_join_idx, lidx, ridx - else: - return multi_join_idx + return multi_join_idx, lidx, ridx jl = list(overlap)[0] @@ -4071,16 +4075,14 @@ def _join_multi(self, other, how, return_indexers=True): how = {"right": "left", "left": "right"}.get(how, how) level = other.names.index(jl) - result = self._join_level( - other, level, how=how, return_indexers=return_indexers - ) + result = self._join_level(other, level, how=how) - if flip_order and isinstance(result, tuple): + if flip_order: return result[0], result[2], result[1] return result @final - def _join_non_unique(self, other, how="left", return_indexers=False): + def _join_non_unique(self, other, how="left"): from pandas.core.reshape.merge import get_join_indexers # We only get here if dtypes match @@ -4102,15 +4104,10 @@ def _join_non_unique(self, other, how="left", return_indexers=False): join_index = self._wrap_joined_index(join_array, other) - if return_indexers: - return join_index, left_idx, right_idx - else: - return join_index + return join_index, left_idx, right_idx @final - def _join_level( - self, other, level, how="left", return_indexers=False, keep_order=True - ): + def _join_level(self, other, level, how="left", keep_order=True): """ The join method *only* affects the level of the resulting MultiIndex. Otherwise it just exactly aligns the Index data to the @@ -4247,28 +4244,22 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> np.ndarray: if flip_order: left_indexer, right_indexer = right_indexer, left_indexer - if return_indexers: - left_indexer = ( - None if left_indexer is None else ensure_platform_int(left_indexer) - ) - right_indexer = ( - None if right_indexer is None else ensure_platform_int(right_indexer) - ) - return join_index, left_indexer, right_indexer - else: - return join_index + left_indexer = ( + None if left_indexer is None else ensure_platform_int(left_indexer) + ) + right_indexer = ( + None if right_indexer is None else ensure_platform_int(right_indexer) + ) + return join_index, left_indexer, right_indexer @final - def _join_monotonic(self, other, how="left", return_indexers=False): + def _join_monotonic(self, other: Index, how="left"): # We only get here with matching dtypes assert other.dtype == self.dtype if self.equals(other): ret_index = other if how == "right" else self - if return_indexers: - return ret_index, None, None - else: - return ret_index + return ret_index, None, None sv = self._get_engine_target() ov = other._get_engine_target() @@ -4304,12 +4295,9 @@ def _join_monotonic(self, other, how="left", return_indexers=False): join_index = self._wrap_joined_index(join_array, other) - if return_indexers: - lidx = None if lidx is None else ensure_platform_int(lidx) - ridx = None if ridx is None else ensure_platform_int(ridx) - return join_index, lidx, ridx - else: - return join_index + lidx = None if lidx is None else ensure_platform_int(lidx) + ridx = None if ridx is None else ensure_platform_int(ridx) + return join_index, lidx, ridx def _wrap_joined_index( self: _IndexT, joined: np.ndarray, other: _IndexT diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index eea1a069b9df6..15514a1627645 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -2532,7 +2532,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None): else: target = ensure_index(target) target, indexer, _ = self._join_level( - target, level, how="right", return_indexers=True, keep_order=False + target, level, how="right", keep_order=False ) else: target = ensure_index(target)