Skip to content

Commit aa25da2

Browse files
authored
ENH: Add table-wise numba rolling to other agg functions (#38995)
1 parent 89ddd8a commit aa25da2

File tree

6 files changed

+49
-37
lines changed

6 files changed

+49
-37
lines changed

ci/deps/azure-37-slow.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@ dependencies:
3636
- xlwt
3737
- moto
3838
- flask
39+
- numba

ci/deps/azure-38-slow.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ dependencies:
3434
- xlwt
3535
- moto
3636
- flask
37+
- numba

doc/source/whatsnew/v1.3.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ For example:
3737

3838
:class:`Rolling` and :class:`Expanding` now support a ``method`` argument with a
3939
``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`.
40-
See ref:`window.overview` for performance and functional benefits. (:issue:`15095`)
40+
See :ref:`window.overview` for performance and functional benefits. (:issue:`15095`, :issue:`38995`)
4141

4242
.. _whatsnew_130.enhancements.other:
4343

pandas/core/window/numba_.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
from typing import Any, Callable, Dict, Optional, Tuple
23

34
import numpy as np
@@ -220,3 +221,21 @@ def roll_table(
220221
return result
221222

222223
return roll_table
224+
225+
226+
# This function will no longer be needed once numba supports
227+
# axis for all np.nan* agg functions
228+
# https://github.com/numba/numba/issues/1269
229+
@functools.lru_cache(maxsize=None)
230+
def generate_manual_numpy_nan_agg_with_axis(nan_func):
231+
numba = import_optional_dependency("numba")
232+
233+
@numba.jit(nopython=True, nogil=True, parallel=True)
234+
def nan_agg_with_axis(table):
235+
result = np.empty(table.shape[1])
236+
for i in numba.prange(table.shape[1]):
237+
partition = table[:, i]
238+
result[i] = nan_func(partition)
239+
return result
240+
241+
return nan_agg_with_axis

pandas/core/window/rolling.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
VariableWindowIndexer,
6666
)
6767
from pandas.core.window.numba_ import (
68+
generate_manual_numpy_nan_agg_with_axis,
6869
generate_numba_apply_func,
6970
generate_numba_table_func,
7071
)
@@ -1378,16 +1379,15 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
13781379
nv.validate_window_func("sum", args, kwargs)
13791380
if maybe_use_numba(engine):
13801381
if self.method == "table":
1381-
raise NotImplementedError("method='table' is not supported.")
1382-
# Once numba supports np.nansum with axis, args will be relevant.
1383-
# https://github.com/numba/numba/issues/6610
1384-
args = () if self.method == "single" else (0,)
1382+
func = generate_manual_numpy_nan_agg_with_axis(np.nansum)
1383+
else:
1384+
func = np.nansum
1385+
13851386
return self.apply(
1386-
np.nansum,
1387+
func,
13871388
raw=True,
13881389
engine=engine,
13891390
engine_kwargs=engine_kwargs,
1390-
args=args,
13911391
)
13921392
window_func = window_aggregations.roll_sum
13931393
return self._apply(window_func, name="sum", **kwargs)
@@ -1424,16 +1424,15 @@ def max(self, *args, engine=None, engine_kwargs=None, **kwargs):
14241424
nv.validate_window_func("max", args, kwargs)
14251425
if maybe_use_numba(engine):
14261426
if self.method == "table":
1427-
raise NotImplementedError("method='table' is not supported.")
1428-
# Once numba supports np.nanmax with axis, args will be relevant.
1429-
# https://github.com/numba/numba/issues/6610
1430-
args = () if self.method == "single" else (0,)
1427+
func = generate_manual_numpy_nan_agg_with_axis(np.nanmax)
1428+
else:
1429+
func = np.nanmax
1430+
14311431
return self.apply(
1432-
np.nanmax,
1432+
func,
14331433
raw=True,
14341434
engine=engine,
14351435
engine_kwargs=engine_kwargs,
1436-
args=args,
14371436
)
14381437
window_func = window_aggregations.roll_max
14391438
return self._apply(window_func, name="max", **kwargs)
@@ -1496,16 +1495,15 @@ def min(self, *args, engine=None, engine_kwargs=None, **kwargs):
14961495
nv.validate_window_func("min", args, kwargs)
14971496
if maybe_use_numba(engine):
14981497
if self.method == "table":
1499-
raise NotImplementedError("method='table' is not supported.")
1500-
# Once numba supports np.nanmin with axis, args will be relevant.
1501-
# https://github.com/numba/numba/issues/6610
1502-
args = () if self.method == "single" else (0,)
1498+
func = generate_manual_numpy_nan_agg_with_axis(np.nanmin)
1499+
else:
1500+
func = np.nanmin
1501+
15031502
return self.apply(
1504-
np.nanmin,
1503+
func,
15051504
raw=True,
15061505
engine=engine,
15071506
engine_kwargs=engine_kwargs,
1508-
args=args,
15091507
)
15101508
window_func = window_aggregations.roll_min
15111509
return self._apply(window_func, name="min", **kwargs)
@@ -1514,16 +1512,15 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
15141512
nv.validate_window_func("mean", args, kwargs)
15151513
if maybe_use_numba(engine):
15161514
if self.method == "table":
1517-
raise NotImplementedError("method='table' is not supported.")
1518-
# Once numba supports np.nanmean with axis, args will be relevant.
1519-
# https://github.com/numba/numba/issues/6610
1520-
args = () if self.method == "single" else (0,)
1515+
func = generate_manual_numpy_nan_agg_with_axis(np.nanmean)
1516+
else:
1517+
func = np.nanmean
1518+
15211519
return self.apply(
1522-
np.nanmean,
1520+
func,
15231521
raw=True,
15241522
engine=engine,
15251523
engine_kwargs=engine_kwargs,
1526-
args=args,
15271524
)
15281525
window_func = window_aggregations.roll_mean
15291526
return self._apply(window_func, name="mean", **kwargs)
@@ -1584,16 +1581,15 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
15841581
def median(self, engine=None, engine_kwargs=None, **kwargs):
15851582
if maybe_use_numba(engine):
15861583
if self.method == "table":
1587-
raise NotImplementedError("method='table' is not supported.")
1588-
# Once numba supports np.nanmedian with axis, args will be relevant.
1589-
# https://github.com/numba/numba/issues/6610
1590-
args = () if self.method == "single" else (0,)
1584+
func = generate_manual_numpy_nan_agg_with_axis(np.nanmedian)
1585+
else:
1586+
func = np.nanmedian
1587+
15911588
return self.apply(
1592-
np.nanmedian,
1589+
func,
15931590
raw=True,
15941591
engine=engine,
15951592
engine_kwargs=engine_kwargs,
1596-
args=args,
15971593
)
15981594
window_func = window_aggregations.roll_median_c
15991595
return self._apply(window_func, name="median", **kwargs)

pandas/tests/window/test_numba.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def test_invalid_kwargs_nopython():
163163

164164

165165
@td.skip_if_no("numba", "0.46.0")
166+
@pytest.mark.slow
166167
@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
167168
# Filter warnings when parallel=True and the function can't be parallelized by Numba
168169
class TestTableMethod:
@@ -177,9 +178,6 @@ def f(x):
177178
f, engine="numba", raw=True
178179
)
179180

180-
@pytest.mark.xfail(
181-
raises=NotImplementedError, reason="method='table' is not supported."
182-
)
183181
def test_table_method_rolling_methods(
184182
self, axis, nogil, parallel, nopython, arithmetic_numba_supported_operators
185183
):
@@ -247,9 +245,6 @@ def f(x):
247245
)
248246
tm.assert_frame_equal(result, expected)
249247

250-
@pytest.mark.xfail(
251-
raises=NotImplementedError, reason="method='table' is not supported."
252-
)
253248
def test_table_method_expanding_methods(
254249
self, axis, nogil, parallel, nopython, arithmetic_numba_supported_operators
255250
):

0 commit comments

Comments
 (0)