Skip to content

Commit f21c722

Browse files
Brian TuBrian Tu
Brian Tu
authored and
Brian Tu
committed
ENH: tolerance now takes list-like argument for reindex and get_indexer.
1 parent 8e6b09f commit f21c722

File tree

15 files changed

+258
-20
lines changed

15 files changed

+258
-20
lines changed

doc/source/whatsnew/v0.21.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Other Enhancements
8181
- :func:`date_range` now accepts 'YS' in addition to 'AS' as an alias for start of year (:issue:`9313`)
8282
- :func:`date_range` now accepts 'Y' in addition to 'A' as an alias for end of year (:issue:`9313`)
8383
- Integration with `Apache Parquet <https://parquet.apache.org/>`__, including a new top-level :func:`pd.read_parquet` and :func:`DataFrame.to_parquet` method, see :ref:`here <io.parquet>`.
84+
- :func:`Series.reindex`, :func:`DataFrame.reindex`, :func:`Index.get_indexer` now support a list-like argument for ``tolerance``.
8485

8586
.. _whatsnew_0210.api_breaking:
8687

pandas/core/generic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2256,9 +2256,10 @@ def reindex_like(self, other, method=None, copy=True, limit=None,
22562256
Maximum number of consecutive labels to fill for inexact matches.
22572257
tolerance : optional
22582258
Maximum distance between labels of the other object and this
2259-
object for inexact matches.
2259+
object for inexact matches. Can be list-like.
22602260
22612261
.. versionadded:: 0.17.0
2262+
.. versionchanged:: 0.21.0 (list-like tolerance)
22622263
22632264
Notes
22642265
-----
@@ -2596,8 +2597,14 @@ def sort_index(self, axis=0, level=None, ascending=True, inplace=False,
25962597
Maximum distance between original and new labels for inexact
25972598
matches. The values of the index at the matching locations must
25982599
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
2600+
Tolerance may be a scalar value, which applies the same tolerance
2601+
to all values, or list-like, which applies variable tolerance per
2602+
element. List-like includes list, tuple, array, Series, and must be
2603+
the same size as the index and its dtype must exactly match the
2604+
index's type.
25992605
26002606
.. versionadded:: 0.17.0
2607+
.. versionchanged:: 0.21.0 (list-like tolerance)
26012608
26022609
Examples
26032610
--------
@@ -2819,8 +2826,14 @@ def _reindex_multi(self, axes, copy, fill_value):
28192826
Maximum distance between original and new labels for inexact
28202827
matches. The values of the index at the matching locations must
28212828
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
2829+
Tolerance may be a scalar value, which applies the same tolerance
2830+
to all values, or list-like, which applies variable tolerance per
2831+
element. List-like includes list, tuple, array, Series, and must be
2832+
the same size as the index and its dtype must exactly match the
2833+
index's type.
28222834
28232835
.. versionadded:: 0.17.0
2836+
.. versionchanged:: 0.21.0 (list-like tolerance)
28242837
28252838
Examples
28262839
--------

pandas/core/indexes/base.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ class InvalidIndexError(Exception):
8585
_o_dtype = np.dtype(object)
8686
_Identity = object
8787

88-
8988
def _new_Index(cls, d):
9089
""" This is called upon unpickling, rather than the default which doesn't
9190
have arguments and breaks __new__
@@ -2436,9 +2435,14 @@ def _get_unique_index(self, dropna=False):
24362435
tolerance : optional
24372436
Maximum distance from index value for inexact matches. The value of
24382437
the index at the matching location must satisfy the equation
2439-
``abs(index[loc] - key) <= tolerance``.
2438+
``abs(index[loc] - key) <= tolerance``. Tolerance may be a scalar
2439+
value, which applies the same tolerance to all values, or
2440+
list-like, which applies variable tolerance per element. List-like
2441+
includes list, tuple, array, Series, and must be the same size as
2442+
the index and its dtype must exactly match the index's type.
24402443
24412444
.. versionadded:: 0.17.0
2445+
.. versionchanged:: 0.21.0 (list-like tolerance)
24422446
24432447
Returns
24442448
-------
@@ -2558,8 +2562,14 @@ def _get_level_values(self, level):
25582562
Maximum distance between original and new labels for inexact
25592563
matches. The values of the index at the matching locations must
25602564
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
2565+
Tolerance may be a scalar value, which applies the same tolerance
2566+
to all values, or list-like, which applies variable tolerance per
2567+
element. List-like includes list, tuple, array, Series, and must be
2568+
the same size as the target index and its dtype must exactly match the
2569+
index's type.
25612570
25622571
.. versionadded:: 0.17.0
2572+
.. versionchanged:: 0.21.0 (list-like tolerance)
25632573
25642574
Examples
25652575
--------
@@ -2580,6 +2590,10 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
25802590
target = _ensure_index(target)
25812591
if tolerance is not None:
25822592
tolerance = self._convert_tolerance(tolerance)
2593+
if isinstance(tolerance, np.ndarray) and \
2594+
target.size != tolerance.size and tolerance.size > 1:
2595+
raise ValueError('ndarray tolerance size must match '
2596+
'target index size')
25832597

25842598
pself, ptarget = self._maybe_promote(target)
25852599
if pself is not self or ptarget is not target:
@@ -2614,7 +2628,7 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
26142628

26152629
def _convert_tolerance(self, tolerance):
26162630
# override this method on subclasses
2617-
return tolerance
2631+
return _list_to_ndarray(tolerance)
26182632

26192633
def _get_fill_indexer(self, target, method, limit=None, tolerance=None):
26202634
if self.is_monotonic_increasing and target.is_monotonic_increasing:
@@ -4008,6 +4022,16 @@ def invalid_op(self, other=None):
40084022
Index._add_comparison_methods()
40094023

40104024

4025+
def _list_to_ndarray(a):
4026+
"""Convert list-like to np.ndarray, otherwise leave as-is.
4027+
Used for converting tolerance to ndarray in _convert_tolerance.
4028+
"""
4029+
if isinstance(a, ABCSeries):
4030+
return a.values
4031+
elif isinstance(a, (list, tuple)):
4032+
return np.array(a)
4033+
return a
4034+
40114035
def _ensure_index(index_like, copy=False):
40124036
if isinstance(index_like, Index):
40134037
if copy:

pandas/core/indexes/datetimelike.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
Timedelta, Timestamp, iNaT, NaT)
2828
from pandas._libs.period import Period
2929

30-
from pandas.core.indexes.base import Index, _index_shared_docs
30+
from pandas.core.indexes.base import (Index, _index_shared_docs,
31+
_list_to_ndarray)
3132
from pandas.util._decorators import Appender, cache_readonly
3233
import pandas.core.dtypes.concat as _concat
3334
import pandas.tseries.frequencies as frequencies
@@ -432,12 +433,35 @@ def asobject(self):
432433
return Index(self._box_values(self.asi8), name=self.name, dtype=object)
433434

434435
def _convert_tolerance(self, tolerance):
435-
try:
436-
return Timedelta(tolerance).to_timedelta64()
437-
except ValueError:
438-
raise ValueError('tolerance argument for %s must be convertible '
439-
'to Timedelta: %r'
440-
% (type(self).__name__, tolerance))
436+
tolerance = _list_to_ndarray(tolerance)
437+
if isinstance(tolerance, np.ndarray):
438+
if np.issubdtype(tolerance.dtype, np.timedelta64):
439+
return tolerance
440+
else:
441+
try:
442+
tolerance = np.array([np.timedelta64(x)
443+
for x in tolerance])
444+
# in case user mixes something like seconds and Month
445+
if not np.issubdtype(tolerance.dtype, np.timedelta64):
446+
raise TypeError('All values in tolerance array must '
447+
'be convertible to [ns]')
448+
except ValueError as e:
449+
raise TypeError(('tolerance argument for %s must contain '
450+
'objects convertible to np.timedelta64 '
451+
'if it is list type') %
452+
(type(self).__name__,)) from e
453+
else:
454+
warnings.warn('Converting tolerance array to '
455+
'np.timedelta64 objects, consider doing preconverting '
456+
'for speed')
457+
return tolerance
458+
else:
459+
try:
460+
return Timedelta(tolerance).to_timedelta64()
461+
except ValueError:
462+
raise ValueError(('tolerance argument for %s must be '
463+
'convertible to Timedelta if it is a scalar: %r')
464+
% (type(self).__name__, tolerance))
441465

442466
def _maybe_mask_results(self, result, fill_value=None, convert=None):
443467
"""

pandas/core/indexes/datetimes.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,12 @@ def get_loc(self, key, method=None, tolerance=None):
14351435
try:
14361436
stamp = Timestamp(key, tz=self.tz)
14371437
return Index.get_loc(self, stamp, method, tolerance)
1438-
except (KeyError, ValueError):
1438+
except KeyError:
1439+
raise KeyError(key)
1440+
except ValueError as e:
1441+
# ndarray tolerance size must match target index size
1442+
if 'ndarray' in str(e):
1443+
raise e
14391444
raise KeyError(key)
14401445

14411446
def _maybe_cast_slice_bound(self, label, side, kind):

pandas/core/indexes/numeric.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pandas import compat
1616
from pandas.core import algorithms
1717
from pandas.core.indexes.base import (
18-
Index, InvalidIndexError, _index_shared_docs)
18+
Index, InvalidIndexError, _index_shared_docs, _list_to_ndarray)
1919
from pandas.util._decorators import Appender, cache_readonly
2020
import pandas.core.indexes.base as ibase
2121

@@ -72,11 +72,20 @@ def _convert_for_op(self, value):
7272
return value
7373

7474
def _convert_tolerance(self, tolerance):
75-
try:
76-
return float(tolerance)
77-
except ValueError:
78-
raise ValueError('tolerance argument for %s must be numeric: %r' %
79-
(type(self).__name__, tolerance))
75+
tolerance = _list_to_ndarray(tolerance)
76+
if isinstance(tolerance, np.ndarray):
77+
if np.issubdtype(tolerance.dtype, np.number):
78+
return tolerance
79+
else:
80+
raise ValueError(('tolerance argument for %s must contain '
81+
'numeric elements if it is list type') % (type(self).__name__,))
82+
else:
83+
try:
84+
return float(tolerance)
85+
except ValueError:
86+
raise ValueError(('tolerance argument for %s must be numeric '
87+
'if it is a scalar: %r') %
88+
(type(self).__name__, tolerance))
8089

8190
@classmethod
8291
def _assert_safe_casting(cls, data, subarr):

pandas/core/indexes/period.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,12 +633,17 @@ def to_timestamp(self, freq=None, how='start'):
633633
return DatetimeIndex(new_data, freq='infer', name=self.name)
634634

635635
def _maybe_convert_timedelta(self, other):
636-
if isinstance(other, (timedelta, np.timedelta64, offsets.Tick)):
636+
if isinstance(other,
637+
(timedelta, np.timedelta64, offsets.Tick, np.ndarray)):
637638
offset = frequencies.to_offset(self.freq.rule_code)
638639
if isinstance(offset, offsets.Tick):
639640
nanos = tslib._delta_to_nanoseconds(other)
640641
offset_nanos = tslib._delta_to_nanoseconds(offset)
641-
if nanos % offset_nanos == 0:
642+
if isinstance(other, np.ndarray):
643+
check = np.all(nanos % offset_nanos == 0)
644+
else:
645+
check = nanos % offset_nanos == 0
646+
if check:
642647
return nanos // offset_nanos
643648
elif isinstance(other, offsets.DateOffset):
644649
freqstr = other.rule_code
@@ -775,6 +780,10 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
775780

776781
if tolerance is not None:
777782
tolerance = self._convert_tolerance(tolerance)
783+
if isinstance(tolerance, np.ndarray) and \
784+
target.size != tolerance.size and tolerance.size > 1:
785+
raise ValueError('ndarray tolerance size must match '
786+
'target index size')
778787
return Index.get_indexer(self._int64index, target, method,
779788
limit, tolerance)
780789

@@ -902,6 +911,10 @@ def _get_string_slice(self, key):
902911

903912
def _convert_tolerance(self, tolerance):
904913
tolerance = DatetimeIndexOpsMixin._convert_tolerance(self, tolerance)
914+
if isinstance(tolerance, np.ndarray) \
915+
and not np.issubdtype(tolerance.dtype, np.timedelta64):
916+
raise TypeError('All values in tolerance array must be '
917+
'convertible to [ns]')
905918
return self._maybe_convert_timedelta(tolerance)
906919

907920
def insert(self, loc, item):

pandas/tests/frame/test_indexing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,9 +1905,13 @@ def test_reindex_methods(self):
19051905

19061906
actual = df.reindex_like(df, method=method, tolerance=0)
19071907
assert_frame_equal(df, actual)
1908+
actual = df.reindex_like(df, method=method, tolerance=[0,0,0,0])
1909+
assert_frame_equal(df, actual)
19081910

19091911
actual = df.reindex(target, method=method, tolerance=1)
19101912
assert_frame_equal(expected, actual)
1913+
actual = df.reindex(target, method=method, tolerance=[1,1,1,1])
1914+
assert_frame_equal(expected, actual)
19111915

19121916
e2 = expected[::-1]
19131917
actual = df.reindex(target[::-1], method=method)
@@ -1928,6 +1932,11 @@ def test_reindex_methods(self):
19281932
actual = df.reindex(target, method='nearest', tolerance=0.2)
19291933
assert_frame_equal(expected, actual)
19301934

1935+
expected = pd.DataFrame({'x': [0, np.nan, 1, np.nan]}, index=target)
1936+
actual = df.reindex(target, method='nearest',
1937+
tolerance=[0.5, 0.01, 0.4, 0.1])
1938+
assert_frame_equal(expected, actual)
1939+
19311940
def test_reindex_frame_add_nat(self):
19321941
rng = date_range('1/1/2000 00:00:00', periods=10, freq='10s')
19331942
df = DataFrame({'A': np.random.randn(len(rng)), 'B': rng})

pandas/tests/indexes/datetimes/test_datetime.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def test_get_loc(self):
4545
idx.get_loc('2000-01-01T12', method='nearest', tolerance='foo')
4646
with pytest.raises(KeyError):
4747
idx.get_loc('2000-01-01T03', method='nearest', tolerance='2 hours')
48+
with pytest.raises(ValueError,
49+
match='tolerance size must match target index size'):
50+
idx.get_loc('2000-01-01', method='nearest',
51+
tolerance=[pd.Timedelta('1day').to_timedelta64(),
52+
pd.Timedelta('1day').to_timedelta64()])
4853

4954
assert idx.get_loc('2000', method='nearest') == slice(0, 3)
5055
assert idx.get_loc('2000-01', method='nearest') == slice(0, 3)
@@ -93,6 +98,30 @@ def test_get_indexer(self):
9398
idx.get_indexer(target, 'nearest',
9499
tolerance=pd.Timedelta('1 hour')),
95100
np.array([0, -1, 1], dtype=np.intp))
101+
tol_raw = [pd.Timedelta('1 hour'),
102+
pd.Timedelta('1 hour'),
103+
pd.Timedelta('1 hour').to_timedelta64(), ]
104+
with pytest.warns(UserWarning) as speedwarning:
105+
tm.assert_numpy_array_equal(
106+
idx.get_indexer(target, 'nearest',
107+
tolerance=tol_raw),
108+
np.array([0, -1, 1], dtype=np.intp))
109+
assert len(speedwarning) == 1
110+
assert speedwarning[0].message.args[0]\
111+
.endswith('preconverting for speed')
112+
tm.assert_numpy_array_equal(
113+
idx.get_indexer(target, 'nearest',
114+
tolerance=[np.timedelta64(x) for x in tol_raw]),
115+
np.array([0, -1, 1], dtype=np.intp))
116+
with pytest.raises(TypeError, match=('must contain objects '
117+
'convertible to np.timedelta64')):
118+
idx.get_indexer(target, 'nearest', tolerance=[1,2,3])
119+
tol_bad = [pd.Timedelta('2 hour').to_timedelta64(),
120+
pd.Timedelta('1 hour').to_timedelta64(),
121+
np.timedelta64(1, 'M'), ]
122+
with pytest.raises(TypeError, match=('All values.*'
123+
'convertible to \\[ns\\]')), pytest.warns(UserWarning):
124+
idx.get_indexer(target, 'nearest', tolerance=tol_bad)
96125
with pytest.raises(ValueError):
97126
idx.get_indexer(idx[[0]], method='nearest', tolerance='foo')
98127

pandas/tests/indexes/period/test_period.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ def test_get_loc(self):
8989
idx.get_loc('2000-01-10', method='nearest', tolerance='1 hour')
9090
with pytest.raises(KeyError):
9191
idx.get_loc('2000-01-10', method='nearest', tolerance='1 day')
92+
with pytest.raises(ValueError, match=('ndarray tolerance size must '
93+
'match target index size')):
94+
idx.get_loc('2000-01-10', method='nearest',
95+
tolerance=[pd.Timedelta('1 day').to_timedelta64(),
96+
pd.Timedelta('1 day').to_timedelta64()])
9297

9398
def test_where(self):
9499
i = self.create_index()
@@ -156,6 +161,30 @@ def test_get_indexer(self):
156161
tm.assert_numpy_array_equal(idx.get_indexer(target, 'nearest',
157162
tolerance='1 day'),
158163
np.array([0, 1, 1], dtype=np.intp))
164+
tol_raw = [pd.Timedelta('1 hour'),
165+
pd.Timedelta('1 hour'),
166+
np.timedelta64(1, 'D'), ]
167+
with pytest.warns(UserWarning) as speedwarning:
168+
tm.assert_numpy_array_equal(
169+
idx.get_indexer(target, 'nearest',
170+
tolerance=tol_raw),
171+
np.array([0, -1, 1], dtype=np.intp))
172+
assert len(speedwarning) == 1
173+
assert speedwarning[0].message.args[0]\
174+
.endswith('preconverting for speed')
175+
tm.assert_numpy_array_equal(
176+
idx.get_indexer(target, 'nearest',
177+
tolerance=[np.timedelta64(x) for x in tol_raw]),
178+
np.array([0, -1, 1], dtype=np.intp))
179+
tol_bad = [pd.Timedelta('2 hour').to_timedelta64(),
180+
pd.Timedelta('1 hour').to_timedelta64(),
181+
np.timedelta64(1, 'M'), ]
182+
with pytest.raises(TypeError, match=('All values.*'
183+
'convertible to \\[ns\\]')), pytest.warns(UserWarning):
184+
idx.get_indexer(target, 'nearest', tolerance=tol_bad)
185+
with pytest.raises(TypeError, match=('must contain objects '
186+
'convertible to np.timedelta64')):
187+
idx.get_indexer(target, 'nearest', tolerance=[1,2,3])
159188

160189
def test_repeat(self):
161190
# GH10183

0 commit comments

Comments
 (0)