diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 0c9db271edf37..9cdb5dfc47776 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -37,7 +37,7 @@ from pandas.core.dtypes.dtypes import CategoricalDtype from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries from pandas.core.dtypes.inference import is_hashable -from pandas.core.dtypes.missing import isna, notna +from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna, notna from pandas.core import ops from pandas.core.accessor import PandasDelegate, delegate_names @@ -1834,7 +1834,7 @@ def __contains__(self, key) -> bool: Returns True if `key` is in this Categorical. """ # if key is a NaN, check if any NaN is in self. - if is_scalar(key) and isna(key): + if is_valid_nat_for_dtype(key, self.categories.dtype): return self.isna().any() return contains(self, key, container=self._codes) diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 0cf6698d316bb..80d3e5c8a0dbe 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -19,7 +19,7 @@ is_scalar, ) from pandas.core.dtypes.dtypes import CategoricalDtype -from pandas.core.dtypes.missing import isna +from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna from pandas.core import accessor from pandas.core.algorithms import take_1d @@ -365,10 +365,9 @@ def _has_complex_internals(self) -> bool: @doc(Index.__contains__) def __contains__(self, key: Any) -> bool: # if key is a NaN, check if any NaN is in self. - if is_scalar(key) and isna(key): + if is_valid_nat_for_dtype(key, self.categories.dtype): return self.hasnans - hash(key) return contains(self, key, container=self._engine) @doc(Index.astype) diff --git a/pandas/tests/indexes/categorical/test_indexing.py b/pandas/tests/indexes/categorical/test_indexing.py index a36568bbbe633..9cf901c0797d8 100644 --- a/pandas/tests/indexes/categorical/test_indexing.py +++ b/pandas/tests/indexes/categorical/test_indexing.py @@ -285,6 +285,43 @@ def test_contains_nan(self): ci = CategoricalIndex(list("aabbca") + [np.nan], categories=list("cabdef")) assert np.nan in ci + @pytest.mark.parametrize("unwrap", [True, False]) + def test_contains_na_dtype(self, unwrap): + dti = pd.date_range("2016-01-01", periods=100).insert(0, pd.NaT) + pi = dti.to_period("D") + tdi = dti - dti[-1] + ci = CategoricalIndex(dti) + + obj = ci + if unwrap: + obj = ci._data + + assert np.nan in obj + assert None in obj + assert pd.NaT in obj + assert np.datetime64("NaT") in obj + assert np.timedelta64("NaT") not in obj + + obj2 = CategoricalIndex(tdi) + if unwrap: + obj2 = obj2._data + + assert np.nan in obj2 + assert None in obj2 + assert pd.NaT in obj2 + assert np.datetime64("NaT") not in obj2 + assert np.timedelta64("NaT") in obj2 + + obj3 = CategoricalIndex(pi) + if unwrap: + obj3 = obj3._data + + assert np.nan in obj3 + assert None in obj3 + assert pd.NaT in obj3 + assert np.datetime64("NaT") not in obj3 + assert np.timedelta64("NaT") not in obj3 + @pytest.mark.parametrize( "item, expected", [