Skip to content

Commit 60ba05f

Browse files
authored
BUG: Respect center=True in rolling.apply when numba engine is used (#34816)
1 parent b599990 commit 60ba05f

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

doc/source/whatsnew/v1.1.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ Groupby/resample/rolling
10151015
The behaviour now is consistent, independent of internal heuristics. (:issue:`31612`, :issue:`14927`, :issue:`13056`)
10161016
- Bug in :meth:`SeriesGroupBy.agg` where any column name was accepted in the named aggregation of ``SeriesGroupBy`` previously. The behaviour now allows only ``str`` and callables else would raise ``TypeError``. (:issue:`34422`)
10171017
- Bug in :meth:`DataFrame.groupby` lost index, when one of the ``agg`` keys referenced an empty list (:issue:`32580`)
1018-
1018+
- Bug in :meth:`Rolling.apply` where ``center=True`` was ignored when ``engine='numba'`` was specified (:issue:`34784`)
10191019

10201020
Reshaping
10211021
^^^^^^^^^

pandas/core/window/rolling.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
obj,
151151
window=None,
152152
min_periods: Optional[int] = None,
153-
center: Optional[bool] = False,
153+
center: bool = False,
154154
win_type: Optional[str] = None,
155155
axis: Axis = 0,
156156
on: Optional[Union[str, Index]] = None,
@@ -1353,17 +1353,20 @@ def apply(
13531353
kwargs = {}
13541354
kwargs.pop("_level", None)
13551355
kwargs.pop("floor", None)
1356-
window = self._get_window()
1357-
offset = calculate_center_offset(window) if self.center else 0
13581356
if not is_bool(raw):
13591357
raise ValueError("raw parameter must be `True` or `False`")
13601358

13611359
if engine == "cython":
13621360
if engine_kwargs is not None:
13631361
raise ValueError("cython engine does not accept engine_kwargs")
1362+
# Cython apply functions handle center, so don't need to use
1363+
# _apply's center handling
1364+
window = self._get_window()
1365+
offset = calculate_center_offset(window) if self.center else 0
13641366
apply_func = self._generate_cython_apply_func(
13651367
args, kwargs, raw, offset, func
13661368
)
1369+
center = False
13671370
elif engine == "numba":
13681371
if raw is False:
13691372
raise ValueError("raw must be `True` when using the numba engine")
@@ -1375,14 +1378,14 @@ def apply(
13751378
apply_func = generate_numba_apply_func(
13761379
args, kwargs, func, engine_kwargs
13771380
)
1381+
center = self.center
13781382
else:
13791383
raise ValueError("engine must be either 'numba' or 'cython'")
13801384

1381-
# TODO: Why do we always pass center=False?
13821385
# name=func & raw=raw for WindowGroupByMixin._apply
13831386
return self._apply(
13841387
apply_func,
1385-
center=False,
1388+
center=center,
13861389
floor=0,
13871390
name=func,
13881391
use_numba_cache=engine == "numba",

pandas/tests/window/test_numba.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# Filter warnings when parallel=True and the function can't be parallelized by Numba
1414
class TestApply:
1515
@pytest.mark.parametrize("jit", [True, False])
16-
def test_numba_vs_cython(self, jit, nogil, parallel, nopython):
16+
def test_numba_vs_cython(self, jit, nogil, parallel, nopython, center):
1717
def f(x, *args):
1818
arg_sum = 0
1919
for arg in args:
@@ -29,10 +29,12 @@ def f(x, *args):
2929
args = (2,)
3030

3131
s = Series(range(10))
32-
result = s.rolling(2).apply(
32+
result = s.rolling(2, center=center).apply(
3333
f, args=args, engine="numba", engine_kwargs=engine_kwargs, raw=True
3434
)
35-
expected = s.rolling(2).apply(f, engine="cython", args=args, raw=True)
35+
expected = s.rolling(2, center=center).apply(
36+
f, engine="cython", args=args, raw=True
37+
)
3638
tm.assert_series_equal(result, expected)
3739

3840
@pytest.mark.parametrize("jit", [True, False])

0 commit comments

Comments
 (0)