diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index c72c5fa019bd7..5be2319088f9a 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -112,7 +112,10 @@ pandas_dtype, validate_all_hashable, ) -from pandas.core.dtypes.concat import concat_compat +from pandas.core.dtypes.concat import ( + concat_compat, + union_categoricals, +) from pandas.core.dtypes.dtypes import ( ArrowDtype, CategoricalDtype, @@ -212,6 +215,7 @@ PeriodArray, ) + __all__ = ["Index"] _unsortable_types = frozenset(("mixed", "mixed-integer")) @@ -2922,6 +2926,27 @@ def union(self, other, sort=None): "Can only union MultiIndex with MultiIndex or Index of tuples, " "try mi.to_flat_index().union(other) instead." ) + + if isinstance(self, ABCCategoricalIndex) and isinstance( + other, ABCCategoricalIndex + ): + both_categories = self.categories + # if ordered and unordered, we set categories to be unordered + ordered = False if self.ordered != other.ordered else None + if ordered is False: + both_categories = union_categoricals( + [self.as_unordered(), other.as_unordered()], # type: ignore[attr-defined] + sort_categories=True, + ).categories + else: + both_categories = union_categoricals( + [self, other], sort_categories=True + ).categories + # Convert both indexes to have the same categories + self = self.set_categories(both_categories, ordered=ordered) # type: ignore[attr-defined] + other = other.set_categories(both_categories, ordered=ordered) # type: ignore[attr-defined] + return self.union(other, sort=sort) + self, other = self._dti_setop_align_tzs(other, "union") dtype = self._find_common_type_compat(other) @@ -3006,7 +3031,7 @@ def _union(self, other: Index, sort: bool | None): else: missing = algos.unique1d(self.get_indexer_non_unique(other)[1]) - result: Index | MultiIndex | ArrayLike + result: Index | MultiIndex | CategoricalIndex | ArrayLike if self._is_multi: # Preserve MultiIndex to avoid losing dtypes result = self.append(other.take(missing)) diff --git a/pandas/tests/frame/test_constructors.py b/pandas/tests/frame/test_constructors.py index 7d1a5b4492740..20c940ad0c6dd 100644 --- a/pandas/tests/frame/test_constructors.py +++ b/pandas/tests/frame/test_constructors.py @@ -2349,7 +2349,7 @@ def test_construct_with_two_categoricalindex_series(self): result = DataFrame([s1, s2]) expected = DataFrame( np.array([[39, 6, 4, np.nan, np.nan], [152.0, 242.0, 150.0, 2.0, 2.0]]), - columns=["female", "male", "unknown", "f", "m"], + columns=CategoricalIndex(["female", "male", "unknown", "f", "m"]), ) tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 27b54ea66f0ac..6fae11781e623 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -573,7 +573,7 @@ def test_union_duplicate_index_subsets_of_each_other( expected = Index([1, 2, 2, 3, 3, 4], dtype=dtype) if isinstance(a, CategoricalIndex): - expected = Index([1, 2, 2, 3, 3, 4]) + expected = CategoricalIndex([1, 2, 2, 3, 3, 4]) result = a.union(b) tm.assert_index_equal(result, expected) result = a.union(b, sort=False) @@ -670,7 +670,7 @@ def test_union_with_duplicate_index_not_subset_and_non_monotonic( b = Index([0, 0, 1], dtype=dtype) expected = Index([0, 0, 1, 2], dtype=dtype) if isinstance(a, CategoricalIndex): - expected = Index([0, 0, 1, 2]) + expected = CategoricalIndex([0, 0, 1, 2]) result = a.union(b) tm.assert_index_equal(result, expected) diff --git a/pandas/tests/reshape/concat/test_append.py b/pandas/tests/reshape/concat/test_append.py index 3fb6a3fb61396..96ca06e1d16a4 100644 --- a/pandas/tests/reshape/concat/test_append.py +++ b/pandas/tests/reshape/concat/test_append.py @@ -234,6 +234,8 @@ def test_append_different_columns_types(self, df_columns, series_index): result = df._append(ser) idx_diff = ser.index.difference(df_columns) combined_columns = Index(df_columns.tolist()).append(idx_diff) + if isinstance(result.columns, pd.CategoricalIndex): + combined_columns = pd.CategoricalIndex(combined_columns) expected = DataFrame( [ [1.0, 2.0, 3.0, np.nan, np.nan, np.nan],