diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx index e4ec9db560b80..2dfc14378baf6 100644 --- a/pandas/_libs/index.pyx +++ b/pandas/_libs/index.pyx @@ -26,19 +26,13 @@ from pandas._libs import algos, hashtable as _hash from pandas._libs.tslibs import Timestamp, Timedelta, period as periodlib from pandas._libs.missing import checknull -cdef int64_t NPY_NAT = util.get_nat() - cdef inline bint is_definitely_invalid_key(object val): - if isinstance(val, tuple): - try: - hash(val) - except TypeError: - return True - - # we have a _data, means we are a NDFrame - return (isinstance(val, slice) or util.is_array(val) - or isinstance(val, list) or hasattr(val, '_data')) + try: + hash(val) + except TypeError: + return True + return False cpdef get_value_at(ndarray arr, object loc, object tz=None): @@ -168,6 +162,15 @@ cdef class IndexEngine: int count indexer = self._get_index_values() == val + return self._unpack_bool_indexer(indexer, val) + + cdef _unpack_bool_indexer(self, + ndarray[uint8_t, ndim=1, cast=True] indexer, + object val): + cdef: + ndarray[intp_t, ndim=1] found + int count + found = np.where(indexer)[0] count = len(found) @@ -446,7 +449,7 @@ cdef class DatetimeEngine(Int64Engine): cdef: int64_t loc if is_definitely_invalid_key(val): - raise TypeError + raise TypeError(f"'{val}' is an invalid key") try: conv = self._unbox_scalar(val) @@ -651,7 +654,10 @@ cdef class BaseMultiIndexCodesEngine: # integers representing labels: we will use its get_loc and get_indexer self._base.__init__(self, lambda: lab_ints, len(lab_ints)) - def _extract_level_codes(self, object target, object method=None): + def _codes_to_ints(self, codes): + raise NotImplementedError("Implemented by subclass") + + def _extract_level_codes(self, object target): """ Map the requested list of (tuple) keys to their integer representations for searching in the underlying integer index. diff --git a/pandas/_libs/index_class_helper.pxi.in b/pandas/_libs/index_class_helper.pxi.in index cd2b9fbe7d6d6..c7b67667bda17 100644 --- a/pandas/_libs/index_class_helper.pxi.in +++ b/pandas/_libs/index_class_helper.pxi.in @@ -10,24 +10,26 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in {{py: -# name, dtype, ctype, hashtable_name, hashtable_dtype -dtypes = [('Float64', 'float64', 'float64_t', 'Float64', 'float64'), - ('Float32', 'float32', 'float32_t', 'Float64', 'float64'), - ('Int64', 'int64', 'int64_t', 'Int64', 'int64'), - ('Int32', 'int32', 'int32_t', 'Int64', 'int64'), - ('Int16', 'int16', 'int16_t', 'Int64', 'int64'), - ('Int8', 'int8', 'int8_t', 'Int64', 'int64'), - ('UInt64', 'uint64', 'uint64_t', 'UInt64', 'uint64'), - ('UInt32', 'uint32', 'uint32_t', 'UInt64', 'uint64'), - ('UInt16', 'uint16', 'uint16_t', 'UInt64', 'uint64'), - ('UInt8', 'uint8', 'uint8_t', 'UInt64', 'uint64'), +# name, dtype, hashtable_name +dtypes = [('Float64', 'float64', 'Float64'), + ('Float32', 'float32', 'Float64'), + ('Int64', 'int64', 'Int64'), + ('Int32', 'int32', 'Int64'), + ('Int16', 'int16', 'Int64'), + ('Int8', 'int8', 'Int64'), + ('UInt64', 'uint64', 'UInt64'), + ('UInt32', 'uint32', 'UInt64'), + ('UInt16', 'uint16', 'UInt64'), + ('UInt8', 'uint8', 'UInt64'), ] }} -{{for name, dtype, ctype, hashtable_name, hashtable_dtype in dtypes}} +{{for name, dtype, hashtable_name in dtypes}} cdef class {{name}}Engine(IndexEngine): + # constructor-caller is responsible for ensuring that vgetter() + # returns an ndarray with dtype {{dtype}}_t cdef _make_hash_table(self, Py_ssize_t n): return _hash.{{hashtable_name}}HashTable(n) @@ -41,22 +43,18 @@ cdef class {{name}}Engine(IndexEngine): cdef void _call_map_locations(self, values): # self.mapping is of type {{hashtable_name}}HashTable, # so convert dtype of values - self.mapping.map_locations(algos.ensure_{{hashtable_dtype}}(values)) - - cdef _get_index_values(self): - return algos.ensure_{{dtype}}(self.vgetter()) + self.mapping.map_locations(algos.ensure_{{hashtable_name.lower()}}(values)) cdef _maybe_get_bool_indexer(self, object val): cdef: ndarray[uint8_t, ndim=1, cast=True] indexer ndarray[intp_t, ndim=1] found - ndarray[{{ctype}}] values + ndarray[{{dtype}}_t, ndim=1] values int count = 0 self._check_type(val) - # A view is needed for some subclasses, such as PeriodEngine: - values = self._get_index_values().view('{{dtype}}') + values = self._get_index_values() try: with warnings.catch_warnings(): # e.g. if values is float64 and `val` is a str, suppress warning @@ -67,14 +65,6 @@ cdef class {{name}}Engine(IndexEngine): # when trying to cast it to ndarray raise KeyError(val) - found = np.where(indexer)[0] - count = len(found) - - if count > 1: - return indexer - if count == 1: - return int(found[0]) - - raise KeyError(val) + return self._unpack_bool_indexer(indexer, val) {{endfor}}