|
31 | 31 | DtypeObj,
|
32 | 32 | IntervalClosedType,
|
33 | 33 | npt,
|
| 34 | + NumpyIndexT, |
34 | 35 | )
|
35 | 36 | from pandas.errors import InvalidIndexError
|
36 | 37 | from pandas.util._decorators import (
|
|
47 | 48 | )
|
48 | 49 | from pandas.core.dtypes.common import (
|
49 | 50 | ensure_platform_int,
|
| 51 | + is_array_like, |
50 | 52 | is_datetime64tz_dtype,
|
51 | 53 | is_datetime_or_timedelta_dtype,
|
52 | 54 | is_dtype_equal,
|
@@ -146,6 +148,24 @@ def _new_IntervalIndex(cls, d):
|
146 | 148 | return cls.from_arrays(**d)
|
147 | 149 |
|
148 | 150 |
|
| 151 | +def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: |
| 152 | + # IntervalTree only supports 64 bit numpy array |
| 153 | + |
| 154 | + if not is_array_like(arr): |
| 155 | + return arr |
| 156 | + dtype = arr.dtype |
| 157 | + if not np.issubclass_(dtype.type, np.number): |
| 158 | + return arr |
| 159 | + elif is_signed_integer_dtype(dtype) and dtype != np.int64: |
| 160 | + return arr.astype(np.int64) |
| 161 | + elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: |
| 162 | + return arr.astype(np.uint64) |
| 163 | + elif is_float_dtype(dtype) and dtype != np.float64: |
| 164 | + return arr.astype(np.float64) |
| 165 | + else: |
| 166 | + return arr |
| 167 | + |
| 168 | + |
149 | 169 | @Appender(
|
150 | 170 | _interval_shared_docs["class"]
|
151 | 171 | % {
|
@@ -343,7 +363,9 @@ def from_tuples(
|
343 | 363 | @cache_readonly
|
344 | 364 | def _engine(self) -> IntervalTree: # type: ignore[override]
|
345 | 365 | left = self._maybe_convert_i8(self.left)
|
| 366 | + left = maybe_convert_numeric_to_64bit(left) |
346 | 367 | right = self._maybe_convert_i8(self.right)
|
| 368 | + right = maybe_convert_numeric_to_64bit(right) |
347 | 369 | return IntervalTree(left, right, closed=self.closed)
|
348 | 370 |
|
349 | 371 | def __contains__(self, key: Any) -> bool:
|
@@ -520,13 +542,12 @@ def _maybe_convert_i8(self, key):
|
520 | 542 | The original key if no conversion occurred, int if converted scalar,
|
521 | 543 | Int64Index if converted list-like.
|
522 | 544 | """
|
523 |
| - original = key |
524 | 545 | if is_list_like(key):
|
525 | 546 | key = ensure_index(key)
|
526 |
| - key = self._maybe_convert_numeric_to_64bit(key) |
| 547 | + key = maybe_convert_numeric_to_64bit(key) |
527 | 548 |
|
528 | 549 | if not self._needs_i8_conversion(key):
|
529 |
| - return original |
| 550 | + return key |
530 | 551 |
|
531 | 552 | scalar = is_scalar(key)
|
532 | 553 | if is_interval_dtype(key) or isinstance(key, Interval):
|
@@ -569,20 +590,6 @@ def _maybe_convert_i8(self, key):
|
569 | 590 |
|
570 | 591 | return key_i8
|
571 | 592 |
|
572 |
| - def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index: |
573 |
| - # IntervalTree only supports 64 bit numpy array |
574 |
| - dtype = idx.dtype |
575 |
| - if np.issubclass_(dtype.type, np.number): |
576 |
| - return idx |
577 |
| - elif is_signed_integer_dtype(dtype) and dtype != np.int64: |
578 |
| - return idx.astype(np.int64) |
579 |
| - elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: |
580 |
| - return idx.astype(np.uint64) |
581 |
| - elif is_float_dtype(dtype) and dtype != np.float64: |
582 |
| - return idx.astype(np.float64) |
583 |
| - else: |
584 |
| - return idx |
585 |
| - |
586 | 593 | def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
|
587 | 594 | if not self.is_non_overlapping_monotonic:
|
588 | 595 | raise KeyError(
|
|
0 commit comments