Skip to content

Commit 1dd8725

Browse files
committed
API: ensure IntervalIndex.left/right are 64bit if numeric part II
1 parent c4a84ab commit 1dd8725

File tree

3 files changed

+37
-22
lines changed

3 files changed

+37
-22
lines changed

pandas/core/arrays/interval.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Dtype,
3636
IntervalClosedType,
3737
NpDtype,
38+
NumpyIndexT,
3839
PositionalIndexer,
3940
ScalarIndexer,
4041
SequenceIndexer,
@@ -48,6 +49,7 @@
4849

4950
from pandas.core.dtypes.cast import LossySetitemError
5051
from pandas.core.dtypes.common import (
52+
is_array_like,
5153
is_categorical_dtype,
5254
is_dtype_equal,
5355
is_float_dtype,
@@ -56,7 +58,9 @@
5658
is_list_like,
5759
is_object_dtype,
5860
is_scalar,
61+
is_signed_integer_dtype,
5962
is_string_dtype,
63+
is_unsigned_integer_dtype,
6064
needs_i8_conversion,
6165
pandas_dtype,
6266
)
@@ -180,6 +184,24 @@
180184
"""
181185

182186

187+
def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
188+
# IntervalTree only supports 64 bit numpy array
189+
190+
if not is_array_like(arr):
191+
return arr
192+
dtype = arr.dtype
193+
if not np.issubclass_(dtype.type, np.number):
194+
return arr
195+
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
196+
return arr.astype(np.int64)
197+
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
198+
return arr.astype(np.uint64)
199+
elif is_float_dtype(dtype) and dtype != np.float64:
200+
return arr.astype(np.float64)
201+
else:
202+
return arr
203+
204+
183205
@Appender(
184206
_interval_shared_docs["class"]
185207
% {
@@ -304,7 +326,10 @@ def _ensure_simple_new_inputs(
304326
from pandas.core.indexes.base import ensure_index
305327

306328
left = ensure_index(left, copy=copy)
329+
left = maybe_convert_numeric_to_64bit(left)
330+
307331
right = ensure_index(right, copy=copy)
332+
right = maybe_convert_numeric_to_64bit(right)
308333

309334
if closed is None and isinstance(dtype, IntervalDtype):
310335
closed = dtype.closed

pandas/core/indexes/interval.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@
5959
is_number,
6060
is_object_dtype,
6161
is_scalar,
62-
is_signed_integer_dtype,
63-
is_unsigned_integer_dtype,
6462
)
6563
from pandas.core.dtypes.dtypes import IntervalDtype
6664
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -69,6 +67,7 @@
6967
from pandas.core.arrays.interval import (
7068
IntervalArray,
7169
_interval_shared_docs,
70+
maybe_convert_numeric_to_64bit,
7271
)
7372
import pandas.core.common as com
7473
from pandas.core.indexers import is_valid_positional_slice
@@ -520,13 +519,12 @@ def _maybe_convert_i8(self, key):
520519
The original key if no conversion occurred, int if converted scalar,
521520
Int64Index if converted list-like.
522521
"""
523-
original = key
524522
if is_list_like(key):
525523
key = ensure_index(key)
526-
key = self._maybe_convert_numeric_to_64bit(key)
524+
key = maybe_convert_numeric_to_64bit(key)
527525

528526
if not self._needs_i8_conversion(key):
529-
return original
527+
return key
530528

531529
scalar = is_scalar(key)
532530
if is_interval_dtype(key) or isinstance(key, Interval):
@@ -569,20 +567,6 @@ def _maybe_convert_i8(self, key):
569567

570568
return key_i8
571569

572-
def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
573-
# IntervalTree only supports 64 bit numpy array
574-
dtype = idx.dtype
575-
if np.issubclass_(dtype.type, np.number):
576-
return idx
577-
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
578-
return idx.astype(np.int64)
579-
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
580-
return idx.astype(np.uint64)
581-
elif is_float_dtype(dtype) and dtype != np.float64:
582-
return idx.astype(np.float64)
583-
else:
584-
return idx
585-
586570
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
587571
if not self.is_non_overlapping_monotonic:
588572
raise KeyError(

pandas/tests/indexes/interval/test_interval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
timedelta_range,
1919
)
2020
import pandas._testing as tm
21-
from pandas.core.api import Float64Index
21+
from pandas.core.api import (
22+
Float64Index,
23+
NumericIndex,
24+
)
2225
import pandas.core.common as com
2326

2427

@@ -435,9 +438,12 @@ def test_maybe_convert_i8_numeric(self, breaks, make_key):
435438
index = IntervalIndex.from_breaks(breaks)
436439
key = make_key(breaks)
437440

438-
# no conversion occurs for numeric
439441
result = index._maybe_convert_i8(key)
440-
assert result is key
442+
if not isinstance(result, NumericIndex):
443+
assert result is key
444+
else:
445+
expected = NumericIndex(key)
446+
tm.assert_index_equal(result, expected)
441447

442448
@pytest.mark.parametrize(
443449
"breaks1, breaks2",

0 commit comments

Comments
 (0)