Skip to content

REF: standardize get_indexer/get_indexer_non_unique code #38648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ def reset_identity(values):
# so we resort to this
# GH 14776, 30667
if ax.has_duplicates and not result.axes[self.axis].equals(ax):
indexer, _ = result.index.get_indexer_non_unique(ax.values)
indexer, _ = result.index.get_indexer_non_unique(ax._values)
indexer = algorithms.unique1d(indexer)
result = result.take(indexer, axis=self.axis)
else:
Expand Down
26 changes: 12 additions & 14 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,7 +2182,7 @@ def _nan_idxs(self):
if self._can_hold_na:
return self._isnan.nonzero()[0]
else:
return np.array([], dtype=np.int64)
return np.array([], dtype=np.intp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this cause a test breakage (on 32-bit)? (or are we not testing this particular case)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like we only really use _nan_idxs in Float64Index.get_loc, so this line won't be reached anyway


@cache_readonly
def hasnans(self) -> bool:
Expand Down Expand Up @@ -2719,12 +2719,12 @@ def _union(self, other, sort):
# find indexes of things in "other" that are not in "self"
if self.is_unique:
indexer = self.get_indexer(other)
indexer = (indexer == -1).nonzero()[0]
missing = (indexer == -1).nonzero()[0]
else:
indexer = algos.unique1d(self.get_indexer_non_unique(other)[1])
missing = algos.unique1d(self.get_indexer_non_unique(other)[1])

if len(indexer) > 0:
other_diff = algos.take_nd(rvals, indexer, allow_fill=False)
if len(missing) > 0:
other_diff = algos.take_nd(rvals, missing, allow_fill=False)
result = concat_compat((lvals, other_diff))

else:
Expand Down Expand Up @@ -2829,13 +2829,14 @@ def _intersection(self, other, sort=False):
return algos.unique1d(result)

try:
indexer = Index(rvals).get_indexer(lvals)
indexer = indexer.take((indexer != -1).nonzero()[0])
indexer = other.get_indexer(lvals)
except (InvalidIndexError, IncompatibleFrequency):
# InvalidIndexError raised by get_indexer if non-unique
# IncompatibleFrequency raised by PeriodIndex.get_indexer
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
indexer = indexer[indexer != -1]
indexer, _ = other.get_indexer_non_unique(lvals)

mask = indexer != -1
indexer = indexer.take(mask.nonzero()[0])

result = other.take(indexer).unique()._values

Expand Down Expand Up @@ -3517,7 +3518,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
"cannot reindex a non-unique index "
"with a method or limit"
)
indexer, missing = self.get_indexer_non_unique(target)
indexer, _ = self.get_indexer_non_unique(target)

if preserve_names and target.nlevels == 1 and target.name != self.name:
target = target.copy()
Expand Down Expand Up @@ -4910,10 +4911,7 @@ def get_indexer_non_unique(self, target):
that = target.astype(dtype, copy=False)
return this.get_indexer_non_unique(that)

if is_categorical_dtype(target.dtype):
tgt_values = np.asarray(target)
else:
tgt_values = target._get_engine_target()
tgt_values = target._get_engine_target()

indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
return ensure_platform_int(indexer), missing
Expand Down
57 changes: 18 additions & 39 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pandas._libs import lib
from pandas._libs.interval import Interval, IntervalMixin, IntervalTree
from pandas._libs.tslibs import BaseOffset, Timedelta, Timestamp, to_offset
from pandas._typing import AnyArrayLike, DtypeObj, Label
from pandas._typing import DtypeObj, Label
from pandas.errors import InvalidIndexError
from pandas.util._decorators import Appender, cache_readonly
from pandas.util._exceptions import rewrite_exception
Expand Down Expand Up @@ -652,9 +652,8 @@ def _get_indexer(
if self.equals(target):
return np.arange(len(self), dtype="intp")

if self._is_non_comparable_own_type(target):
# different closed or incompatible subtype -> no matches
return np.repeat(np.intp(-1), len(target))
if not self._should_compare(target):
return self._get_indexer_non_comparable(target, method, unique=True)

# non-overlapping -> at most one match per interval in target
# want exact matches -> need both left/right to match, so defer to
Expand All @@ -678,32 +677,22 @@ def _get_indexer(
return ensure_platform_int(indexer)

@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
def get_indexer_non_unique(
self, target: AnyArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
target_as_index = ensure_index(target)

# check that target_as_index IntervalIndex is compatible
if isinstance(target_as_index, IntervalIndex):

if self._is_non_comparable_own_type(target_as_index):
# different closed or incompatible subtype -> no matches
return (
np.repeat(-1, len(target_as_index)),
np.arange(len(target_as_index)),
)

if is_object_dtype(target_as_index) or isinstance(
target_as_index, IntervalIndex
):
# target_as_index might contain intervals: defer elementwise to get_loc
return self._get_indexer_pointwise(target_as_index)
def get_indexer_non_unique(self, target: Index) -> Tuple[np.ndarray, np.ndarray]:
target = ensure_index(target)

if isinstance(target, IntervalIndex) and not self._should_compare(target):
# different closed or incompatible subtype -> no matches
return self._get_indexer_non_comparable(target, None, unique=False)

elif is_object_dtype(target.dtype) or isinstance(target, IntervalIndex):
# target might contain intervals: defer elementwise to get_loc
return self._get_indexer_pointwise(target)

else:
target_as_index = self._maybe_convert_i8(target_as_index)
indexer, missing = self._engine.get_indexer_non_unique(
target_as_index.values
)
# Note: this case behaves differently from other Index subclasses
# because IntervalIndex does partial-int indexing
target = self._maybe_convert_i8(target)
indexer, missing = self._engine.get_indexer_non_unique(target.values)

return ensure_platform_int(indexer), ensure_platform_int(missing)

Expand Down Expand Up @@ -789,16 +778,6 @@ def _should_compare(self, other) -> bool:
return False
return other.closed == self.closed

# TODO: use should_compare and get rid of _is_non_comparable_own_type
def _is_non_comparable_own_type(self, other: "IntervalIndex") -> bool:
    """Return True if ``other`` is an IntervalIndex we can never match against.

    Two IntervalIndexes are non-comparable when they are closed on
    different sides or have incompatible subtypes; lookups between them
    can therefore produce no matches.
    """
    # different closed or incompatible subtype -> no matches

    # TODO: once closed is part of IntervalDtype, we can just define
    # is_comparable_dtype GH#19371
    if self.closed != other.closed:
        # mismatched closed sides (e.g. "left" vs "right") can never align
        return True
    # delegate the subtype compatibility check to the shared dtype helper
    return not self._is_comparable_dtype(other.dtype)

# --------------------------------------------------------------------

@cache_readonly
Expand Down Expand Up @@ -938,7 +917,7 @@ def _format_space(self) -> str:
def _assert_can_do_setop(self, other):
super()._assert_can_do_setop(other)

if isinstance(other, IntervalIndex) and self._is_non_comparable_own_type(other):
if isinstance(other, IntervalIndex) and not self._should_compare(other):
# GH#19016: ensure set op will not return a prohibited dtype
raise TypeError(
"can only do set operations between two IntervalIndex "
Expand Down