From c7a33f5bfd556da7f164f7ca11866f30714c3527 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 15 Jun 2020 21:06:27 -0700 Subject: [PATCH 1/2] BUG: Respect center=True in rolling.apply when numba engine is used --- doc/source/whatsnew/v1.1.0.rst | 2 +- pandas/core/window/rolling.py | 11 +++++++---- pandas/tests/window/test_numba.py | 8 +++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index f68135bf8cf9c..7e04d8f906cb0 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -1015,7 +1015,7 @@ Groupby/resample/rolling The behaviour now is consistent, independent of internal heuristics. (:issue:`31612`, :issue:`14927`, :issue:`13056`) - Bug in :meth:`SeriesGroupBy.agg` where any column name was accepted in the named aggregation of ``SeriesGroupBy`` previously. The behaviour now allows only ``str`` and callables else would raise ``TypeError``. (:issue:`34422`) - Bug in :meth:`DataFrame.groupby` lost index, when one of the ``agg`` keys referenced an empty list (:issue:`32580`) - +- Bug in :meth:`Rolling.apply` where ``center=True`` was ignored when ``engine='numba'`` was specified (:issue:`34784`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 92be2d056cfcb..301a5efe1a339 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -1353,17 +1353,20 @@ def apply( kwargs = {} kwargs.pop("_level", None) kwargs.pop("floor", None) - window = self._get_window() - offset = calculate_center_offset(window) if self.center else 0 if not is_bool(raw): raise ValueError("raw parameter must be `True` or `False`") if engine == "cython": if engine_kwargs is not None: raise ValueError("cython engine does not accept engine_kwargs") + # Cython apply functions handle center, so don't need to use + # _apply's center handling + window = self._get_window() + offset = calculate_center_offset(window) if self.center else 0 apply_func = self._generate_cython_apply_func( args, kwargs, raw, offset, func ) + center = False elif engine == "numba": if raw is False: raise ValueError("raw must be `True` when using the numba engine") @@ -1375,14 +1378,14 @@ def apply( apply_func = generate_numba_apply_func( args, kwargs, func, engine_kwargs ) + center = self.center else: raise ValueError("engine must be either 'numba' or 'cython'") - # TODO: Why do we always pass center=False? # name=func & raw=raw for WindowGroupByMixin._apply return self._apply( apply_func, - center=False, + center=center, floor=0, name=func, use_numba_cache=engine == "numba", diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 8ecf64b171df4..7e049af0ca1f8 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -13,7 +13,7 @@ # Filter warnings when parallel=True and the function can't be parallelized by Numba class TestApply: @pytest.mark.parametrize("jit", [True, False]) - def test_numba_vs_cython(self, jit, nogil, parallel, nopython): + def test_numba_vs_cython(self, jit, nogil, parallel, nopython, center): def f(x, *args): arg_sum = 0 for arg in args: @@ -29,10 +29,12 @@ def f(x, *args): args = (2,) s = Series(range(10)) - result = s.rolling(2).apply( + result = s.rolling(2, center=center).apply( f, args=args, engine="numba", engine_kwargs=engine_kwargs, raw=True ) - expected = s.rolling(2).apply(f, engine="cython", args=args, raw=True) + expected = s.rolling(2, center=center).apply( + f, engine="cython", args=args, raw=True + ) tm.assert_series_equal(result, expected) @pytest.mark.parametrize("jit", [True, False]) From 51af75c3be1c91279df8265a409d9ae90d561a25 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 15 Jun 2020 21:36:03 -0700 Subject: [PATCH 2/2] Stricter typing on center --- pandas/core/window/rolling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 301a5efe1a339..ce0a2a9b95025 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -150,7 +150,7 @@ def __init__( obj, window=None, min_periods: Optional[int] = None, - center: Optional[bool] = False, + center: bool = False, win_type: Optional[str] = None, axis: Axis = 0, on: Optional[Union[str, Index]] = None,