diff --git a/doc/source/whatsnew/v0.17.0.txt b/doc/source/whatsnew/v0.17.0.txt index 0b4acdc3e89bb..3d625c0299db5 100644 --- a/doc/source/whatsnew/v0.17.0.txt +++ b/doc/source/whatsnew/v0.17.0.txt @@ -383,7 +383,7 @@ Other enhancements - ``DataFrame`` has gained the ``nlargest`` and ``nsmallest`` methods (:issue:`10393`) -- Add a ``limit_direction`` keyword argument that works with ``limit`` to enable ``interpolate`` to fill ``NaN`` values forward, backward, or both (:issue:`9218` and :issue:`10420`) +- Add a ``limit_direction`` keyword argument that works with ``limit`` to enable ``interpolate`` to fill ``NaN`` values forward, backward, or both (:issue:`9218`, :issue:`10420`, :issue:`11115`) .. ipython:: python diff --git a/pandas/core/common.py b/pandas/core/common.py index 8ffffae6bd160..9189b0d89de4f 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -1582,13 +1582,10 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None, method = 'values' def _interp_limit(invalid, fw_limit, bw_limit): - "Get idx of values that won't be forward-filled b/c they exceed the limit." - all_nans = np.where(invalid)[0] - if all_nans.size == 0: # no nans anyway - return [] - violate = [invalid[max(0, x - bw_limit):x + fw_limit + 1] for x in all_nans] - violate = np.array([x.all() & (x.size > bw_limit + fw_limit) for x in violate]) - return all_nans[violate] + fw_limit - bw_limit + "Get idx of values that won't be filled b/c they exceed the limits." + for x in np.where(invalid)[0]: + if invalid[max(0, x - fw_limit):x + bw_limit + 1].all(): + yield x valid_limit_directions = ['forward', 'backward', 'both'] limit_direction = limit_direction.lower() @@ -1624,7 +1621,7 @@ def _interp_limit(invalid, fw_limit, bw_limit): if limit_direction == 'backward': violate_limit = sorted(end_nans | set(_interp_limit(invalid, 0, limit))) if limit_direction == 'both': - violate_limit = _interp_limit(invalid, limit, limit) + violate_limit = sorted(_interp_limit(invalid, limit, limit)) xvalues = getattr(xvalues, 'values', xvalues) yvalues = getattr(yvalues, 'values', yvalues) diff --git a/pandas/tests/test_generic.py b/pandas/tests/test_generic.py index 19989116b26df..3a26be2ca1032 100644 --- a/pandas/tests/test_generic.py +++ b/pandas/tests/test_generic.py @@ -878,7 +878,6 @@ def test_interp_limit_forward(self): def test_interp_limit_bad_direction(self): s = Series([1, 3, np.nan, np.nan, np.nan, 11]) - expected = Series([1., 3., 5., 7., 9., 11.]) self.assertRaises(ValueError, s.interpolate, method='linear', limit=2, @@ -930,6 +929,25 @@ def test_interp_limit_to_ends(self): method='linear', limit=2, limit_direction='both') assert_series_equal(result, expected) + def test_interp_limit_before_ends(self): + # These test are for issue #11115 -- limit ends properly. + s = Series([np.nan, np.nan, 5, 7, np.nan, np.nan]) + + expected = Series([np.nan, np.nan, 5., 7., 7., np.nan]) + result = s.interpolate( + method='linear', limit=1, limit_direction='forward') + assert_series_equal(result, expected) + + expected = Series([np.nan, 5., 5., 7., np.nan, np.nan]) + result = s.interpolate( + method='linear', limit=1, limit_direction='backward') + assert_series_equal(result, expected) + + expected = Series([np.nan, 5., 5., 7., 7., np.nan]) + result = s.interpolate( + method='linear', limit=1, limit_direction='both') + assert_series_equal(result, expected) + def test_interp_all_good(self): # scipy tm._skip_if_no_scipy()