Skip to content

REF: simplify index.pyx #31168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,13 @@ from pandas._libs import algos, hashtable as _hash
from pandas._libs.tslibs import Timestamp, Timedelta, period as periodlib
from pandas._libs.missing import checknull

cdef int64_t NPY_NAT = util.get_nat()


cdef inline bint is_definitely_invalid_key(object val):
if isinstance(val, tuple):
try:
hash(val)
except TypeError:
return True

# we have a _data, means we are a NDFrame
return (isinstance(val, slice) or util.is_array(val)
or isinstance(val, list) or hasattr(val, '_data'))
try:
hash(val)
except TypeError:
return True
return False


cpdef get_value_at(ndarray arr, object loc, object tz=None):
Expand Down Expand Up @@ -168,6 +162,15 @@ cdef class IndexEngine:
int count

indexer = self._get_index_values() == val
return self._unpack_bool_indexer(indexer, val)

cdef _unpack_bool_indexer(self,
ndarray[uint8_t, ndim=1, cast=True] indexer,
object val):
cdef:
ndarray[intp_t, ndim=1] found
int count

found = np.where(indexer)[0]
count = len(found)

Expand Down Expand Up @@ -446,7 +449,7 @@ cdef class DatetimeEngine(Int64Engine):
cdef:
int64_t loc
if is_definitely_invalid_key(val):
raise TypeError
raise TypeError(f"'{val}' is an invalid key")

try:
conv = self._unbox_scalar(val)
Expand Down Expand Up @@ -651,7 +654,10 @@ cdef class BaseMultiIndexCodesEngine:
# integers representing labels: we will use its get_loc and get_indexer
self._base.__init__(self, lambda: lab_ints, len(lab_ints))

def _extract_level_codes(self, object target, object method=None):
def _codes_to_ints(self, codes):
raise NotImplementedError("Implemented by subclass")

def _extract_level_codes(self, object target):
"""
Map the requested list of (tuple) keys to their integer representations
for searching in the underlying integer index.
Expand Down
46 changes: 18 additions & 28 deletions pandas/_libs/index_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,26 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in

{{py:

# name, dtype, ctype, hashtable_name, hashtable_dtype
dtypes = [('Float64', 'float64', 'float64_t', 'Float64', 'float64'),
('Float32', 'float32', 'float32_t', 'Float64', 'float64'),
('Int64', 'int64', 'int64_t', 'Int64', 'int64'),
('Int32', 'int32', 'int32_t', 'Int64', 'int64'),
('Int16', 'int16', 'int16_t', 'Int64', 'int64'),
('Int8', 'int8', 'int8_t', 'Int64', 'int64'),
('UInt64', 'uint64', 'uint64_t', 'UInt64', 'uint64'),
('UInt32', 'uint32', 'uint32_t', 'UInt64', 'uint64'),
('UInt16', 'uint16', 'uint16_t', 'UInt64', 'uint64'),
('UInt8', 'uint8', 'uint8_t', 'UInt64', 'uint64'),
# name, dtype, hashtable_name
dtypes = [('Float64', 'float64', 'Float64'),
('Float32', 'float32', 'Float64'),
('Int64', 'int64', 'Int64'),
('Int32', 'int32', 'Int64'),
('Int16', 'int16', 'Int64'),
('Int8', 'int8', 'Int64'),
('UInt64', 'uint64', 'UInt64'),
('UInt32', 'uint32', 'UInt64'),
('UInt16', 'uint16', 'UInt64'),
('UInt8', 'uint8', 'UInt64'),
]
}}

{{for name, dtype, ctype, hashtable_name, hashtable_dtype in dtypes}}
{{for name, dtype, hashtable_name in dtypes}}


cdef class {{name}}Engine(IndexEngine):
# constructor-caller is responsible for ensuring that vgetter()
# returns an ndarray with dtype {{dtype}}_t

cdef _make_hash_table(self, Py_ssize_t n):
return _hash.{{hashtable_name}}HashTable(n)
Expand All @@ -41,22 +43,18 @@ cdef class {{name}}Engine(IndexEngine):
cdef void _call_map_locations(self, values):
# self.mapping is of type {{hashtable_name}}HashTable,
# so convert dtype of values
self.mapping.map_locations(algos.ensure_{{hashtable_dtype}}(values))

cdef _get_index_values(self):
return algos.ensure_{{dtype}}(self.vgetter())
self.mapping.map_locations(algos.ensure_{{hashtable_name.lower()}}(values))

cdef _maybe_get_bool_indexer(self, object val):
cdef:
ndarray[uint8_t, ndim=1, cast=True] indexer
ndarray[intp_t, ndim=1] found
ndarray[{{ctype}}] values
ndarray[{{dtype}}_t, ndim=1] values
int count = 0

self._check_type(val)

# A view is needed for some subclasses, such as PeriodEngine:
values = self._get_index_values().view('{{dtype}}')
values = self._get_index_values()
try:
with warnings.catch_warnings():
# e.g. if values is float64 and `val` is a str, suppress warning
Expand All @@ -67,14 +65,6 @@ cdef class {{name}}Engine(IndexEngine):
# when trying to cast it to ndarray
raise KeyError(val)

found = np.where(indexer)[0]
count = len(found)

if count > 1:
return indexer
if count == 1:
return int(found[0])

raise KeyError(val)
return self._unpack_bool_indexer(indexer, val)

{{endfor}}