diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index f246ebad3aa2c..42c4134437ff6 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -755,7 +755,7 @@ Interval - Bug in the :class:`IntervalIndex` constructor where the ``closed`` parameter did not always override the inferred ``closed`` (:issue:`19370`) - Bug in the ``IntervalIndex`` repr where a trailing comma was missing after the list of intervals (:issue:`20611`) - Bug in :class:`Interval` where scalar arithmetic operations did not retain the ``closed`` value (:issue:`22313`) -- +- Bug in :class:`IntervalIndex` where indexing with datetime-like values raised a ``KeyError`` (:issue:`20636`) Indexing ^^^^^^^^ diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index f72f87aeb2af6..25d4dd0cbcc81 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -6,12 +6,14 @@ from pandas.compat import add_metaclass from pandas.core.dtypes.missing import isna -from pandas.core.dtypes.cast import find_common_type, maybe_downcast_to_dtype +from pandas.core.dtypes.cast import ( + find_common_type, maybe_downcast_to_dtype, infer_dtype_from_scalar) from pandas.core.dtypes.common import ( ensure_platform_int, is_list_like, is_datetime_or_timedelta_dtype, is_datetime64tz_dtype, + is_dtype_equal, is_integer_dtype, is_float_dtype, is_interval_dtype, @@ -29,8 +31,8 @@ Interval, IntervalMixin, IntervalTree, ) -from pandas.core.indexes.datetimes import date_range -from pandas.core.indexes.timedeltas import timedelta_range +from pandas.core.indexes.datetimes import date_range, DatetimeIndex +from pandas.core.indexes.timedeltas import timedelta_range, TimedeltaIndex from pandas.core.indexes.multi import MultiIndex import pandas.core.common as com from pandas.util._decorators import cache_readonly, Appender @@ -192,7 +194,9 @@ def _isnan(self): @cache_readonly def _engine(self): - return IntervalTree(self.left, 
self.right, closed=self.closed) + left = self._maybe_convert_i8(self.left) + right = self._maybe_convert_i8(self.right) + return IntervalTree(left, right, closed=self.closed) def __contains__(self, key): """ @@ -514,6 +518,78 @@ def _maybe_cast_indexed(self, key): return key + def _needs_i8_conversion(self, key): + """ + Check if a given key needs i8 conversion. Conversion is necessary for + Timestamp, Timedelta, DatetimeIndex, and TimedeltaIndex keys. An + Interval-like requires conversion if its endpoints are one of the + aforementioned types. + + Assumes that any list-like data has already been cast to an Index. + + Parameters + ---------- + key : scalar or Index-like + The key that should be checked for i8 conversion + + Returns + ------- + boolean + """ + if is_interval_dtype(key) or isinstance(key, Interval): + return self._needs_i8_conversion(key.left) + + i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex) + return isinstance(key, i8_types) + + def _maybe_convert_i8(self, key): + """ + Maybe convert a given key to its equivalent i8 value(s). Used as a + preprocessing step prior to IntervalTree queries (self._engine), which + expects numeric data. + + Parameters + ---------- + key : scalar or list-like + The key that should maybe be converted to i8. + + Returns + ------- + key: scalar or list-like + The original key if no conversion occurred, int if converted scalar, + Int64Index if converted list-like. 
+ """ + original = key + if is_list_like(key): + key = ensure_index(key) + + if not self._needs_i8_conversion(key): + return original + + scalar = is_scalar(key) + if is_interval_dtype(key) or isinstance(key, Interval): + # convert left/right and reconstruct + left = self._maybe_convert_i8(key.left) + right = self._maybe_convert_i8(key.right) + constructor = Interval if scalar else IntervalIndex.from_arrays + return constructor(left, right, closed=self.closed) + + if scalar: + # Timestamp/Timedelta + key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True) + else: + # DatetimeIndex/TimedeltaIndex + key_dtype, key_i8 = key.dtype, Index(key.asi8) + + # ensure consistency with IntervalIndex subtype + subtype = self.dtype.subtype + msg = ('Cannot index an IntervalIndex of subtype {subtype} with ' + 'values of dtype {other}') + if not is_dtype_equal(subtype, key_dtype): + raise ValueError(msg.format(subtype=subtype, other=key_dtype)) + + return key_i8 + def _check_method(self, method): if method is None: return @@ -648,6 +724,7 @@ def get_loc(self, key, method=None): else: # use the interval tree + key = self._maybe_convert_i8(key) if isinstance(key, Interval): left, right = _get_interval_closed_bounds(key) return self._engine.get_loc_interval(left, right) @@ -711,8 +788,10 @@ def _get_reindexer(self, target): """ # find the left and right indexers - lindexer = self._engine.get_indexer(target.left.values) - rindexer = self._engine.get_indexer(target.right.values) + left = self._maybe_convert_i8(target.left) + right = self._maybe_convert_i8(target.right) + lindexer = self._engine.get_indexer(left.values) + rindexer = self._engine.get_indexer(right.values) # we want to return an indexer on the intervals # however, our keys could provide overlapping of multiple diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 71f56c5bc1164..0ff5ab232d670 100644 --- a/pandas/tests/indexes/interval/test_interval.py 
+++ b/pandas/tests/indexes/interval/test_interval.py @@ -1,7 +1,9 @@ from __future__ import division +from itertools import permutations import pytest import numpy as np +import re from pandas import ( Interval, IntervalIndex, Index, isna, notna, interval_range, Timestamp, Timedelta, date_range, timedelta_range) @@ -498,6 +500,48 @@ def test_get_loc_length_one(self, item, closed): result = index.get_loc(item) assert result == 0 + # Make consistent with test_interval_new.py (see #16316, #16386) + @pytest.mark.parametrize('breaks', [ + date_range('20180101', periods=4), + date_range('20180101', periods=4, tz='US/Eastern'), + timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype)) + def test_get_loc_datetimelike_nonoverlapping(self, breaks): + # GH 20636 + # nonoverlapping = IntervalIndex method and no i8 conversion + index = IntervalIndex.from_breaks(breaks) + + value = index[0].mid + result = index.get_loc(value) + expected = 0 + assert result == expected + + interval = Interval(index[0].left, index[1].right) + result = index.get_loc(interval) + expected = slice(0, 2) + assert result == expected + + # Make consistent with test_interval_new.py (see #16316, #16386) + @pytest.mark.parametrize('arrays', [ + (date_range('20180101', periods=4), date_range('20180103', periods=4)), + (date_range('20180101', periods=4, tz='US/Eastern'), + date_range('20180103', periods=4, tz='US/Eastern')), + (timedelta_range('0 days', periods=4), + timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype)) + def test_get_loc_datetimelike_overlapping(self, arrays): + # GH 20636 + # overlapping = IntervalTree method with i8 conversion + index = IntervalIndex.from_arrays(*arrays) + + value = index[0].mid + Timedelta('12 hours') + result = np.sort(index.get_loc(value)) + expected = np.array([0, 1], dtype='int64') + assert tm.assert_numpy_array_equal(result, expected) + + interval = Interval(index[0].left, index[1].right) + result = np.sort(index.get_loc(interval)) + 
expected = np.array([0, 1, 2], dtype='int64') + assert tm.assert_numpy_array_equal(result, expected) + # To be removed, replaced by test_interval_new.py (see #16316, #16386) def test_get_indexer(self): actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3]) @@ -555,6 +599,97 @@ def test_get_indexer_length_one(self, item, closed): expected = np.array([0] * len(item), dtype='intp') tm.assert_numpy_array_equal(result, expected) + # Make consistent with test_interval_new.py (see #16316, #16386) + @pytest.mark.parametrize('arrays', [ + (date_range('20180101', periods=4), date_range('20180103', periods=4)), + (date_range('20180101', periods=4, tz='US/Eastern'), + date_range('20180103', periods=4, tz='US/Eastern')), + (timedelta_range('0 days', periods=4), + timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype)) + def test_get_reindexer_datetimelike(self, arrays): + # GH 20636 + index = IntervalIndex.from_arrays(*arrays) + tuples = [(index[0].left, index[0].left + pd.Timedelta('12H')), + (index[-1].right - pd.Timedelta('12H'), index[-1].right)] + target = IntervalIndex.from_tuples(tuples) + + result = index._get_reindexer(target) + expected = np.array([0, 3], dtype='int64') + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize('breaks', [ + date_range('20180101', periods=4), + date_range('20180101', periods=4, tz='US/Eastern'), + timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype)) + def test_maybe_convert_i8(self, breaks): + # GH 20636 + index = IntervalIndex.from_breaks(breaks) + + # intervalindex + result = index._maybe_convert_i8(index) + expected = IntervalIndex.from_breaks(breaks.asi8) + tm.assert_index_equal(result, expected) + + # interval + interval = Interval(breaks[0], breaks[1]) + result = index._maybe_convert_i8(interval) + expected = Interval(breaks[0].value, breaks[1].value) + assert result == expected + + # datetimelike index + result = index._maybe_convert_i8(breaks) + expected = Index(breaks.asi8) 
+ tm.assert_index_equal(result, expected) + + # datetimelike scalar + result = index._maybe_convert_i8(breaks[0]) + expected = breaks[0].value + assert result == expected + + # list-like of datetimelike scalars + result = index._maybe_convert_i8(list(breaks)) + expected = Index(breaks.asi8) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize('breaks', [ + np.arange(5, dtype='int64'), + np.arange(5, dtype='float64')], ids=lambda x: str(x.dtype)) + @pytest.mark.parametrize('make_key', [ + IntervalIndex.from_breaks, + lambda breaks: Interval(breaks[0], breaks[1]), + lambda breaks: breaks, + lambda breaks: breaks[0], + list], ids=['IntervalIndex', 'Interval', 'Index', 'scalar', 'list']) + def test_maybe_convert_i8_numeric(self, breaks, make_key): + # GH 20636 + index = IntervalIndex.from_breaks(breaks) + key = make_key(breaks) + + # no conversion occurs for numeric + result = index._maybe_convert_i8(key) + assert result is key + + @pytest.mark.parametrize('breaks1, breaks2', permutations([ + date_range('20180101', periods=4), + date_range('20180101', periods=4, tz='US/Eastern'), + timedelta_range('0 days', periods=4)], 2), ids=lambda x: str(x.dtype)) + @pytest.mark.parametrize('make_key', [ + IntervalIndex.from_breaks, + lambda breaks: Interval(breaks[0], breaks[1]), + lambda breaks: breaks, + lambda breaks: breaks[0], + list], ids=['IntervalIndex', 'Interval', 'Index', 'scalar', 'list']) + def test_maybe_convert_i8_errors(self, breaks1, breaks2, make_key): + # GH 20636 + index = IntervalIndex.from_breaks(breaks1) + key = make_key(breaks2) + + msg = ('Cannot index an IntervalIndex of subtype {dtype1} with ' + 'values of dtype {dtype2}') + msg = re.escape(msg.format(dtype1=breaks1.dtype, dtype2=breaks2.dtype)) + with tm.assert_raises_regex(ValueError, msg): + index._maybe_convert_i8(key) + # To be removed, replaced by test_interval_new.py (see #16316, #16386) def test_contains(self): # Only endpoints are valid.