Skip to content

Commit 4117441

Browse files
committed
move maybe_upcast_numeric_to_64bit to core.dtypes.cast
1 parent 1714a53 commit 4117441

File tree

3 files changed

+42
-27
lines changed

3 files changed

+42
-27
lines changed

pandas/core/arrays/interval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@
4545
from pandas.errors import IntCastingNaNError
4646
from pandas.util._decorators import Appender
4747

48-
from pandas.core.dtypes.cast import LossySetitemError
48+
from pandas.core.dtypes.cast import (
49+
LossySetitemError,
50+
maybe_upcast_numeric_to_64bit,
51+
)
4952
from pandas.core.dtypes.common import (
5053
is_categorical_dtype,
5154
is_dtype_equal,
@@ -1750,5 +1753,6 @@ def _maybe_convert_platform_interval(values) -> ArrayLike:
17501753
values = extract_array(values, extract_numpy=True)
17511754

17521755
if not hasattr(values, "dtype"):
1753-
return np.asarray(values)
1756+
values = np.asarray(values)
1757+
values = maybe_upcast_numeric_to_64bit(values)
17541758
return values

pandas/core/dtypes/cast.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ArrayLike,
3636
Dtype,
3737
DtypeObj,
38+
NumpyIndexT,
3839
Scalar,
3940
npt,
4041
)
@@ -52,6 +53,7 @@
5253
ensure_int64,
5354
ensure_object,
5455
ensure_str,
56+
is_array_like,
5557
is_bool,
5658
is_bool_dtype,
5759
is_complex,
@@ -65,6 +67,7 @@
6567
is_numeric_dtype,
6668
is_object_dtype,
6769
is_scalar,
70+
is_signed_integer_dtype,
6871
is_string_dtype,
6972
is_timedelta64_dtype,
7073
is_unsigned_integer_dtype,
@@ -412,6 +415,34 @@ def trans(x):
412415
return result
413416

414417

418+
def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
419+
"""
420+
If the array has an int/uint/float dtype narrower than 64 bits, upcast it to 64 bits.
421+
422+
Parameters
423+
----------
424+
arr : ndarray or ExtensionArray
425+
426+
Returns
427+
-------
428+
ndarray or ExtensionArray
429+
"""
430+
431+
if not is_array_like(arr):
432+
return arr
433+
dtype = arr.dtype
434+
if not np.issubclass_(dtype.type, np.number):
435+
return arr
436+
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
437+
return arr.astype(np.int64)
438+
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
439+
return arr.astype(np.uint64)
440+
elif is_float_dtype(dtype) and dtype != np.float64:
441+
return arr.astype(np.float64)
442+
else:
443+
return arr
444+
445+
415446
def maybe_cast_pointwise_result(
416447
result: ArrayLike,
417448
dtype: DtypeObj,

pandas/core/indexes/interval.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
DtypeObj,
3232
IntervalClosedType,
3333
npt,
34-
NumpyIndexT,
3534
)
3635
from pandas.errors import InvalidIndexError
3736
from pandas.util._decorators import (
@@ -45,10 +44,10 @@
4544
infer_dtype_from_scalar,
4645
maybe_box_datetimelike,
4746
maybe_downcast_numeric,
47+
maybe_upcast_numeric_to_64bit,
4848
)
4949
from pandas.core.dtypes.common import (
5050
ensure_platform_int,
51-
is_array_like,
5251
is_datetime64tz_dtype,
5352
is_datetime_or_timedelta_dtype,
5453
is_dtype_equal,
@@ -61,8 +60,6 @@
6160
is_number,
6261
is_object_dtype,
6362
is_scalar,
64-
is_signed_integer_dtype,
65-
is_unsigned_integer_dtype,
6663
)
6764
from pandas.core.dtypes.dtypes import IntervalDtype
6865
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -148,24 +145,6 @@ def _new_IntervalIndex(cls, d):
148145
return cls.from_arrays(**d)
149146

150147

151-
def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
152-
# IntervalTree only supports 64 bit numpy array
153-
154-
if not is_array_like(arr):
155-
return arr
156-
dtype = arr.dtype
157-
if not np.issubclass_(dtype.type, np.number):
158-
return arr
159-
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
160-
return arr.astype(np.int64)
161-
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
162-
return arr.astype(np.uint64)
163-
elif is_float_dtype(dtype) and dtype != np.float64:
164-
return arr.astype(np.float64)
165-
else:
166-
return arr
167-
168-
169148
@Appender(
170149
_interval_shared_docs["class"]
171150
% {
@@ -362,10 +341,11 @@ def from_tuples(
362341
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
363342
@cache_readonly
364343
def _engine(self) -> IntervalTree: # type: ignore[override]
344+
# IntervalTree does not support numpy arrays unless they are 64 bit
365345
left = self._maybe_convert_i8(self.left)
366-
left = maybe_convert_numeric_to_64bit(left)
346+
left = maybe_upcast_numeric_to_64bit(left)
367347
right = self._maybe_convert_i8(self.right)
368-
right = maybe_convert_numeric_to_64bit(right)
348+
right = maybe_upcast_numeric_to_64bit(right)
369349
return IntervalTree(left, right, closed=self.closed)
370350

371351
def __contains__(self, key: Any) -> bool:
@@ -544,7 +524,7 @@ def _maybe_convert_i8(self, key):
544524
"""
545525
if is_list_like(key):
546526
key = ensure_index(key)
547-
key = maybe_convert_numeric_to_64bit(key)
527+
key = maybe_upcast_numeric_to_64bit(key)
548528

549529
if not self._needs_i8_conversion(key):
550530
return key

0 commit comments

Comments
 (0)