Skip to content

Commit 65efe5d

Browse files
committed
API: ensure IntervalIndex.left/right are 64bit if numeric part II
1 parent 32a261a commit 65efe5d

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

pandas/core/arrays/interval.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Dtype,
3636
IntervalClosedType,
3737
NpDtype,
38+
NumpyIndexT,
3839
PositionalIndexer,
3940
ScalarIndexer,
4041
SequenceIndexer,
@@ -56,7 +57,9 @@
5657
is_list_like,
5758
is_object_dtype,
5859
is_scalar,
60+
is_signed_integer_dtype,
5961
is_string_dtype,
62+
is_unsigned_integer_dtype,
6063
needs_i8_conversion,
6164
pandas_dtype,
6265
)
@@ -180,6 +183,21 @@
180183
"""
181184

182185

186+
def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
187+
# IntervalTree only supports 64 bit numpy array
188+
dtype = arr.dtype
189+
if not np.issubclass_(dtype.type, np.number):
190+
return arr
191+
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
192+
return arr.astype(np.int64)
193+
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
194+
return arr.astype(np.uint64)
195+
elif is_float_dtype(dtype) and dtype != np.float64:
196+
return arr.astype(np.float64)
197+
else:
198+
return arr
199+
200+
183201
@Appender(
184202
_interval_shared_docs["class"]
185203
% {
@@ -252,6 +270,7 @@ def __new__(
252270

253271
# might need to convert empty or purely na data
254272
data = _maybe_convert_platform_interval(data)
273+
data = maybe_convert_numeric_to_64bit(data)
255274
left, right, infer_closed = intervals_to_interval_bounds(
256275
data, validate_closed=closed is None
257276
)

pandas/core/indexes/interval.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@
5959
is_number,
6060
is_object_dtype,
6161
is_scalar,
62-
is_signed_integer_dtype,
63-
is_unsigned_integer_dtype,
6462
)
6563
from pandas.core.dtypes.dtypes import IntervalDtype
6664
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -69,6 +67,7 @@
6967
from pandas.core.arrays.interval import (
7068
IntervalArray,
7169
_interval_shared_docs,
70+
maybe_convert_numeric_to_64bit,
7271
)
7372
import pandas.core.common as com
7473
from pandas.core.indexers import is_valid_positional_slice
@@ -520,13 +519,12 @@ def _maybe_convert_i8(self, key):
520519
The original key if no conversion occurred, int if converted scalar,
521520
Int64Index if converted list-like.
522521
"""
523-
original = key
524522
if is_list_like(key):
525523
key = ensure_index(key)
526-
key = self._maybe_convert_numeric_to_64bit(key)
524+
key = maybe_convert_numeric_to_64bit(key)
527525

528526
if not self._needs_i8_conversion(key):
529-
return original
527+
return key
530528

531529
scalar = is_scalar(key)
532530
if is_interval_dtype(key) or isinstance(key, Interval):
@@ -569,20 +567,6 @@ def _maybe_convert_i8(self, key):
569567

570568
return key_i8
571569

572-
def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
573-
# IntervalTree only supports 64 bit numpy array
574-
dtype = idx.dtype
575-
if np.issubclass_(dtype.type, np.number):
576-
return idx
577-
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
578-
return idx.astype(np.int64)
579-
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
580-
return idx.astype(np.uint64)
581-
elif is_float_dtype(dtype) and dtype != np.float64:
582-
return idx.astype(np.float64)
583-
else:
584-
return idx
585-
586570
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
587571
if not self.is_non_overlapping_monotonic:
588572
raise KeyError(

pandas/tests/indexes/interval/test_interval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
timedelta_range,
1919
)
2020
import pandas._testing as tm
21-
from pandas.core.api import Float64Index
21+
from pandas.core.api import (
22+
Float64Index,
23+
NumericIndex,
24+
)
2225
import pandas.core.common as com
2326

2427

@@ -435,9 +438,12 @@ def test_maybe_convert_i8_numeric(self, breaks, make_key):
435438
index = IntervalIndex.from_breaks(breaks)
436439
key = make_key(breaks)
437440

438-
# no conversion occurs for numeric
439441
result = index._maybe_convert_i8(key)
440-
assert result is key
442+
if not isinstance(result, NumericIndex):
443+
assert result is key
444+
else:
445+
expected = NumericIndex(key)
446+
tm.assert_index_equal(result, expected)
441447

442448
@pytest.mark.parametrize(
443449
"breaks1, breaks2",

0 commit comments

Comments
 (0)