Skip to content

Commit 1714a53

Browse files
committed
API:move use of maybe_convert_numeric_to_64bit to to also be used in IntervalIndex._engine
1 parent 8117a55 commit 1714a53

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

pandas/core/indexes/interval.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DtypeObj,
3232
IntervalClosedType,
3333
npt,
34+
NumpyIndexT,
3435
)
3536
from pandas.errors import InvalidIndexError
3637
from pandas.util._decorators import (
@@ -47,6 +48,7 @@
4748
)
4849
from pandas.core.dtypes.common import (
4950
ensure_platform_int,
51+
is_array_like,
5052
is_datetime64tz_dtype,
5153
is_datetime_or_timedelta_dtype,
5254
is_dtype_equal,
@@ -146,6 +148,24 @@ def _new_IntervalIndex(cls, d):
146148
return cls.from_arrays(**d)
147149

148150

151+
def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
152+
# IntervalTree only supports 64 bit numpy array
153+
154+
if not is_array_like(arr):
155+
return arr
156+
dtype = arr.dtype
157+
if not np.issubclass_(dtype.type, np.number):
158+
return arr
159+
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
160+
return arr.astype(np.int64)
161+
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
162+
return arr.astype(np.uint64)
163+
elif is_float_dtype(dtype) and dtype != np.float64:
164+
return arr.astype(np.float64)
165+
else:
166+
return arr
167+
168+
149169
@Appender(
150170
_interval_shared_docs["class"]
151171
% {
@@ -343,7 +363,9 @@ def from_tuples(
343363
@cache_readonly
344364
def _engine(self) -> IntervalTree: # type: ignore[override]
345365
left = self._maybe_convert_i8(self.left)
366+
left = maybe_convert_numeric_to_64bit(left)
346367
right = self._maybe_convert_i8(self.right)
368+
right = maybe_convert_numeric_to_64bit(right)
347369
return IntervalTree(left, right, closed=self.closed)
348370

349371
def __contains__(self, key: Any) -> bool:
@@ -520,13 +542,12 @@ def _maybe_convert_i8(self, key):
520542
The original key if no conversion occurred, int if converted scalar,
521543
Int64Index if converted list-like.
522544
"""
523-
original = key
524545
if is_list_like(key):
525546
key = ensure_index(key)
526-
key = self._maybe_convert_numeric_to_64bit(key)
547+
key = maybe_convert_numeric_to_64bit(key)
527548

528549
if not self._needs_i8_conversion(key):
529-
return original
550+
return key
530551

531552
scalar = is_scalar(key)
532553
if is_interval_dtype(key) or isinstance(key, Interval):
@@ -569,20 +590,6 @@ def _maybe_convert_i8(self, key):
569590

570591
return key_i8
571592

572-
def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
573-
# IntervalTree only supports 64 bit numpy array
574-
dtype = idx.dtype
575-
if np.issubclass_(dtype.type, np.number):
576-
return idx
577-
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
578-
return idx.astype(np.int64)
579-
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
580-
return idx.astype(np.uint64)
581-
elif is_float_dtype(dtype) and dtype != np.float64:
582-
return idx.astype(np.float64)
583-
else:
584-
return idx
585-
586593
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
587594
if not self.is_non_overlapping_monotonic:
588595
raise KeyError(

pandas/tests/indexes/interval/test_interval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
timedelta_range,
1919
)
2020
import pandas._testing as tm
21-
from pandas.core.api import Float64Index
21+
from pandas.core.api import (
22+
Float64Index,
23+
NumericIndex,
24+
)
2225
import pandas.core.common as com
2326

2427

@@ -435,9 +438,12 @@ def test_maybe_convert_i8_numeric(self, breaks, make_key):
435438
index = IntervalIndex.from_breaks(breaks)
436439
key = make_key(breaks)
437440

438-
# no conversion occurs for numeric
439441
result = index._maybe_convert_i8(key)
440-
assert result is key
442+
if not isinstance(result, NumericIndex):
443+
assert result is key
444+
else:
445+
expected = NumericIndex(key)
446+
tm.assert_index_equal(result, expected)
441447

442448
@pytest.mark.parametrize(
443449
"breaks1, breaks2",

0 commit comments

Comments
 (0)