Skip to content

REF: simplify Index.join dispatch/wrapping #40793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 73 additions & 85 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from copy import copy as copy_func
from datetime import datetime
import functools
from itertools import zip_longest
import operator
from typing import (
Expand Down Expand Up @@ -39,6 +40,7 @@
ArrayLike,
Dtype,
DtypeObj,
F,
Shape,
T,
final,
Expand Down Expand Up @@ -186,6 +188,33 @@
_o_dtype = np.dtype("object")


def _maybe_return_indexers(meth: F) -> F:
"""
Decorator to simplify 'return_indexers' checks in Index.join.
"""

@functools.wraps(meth)
def join(
self,
other,
how: str_t = "left",
level=None,
return_indexers: bool = False,
sort: bool = False,
):
join_index, lidx, ridx = meth(self, other, how=how, level=level, sort=sort)
if not return_indexers:
return join_index

if lidx is not None:
lidx = ensure_platform_int(lidx)
if ridx is not None:
ridx = ensure_platform_int(ridx)
return join_index, lidx, ridx

return cast(F, join)


def disallow_kwargs(kwargs: dict[str, Any]):
if kwargs:
raise TypeError(f"Unexpected keyword arguments {repr(set(kwargs))}")
Expand Down Expand Up @@ -3761,9 +3790,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
if level is not None:
if method is not None:
raise TypeError("Fill method not supported if level passed")
_, indexer, _ = self._join_level(
target, level, how="right", return_indexers=True
)
_, indexer, _ = self._join_level(target, level, how="right")
else:
if self.equals(target):
indexer = None
Expand Down Expand Up @@ -3859,6 +3886,7 @@ def _reindex_non_unique(self, target):
# --------------------------------------------------------------------
# Join Methods

@_maybe_return_indexers
def join(
self,
other,
Expand Down Expand Up @@ -3900,60 +3928,44 @@ def join(
if self.names == other.names:
pass
else:
return self._join_multi(other, how=how, return_indexers=return_indexers)
return self._join_multi(other, how=how)

# join on the level
if level is not None and (self_is_mi or other_is_mi):
return self._join_level(
other, level, how=how, return_indexers=return_indexers
)
return self._join_level(other, level, how=how)

if len(other) == 0 and how in ("left", "outer"):
join_index = self._view()
if return_indexers:
rindexer = np.repeat(np.intp(-1), len(join_index))
return join_index, None, rindexer
else:
return join_index
rindexer = np.repeat(np.intp(-1), len(join_index))
return join_index, None, rindexer

if len(self) == 0 and how in ("right", "outer"):
join_index = other._view()
if return_indexers:
lindexer = np.repeat(np.intp(-1), len(join_index))
return join_index, lindexer, None
else:
return join_index
lindexer = np.repeat(np.intp(-1), len(join_index))
return join_index, lindexer, None

if self._join_precedence < other._join_precedence:
how = {"right": "left", "left": "right"}.get(how, how)
result = other.join(
self, how=how, level=level, return_indexers=return_indexers
join_index, lidx, ridx = other.join(
self, how=how, level=level, return_indexers=True
)
if return_indexers:
x, y, z = result
result = x, z, y
return result
lidx, ridx = ridx, lidx
return join_index, lidx, ridx

if not is_dtype_equal(self.dtype, other.dtype):
this = self.astype("O")
other = other.astype("O")
return this.join(other, how=how, return_indexers=return_indexers)
return this.join(other, how=how, return_indexers=True)

_validate_join_method(how)

if not self.is_unique and not other.is_unique:
return self._join_non_unique(
other, how=how, return_indexers=return_indexers
)
return self._join_non_unique(other, how=how)
elif not self.is_unique or not other.is_unique:
if self.is_monotonic and other.is_monotonic:
return self._join_monotonic(
other, how=how, return_indexers=return_indexers
)
return self._join_monotonic(other, how=how)
else:
return self._join_non_unique(
other, how=how, return_indexers=return_indexers
)
return self._join_non_unique(other, how=how)
elif (
self.is_monotonic
and other.is_monotonic
Expand All @@ -3965,9 +3977,7 @@ def join(
# Categorical is monotonic if data are ordered as categories, but join can
# not handle this in case of not lexicographically monotonic GH#38502
try:
return self._join_monotonic(
other, how=how, return_indexers=return_indexers
)
return self._join_monotonic(other, how=how)
except TypeError:
pass

Expand All @@ -3987,21 +3997,18 @@ def join(
if sort:
join_index = join_index.sort_values()

if return_indexers:
if join_index is self:
lindexer = None
else:
lindexer = self.get_indexer(join_index)
if join_index is other:
rindexer = None
else:
rindexer = other.get_indexer(join_index)
return join_index, lindexer, rindexer
if join_index is self:
lindexer = None
else:
return join_index
lindexer = self.get_indexer(join_index)
if join_index is other:
rindexer = None
else:
rindexer = other.get_indexer(join_index)
return join_index, lindexer, rindexer

@final
def _join_multi(self, other, how, return_indexers=True):
def _join_multi(self, other, how):
from pandas.core.indexes.multi import MultiIndex
from pandas.core.reshape.merge import restore_dropped_levels_multijoin

Expand Down Expand Up @@ -4054,10 +4061,7 @@ def _join_multi(self, other, how, return_indexers=True):

multi_join_idx = multi_join_idx.remove_unused_levels()

if return_indexers:
return multi_join_idx, lidx, ridx
else:
return multi_join_idx
return multi_join_idx, lidx, ridx

jl = list(overlap)[0]

Expand All @@ -4071,16 +4075,14 @@ def _join_multi(self, other, how, return_indexers=True):
how = {"right": "left", "left": "right"}.get(how, how)

level = other.names.index(jl)
result = self._join_level(
other, level, how=how, return_indexers=return_indexers
)
result = self._join_level(other, level, how=how)

if flip_order and isinstance(result, tuple):
if flip_order:
return result[0], result[2], result[1]
return result

@final
def _join_non_unique(self, other, how="left", return_indexers=False):
def _join_non_unique(self, other, how="left"):
from pandas.core.reshape.merge import get_join_indexers

# We only get here if dtypes match
Expand All @@ -4102,15 +4104,10 @@ def _join_non_unique(self, other, how="left", return_indexers=False):

join_index = self._wrap_joined_index(join_array, other)

if return_indexers:
return join_index, left_idx, right_idx
else:
return join_index
return join_index, left_idx, right_idx

@final
def _join_level(
self, other, level, how="left", return_indexers=False, keep_order=True
):
def _join_level(self, other, level, how="left", keep_order=True):
"""
The join method *only* affects the level of the resulting
MultiIndex. Otherwise it just exactly aligns the Index data to the
Expand Down Expand Up @@ -4247,28 +4244,22 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> np.ndarray:
if flip_order:
left_indexer, right_indexer = right_indexer, left_indexer

if return_indexers:
left_indexer = (
None if left_indexer is None else ensure_platform_int(left_indexer)
)
right_indexer = (
None if right_indexer is None else ensure_platform_int(right_indexer)
)
return join_index, left_indexer, right_indexer
else:
return join_index
left_indexer = (
None if left_indexer is None else ensure_platform_int(left_indexer)
)
right_indexer = (
None if right_indexer is None else ensure_platform_int(right_indexer)
)
return join_index, left_indexer, right_indexer

@final
def _join_monotonic(self, other, how="left", return_indexers=False):
def _join_monotonic(self, other: Index, how="left"):
# We only get here with matching dtypes
assert other.dtype == self.dtype

if self.equals(other):
ret_index = other if how == "right" else self
if return_indexers:
return ret_index, None, None
else:
return ret_index
return ret_index, None, None

sv = self._get_engine_target()
ov = other._get_engine_target()
Expand Down Expand Up @@ -4304,12 +4295,9 @@ def _join_monotonic(self, other, how="left", return_indexers=False):

join_index = self._wrap_joined_index(join_array, other)

if return_indexers:
lidx = None if lidx is None else ensure_platform_int(lidx)
ridx = None if ridx is None else ensure_platform_int(ridx)
return join_index, lidx, ridx
else:
return join_index
lidx = None if lidx is None else ensure_platform_int(lidx)
ridx = None if ridx is None else ensure_platform_int(ridx)
return join_index, lidx, ridx

def _wrap_joined_index(
self: _IndexT, joined: np.ndarray, other: _IndexT
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
else:
target = ensure_index(target)
target, indexer, _ = self._join_level(
target, level, how="right", return_indexers=True, keep_order=False
target, level, how="right", keep_order=False
)
else:
target = ensure_index(target)
Expand Down