Skip to content

Commit df69f2a

Browse files
authored
ENH: Add numba engine to several rolling aggregations (#38895)
1 parent 7f2a768 commit df69f2a

File tree

7 files changed

+290
-45
lines changed

7 files changed

+290
-45
lines changed

asv_bench/benchmarks/rolling.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,24 @@ class Engine:
5050
["int", "float"],
5151
[np.sum, lambda x: np.sum(x) + 5],
5252
["cython", "numba"],
53+
["sum", "max", "min", "median", "mean"],
5354
)
54-
param_names = ["constructor", "dtype", "function", "engine"]
55+
param_names = ["constructor", "dtype", "function", "engine", "method"]
5556

56-
def setup(self, constructor, dtype, function, engine):
57+
def setup(self, constructor, dtype, function, engine, method):
5758
N = 10 ** 3
5859
arr = (100 * np.random.random(N)).astype(dtype)
5960
self.data = getattr(pd, constructor)(arr)
6061

61-
def time_rolling_apply(self, constructor, dtype, function, engine):
62+
def time_rolling_apply(self, constructor, dtype, function, engine, method):
6263
self.data.rolling(10).apply(function, raw=True, engine=engine)
6364

64-
def time_expanding_apply(self, constructor, dtype, function, engine):
65+
def time_expanding_apply(self, constructor, dtype, function, engine, method):
6566
self.data.expanding().apply(function, raw=True, engine=engine)
6667

68+
def time_rolling_methods(self, constructor, dtype, function, engine, method):
69+
getattr(self.data.rolling(10), method)(engine=engine)
70+
6771

6872
class ExpandingMethods:
6973

doc/source/user_guide/window.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ Numba will be applied in potentially two routines:
321321
#. If ``func`` is a standard Python function, the engine will `JIT <https://numba.pydata.org/numba-doc/latest/user/overview.html>`__ the passed function. ``func`` can also be a JITed function in which case the engine will not JIT the function again.
322322
#. The engine will JIT the for loop where the apply function is applied to each window.
323323

324+
.. versionadded:: 1.3
325+
326+
``mean``, ``median``, ``max``, ``min``, and ``sum`` also support the ``engine`` and ``engine_kwargs`` arguments.
327+
324328
The ``engine_kwargs`` argument is a dictionary of keyword arguments that will be passed into the
325329
`numba.jit decorator <https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html#numba.jit>`__.
326330
These keyword arguments will be applied to *both* the passed function (if a standard Python function)

doc/source/whatsnew/v1.3.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Other enhancements
5353
- :func:`to_numeric` now supports downcasting of nullable ``ExtensionDtype`` objects (:issue:`33013`)
5454
- Add support for dict-like names in :class:`MultiIndex.set_names` and :class:`MultiIndex.rename` (:issue:`20421`)
5555
- :func:`pandas.read_excel` can now auto detect .xlsb files (:issue:`35416`)
56+
- :meth:`.Rolling.sum`, :meth:`.Expanding.sum`, :meth:`.Rolling.mean`, :meth:`.Expanding.mean`, :meth:`.Rolling.median`, :meth:`.Expanding.median`, :meth:`.Rolling.max`, :meth:`.Expanding.max`, :meth:`.Rolling.min`, and :meth:`.Expanding.min` now support ``Numba`` execution with the ``engine`` keyword (:issue:`38895`)
5657

5758
.. ---------------------------------------------------------------------------
5859

pandas/core/window/expanding.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,33 +172,33 @@ def apply(
172172

173173
@Substitution(name="expanding")
174174
@Appender(_shared_docs["sum"])
175-
def sum(self, *args, **kwargs):
175+
def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
176176
nv.validate_expanding_func("sum", args, kwargs)
177-
return super().sum(*args, **kwargs)
177+
return super().sum(*args, engine=engine, engine_kwargs=engine_kwargs, **kwargs)
178178

179179
@Substitution(name="expanding", func_name="max")
180180
@Appender(_doc_template)
181181
@Appender(_shared_docs["max"])
182-
def max(self, *args, **kwargs):
182+
def max(self, *args, engine=None, engine_kwargs=None, **kwargs):
183183
nv.validate_expanding_func("max", args, kwargs)
184-
return super().max(*args, **kwargs)
184+
return super().max(*args, engine=engine, engine_kwargs=engine_kwargs, **kwargs)
185185

186186
@Substitution(name="expanding")
187187
@Appender(_shared_docs["min"])
188-
def min(self, *args, **kwargs):
188+
def min(self, *args, engine=None, engine_kwargs=None, **kwargs):
189189
nv.validate_expanding_func("min", args, kwargs)
190-
return super().min(*args, **kwargs)
190+
return super().min(*args, engine=engine, engine_kwargs=engine_kwargs, **kwargs)
191191

192192
@Substitution(name="expanding")
193193
@Appender(_shared_docs["mean"])
194-
def mean(self, *args, **kwargs):
194+
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
195195
nv.validate_expanding_func("mean", args, kwargs)
196-
return super().mean(*args, **kwargs)
196+
return super().mean(*args, engine=engine, engine_kwargs=engine_kwargs, **kwargs)
197197

198198
@Substitution(name="expanding")
199199
@Appender(_shared_docs["median"])
200-
def median(self, **kwargs):
201-
return super().median(**kwargs)
200+
def median(self, engine=None, engine_kwargs=None, **kwargs):
201+
return super().median(engine=engine, engine_kwargs=engine_kwargs, **kwargs)
202202

203203
@Substitution(name="expanding", versionadded="")
204204
@Appender(_shared_docs["std"])
@@ -256,9 +256,16 @@ def kurt(self, **kwargs):
256256

257257
@Substitution(name="expanding")
258258
@Appender(_shared_docs["quantile"])
259-
def quantile(self, quantile, interpolation="linear", **kwargs):
259+
def quantile(
260+
self,
261+
quantile,
262+
interpolation="linear",
263+
**kwargs,
264+
):
260265
return super().quantile(
261-
quantile=quantile, interpolation=interpolation, **kwargs
266+
quantile=quantile,
267+
interpolation=interpolation,
268+
**kwargs,
262269
)
263270

264271
@Substitution(name="expanding", func_name="cov")

0 commit comments

Comments
 (0)