Skip to content

REF: helper for merge casting #53976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 47 additions & 59 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,6 +2067,32 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
else:
raise MergeError("key must be integer, timestamp or float")

def _convert_values_for_libjoin(
self, values: AnyArrayLike, side: str
) -> np.ndarray:
# we require sortedness and non-null values in the join keys
if not Index(values).is_monotonic_increasing:
if isna(values).any():
raise ValueError(f"Merge keys contain null values on {side} side")
raise ValueError(f"{side} keys must be sorted")

if isinstance(values, ArrowExtensionArray):
values = values._maybe_convert_datelike_array()

if needs_i8_conversion(values.dtype):
values = values.view("i8")

elif isinstance(values, BaseMaskedArray):
# we've verified above that no nulls exist
values = values._data
elif isinstance(values, ExtensionArray):
values = values.to_numpy()

# error: Incompatible return value type (got "Union[ExtensionArray,
# Any, ndarray[Any, Any], ndarray[Any, dtype[Any]], Index, Series]",
# expected "ndarray[Any, Any]")
return values # type: ignore[return-value]

def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
"""return the join indexers"""

Expand Down Expand Up @@ -2110,31 +2136,11 @@ def injection(obj: ArrayLike):
assert left_values.dtype == right_values.dtype

tolerance = self.tolerance

# we require sortedness and non-null values in the join keys
if not Index(left_values).is_monotonic_increasing:
side = "left"
if isna(left_values).any():
raise ValueError(f"Merge keys contain null values on {side} side")
raise ValueError(f"{side} keys must be sorted")

if not Index(right_values).is_monotonic_increasing:
side = "right"
if isna(right_values).any():
raise ValueError(f"Merge keys contain null values on {side} side")
raise ValueError(f"{side} keys must be sorted")

if isinstance(left_values, ArrowExtensionArray):
left_values = left_values._maybe_convert_datelike_array()

if isinstance(right_values, ArrowExtensionArray):
right_values = right_values._maybe_convert_datelike_array()

# initial type conversion as needed
if needs_i8_conversion(getattr(left_values, "dtype", None)):
if tolerance is not None:
if tolerance is not None:
# TODO: can we reuse a tolerance-conversion function from
# e.g. TimedeltaIndex?
if needs_i8_conversion(left_values.dtype):
tolerance = Timedelta(tolerance)

# TODO: we have no test cases with PeriodDtype here; probably
# need to adjust tolerance for that case.
if left_values.dtype.kind in "mM":
Expand All @@ -2145,22 +2151,9 @@ def injection(obj: ArrayLike):

tolerance = tolerance._value

# TODO: require left_values.dtype == right_values.dtype, or at least
# comparable for e.g. dt64tz
left_values = left_values.view("i8")
right_values = right_values.view("i8")

if isinstance(left_values, BaseMaskedArray):
# we've verified above that no nulls exist
left_values = left_values._data
elif isinstance(left_values, ExtensionArray):
left_values = left_values.to_numpy()

if isinstance(right_values, BaseMaskedArray):
# we've verified above that no nulls exist
right_values = right_values._data
elif isinstance(right_values, ExtensionArray):
right_values = right_values.to_numpy()
# initial type conversion as needed
left_values = self._convert_values_for_libjoin(left_values, "left")
right_values = self._convert_values_for_libjoin(right_values, "right")

# a "by" parameter requires special handling
if self.left_by is not None:
Expand Down Expand Up @@ -2259,19 +2252,7 @@ def _get_multiindex_indexer(

# get flat i8 join keys
lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort)

# factorize keys to a dense i8 space
lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort)

return libjoin.left_outer_join(lkey, rkey, count, sort=sort)


def _get_single_indexer(
    join_key: ArrayLike, index: Index, sort: bool = False
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
    """
    Compute left-outer-join indexers for one join key against an index.

    Parameters
    ----------
    join_key : ArrayLike
        Key values of the left side.
    index : Index
        Index of the right side; its ``_values`` are used as the right keys.
    sort : bool, default False
        Forwarded to both ``_factorize_keys`` and
        ``libjoin.left_outer_join``.

    Returns
    -------
    tuple of two np.ndarray[np.intp]
        The result of ``libjoin.left_outer_join`` on the factorized keys.
    """
    # Factorize both key sets in a single call, then hand the three results
    # straight to the left outer join.
    factorized = _factorize_keys(join_key, index._values, sort=sort)
    lkey, rkey, ngroups = factorized
    return libjoin.left_outer_join(lkey, rkey, ngroups, sort=sort)
return lkey, rkey


def _get_empty_indexer() -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
Expand Down Expand Up @@ -2315,13 +2296,20 @@ def _left_join_on_index(
left_ax: Index, right_ax: Index, join_keys: list[ArrayLike], sort: bool = False
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp]]:
if isinstance(right_ax, MultiIndex):
left_indexer, right_indexer = _get_multiindex_indexer(
join_keys, right_ax, sort=sort
)
lkey, rkey = _get_multiindex_indexer(join_keys, right_ax, sort=sort)
else:
left_indexer, right_indexer = _get_single_indexer(
join_keys[0], right_ax, sort=sort
)
# error: Incompatible types in assignment (expression has type
# "Union[Union[ExtensionArray, ndarray[Any, Any]], Index, Series]",
# variable has type "ndarray[Any, dtype[signedinteger[Any]]]")
lkey = join_keys[0] # type: ignore[assignment]
# error: Incompatible types in assignment (expression has type "Index",
# variable has type "ndarray[Any, dtype[signedinteger[Any]]]")
rkey = right_ax._values # type: ignore[assignment]

left_key, right_key, count = _factorize_keys(lkey, rkey, sort=sort)
left_indexer, right_indexer = libjoin.left_outer_join(
left_key, right_key, count, sort=sort
)

if sort or len(left_ax) != len(left_indexer):
# if asked to sort or there are 1-to-many matches
Expand Down