From 2dd480145d61bd1ab583a5cc8c76636ada866a0f Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 3 Oct 2023 22:01:34 +0200 Subject: [PATCH] Backport PR #55364: BUG: eq not implemented for categorical and arrow backed strings --- doc/source/whatsnew/v2.1.2.rst | 1 + pandas/core/arrays/arrow/array.py | 5 ++++- pandas/tests/indexes/categorical/test_equals.py | 6 ++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.1.2.rst b/doc/source/whatsnew/v2.1.2.rst index 090df1489e493..b6cacecfdc5f8 100644 --- a/doc/source/whatsnew/v2.1.2.rst +++ b/doc/source/whatsnew/v2.1.2.rst @@ -21,6 +21,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ +- Fixed bug in :meth:`Categorical.equals` if other has arrow backed string dtype (:issue:`55364`) - Fixed bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax` raising for arrow dtypes (:issue:`55368`) - diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 190fd8fd54e02..2e3ad8bc13091 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -32,6 +32,7 @@ from pandas.core.dtypes.cast import infer_dtype_from_scalar from pandas.core.dtypes.common import ( + CategoricalDtype, is_array_like, is_bool_dtype, is_integer, @@ -628,7 +629,9 @@ def __setstate__(self, state) -> None: def _cmp_method(self, other, op): pc_func = ARROW_CMP_FUNCS[op.__name__] - if isinstance(other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)): + if isinstance( + other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray) + ) or isinstance(getattr(other, "dtype", None), CategoricalDtype): result = pc_func(self._pa_array, self._box_pa(other)) elif is_scalar(other): try: diff --git a/pandas/tests/indexes/categorical/test_equals.py b/pandas/tests/indexes/categorical/test_equals.py index 1ed8f3a903439..a8353f301a3c3 100644 --- a/pandas/tests/indexes/categorical/test_equals.py +++ b/pandas/tests/indexes/categorical/test_equals.py @@ -88,3 +88,9 @@ def test_equals_multiindex(self): ci = mi.to_flat_index().astype("category") assert not ci.equals(mi) + + def test_equals_string_dtype(self, any_string_dtype): + # GH#55364 + idx = CategoricalIndex(list("abc"), name="B") + other = Index(["a", "b", "c"], name="B", dtype=any_string_dtype) + assert idx.equals(other)