diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index ce64940d964ca..fbd2c2b5345fc 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -133,7 +133,7 @@ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for mor Other API changes ^^^^^^^^^^^^^^^^^ - +- Partially initialized :class:`CategoricalDtype` (i.e. those with ``categories=None`` objects will no longer compare as equal to fully initialized dtype objects. - - diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 413309b3d01ad..24cbbd9ec6ac9 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -36,6 +36,7 @@ is_scalar, is_timedelta64_dtype, needs_i8_conversion, + pandas_dtype, ) from pandas.core.dtypes.dtypes import CategoricalDtype from pandas.core.dtypes.generic import ABCIndex, ABCSeries @@ -409,6 +410,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: If copy is set to False and dtype is categorical, the original object is returned. """ + dtype = pandas_dtype(dtype) if self.dtype is dtype: result = self.copy() if copy else self diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 081339583e3fd..5869b2cf22516 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -639,6 +639,19 @@ def is_dtype_equal(source, target) -> bool: >>> is_dtype_equal(DatetimeTZDtype(tz="UTC"), "datetime64") False """ + if isinstance(target, str): + if not isinstance(source, str): + # GH#38516 ensure we get the same behavior from + # is_dtype_equal(CDT, "category") and CDT == "category" + try: + src = get_dtype(source) + if isinstance(src, ExtensionDtype): + return src == target + except (TypeError, AttributeError): + return False + elif isinstance(source, str): + return is_dtype_equal(target, source) + try: source = get_dtype(source) target = get_dtype(target) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 0de8a07abbec3..75f3b511bc57d 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool: elif not (hasattr(other, "ordered") and hasattr(other, "categories")): return False elif self.categories is None or other.categories is None: - # We're forced into a suboptimal corner thanks to math and - # backwards compatibility. We require that `CDT(...) == 'category'` - # for all CDTs **including** `CDT(None, ...)`. Therefore, *all* - # CDT(., .) = CDT(None, False) and *all* - # CDT(., .) = CDT(None, True). - return True + # For non-fully-initialized dtypes, these are only equal to + # - the string "category" (handled above) + # - other CategoricalDtype with categories=None + return self.categories is other.categories elif self.ordered or other.ordered: # At least one has ordered=True; equal if both have ordered=True # and the same values for categories in the same order. diff --git a/pandas/tests/arrays/categorical/test_dtypes.py b/pandas/tests/arrays/categorical/test_dtypes.py index 12654388de904..a2192b2810596 100644 --- a/pandas/tests/arrays/categorical/test_dtypes.py +++ b/pandas/tests/arrays/categorical/test_dtypes.py @@ -127,7 +127,7 @@ def test_astype(self, ordered): expected = np.array(cat) tm.assert_numpy_array_equal(result, expected) - msg = r"Cannot cast object dtype to " + msg = r"Cannot cast object dtype to float64" with pytest.raises(ValueError, match=msg): cat.astype(float) diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 0d0601aa542b4..9e75ba0864e76 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -641,7 +641,6 @@ def test_is_complex_dtype(): (pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])), (pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])), (CategoricalDtype(), CategoricalDtype()), - (CategoricalDtype(["a", "b"]), CategoricalDtype()), (pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")), (pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")), ("