From 02f19f81f68e3c411f83d2341fa936a07f13a634 Mon Sep 17 00:00:00 2001 From: Jeremy Schendel Date: Wed, 3 Oct 2018 23:36:49 -0600 Subject: [PATCH 1/3] BUG: Perform i8 conversion for datetimelike IntervalTree queries --- doc/source/whatsnew/v0.24.0.txt | 2 +- pandas/core/indexes/interval.py | 57 +++++++- .../tests/indexes/interval/test_interval.py | 135 ++++++++++++++++++ 3 files changed, 186 insertions(+), 8 deletions(-) diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index f246ebad3aa2c..42c4134437ff6 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -755,7 +755,7 @@ Interval - Bug in the :class:`IntervalIndex` constructor where the ``closed`` parameter did not always override the inferred ``closed`` (:issue:`19370`) - Bug in the ``IntervalIndex`` repr where a trailing comma was missing after the list of intervals (:issue:`20611`) - Bug in :class:`Interval` where scalar arithmetic operations did not retain the ``closed`` value (:issue:`22313`) -- +- Bug in :class:`IntervalIndex` where indexing with datetime-like values raised a ``KeyError`` (:issue:`20636`) Indexing ^^^^^^^^ diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index f72f87aeb2af6..9b8570796a455 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -6,12 +6,14 @@ from pandas.compat import add_metaclass from pandas.core.dtypes.missing import isna -from pandas.core.dtypes.cast import find_common_type, maybe_downcast_to_dtype +from pandas.core.dtypes.cast import ( + find_common_type, maybe_downcast_to_dtype, infer_dtype_from_scalar) from pandas.core.dtypes.common import ( ensure_platform_int, is_list_like, is_datetime_or_timedelta_dtype, is_datetime64tz_dtype, + is_dtype_equal, is_integer_dtype, is_float_dtype, is_interval_dtype, @@ -19,7 +21,8 @@ is_scalar, is_float, is_number, - is_integer) + is_integer, + needs_i8_conversion) from pandas.core.indexes.base import ( Index, ensure_index, default_pprint, _index_shared_docs) @@ -29,8 +32,8 @@ Interval, IntervalMixin, IntervalTree, ) -from pandas.core.indexes.datetimes import date_range -from pandas.core.indexes.timedeltas import timedelta_range +from pandas.core.indexes.datetimes import date_range, DatetimeIndex +from pandas.core.indexes.timedeltas import timedelta_range, TimedeltaIndex from pandas.core.indexes.multi import MultiIndex import pandas.core.common as com from pandas.util._decorators import cache_readonly, Appender @@ -192,7 +195,9 @@ def _isnan(self): @cache_readonly def _engine(self): - return IntervalTree(self.left, self.right, closed=self.closed) + left = self._maybe_convert_i8(self.left) + right = self._maybe_convert_i8(self.right) + return IntervalTree(left, right, closed=self.closed) def __contains__(self, key): """ @@ -514,6 +519,41 @@ def _maybe_cast_indexed(self, key): return key + def _maybe_convert_i8(self, key): + if isinstance(key, Interval): + if not isinstance(key.left, (Timestamp, Timedelta)): + return key + left = self._maybe_convert_i8(key.left) + right = self._maybe_convert_i8(key.right) + return Interval(left, right, key.closed) + elif isinstance(key, (IntervalIndex, IntervalArray)): + if not needs_i8_conversion(key.left): + return key + left = self._maybe_convert_i8(key.left) + right = self._maybe_convert_i8(key.right) + return IntervalIndex.from_arrays(left, right, key.closed) + elif is_list_like(key) and not isinstance(key, Index): + result = self._maybe_convert_i8(ensure_index(key)) + if result[0] == key[0]: + # return the list-like key if no conversion + return key + return result + + subtype = self.dtype.subtype + msg = ('Cannot index an IntervalIndex of subtype {subtype} with ' + 'values of dtype {other}') + if isinstance(key, (Timestamp, Timedelta)): + key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True) + if not is_dtype_equal(subtype, key_dtype): + raise ValueError(msg.format(subtype=subtype, other=key_dtype)) + return key_i8 + elif isinstance(key, (DatetimeIndex, TimedeltaIndex)): + if not is_dtype_equal(subtype, key.dtype): + raise ValueError(msg.format(subtype=subtype, other=key.dtype)) + return Index(key.asi8) + + return key + def _check_method(self, method): if method is None: return @@ -648,6 +688,7 @@ def get_loc(self, key, method=None): else: # use the interval tree + key = self._maybe_convert_i8(key) if isinstance(key, Interval): left, right = _get_interval_closed_bounds(key) return self._engine.get_loc_interval(left, right) @@ -711,8 +752,10 @@ def _get_reindexer(self, target): """ # find the left and right indexers - lindexer = self._engine.get_indexer(target.left.values) - rindexer = self._engine.get_indexer(target.right.values) + left = self._maybe_convert_i8(target.left) + right = self._maybe_convert_i8(target.right) + lindexer = self._engine.get_indexer(left.values) + rindexer = self._engine.get_indexer(right.values) # we want to return an indexer on the intervals # however, our keys could provide overlapping of multiple diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 71f56c5bc1164..559b03d51407d 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -1,7 +1,9 @@ from __future__ import division +from itertools import permutations import pytest import numpy as np +import re from pandas import ( Interval, IntervalIndex, Index, isna, notna, interval_range, Timestamp, Timedelta, date_range, timedelta_range) @@ -498,6 +500,48 @@ def test_get_loc_length_one(self, item, closed): result = index.get_loc(item) assert result == 0 + # Make consistent with test_interval_new.py (see #16316, #16386) + @pytest.mark.parametrize('breaks', [ + date_range('20180101', periods=4), + date_range('20180101', periods=4, tz='US/Eastern'), + timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype)) + def test_get_loc_datetimelike_nonoverlapping(self, breaks): + # GH 20636 + # nonoverlapping = IntervalIndex method and no i8 conversion + index = IntervalIndex.from_breaks(breaks) + + value = index[0].mid + result = index.get_loc(value) + expected = 0 + assert result == expected + + interval = Interval(index[0].left, index[1].right) + result = index.get_loc(interval) + expected = slice(0, 2) + assert result == expected + + # Make consistent with test_interval_new.py (see #16316, #16386) + @pytest.mark.parametrize('arrays', [ + (date_range('20180101', periods=4), date_range('20180103', periods=4)), + (date_range('20180101', periods=4, tz='US/Eastern'), + date_range('20180103', periods=4, tz='US/Eastern')), + (timedelta_range('0 days', periods=4), + timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype)) + def test_get_loc_datetimelike_overlapping(self, arrays): + # GH 20636 + # overlapping = IntervalTree method with i8 conversion + index = IntervalIndex.from_arrays(*arrays) + + value = index[0].mid + Timedelta('12 hours') + result = np.sort(index.get_loc(value)) + expected = np.array([0, 1], dtype='int64') + assert tm.assert_numpy_array_equal(result, expected) + + interval = Interval(index[0].left, index[1].right) + result = np.sort(index.get_loc(interval)) + expected = np.array([0, 1, 2], dtype='int64') + assert tm.assert_numpy_array_equal(result, expected) + # To be removed, replaced by test_interval_new.py (see #16316, #16386) def test_get_indexer(self): actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3]) @@ -555,6 +599,97 @@ def test_get_indexer_length_one(self, item, closed): expected = np.array([0] * len(item), dtype='intp') tm.assert_numpy_array_equal(result, expected) + # Make consistent with test_interval_new.py (see #16316, #16386) + @pytest.mark.parametrize('arrays', [ + (date_range('20180101', periods=4), date_range('20180103', periods=4)), + (date_range('20180101', periods=4, tz='US/Eastern'), + date_range('20180103', periods=4, tz='US/Eastern')), + (timedelta_range('0 days', periods=4), + timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype)) + def test_get_reindexer_datetimelike(self, arrays): + # GH 20636 + index = IntervalIndex.from_arrays(*arrays) + tuples = [(index[0].left, index[0].left + pd.Timedelta('12H')), + (index[-1].right - pd.Timedelta('12H'), index[-1].right)] + target = IntervalIndex.from_tuples(tuples) + + result = index._get_reindexer(target) + expected = np.array([0, 3], dtype='int64') + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize('breaks', [ + date_range('20180101', periods=4), + date_range('20180101', periods=4, tz='US/Eastern'), + timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype)) + def test_maybe_convert_i8(self, breaks): + # GH 20636 + index = IntervalIndex.from_breaks(breaks) + + # intervalindex + result = index._maybe_convert_i8(index) + expected = IntervalIndex.from_breaks(breaks.asi8) + tm.assert_index_equal(result, expected) + + # interval + interval = Interval(breaks[0], breaks[1]) + result = index._maybe_convert_i8(interval) + expected = Interval(breaks[0].value, breaks[1].value) + assert result == expected + + # datetimelike index + result = index._maybe_convert_i8(breaks) + expected = Index(breaks.asi8) + tm.assert_index_equal(result, expected) + + # datetimelike scalar + result = index._maybe_convert_i8(breaks[0]) + expected = breaks[0].value + assert result == expected + + # list-like of datetimelike scalars + result = index._maybe_convert_i8(list(breaks)) + expected = Index(breaks.asi8) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize('breaks', [ + np.arange(5, dtype='int64'), + np.arange(5, dtype='float64')], ids=lambda x: str(x.dtype)) + def test_maybe_convert_i8_numeric(self, breaks): + # GH 20636 + index = IntervalIndex.from_breaks(breaks) + numeric_keys = [ + IntervalIndex.from_breaks(breaks), + Interval(breaks[0], breaks[1]), + breaks, + breaks[0], + list(breaks)] + + # no conversion occurs for numeric + for key in numeric_keys: + result = index._maybe_convert_i8(key) + assert result is key + + @pytest.mark.parametrize('breaks1, breaks2', permutations([ + date_range('20180101', periods=4), + date_range('20180101', periods=4, tz='US/Eastern'), + timedelta_range('0 days', periods=4)], 2), ids=lambda x: str(x.dtype)) + def test_maybe_convert_i8_errors(self, breaks1, breaks2): + # GH 20636 + index = IntervalIndex.from_breaks(breaks1) + invalid_keys = [ + IntervalIndex.from_breaks(breaks2), + Interval(breaks2[0], breaks2[1]), + breaks2, + breaks2[0], + list(breaks2)] + + msg = ('Cannot index an IntervalIndex of subtype {dtype1} with ' + 'values of dtype {dtype2}') + msg = re.escape(msg.format(dtype1=breaks1.dtype, dtype2=breaks2.dtype)) + for key in invalid_keys: + with tm.assert_raises_regex(ValueError, msg): + index._maybe_convert_i8(key) + # To be removed, replaced by test_interval_new.py (see #16316, #16386) def test_contains(self): # Only endpoints are valid. From d2c8b46f5fb7794df3a70157a80587079efbb44a Mon Sep 17 00:00:00 2001 From: Jeremy Schendel Date: Sat, 6 Oct 2018 10:44:19 -0600 Subject: [PATCH 2/3] document and simplify _maybe_convert_i8 --- pandas/core/indexes/interval.py | 92 +++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 28 deletions(-) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 9b8570796a455..25d4dd0cbcc81 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -21,8 +21,7 @@ is_scalar, is_float, is_number, - is_integer, - needs_i8_conversion) + is_integer) from pandas.core.indexes.base import ( Index, ensure_index, default_pprint, _index_shared_docs) @@ -519,40 +518,77 @@ def _maybe_cast_indexed(self, key): return key + def _needs_i8_conversion(self, key): + """ + Check if a given key needs i8 conversion. Conversion is necessary for + Timestamp, Timedelta, DatetimeIndex, and TimedeltaIndex keys. An + Interval-like requires conversion if it's endpoints are one of the + aforementioned types. + + Assumes that any list-like data has already been cast to an Index. + + Parameters + ---------- + key : scalar or Index-like + The key that should be checked for i8 conversion + + Returns + ------- + boolean + """ + if is_interval_dtype(key) or isinstance(key, Interval): + return self._needs_i8_conversion(key.left) + + i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex) + return isinstance(key, i8_types) + def _maybe_convert_i8(self, key): - if isinstance(key, Interval): - if not isinstance(key.left, (Timestamp, Timedelta)): - return key - left = self._maybe_convert_i8(key.left) - right = self._maybe_convert_i8(key.right) - return Interval(left, right, key.closed) - elif isinstance(key, (IntervalIndex, IntervalArray)): - if not needs_i8_conversion(key.left): - return key + """ + Maybe convert a given key to it's equivalent i8 value(s). Used as a + preprocessing step prior to IntervalTree queries (self._engine), which + expects numeric data. + + Parameters + ---------- + key : scalar or list-like + The key that should maybe be converted to i8. + + Returns + ------- + key: scalar or list-like + The original key if no conversion occured, int if converted scalar, + Int64Index if converted list-like. + """ + original = key + if is_list_like(key): + key = ensure_index(key) + + if not self._needs_i8_conversion(key): + return original + + scalar = is_scalar(key) + if is_interval_dtype(key) or isinstance(key, Interval): + # convert left/right and reconstruct left = self._maybe_convert_i8(key.left) right = self._maybe_convert_i8(key.right) - return IntervalIndex.from_arrays(left, right, key.closed) - elif is_list_like(key) and not isinstance(key, Index): - result = self._maybe_convert_i8(ensure_index(key)) - if result[0] == key[0]: - # return the list-like key if no conversion - return key - return result + constructor = Interval if scalar else IntervalIndex.from_arrays + return constructor(left, right, closed=self.closed) + if scalar: + # Timestamp/Timedelta + key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True) + else: + # DatetimeIndex/TimedeltaIndex + key_dtype, key_i8 = key.dtype, Index(key.asi8) + + # ensure consistency with IntervalIndex subtype subtype = self.dtype.subtype msg = ('Cannot index an IntervalIndex of subtype {subtype} with ' 'values of dtype {other}') - if isinstance(key, (Timestamp, Timedelta)): - key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True) - if not is_dtype_equal(subtype, key_dtype): - raise ValueError(msg.format(subtype=subtype, other=key_dtype)) - return key_i8 - elif isinstance(key, (DatetimeIndex, TimedeltaIndex)): - if not is_dtype_equal(subtype, key.dtype): - raise ValueError(msg.format(subtype=subtype, other=key.dtype)) - return Index(key.asi8) + if not is_dtype_equal(subtype, key_dtype): + raise ValueError(msg.format(subtype=subtype, other=key_dtype)) - return key + return key_i8 def _check_method(self, method): if method is None: From f338f3d8814929240ee7a55e37e8f4bca508a6ef Mon Sep 17 00:00:00 2001 From: Jeremy Schendel Date: Sat, 6 Oct 2018 10:45:01 -0600 Subject: [PATCH 3/3] parametrize some tests --- .../tests/indexes/interval/test_interval.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 559b03d51407d..0ff5ab232d670 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -654,41 +654,41 @@ def test_maybe_convert_i8(self, breaks): @pytest.mark.parametrize('breaks', [ np.arange(5, dtype='int64'), np.arange(5, dtype='float64')], ids=lambda x: str(x.dtype)) - def test_maybe_convert_i8_numeric(self, breaks): + @pytest.mark.parametrize('make_key', [ + IntervalIndex.from_breaks, + lambda breaks: Interval(breaks[0], breaks[1]), + lambda breaks: breaks, + lambda breaks: breaks[0], + list], ids=['IntervalIndex', 'Interval', 'Index', 'scalar', 'list']) + def test_maybe_convert_i8_numeric(self, breaks, make_key): # GH 20636 index = IntervalIndex.from_breaks(breaks) - numeric_keys = [ - IntervalIndex.from_breaks(breaks), - Interval(breaks[0], breaks[1]), - breaks, - breaks[0], - list(breaks)] + key = make_key(breaks) # no conversion occurs for numeric - for key in numeric_keys: - result = index._maybe_convert_i8(key) - assert result is key + result = index._maybe_convert_i8(key) + assert result is key @pytest.mark.parametrize('breaks1, breaks2', permutations([ date_range('20180101', periods=4), date_range('20180101', periods=4, tz='US/Eastern'), timedelta_range('0 days', periods=4)], 2), ids=lambda x: str(x.dtype)) - def test_maybe_convert_i8_errors(self, breaks1, breaks2): + @pytest.mark.parametrize('make_key', [ + IntervalIndex.from_breaks, + lambda breaks: Interval(breaks[0], breaks[1]), + lambda breaks: breaks, + lambda breaks: breaks[0], + list], ids=['IntervalIndex', 'Interval', 'Index', 'scalar', 'list']) + def test_maybe_convert_i8_errors(self, breaks1, breaks2, make_key): # GH 20636 index = IntervalIndex.from_breaks(breaks1) - invalid_keys = [ - IntervalIndex.from_breaks(breaks2), - Interval(breaks2[0], breaks2[1]), - breaks2, - breaks2[0], - list(breaks2)] + key = make_key(breaks2) msg = ('Cannot index an IntervalIndex of subtype {dtype1} with ' 'values of dtype {dtype2}') msg = re.escape(msg.format(dtype1=breaks1.dtype, dtype2=breaks2.dtype)) - for key in invalid_keys: - with tm.assert_raises_regex(ValueError, msg): - index._maybe_convert_i8(key) + with tm.assert_raises_regex(ValueError, msg): + index._maybe_convert_i8(key) # To be removed, replaced by test_interval_new.py (see #16316, #16386) def test_contains(self):