Skip to content

Commit 25c71c7

Browse files
Terji PetersenTerji Petersen
Terji Petersen
authored and
Terji Petersen
committed
IntervalIndex tests
1 parent 1ce1e8a commit 25c71c7

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

pandas/core/indexes/interval.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
is_number,
6060
is_object_dtype,
6161
is_scalar,
62+
is_signed_integer_dtype,
63+
is_unsigned_integer_dtype,
6264
)
6365
from pandas.core.dtypes.dtypes import IntervalDtype
6466
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -340,8 +342,8 @@ def from_tuples(
340342
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
341343
@cache_readonly
342344
def _engine(self) -> IntervalTree: # type: ignore[override]
343-
left = self._maybe_convert_i8(self.left)
344-
right = self._maybe_convert_i8(self.right)
345+
left = self._maybe_convert_to_64bit_if_numeric(self.left)
346+
right = self._maybe_convert_to_64bit_if_numeric(self.right)
345347
return IntervalTree(left, right, closed=self.closed)
346348

347349
def __contains__(self, key: Any) -> bool:
@@ -501,6 +503,18 @@ def _needs_i8_conversion(self, key) -> bool:
501503
i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex)
502504
return isinstance(key, i8_types)
503505

506+
def _maybe_convert_to_64bit_if_numeric(self, key):
507+
key = self._maybe_convert_i8(key)
508+
dtype = key.dtype
509+
if is_signed_integer_dtype(dtype) and dtype != "int64":
510+
return key.astype(np.int64)
511+
elif is_unsigned_integer_dtype(dtype) and dtype != "uint64":
512+
return key.astype(np.uint64)
513+
elif is_float_dtype(dtype) and dtype != "float64":
514+
return key.astype(np.float64)
515+
else:
516+
return key
517+
504518
def _maybe_convert_i8(self, key):
505519
"""
506520
Maybe convert a given key to its equivalent i8 value(s). Used as a

pandas/tests/indexes/interval/test_constructors.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@
1919
)
2020
import pandas._testing as tm
2121
from pandas.api.types import is_unsigned_integer_dtype
22-
from pandas.core.api import (
23-
Float64Index,
24-
Int64Index,
25-
UInt64Index,
26-
)
22+
from pandas.core.api import NumericIndex
2723
from pandas.core.arrays import IntervalArray
2824
import pandas.core.common as com
2925

@@ -44,9 +40,9 @@ class ConstructorTests:
4440
params=[
4541
([3, 14, 15, 92, 653], np.int64),
4642
(np.arange(10, dtype="int64"), np.int64),
47-
(Int64Index(range(-10, 11)), np.int64),
48-
(UInt64Index(range(10, 31)), np.uint64),
49-
(Float64Index(np.arange(20, 30, 0.5)), np.float64),
43+
(NumericIndex(range(-10, 11), dtype=np.int64), np.int64),
44+
(NumericIndex(range(10, 31), dtype=np.uint64), np.uint64),
45+
(NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64), np.float64),
5046
(date_range("20180101", periods=10), "<M8[ns]"),
5147
(
5248
date_range("20180101", periods=10, tz="US/Eastern"),
@@ -74,10 +70,10 @@ def test_constructor(self, constructor, breaks_and_expected_subtype, closed, nam
7470
@pytest.mark.parametrize(
7571
"breaks, subtype",
7672
[
77-
(Int64Index([0, 1, 2, 3, 4]), "float64"),
78-
(Int64Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
79-
(Int64Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
80-
(Float64Index([0, 1, 2, 3, 4]), "int64"),
73+
(NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), "float64"),
74+
(NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), "datetime64[ns]"),
75+
(NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), "timedelta64[ns]"),
76+
(NumericIndex([0, 1, 2, 3, 4], dtype=np.float64), "int64"),
8177
(date_range("2017-01-01", periods=5), "int64"),
8278
(timedelta_range("1 day", periods=5), "int64"),
8379
],
@@ -96,9 +92,9 @@ def test_constructor_dtype(self, constructor, breaks, subtype):
9692
@pytest.mark.parametrize(
9793
"breaks",
9894
[
99-
Int64Index([0, 1, 2, 3, 4]),
100-
UInt64Index([0, 1, 2, 3, 4]),
101-
Float64Index([0, 1, 2, 3, 4]),
95+
NumericIndex([0, 1, 2, 3, 4], dtype=np.int64),
96+
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint64),
97+
NumericIndex([0, 1, 2, 3, 4], dtype=np.float64),
10298
date_range("2017-01-01", periods=5),
10399
timedelta_range("1 day", periods=5),
104100
],
@@ -255,8 +251,8 @@ def test_mixed_float_int(self, left_subtype, right_subtype):
255251
right = np.arange(1, 10, dtype=right_subtype)
256252
result = IntervalIndex.from_arrays(left, right)
257253

258-
expected_left = Float64Index(left)
259-
expected_right = Float64Index(right)
254+
expected_left = NumericIndex(left, dtype=np.float64)
255+
expected_right = NumericIndex(right, dtype=np.float64)
260256
expected_subtype = np.float64
261257

262258
tm.assert_index_equal(result.left, expected_left)
@@ -307,9 +303,9 @@ class TuplesClassConstructorTests(ConstructorTests):
307303
params=[
308304
([3, 14, 15, 92, 653], np.int64),
309305
(np.arange(10, dtype="int64"), np.int64),
310-
(Int64Index(range(-10, 11)), np.int64),
311-
(UInt64Index(range(10, 31)), np.int64),
312-
(Float64Index(np.arange(20, 30, 0.5)), np.float64),
306+
(NumericIndex(range(-10, 11), dtype=np.int64), np.int64),
307+
(NumericIndex(range(10, 31), dtype=np.uint64), np.int64),
308+
(NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64), np.float64),
313309
(date_range("20180101", periods=10), "<M8[ns]"),
314310
(
315311
date_range("20180101", periods=10, tz="US/Eastern"),

0 commit comments

Comments
 (0)