Skip to content

Commit ffbeda7

Browse files
authored
PERF: Improve performance in rolling.mean(engine="numba") (#43612)
1 parent f3d4817 commit ffbeda7

File tree

7 files changed

+250
-21
lines changed

7 files changed

+250
-21
lines changed

doc/source/whatsnew/v1.4.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ Performance improvements
357357
- Performance improvement in :meth:`GroupBy.quantile` (:issue:`43469`)
358358
- :meth:`SparseArray.min` and :meth:`SparseArray.max` no longer require converting to a dense array (:issue:`43526`)
359359
- Performance improvement in :meth:`SparseArray.take` with ``allow_fill=False`` (:issue:`43654`)
360-
-
360+
- Performance improvement in :meth:`.Rolling.mean` and :meth:`.Expanding.mean` with ``engine="numba"`` (:issue:`43612`)
361361

362362
.. ---------------------------------------------------------------------------
363363

pandas/core/_numba/__init__.py

Whitespace-only changes.

pandas/core/_numba/executor.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
5+
import numpy as np
6+
7+
from pandas._typing import Scalar
8+
from pandas.compat._optional import import_optional_dependency
9+
10+
from pandas.core.util.numba_ import (
11+
NUMBA_FUNC_CACHE,
12+
get_jit_arguments,
13+
)
14+
15+
16+
def generate_shared_aggregator(
17+
func: Callable[..., Scalar],
18+
engine_kwargs: dict[str, bool] | None,
19+
cache_key_str: str,
20+
):
21+
"""
22+
Generate a Numba function that loops over the columns 2D object and applies
23+
a 1D numba kernel over each column.
24+
25+
Parameters
26+
----------
27+
func : function
28+
aggregation function to be applied to each column
29+
engine_kwargs : dict
30+
dictionary of arguments to be passed into numba.jit
31+
cache_key_str: str
32+
string to access the compiled function of the form
33+
<caller_type>_<aggregation_type> e.g. rolling_mean, groupby_mean
34+
35+
Returns
36+
-------
37+
Numba function
38+
"""
39+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, None)
40+
41+
cache_key = (func, cache_key_str)
42+
if cache_key in NUMBA_FUNC_CACHE:
43+
return NUMBA_FUNC_CACHE[cache_key]
44+
45+
numba = import_optional_dependency("numba")
46+
47+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
48+
def column_looper(
49+
values: np.ndarray,
50+
start: np.ndarray,
51+
end: np.ndarray,
52+
min_periods: int,
53+
):
54+
result = np.empty((len(start), values.shape[1]), dtype=np.float64)
55+
for i in numba.prange(values.shape[1]):
56+
result[:, i] = func(values[:, i], start, end, min_periods)
57+
return result
58+
59+
return column_looper
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pandas.core._numba.kernels.mean_ import sliding_mean
2+
3+
__all__ = ["sliding_mean"]

pandas/core/_numba/kernels/mean_.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Numba 1D aggregation kernels that can be shared by
3+
* Dataframe / Series
4+
* groupby
5+
* rolling / expanding
6+
7+
Mirrors pandas/_libs/window/aggregation.pyx
8+
"""
9+
from __future__ import annotations
10+
11+
import numba
12+
import numpy as np
13+
14+
15+
@numba.jit(nopython=True, nogil=True, parallel=False)
16+
def is_monotonic_increasing(bounds: np.ndarray) -> bool:
17+
"""Check if int64 values are monotonically increasing."""
18+
n = len(bounds)
19+
if n < 2:
20+
return True
21+
prev = bounds[0]
22+
for i in range(1, n):
23+
cur = bounds[i]
24+
if cur < prev:
25+
return False
26+
prev = cur
27+
return True
28+
29+
30+
@numba.jit(nopython=True, nogil=True, parallel=False)
31+
def add_mean(
32+
val: float, nobs: int, sum_x: float, neg_ct: int, compensation: float
33+
) -> tuple[int, float, int, float]:
34+
if not np.isnan(val):
35+
nobs += 1
36+
y = val - compensation
37+
t = sum_x + y
38+
compensation = t - sum_x - y
39+
sum_x = t
40+
if val < 0:
41+
neg_ct += 1
42+
return nobs, sum_x, neg_ct, compensation
43+
44+
45+
@numba.jit(nopython=True, nogil=True, parallel=False)
46+
def remove_mean(
47+
val: float, nobs: int, sum_x: float, neg_ct: int, compensation: float
48+
) -> tuple[int, float, int, float]:
49+
if not np.isnan(val):
50+
nobs -= 1
51+
y = -val - compensation
52+
t = sum_x + y
53+
compensation = t - sum_x - y
54+
sum_x = t
55+
if val < 0:
56+
neg_ct -= 1
57+
return nobs, sum_x, neg_ct, compensation
58+
59+
60+
@numba.jit(nopython=True, nogil=True, parallel=False)
61+
def sliding_mean(
62+
values: np.ndarray,
63+
start: np.ndarray,
64+
end: np.ndarray,
65+
min_periods: int,
66+
) -> np.ndarray:
67+
N = len(start)
68+
nobs = 0
69+
sum_x = 0.0
70+
neg_ct = 0
71+
compensation_add = 0.0
72+
compensation_remove = 0.0
73+
74+
is_monotonic_increasing_bounds = is_monotonic_increasing(
75+
start
76+
) and is_monotonic_increasing(end)
77+
78+
output = np.empty(N, dtype=np.float64)
79+
80+
for i in range(N):
81+
s = start[i]
82+
e = end[i]
83+
if i == 0 or not is_monotonic_increasing_bounds:
84+
for j in range(s, e):
85+
val = values[j]
86+
nobs, sum_x, neg_ct, compensation_add = add_mean(
87+
val, nobs, sum_x, neg_ct, compensation_add
88+
)
89+
else:
90+
for j in range(start[i - 1], s):
91+
val = values[j]
92+
nobs, sum_x, neg_ct, compensation_remove = remove_mean(
93+
val, nobs, sum_x, neg_ct, compensation_remove
94+
)
95+
96+
for j in range(end[i - 1], e):
97+
val = values[j]
98+
nobs, sum_x, neg_ct, compensation_add = add_mean(
99+
val, nobs, sum_x, neg_ct, compensation_add
100+
)
101+
102+
if nobs >= min_periods and nobs > 0:
103+
result = sum_x / nobs
104+
if neg_ct == 0 and result < 0:
105+
result = 0
106+
elif neg_ct == nobs and result > 0:
107+
result = 0
108+
else:
109+
result = np.nan
110+
111+
output[i] = result
112+
113+
if not is_monotonic_increasing_bounds:
114+
nobs = 0
115+
sum_x = 0.0
116+
neg_ct = 0
117+
compensation_remove = 0.0
118+
119+
return output

pandas/core/window/rolling.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
)
5050
from pandas.core.dtypes.missing import notna
5151

52+
from pandas.core._numba import executor
5253
from pandas.core.algorithms import factorize
5354
from pandas.core.apply import ResamplerWindowApply
5455
from pandas.core.arrays import ExtensionArray
@@ -576,6 +577,44 @@ def calc(x):
576577
else:
577578
return self._apply_tablewise(homogeneous_func, name)
578579

580+
def _numba_apply(
581+
self,
582+
func: Callable[..., Any],
583+
numba_cache_key_str: str,
584+
engine_kwargs: dict[str, bool] | None = None,
585+
):
586+
window_indexer = self._get_window_indexer()
587+
min_periods = (
588+
self.min_periods
589+
if self.min_periods is not None
590+
else window_indexer.window_size
591+
)
592+
obj = self._create_data(self._selected_obj)
593+
if self.axis == 1:
594+
obj = obj.T
595+
values = self._prep_values(obj.to_numpy())
596+
if values.ndim == 1:
597+
values = values.reshape(-1, 1)
598+
start, end = window_indexer.get_window_bounds(
599+
num_values=len(values),
600+
min_periods=min_periods,
601+
center=self.center,
602+
closed=self.closed,
603+
)
604+
aggregator = executor.generate_shared_aggregator(
605+
func, engine_kwargs, numba_cache_key_str
606+
)
607+
result = aggregator(values, start, end, min_periods)
608+
NUMBA_FUNC_CACHE[(func, numba_cache_key_str)] = aggregator
609+
result = result.T if self.axis == 1 else result
610+
if obj.ndim == 1:
611+
result = result.squeeze()
612+
out = obj._constructor(result, index=obj.index, name=obj.name)
613+
return out
614+
else:
615+
out = obj._constructor(result, index=obj.index, columns=obj.columns)
616+
return self._resolve_output(out, obj)
617+
579618
def aggregate(self, func, *args, **kwargs):
580619
result = ResamplerWindowApply(self, func, args=args, kwargs=kwargs).agg()
581620
if result is None:
@@ -1331,15 +1370,16 @@ def mean(
13311370
if maybe_use_numba(engine):
13321371
if self.method == "table":
13331372
func = generate_manual_numpy_nan_agg_with_axis(np.nanmean)
1373+
return self.apply(
1374+
func,
1375+
raw=True,
1376+
engine=engine,
1377+
engine_kwargs=engine_kwargs,
1378+
)
13341379
else:
1335-
func = np.nanmean
1380+
from pandas.core._numba.kernels import sliding_mean
13361381

1337-
return self.apply(
1338-
func,
1339-
raw=True,
1340-
engine=engine,
1341-
engine_kwargs=engine_kwargs,
1342-
)
1382+
return self._numba_apply(sliding_mean, "rolling_mean", engine_kwargs)
13431383
window_func = window_aggregations.roll_mean
13441384
return self._apply(window_func, name="mean", **kwargs)
13451385

pandas/tests/window/test_numba.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,44 +43,52 @@ def f(x, *args):
4343
)
4444
tm.assert_series_equal(result, expected)
4545

46+
@pytest.mark.parametrize(
47+
"data", [DataFrame(np.eye(5)), Series(range(5), name="foo")]
48+
)
4649
def test_numba_vs_cython_rolling_methods(
47-
self, nogil, parallel, nopython, arithmetic_numba_supported_operators
50+
self, data, nogil, parallel, nopython, arithmetic_numba_supported_operators
4851
):
4952

5053
method = arithmetic_numba_supported_operators
5154

5255
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
5356

54-
df = DataFrame(np.eye(5))
55-
roll = df.rolling(2)
57+
roll = data.rolling(2)
5658
result = getattr(roll, method)(engine="numba", engine_kwargs=engine_kwargs)
5759
expected = getattr(roll, method)(engine="cython")
5860

5961
# Check the cache
60-
assert (getattr(np, f"nan{method}"), "Rolling_apply_single") in NUMBA_FUNC_CACHE
62+
if method != "mean":
63+
assert (
64+
getattr(np, f"nan{method}"),
65+
"Rolling_apply_single",
66+
) in NUMBA_FUNC_CACHE
6167

62-
tm.assert_frame_equal(result, expected)
68+
tm.assert_equal(result, expected)
6369

70+
@pytest.mark.parametrize("data", [DataFrame(np.eye(5)), Series(range(5))])
6471
def test_numba_vs_cython_expanding_methods(
65-
self, nogil, parallel, nopython, arithmetic_numba_supported_operators
72+
self, data, nogil, parallel, nopython, arithmetic_numba_supported_operators
6673
):
6774

6875
method = arithmetic_numba_supported_operators
6976

7077
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
7178

72-
df = DataFrame(np.eye(5))
73-
expand = df.expanding()
79+
data = DataFrame(np.eye(5))
80+
expand = data.expanding()
7481
result = getattr(expand, method)(engine="numba", engine_kwargs=engine_kwargs)
7582
expected = getattr(expand, method)(engine="cython")
7683

7784
# Check the cache
78-
assert (
79-
getattr(np, f"nan{method}"),
80-
"Expanding_apply_single",
81-
) in NUMBA_FUNC_CACHE
85+
if method != "mean":
86+
assert (
87+
getattr(np, f"nan{method}"),
88+
"Expanding_apply_single",
89+
) in NUMBA_FUNC_CACHE
8290

83-
tm.assert_frame_equal(result, expected)
91+
tm.assert_equal(result, expected)
8492

8593
@pytest.mark.parametrize("jit", [True, False])
8694
def test_cache_apply(self, jit, nogil, parallel, nopython):

0 commit comments

Comments
 (0)