Skip to content

Commit 031f8c8

Browse files
authored
REF: helper for merge casting (#53976)
1 parent 2cc6bfa commit 031f8c8

File tree

1 file changed

+47
-59
lines changed

1 file changed

+47
-59
lines changed

pandas/core/reshape/merge.py

Lines changed: 47 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2067,6 +2067,32 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
20672067
else:
20682068
raise MergeError("key must be integer, timestamp or float")
20692069

2070+
def _convert_values_for_libjoin(
2071+
self, values: AnyArrayLike, side: str
2072+
) -> np.ndarray:
2073+
# we require sortedness and non-null values in the join keys
2074+
if not Index(values).is_monotonic_increasing:
2075+
if isna(values).any():
2076+
raise ValueError(f"Merge keys contain null values on {side} side")
2077+
raise ValueError(f"{side} keys must be sorted")
2078+
2079+
if isinstance(values, ArrowExtensionArray):
2080+
values = values._maybe_convert_datelike_array()
2081+
2082+
if needs_i8_conversion(values.dtype):
2083+
values = values.view("i8")
2084+
2085+
elif isinstance(values, BaseMaskedArray):
2086+
# we've verified above that no nulls exist
2087+
values = values._data
2088+
elif isinstance(values, ExtensionArray):
2089+
values = values.to_numpy()
2090+
2091+
# error: Incompatible return value type (got "Union[ExtensionArray,
2092+
# Any, ndarray[Any, Any], ndarray[Any, dtype[Any]], Index, Series]",
2093+
# expected "ndarray[Any, Any]")
2094+
return values # type: ignore[return-value]
2095+
20702096
def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
20712097
"""return the join indexers"""
20722098

@@ -2110,31 +2136,11 @@ def injection(obj: ArrayLike):
21102136
assert left_values.dtype == right_values.dtype
21112137

21122138
tolerance = self.tolerance
2113-
2114-
# we require sortedness and non-null values in the join keys
2115-
if not Index(left_values).is_monotonic_increasing:
2116-
side = "left"
2117-
if isna(left_values).any():
2118-
raise ValueError(f"Merge keys contain null values on {side} side")
2119-
raise ValueError(f"{side} keys must be sorted")
2120-
2121-
if not Index(right_values).is_monotonic_increasing:
2122-
side = "right"
2123-
if isna(right_values).any():
2124-
raise ValueError(f"Merge keys contain null values on {side} side")
2125-
raise ValueError(f"{side} keys must be sorted")
2126-
2127-
if isinstance(left_values, ArrowExtensionArray):
2128-
left_values = left_values._maybe_convert_datelike_array()
2129-
2130-
if isinstance(right_values, ArrowExtensionArray):
2131-
right_values = right_values._maybe_convert_datelike_array()
2132-
2133-
# initial type conversion as needed
2134-
if needs_i8_conversion(getattr(left_values, "dtype", None)):
2135-
if tolerance is not None:
2139+
if tolerance is not None:
2140+
# TODO: can we reuse a tolerance-conversion function from
2141+
# e.g. TimedeltaIndex?
2142+
if needs_i8_conversion(left_values.dtype):
21362143
tolerance = Timedelta(tolerance)
2137-
21382144
# TODO: we have no test cases with PeriodDtype here; probably
21392145
# need to adjust tolerance for that case.
21402146
if left_values.dtype.kind in "mM":
@@ -2145,22 +2151,9 @@ def injection(obj: ArrayLike):
21452151

21462152
tolerance = tolerance._value
21472153

2148-
# TODO: require left_values.dtype == right_values.dtype, or at least
2149-
# comparable for e.g. dt64tz
2150-
left_values = left_values.view("i8")
2151-
right_values = right_values.view("i8")
2152-
2153-
if isinstance(left_values, BaseMaskedArray):
2154-
# we've verified above that no nulls exist
2155-
left_values = left_values._data
2156-
elif isinstance(left_values, ExtensionArray):
2157-
left_values = left_values.to_numpy()
2158-
2159-
if isinstance(right_values, BaseMaskedArray):
2160-
# we've verified above that no nulls exist
2161-
right_values = right_values._data
2162-
elif isinstance(right_values, ExtensionArray):
2163-
right_values = right_values.to_numpy()
2154+
# initial type conversion as needed
2155+
left_values = self._convert_values_for_libjoin(left_values, "left")
2156+
right_values = self._convert_values_for_libjoin(right_values, "right")
21642157

21652158
# a "by" parameter requires special handling
21662159
if self.left_by is not None:
@@ -2259,19 +2252,7 @@ def _get_multiindex_indexer(
22592252

22602253
# get flat i8 join keys
22612254
lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort)
2262-
2263-
# factorize keys to a dense i8 space
2264-
lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort)
2265-
2266-
return libjoin.left_outer_join(lkey, rkey, count, sort=sort)
2267-
2268-
2269-
def _get_single_indexer(
2270-
join_key: ArrayLike, index: Index, sort: bool = False
2271-
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
2272-
left_key, right_key, count = _factorize_keys(join_key, index._values, sort=sort)
2273-
2274-
return libjoin.left_outer_join(left_key, right_key, count, sort=sort)
2255+
return lkey, rkey
22752256

22762257

22772258
def _get_empty_indexer() -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
@@ -2315,13 +2296,20 @@ def _left_join_on_index(
23152296
left_ax: Index, right_ax: Index, join_keys: list[ArrayLike], sort: bool = False
23162297
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp]]:
23172298
if isinstance(right_ax, MultiIndex):
2318-
left_indexer, right_indexer = _get_multiindex_indexer(
2319-
join_keys, right_ax, sort=sort
2320-
)
2299+
lkey, rkey = _get_multiindex_indexer(join_keys, right_ax, sort=sort)
23212300
else:
2322-
left_indexer, right_indexer = _get_single_indexer(
2323-
join_keys[0], right_ax, sort=sort
2324-
)
2301+
# error: Incompatible types in assignment (expression has type
2302+
# "Union[Union[ExtensionArray, ndarray[Any, Any]], Index, Series]",
2303+
# variable has type "ndarray[Any, dtype[signedinteger[Any]]]")
2304+
lkey = join_keys[0] # type: ignore[assignment]
2305+
# error: Incompatible types in assignment (expression has type "Index",
2306+
# variable has type "ndarray[Any, dtype[signedinteger[Any]]]")
2307+
rkey = right_ax._values # type: ignore[assignment]
2308+
2309+
left_key, right_key, count = _factorize_keys(lkey, rkey, sort=sort)
2310+
left_indexer, right_indexer = libjoin.left_outer_join(
2311+
left_key, right_key, count, sort=sort
2312+
)
23252313

23262314
if sort or len(left_ax) != len(left_indexer):
23272315
# if asked to sort or there are 1-to-many matches

0 commit comments

Comments
 (0)