Skip to content

Commit 5636ac0

Browse files
committed
REF: Simplify Index.union
1 parent e186e18 commit 5636ac0

File tree

3 files changed

+27
-31
lines changed

3 files changed

+27
-31
lines changed

pandas/core/dtypes/cast.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
is_numeric_dtype,
6969
is_object_dtype,
7070
is_scalar,
71+
is_signed_integer_dtype,
7172
is_string_dtype,
7273
is_timedelta64_dtype,
7374
is_unsigned_integer_dtype,
@@ -1779,13 +1780,14 @@ def ensure_nanosecond_dtype(dtype: DtypeObj) -> DtypeObj:
17791780
return dtype
17801781

17811782

1782-
def find_common_type(types: list[DtypeObj]) -> DtypeObj:
1783+
def find_common_type(types: list[DtypeObj], *, strict_uint64: bool = False) -> DtypeObj:
17831784
"""
17841785
Find a common data type among the given dtypes.
17851786
17861787
Parameters
17871788
----------
17881789
types : list of dtypes
1790+
strict_uint64 : bool, default False. If True, object dtype is returned when both uint64 and a signed integer dtype are present.
17891791
17901792
Returns
17911793
-------
@@ -1831,6 +1833,13 @@ def find_common_type(types: list[DtypeObj]) -> DtypeObj:
18311833
if is_integer_dtype(t) or is_float_dtype(t) or is_complex_dtype(t):
18321834
return np.dtype("object")
18331835

1836+
# Index.union is special: uint64 & signed int -> object
1837+
if strict_uint64:
1838+
has_uint64 = any(t == "uint64" for t in types)
1839+
has_signed_int = any(is_signed_integer_dtype(t) for t in types)
1840+
if has_uint64 and has_signed_int:
1841+
return np.dtype("object")
1842+
18341843
# error: Argument 1 to "find_common_type" has incompatible type
18351844
# "List[Union[dtype, ExtensionDtype]]"; expected "Sequence[Union[dtype,
18361845
# None, type, _SupportsDtype, str, Tuple[Any, int], Tuple[Any, Union[int,

pandas/core/indexes/base.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
is_float_dtype,
7878
is_hashable,
7979
is_integer,
80-
is_integer_dtype,
8180
is_interval_dtype,
8281
is_iterator,
8382
is_list_like,
@@ -2963,19 +2962,13 @@ def union(self, other, sort=None):
29632962
stacklevel=2,
29642963
)
29652964

2966-
dtype = find_common_type([self.dtype, other.dtype])
2967-
if self._is_numeric_dtype and other._is_numeric_dtype:
2968-
# Right now, we treat union(int, float) a bit special.
2969-
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion
2970-
# We may change union(int, float) to go to object.
2971-
# float | [u]int -> float (the special case)
2972-
# <T> | <T> -> T
2973-
# <T> | <U> -> object
2974-
if not (is_integer_dtype(self.dtype) and is_integer_dtype(other.dtype)):
2975-
dtype = np.dtype("float64")
2976-
else:
2977-
# one is int64 other is uint64
2978-
dtype = np.dtype("object")
2965+
dtype = find_common_type([self.dtype, other.dtype], strict_uint64=True)
2966+
# Right now, we treat union(float, [u]int) a bit special.
2967+
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion
2968+
# Now it's:
2969+
# * float | [u]int -> float
2970+
# * uint64 | signed int -> object
2971+
# We may change union(float | [u]int) to go to object.
29792972

29802973
left = self.astype(dtype, copy=False)
29812974
right = other.astype(dtype, copy=False)

pandas/tests/indexes/test_setops.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import pytest
1010

11-
from pandas.core.dtypes.common import is_dtype_equal
11+
from pandas.core.dtypes.cast import find_common_type
1212

1313
from pandas import (
1414
CategoricalIndex,
@@ -25,6 +25,7 @@
2525
import pandas._testing as tm
2626
from pandas.api.types import (
2727
is_datetime64tz_dtype,
28+
is_signed_integer_dtype,
2829
pandas_dtype,
2930
)
3031

@@ -48,7 +49,11 @@ def test_union_different_types(index_flat, index_flat2):
4849
idx1 = index_flat
4950
idx2 = index_flat2
5051

51-
type_pair = tuple(sorted([idx1.dtype.type, idx2.dtype.type], key=lambda x: str(x)))
52+
common_dtype = find_common_type([idx1.dtype, idx2.dtype])
53+
54+
any_uint64 = idx1.dtype == np.uint64 or idx2.dtype == np.uint64
55+
idx1_signed = is_signed_integer_dtype(idx1.dtype)
56+
idx2_signed = is_signed_integer_dtype(idx2.dtype)
5257

5358
# Union with a non-unique, non-monotonic index raises error
5459
# This applies to the boolean index
@@ -58,23 +63,12 @@ def test_union_different_types(index_flat, index_flat2):
5863
res1 = idx1.union(idx2)
5964
res2 = idx2.union(idx1)
6065

61-
if is_dtype_equal(idx1.dtype, idx2.dtype):
62-
assert res1.dtype == idx1.dtype
63-
assert res2.dtype == idx1.dtype
64-
65-
elif type_pair not in COMPATIBLE_INCONSISTENT_PAIRS:
66-
# A union with a CategoricalIndex (even as dtype('O')) and a
67-
# non-CategoricalIndex can only be made if both indices are monotonic.
68-
# This is true before this PR as well.
66+
if any_uint64 and (idx1_signed or idx2_signed):
6967
assert res1.dtype == np.dtype("O")
7068
assert res2.dtype == np.dtype("O")
71-
72-
elif idx1.dtype.kind in ["f", "i", "u"] and idx2.dtype.kind in ["f", "i", "u"]:
73-
assert res1.dtype == np.dtype("f8")
74-
assert res2.dtype == np.dtype("f8")
75-
7669
else:
77-
raise NotImplementedError
70+
assert res1.dtype == common_dtype
71+
assert res2.dtype == common_dtype
7872

7973

8074
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)