Skip to content

Commit 939d0ba

Browse files
authored
API: ensure IntervalIndex.left/right are 64bit if numeric, part II (#50195)
1 parent 1d63474 commit 939d0ba

File tree

4 files changed

+51
-28
lines changed

4 files changed

+51
-28
lines changed

pandas/core/arrays/interval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,5 +1787,7 @@ def _maybe_convert_platform_interval(values) -> ArrayLike:
17871787
values = extract_array(values, extract_numpy=True)
17881788

17891789
if not hasattr(values, "dtype"):
1790-
return np.asarray(values)
1790+
values = np.asarray(values)
1791+
if is_integer_dtype(values) and values.dtype != np.int64:
1792+
values = values.astype(np.int64)
17911793
return values

pandas/core/dtypes/cast.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ArrayLike,
3636
Dtype,
3737
DtypeObj,
38+
NumpyIndexT,
3839
Scalar,
3940
npt,
4041
)
@@ -65,6 +66,7 @@
6566
is_numeric_dtype,
6667
is_object_dtype,
6768
is_scalar,
69+
is_signed_integer_dtype,
6870
is_string_dtype,
6971
is_timedelta64_dtype,
7072
is_unsigned_integer_dtype,
@@ -412,6 +414,29 @@ def trans(x):
412414
return result
413415

414416

417+
def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
418+
"""
419+
If array is an int/uint/float array with a bit size lower than 64 bit, upcast it to 64 bit.
420+
421+
Parameters
422+
----------
423+
arr : ndarray or ExtensionArray
424+
425+
Returns
426+
-------
427+
ndarray or ExtensionArray
428+
"""
429+
dtype = arr.dtype
430+
if is_signed_integer_dtype(dtype) and dtype != np.int64:
431+
return arr.astype(np.int64)
432+
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
433+
return arr.astype(np.uint64)
434+
elif is_float_dtype(dtype) and dtype != np.float64:
435+
return arr.astype(np.float64)
436+
else:
437+
return arr
438+
439+
415440
def maybe_cast_pointwise_result(
416441
result: ArrayLike,
417442
dtype: DtypeObj,

pandas/core/indexes/interval.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
infer_dtype_from_scalar,
4545
maybe_box_datetimelike,
4646
maybe_downcast_numeric,
47+
maybe_upcast_numeric_to_64bit,
4748
)
4849
from pandas.core.dtypes.common import (
4950
ensure_platform_int,
@@ -59,8 +60,6 @@
5960
is_number,
6061
is_object_dtype,
6162
is_scalar,
62-
is_signed_integer_dtype,
63-
is_unsigned_integer_dtype,
6463
)
6564
from pandas.core.dtypes.dtypes import IntervalDtype
6665
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -342,8 +341,11 @@ def from_tuples(
342341
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
343342
@cache_readonly
344343
def _engine(self) -> IntervalTree: # type: ignore[override]
344+
# IntervalTree does not support numpy arrays unless they are 64 bit
345345
left = self._maybe_convert_i8(self.left)
346+
left = maybe_upcast_numeric_to_64bit(left)
346347
right = self._maybe_convert_i8(self.right)
348+
right = maybe_upcast_numeric_to_64bit(right)
347349
return IntervalTree(left, right, closed=self.closed)
348350

349351
def __contains__(self, key: Any) -> bool:
@@ -520,13 +522,12 @@ def _maybe_convert_i8(self, key):
520522
The original key if no conversion occurred, int if converted scalar,
521523
Int64Index if converted list-like.
522524
"""
523-
original = key
524525
if is_list_like(key):
525526
key = ensure_index(key)
526-
key = self._maybe_convert_numeric_to_64bit(key)
527+
key = maybe_upcast_numeric_to_64bit(key)
527528

528529
if not self._needs_i8_conversion(key):
529-
return original
530+
return key
530531

531532
scalar = is_scalar(key)
532533
if is_interval_dtype(key) or isinstance(key, Interval):
@@ -569,20 +570,6 @@ def _maybe_convert_i8(self, key):
569570

570571
return key_i8
571572

572-
def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
573-
# IntervalTree only supports 64 bit numpy array
574-
dtype = idx.dtype
575-
if np.issubclass_(dtype.type, np.number):
576-
return idx
577-
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
578-
return idx.astype(np.int64)
579-
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
580-
return idx.astype(np.uint64)
581-
elif is_float_dtype(dtype) and dtype != np.float64:
582-
return idx.astype(np.float64)
583-
else:
584-
return idx
585-
586573
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
587574
if not self.is_non_overlapping_monotonic:
588575
raise KeyError(

pandas/tests/indexes/interval/test_interval.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -415,27 +415,36 @@ def test_maybe_convert_i8_nat(self, breaks):
415415
tm.assert_index_equal(result, expected)
416416

417417
@pytest.mark.parametrize(
418-
"breaks",
419-
[np.arange(5, dtype="int64"), np.arange(5, dtype="float64")],
420-
ids=lambda x: str(x.dtype),
418+
"make_key",
419+
[lambda breaks: breaks, list],
420+
ids=["lambda", "list"],
421421
)
422+
def test_maybe_convert_i8_numeric(self, make_key, any_real_numpy_dtype):
423+
# GH 20636
424+
breaks = np.arange(5, dtype=any_real_numpy_dtype)
425+
index = IntervalIndex.from_breaks(breaks)
426+
key = make_key(breaks)
427+
428+
result = index._maybe_convert_i8(key)
429+
expected = Index(key)
430+
tm.assert_index_equal(result, expected)
431+
422432
@pytest.mark.parametrize(
423433
"make_key",
424434
[
425435
IntervalIndex.from_breaks,
426436
lambda breaks: Interval(breaks[0], breaks[1]),
427-
lambda breaks: breaks,
428437
lambda breaks: breaks[0],
429-
list,
430438
],
431-
ids=["IntervalIndex", "Interval", "Index", "scalar", "list"],
439+
ids=["IntervalIndex", "Interval", "scalar"],
432440
)
433-
def test_maybe_convert_i8_numeric(self, breaks, make_key):
441+
def test_maybe_convert_i8_numeric_identical(self, make_key, any_real_numpy_dtype):
434442
# GH 20636
443+
breaks = np.arange(5, dtype=any_real_numpy_dtype)
435444
index = IntervalIndex.from_breaks(breaks)
436445
key = make_key(breaks)
437446

438-
# no conversion occurs for numeric
447+
# test that _maybe_convert_i8 returns the key unchanged if it is an Interval, IntervalIndex or scalar
439448
result = index._maybe_convert_i8(key)
440449
assert result is key
441450

0 commit comments

Comments
 (0)