Skip to content

Commit e37eda5

Browse files
authored
REF: use can_use_libjoin more consistently, docstring for it (#54664)
* REF: fix can_use_libjoin check * DOC: docstring for can_use_libjoin * Make can_use_libjoin checks more-correct * mypy fixup * Fix exception message
1 parent 11b8d1a commit e37eda5

File tree

1 file changed

+31
-19
lines changed

1 file changed

+31
-19
lines changed

pandas/core/indexes/base.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,6 @@ def _left_indexer_unique(self, other: Self) -> npt.NDArray[np.intp]:
373373
# Caller is responsible for ensuring other.dtype == self.dtype
374374
sv = self._get_join_target()
375375
ov = other._get_join_target()
376-
# can_use_libjoin assures sv and ov are ndarrays
377-
sv = cast(np.ndarray, sv)
378-
ov = cast(np.ndarray, ov)
379376
# similar but not identical to ov.searchsorted(sv)
380377
return libjoin.left_join_indexer_unique(sv, ov)
381378

@@ -386,9 +383,6 @@ def _left_indexer(
386383
# Caller is responsible for ensuring other.dtype == self.dtype
387384
sv = self._get_join_target()
388385
ov = other._get_join_target()
389-
# can_use_libjoin assures sv and ov are ndarrays
390-
sv = cast(np.ndarray, sv)
391-
ov = cast(np.ndarray, ov)
392386
joined_ndarray, lidx, ridx = libjoin.left_join_indexer(sv, ov)
393387
joined = self._from_join_target(joined_ndarray)
394388
return joined, lidx, ridx
@@ -400,9 +394,6 @@ def _inner_indexer(
400394
# Caller is responsible for ensuring other.dtype == self.dtype
401395
sv = self._get_join_target()
402396
ov = other._get_join_target()
403-
# can_use_libjoin assures sv and ov are ndarrays
404-
sv = cast(np.ndarray, sv)
405-
ov = cast(np.ndarray, ov)
406397
joined_ndarray, lidx, ridx = libjoin.inner_join_indexer(sv, ov)
407398
joined = self._from_join_target(joined_ndarray)
408399
return joined, lidx, ridx
@@ -414,9 +405,6 @@ def _outer_indexer(
414405
# Caller is responsible for ensuring other.dtype == self.dtype
415406
sv = self._get_join_target()
416407
ov = other._get_join_target()
417-
# can_use_libjoin assures sv and ov are ndarrays
418-
sv = cast(np.ndarray, sv)
419-
ov = cast(np.ndarray, ov)
420408
joined_ndarray, lidx, ridx = libjoin.outer_join_indexer(sv, ov)
421409
joined = self._from_join_target(joined_ndarray)
422410
return joined, lidx, ridx
@@ -3354,6 +3342,7 @@ def _union(self, other: Index, sort: bool | None):
33543342
and other.is_monotonic_increasing
33553343
and not (self.has_duplicates and other.has_duplicates)
33563344
and self._can_use_libjoin
3345+
and other._can_use_libjoin
33573346
):
33583347
# Both are monotonic and at least one is unique, so can use outer join
33593348
# (actually don't need either unique, but without this restriction
@@ -3452,7 +3441,7 @@ def intersection(self, other, sort: bool = False):
34523441
self, other = self._dti_setop_align_tzs(other, "intersection")
34533442

34543443
if self.equals(other):
3455-
if self.has_duplicates:
3444+
if not self.is_unique:
34563445
result = self.unique()._get_reconciled_name_object(other)
34573446
else:
34583447
result = self._get_reconciled_name_object(other)
@@ -3507,7 +3496,9 @@ def _intersection(self, other: Index, sort: bool = False):
35073496
self.is_monotonic_increasing
35083497
and other.is_monotonic_increasing
35093498
and self._can_use_libjoin
3499+
and other._can_use_libjoin
35103500
and not isinstance(self, ABCMultiIndex)
3501+
and not isinstance(other, ABCMultiIndex)
35113502
):
35123503
try:
35133504
res_indexer, indexer, _ = self._inner_indexer(other)
@@ -4654,7 +4645,10 @@ def join(
46544645
return self._join_non_unique(other, how=how)
46554646
elif not self.is_unique or not other.is_unique:
46564647
if self.is_monotonic_increasing and other.is_monotonic_increasing:
4657-
if not isinstance(self.dtype, IntervalDtype):
4648+
# Note: 2023-08-15 we *do* have tests that get here with
4649+
# Categorical, string[python] (can use libjoin)
4650+
# and Interval (cannot)
4651+
if self._can_use_libjoin and other._can_use_libjoin:
46584652
# otherwise we will fall through to _join_via_get_indexer
46594653
# GH#39133
46604654
# go through object dtype for ea till engine is supported properly
@@ -4666,6 +4660,7 @@ def join(
46664660
self.is_monotonic_increasing
46674661
and other.is_monotonic_increasing
46684662
and self._can_use_libjoin
4663+
and other._can_use_libjoin
46694664
and not isinstance(self, ABCMultiIndex)
46704665
and not isinstance(self.dtype, CategoricalDtype)
46714666
):
@@ -4970,6 +4965,7 @@ def _join_monotonic(
49704965
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
49714966
# We only get here with matching dtypes and both monotonic increasing
49724967
assert other.dtype == self.dtype
4968+
assert self._can_use_libjoin and other._can_use_libjoin
49734969

49744970
if self.equals(other):
49754971
# This is a convenient place for this check, but its correctness
@@ -5038,19 +5034,28 @@ def _wrap_joined_index(
50385034
name = get_op_result_name(self, other)
50395035
return self._constructor._with_infer(joined, name=name, dtype=self.dtype)
50405036

5037+
@final
50415038
@cache_readonly
50425039
def _can_use_libjoin(self) -> bool:
50435040
"""
5044-
Whether we can use the fastpaths implement in _libs.join
5041+
Whether we can use the fastpaths implemented in _libs.join.
5042+
5043+
This is driven by whether (in monotonic increasing cases that are
5044+
guaranteed not to have NAs) we can convert to a np.ndarray without
5045+
making a copy. If we cannot, this negates the performance benefit
5046+
of using libjoin.
50455047
"""
50465048
if type(self) is Index:
50475049
# excludes EAs, but include masks, we get here with monotonic
50485050
# values only, meaning no NA
50495051
return (
50505052
isinstance(self.dtype, np.dtype)
5051-
or isinstance(self.values, BaseMaskedArray)
5052-
or isinstance(self._values, ArrowExtensionArray)
5053+
or isinstance(self._values, (ArrowExtensionArray, BaseMaskedArray))
5054+
or self.dtype == "string[python]"
50535055
)
5056+
# For IntervalIndex, the conversion to numpy converts
5057+
# to object dtype, which negates the performance benefit of libjoin
5058+
# TODO: exclude RangeIndex and MultiIndex as these also make copies?
50545059
return not isinstance(self.dtype, IntervalDtype)
50555060

50565061
# --------------------------------------------------------------------
@@ -5172,7 +5177,8 @@ def _get_engine_target(self) -> ArrayLike:
51725177
return self._values.astype(object)
51735178
return vals
51745179

5175-
def _get_join_target(self) -> ArrayLike:
5180+
@final
5181+
def _get_join_target(self) -> np.ndarray:
51765182
"""
51775183
Get the ndarray or ExtensionArray that we can pass to the join
51785184
functions.
@@ -5184,7 +5190,13 @@ def _get_join_target(self) -> ArrayLike:
51845190
# This is only used if our array is monotonic, so no missing values
51855191
# present
51865192
return self._values.to_numpy()
5187-
return self._get_engine_target()
5193+
5194+
# TODO: exclude ABCRangeIndex, ABCMultiIndex cases here as those create
5195+
# copies.
5196+
target = self._get_engine_target()
5197+
if not isinstance(target, np.ndarray):
5198+
raise ValueError("_can_use_libjoin should return False.")
5199+
return target
51885200

51895201
def _from_join_target(self, result: np.ndarray) -> ArrayLike:
51905202
"""

0 commit comments

Comments
 (0)