diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index e68277c38063e..cfaa5a1fdad64 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -2067,6 +2067,32 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: else: raise MergeError("key must be integer, timestamp or float") + def _convert_values_for_libjoin( + self, values: AnyArrayLike, side: str + ) -> np.ndarray: + # we require sortedness and non-null values in the join keys + if not Index(values).is_monotonic_increasing: + if isna(values).any(): + raise ValueError(f"Merge keys contain null values on {side} side") + raise ValueError(f"{side} keys must be sorted") + + if isinstance(values, ArrowExtensionArray): + values = values._maybe_convert_datelike_array() + + if needs_i8_conversion(values.dtype): + values = values.view("i8") + + elif isinstance(values, BaseMaskedArray): + # we've verified above that no nulls exist + values = values._data + elif isinstance(values, ExtensionArray): + values = values.to_numpy() + + # error: Incompatible return value type (got "Union[ExtensionArray, + # Any, ndarray[Any, Any], ndarray[Any, dtype[Any]], Index, Series]", + # expected "ndarray[Any, Any]") + return values # type: ignore[return-value] + def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: """return the join indexers""" @@ -2110,31 +2136,11 @@ def injection(obj: ArrayLike): assert left_values.dtype == right_values.dtype tolerance = self.tolerance - - # we require sortedness and non-null values in the join keys - if not Index(left_values).is_monotonic_increasing: - side = "left" - if isna(left_values).any(): - raise ValueError(f"Merge keys contain null values on {side} side") - raise ValueError(f"{side} keys must be sorted") - - if not Index(right_values).is_monotonic_increasing: - side = "right" - if isna(right_values).any(): - raise ValueError(f"Merge keys contain null values on {side} side") - raise ValueError(f"{side} keys must be sorted") - - if isinstance(left_values, ArrowExtensionArray): - left_values = left_values._maybe_convert_datelike_array() - - if isinstance(right_values, ArrowExtensionArray): - right_values = right_values._maybe_convert_datelike_array() - - # initial type conversion as needed - if needs_i8_conversion(getattr(left_values, "dtype", None)): - if tolerance is not None: + if tolerance is not None: + # TODO: can we reuse a tolerance-conversion function from + # e.g. TimedeltaIndex? + if needs_i8_conversion(left_values.dtype): tolerance = Timedelta(tolerance) - # TODO: we have no test cases with PeriodDtype here; probably # need to adjust tolerance for that case. if left_values.dtype.kind in "mM": @@ -2145,22 +2151,9 @@ def injection(obj: ArrayLike): tolerance = tolerance._value - # TODO: require left_values.dtype == right_values.dtype, or at least - # comparable for e.g. dt64tz - left_values = left_values.view("i8") - right_values = right_values.view("i8") - - if isinstance(left_values, BaseMaskedArray): - # we've verified above that no nulls exist - left_values = left_values._data - elif isinstance(left_values, ExtensionArray): - left_values = left_values.to_numpy() - - if isinstance(right_values, BaseMaskedArray): - # we've verified above that no nulls exist - right_values = right_values._data - elif isinstance(right_values, ExtensionArray): - right_values = right_values.to_numpy() + # initial type conversion as needed + left_values = self._convert_values_for_libjoin(left_values, "left") + right_values = self._convert_values_for_libjoin(right_values, "right") # a "by" parameter requires special handling if self.left_by is not None: @@ -2259,19 +2252,7 @@ def _get_multiindex_indexer( # get flat i8 join keys lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort) - - # factorize keys to a dense i8 space - lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort) - - return libjoin.left_outer_join(lkey, rkey, count, sort=sort) - - -def _get_single_indexer( - join_key: ArrayLike, index: Index, sort: bool = False -) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: - left_key, right_key, count = _factorize_keys(join_key, index._values, sort=sort) - - return libjoin.left_outer_join(left_key, right_key, count, sort=sort) + return lkey, rkey def _get_empty_indexer() -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: @@ -2315,13 +2296,20 @@ def _left_join_on_index( left_ax: Index, right_ax: Index, join_keys: list[ArrayLike], sort: bool = False ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp]]: if isinstance(right_ax, MultiIndex): - left_indexer, right_indexer = _get_multiindex_indexer( - join_keys, right_ax, sort=sort - ) + lkey, rkey = _get_multiindex_indexer(join_keys, right_ax, sort=sort) else: - left_indexer, right_indexer = _get_single_indexer( - join_keys[0], right_ax, sort=sort - ) + # error: Incompatible types in assignment (expression has type + # "Union[Union[ExtensionArray, ndarray[Any, Any]], Index, Series]", + # variable has type "ndarray[Any, dtype[signedinteger[Any]]]") + lkey = join_keys[0] # type: ignore[assignment] + # error: Incompatible types in assignment (expression has type "Index", + # variable has type "ndarray[Any, dtype[signedinteger[Any]]]") + rkey = right_ax._values # type: ignore[assignment] + + left_key, right_key, count = _factorize_keys(lkey, rkey, sort=sort) + left_indexer, right_indexer = libjoin.left_outer_join( + left_key, right_key, count, sort=sort + ) if sort or len(left_ax) != len(left_indexer): # if asked to sort or there are 1-to-many matches