diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx index e5e3b27c41721..e4ec9db560b80 100644 --- a/pandas/_libs/index.pyx +++ b/pandas/_libs/index.pyx @@ -72,9 +72,10 @@ cdef class IndexEngine: self.over_size_threshold = n >= _SIZE_CUTOFF self.clear_mapping() - def __contains__(self, object val): + def __contains__(self, val: object) -> bool: + # We assume before we get here: + # - val is hashable self._ensure_mapping_populated() - hash(val) return val in self.mapping cpdef get_value(self, ndarray arr, object key, object tz=None): @@ -415,7 +416,9 @@ cdef class DatetimeEngine(Int64Engine): raise TypeError(scalar) return scalar.value - def __contains__(self, object val): + def __contains__(self, val: object) -> bool: + # We assume before we get here: + # - val is hashable cdef: int64_t loc, conv @@ -712,7 +715,9 @@ cdef class BaseMultiIndexCodesEngine: return indexer - def __contains__(self, object val): + def __contains__(self, val: object) -> bool: + # We assume before we get here: + # - val is hashable # Default __contains__ looks in the underlying mapping, which in this # case only contains integer representations. try: diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index dc74840958e1f..98e5ed678f945 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -1,7 +1,7 @@ from datetime import datetime import operator from textwrap import dedent -from typing import Dict, FrozenSet, Hashable, Optional, Union +from typing import Any, Dict, FrozenSet, Hashable, Optional, Union import warnings import numpy as np @@ -4145,7 +4145,7 @@ def is_type_compatible(self, kind) -> bool: """ @Appender(_index_shared_docs["contains"] % _index_doc_kwargs) - def __contains__(self, key) -> bool: + def __contains__(self, key: Any) -> bool: hash(key) try: return key in self._engine diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 0ff6469d6b19c..268ab9ba4e4c4 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -385,11 +385,12 @@ def _wrap_setop_result(self, other, result): return self._shallow_copy(result, name=name) @Appender(_index_shared_docs["contains"] % _index_doc_kwargs) - def __contains__(self, key) -> bool: + def __contains__(self, key: Any) -> bool: # if key is a NaN, check if any NaN is in self. if is_scalar(key) and isna(key): return self.hasnans + hash(key) return contains(self, key, container=self._engine) def __array__(self, dtype=None) -> np.ndarray: diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index bf1272b223f70..2ba1d3a188c01 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -2,7 +2,7 @@ Base and utility classes for tseries type pandas objects. """ import operator -from typing import List, Optional, Set +from typing import Any, List, Optional, Set import numpy as np @@ -153,7 +153,8 @@ def equals(self, other) -> bool: return np.array_equal(self.asi8, other.asi8) @Appender(_index_shared_docs["contains"] % _index_doc_kwargs) - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: + hash(key) try: res = self.get_loc(key) except (KeyError, TypeError, ValueError): diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 523d6404f5efa..3108c1a1afd0c 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -374,7 +374,7 @@ def _engine(self): right = self._maybe_convert_i8(self.right) return IntervalTree(left, right, closed=self.closed) - def __contains__(self, key) -> bool: + def __contains__(self, key: Any) -> bool: """ return a boolean if this key is IN the index We *only* accept an Interval @@ -387,6 +387,7 @@ def __contains__(self, key) -> bool: ------- bool """ + hash(key) if not isinstance(key, Interval): return False diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 10a2d9f68a7b6..8682af6ab6369 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -1,6 +1,6 @@ import datetime from sys import getsizeof -from typing import Hashable, List, Optional, Sequence, Union +from typing import Any, Hashable, List, Optional, Sequence, Union import warnings import numpy as np @@ -973,7 +973,7 @@ def _shallow_copy_with_infer(self, values, **kwargs): return self._shallow_copy(values, **kwargs) @Appender(_index_shared_docs["contains"] % _index_doc_kwargs) - def __contains__(self, key) -> bool: + def __contains__(self, key: Any) -> bool: hash(key) try: self.get_loc(key) diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index def77ffbea591..465f21da1278a 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -461,7 +461,8 @@ def equals(self, other) -> bool: except (TypeError, ValueError): return False - def __contains__(self, other) -> bool: + def __contains__(self, other: Any) -> bool: + hash(other) if super().__contains__(other): return True diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 35f96e61704f0..af6361826a76d 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Any import weakref import numpy as np @@ -369,18 +370,18 @@ def _engine(self): return self._engine_type(period, len(self)) @Appender(_index_shared_docs["contains"]) - def __contains__(self, key) -> bool: + def __contains__(self, key: Any) -> bool: if isinstance(key, Period): if key.freq != self.freq: return False else: return key.ordinal in self._engine else: + hash(key) try: self.get_loc(key) return True - except (TypeError, KeyError): - # TypeError can be reached if we pass a tuple that is not hashable + except KeyError: return False @cache_readonly diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 336f65ca574dc..22940f851ddb0 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -1,7 +1,7 @@ from datetime import timedelta import operator from sys import getsizeof -from typing import Optional, Union +from typing import Any, Optional import warnings import numpy as np @@ -332,7 +332,7 @@ def is_monotonic_decreasing(self) -> bool: def has_duplicates(self) -> bool: return False - def __contains__(self, key: Union[int, np.integer]) -> bool: + def __contains__(self, key: Any) -> bool: hash(key) try: key = ensure_python_int(key) diff --git a/pandas/tests/indexes/common.py b/pandas/tests/indexes/common.py index afc068d6696ef..f3ebe8313d0c6 100644 --- a/pandas/tests/indexes/common.py +++ b/pandas/tests/indexes/common.py @@ -883,3 +883,11 @@ def test_getitem_2d_deprecated(self): res = idx[:, None] assert isinstance(res, np.ndarray), type(res) + + def test_contains_requires_hashable_raises(self): + idx = self.create_index() + with pytest.raises(TypeError, match="unhashable type"): + [] in idx + + with pytest.raises(TypeError): + {} in idx._engine