@@ -2067,6 +2067,32 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
2067
2067
else :
2068
2068
raise MergeError ("key must be integer, timestamp or float" )
2069
2069
2070
+ def _convert_values_for_libjoin (
2071
+ self , values : AnyArrayLike , side : str
2072
+ ) -> np .ndarray :
2073
+ # we require sortedness and non-null values in the join keys
2074
+ if not Index (values ).is_monotonic_increasing :
2075
+ if isna (values ).any ():
2076
+ raise ValueError (f"Merge keys contain null values on { side } side" )
2077
+ raise ValueError (f"{ side } keys must be sorted" )
2078
+
2079
+ if isinstance (values , ArrowExtensionArray ):
2080
+ values = values ._maybe_convert_datelike_array ()
2081
+
2082
+ if needs_i8_conversion (values .dtype ):
2083
+ values = values .view ("i8" )
2084
+
2085
+ elif isinstance (values , BaseMaskedArray ):
2086
+ # we've verified above that no nulls exist
2087
+ values = values ._data
2088
+ elif isinstance (values , ExtensionArray ):
2089
+ values = values .to_numpy ()
2090
+
2091
+ # error: Incompatible return value type (got "Union[ExtensionArray,
2092
+ # Any, ndarray[Any, Any], ndarray[Any, dtype[Any]], Index, Series]",
2093
+ # expected "ndarray[Any, Any]")
2094
+ return values # type: ignore[return-value]
2095
+
2070
2096
def _get_join_indexers (self ) -> tuple [npt .NDArray [np .intp ], npt .NDArray [np .intp ]]:
2071
2097
"""return the join indexers"""
2072
2098
@@ -2110,31 +2136,11 @@ def injection(obj: ArrayLike):
2110
2136
assert left_values .dtype == right_values .dtype
2111
2137
2112
2138
tolerance = self .tolerance
2113
-
2114
- # we require sortedness and non-null values in the join keys
2115
- if not Index (left_values ).is_monotonic_increasing :
2116
- side = "left"
2117
- if isna (left_values ).any ():
2118
- raise ValueError (f"Merge keys contain null values on { side } side" )
2119
- raise ValueError (f"{ side } keys must be sorted" )
2120
-
2121
- if not Index (right_values ).is_monotonic_increasing :
2122
- side = "right"
2123
- if isna (right_values ).any ():
2124
- raise ValueError (f"Merge keys contain null values on { side } side" )
2125
- raise ValueError (f"{ side } keys must be sorted" )
2126
-
2127
- if isinstance (left_values , ArrowExtensionArray ):
2128
- left_values = left_values ._maybe_convert_datelike_array ()
2129
-
2130
- if isinstance (right_values , ArrowExtensionArray ):
2131
- right_values = right_values ._maybe_convert_datelike_array ()
2132
-
2133
- # initial type conversion as needed
2134
- if needs_i8_conversion (getattr (left_values , "dtype" , None )):
2135
- if tolerance is not None :
2139
+ if tolerance is not None :
2140
+ # TODO: can we reuse a tolerance-conversion function from
2141
+ # e.g. TimedeltaIndex?
2142
+ if needs_i8_conversion (left_values .dtype ):
2136
2143
tolerance = Timedelta (tolerance )
2137
-
2138
2144
# TODO: we have no test cases with PeriodDtype here; probably
2139
2145
# need to adjust tolerance for that case.
2140
2146
if left_values .dtype .kind in "mM" :
@@ -2145,22 +2151,9 @@ def injection(obj: ArrayLike):
2145
2151
2146
2152
tolerance = tolerance ._value
2147
2153
2148
- # TODO: require left_values.dtype == right_values.dtype, or at least
2149
- # comparable for e.g. dt64tz
2150
- left_values = left_values .view ("i8" )
2151
- right_values = right_values .view ("i8" )
2152
-
2153
- if isinstance (left_values , BaseMaskedArray ):
2154
- # we've verified above that no nulls exist
2155
- left_values = left_values ._data
2156
- elif isinstance (left_values , ExtensionArray ):
2157
- left_values = left_values .to_numpy ()
2158
-
2159
- if isinstance (right_values , BaseMaskedArray ):
2160
- # we've verified above that no nulls exist
2161
- right_values = right_values ._data
2162
- elif isinstance (right_values , ExtensionArray ):
2163
- right_values = right_values .to_numpy ()
2154
+ # initial type conversion as needed
2155
+ left_values = self ._convert_values_for_libjoin (left_values , "left" )
2156
+ right_values = self ._convert_values_for_libjoin (right_values , "right" )
2164
2157
2165
2158
# a "by" parameter requires special handling
2166
2159
if self .left_by is not None :
@@ -2259,19 +2252,7 @@ def _get_multiindex_indexer(
2259
2252
2260
2253
# get flat i8 join keys
2261
2254
lkey , rkey = _get_join_keys (lcodes , rcodes , tuple (shape ), sort )
2262
-
2263
- # factorize keys to a dense i8 space
2264
- lkey , rkey , count = _factorize_keys (lkey , rkey , sort = sort )
2265
-
2266
- return libjoin .left_outer_join (lkey , rkey , count , sort = sort )
2267
-
2268
-
2269
- def _get_single_indexer (
2270
- join_key : ArrayLike , index : Index , sort : bool = False
2271
- ) -> tuple [npt .NDArray [np .intp ], npt .NDArray [np .intp ]]:
2272
- left_key , right_key , count = _factorize_keys (join_key , index ._values , sort = sort )
2273
-
2274
- return libjoin .left_outer_join (left_key , right_key , count , sort = sort )
2255
+ return lkey , rkey
2275
2256
2276
2257
2277
2258
def _get_empty_indexer () -> tuple [npt .NDArray [np .intp ], npt .NDArray [np .intp ]]:
@@ -2315,13 +2296,20 @@ def _left_join_on_index(
2315
2296
left_ax : Index , right_ax : Index , join_keys : list [ArrayLike ], sort : bool = False
2316
2297
) -> tuple [Index , npt .NDArray [np .intp ] | None , npt .NDArray [np .intp ]]:
2317
2298
if isinstance (right_ax , MultiIndex ):
2318
- left_indexer , right_indexer = _get_multiindex_indexer (
2319
- join_keys , right_ax , sort = sort
2320
- )
2299
+ lkey , rkey = _get_multiindex_indexer (join_keys , right_ax , sort = sort )
2321
2300
else :
2322
- left_indexer , right_indexer = _get_single_indexer (
2323
- join_keys [0 ], right_ax , sort = sort
2324
- )
2301
+ # error: Incompatible types in assignment (expression has type
2302
+ # "Union[Union[ExtensionArray, ndarray[Any, Any]], Index, Series]",
2303
+ # variable has type "ndarray[Any, dtype[signedinteger[Any]]]")
2304
+ lkey = join_keys [0 ] # type: ignore[assignment]
2305
+ # error: Incompatible types in assignment (expression has type "Index",
2306
+ # variable has type "ndarray[Any, dtype[signedinteger[Any]]]")
2307
+ rkey = right_ax ._values # type: ignore[assignment]
2308
+
2309
+ left_key , right_key , count = _factorize_keys (lkey , rkey , sort = sort )
2310
+ left_indexer , right_indexer = libjoin .left_outer_join (
2311
+ left_key , right_key , count , sort = sort
2312
+ )
2325
2313
2326
2314
if sort or len (left_ax ) != len (left_indexer ):
2327
2315
# if asked to sort or there are 1-to-many matches
0 commit comments