From 11c24324862071065e2c44b622bd2cf4bc787452 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Fri, 17 Apr 2020 21:58:05 -0700 Subject: [PATCH 1/4] REF: Make numba function cache globally accessible --- pandas/core/groupby/generic.py | 23 ++++++++++---------- pandas/core/util/numba_.py | 2 ++ pandas/core/window/rolling.py | 9 ++++---- pandas/tests/groupby/transform/test_numba.py | 5 +++-- pandas/tests/window/test_numba.py | 3 ++- 5 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index c007d4920cbe7..3376ca379786d 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -76,6 +76,7 @@ from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.core.util.numba_ import ( + _numba_func_cache, check_kwargs_and_nopython, get_jit_arguments, jit_user_function, @@ -161,8 +162,6 @@ def pinner(cls): class SeriesGroupBy(GroupBy[Series]): _apply_whitelist = base.series_apply_whitelist - _numba_func_cache: Dict[Callable, Callable] = {} - def _iterate_slices(self) -> Iterable[Series]: yield self._selected_obj @@ -504,8 +503,9 @@ def _transform_general( nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) - numba_func = self._numba_func_cache.get( - func, jit_user_function(func, nopython, nogil, parallel) + cache_key = (func, "groupby_transform") + numba_func = _numba_func_cache.get( + cache_key, jit_user_function(func, nopython, nogil, parallel) ) klass = type(self._selected_obj) @@ -516,8 +516,8 @@ def _transform_general( if engine == "numba": values, index = split_for_numba(group) res = numba_func(values, index, *args) - if func not in self._numba_func_cache: - self._numba_func_cache[func] = numba_func + if cache_key not in _numba_func_cache: + _numba_func_cache[cache_key] = numba_func else: res = func(group, *args, **kwargs) @@ -847,8 +847,6 @@ class 
DataFrameGroupBy(GroupBy[DataFrame]): _apply_whitelist = base.dataframe_apply_whitelist - _numba_func_cache: Dict[Callable, Callable] = {} - _agg_see_also_doc = dedent( """ See Also @@ -1397,8 +1395,9 @@ def _transform_general( nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) - numba_func = self._numba_func_cache.get( - func, jit_user_function(func, nopython, nogil, parallel) + cache_key = (func, "groupby_transform") + numba_func = _numba_func_cache.get( + cache_key, jit_user_function(func, nopython, nogil, parallel) ) else: fast_path, slow_path = self._define_paths(func, *args, **kwargs) @@ -1409,8 +1408,8 @@ def _transform_general( if engine == "numba": values, index = split_for_numba(group) res = numba_func(values, index, *args) - if func not in self._numba_func_cache: - self._numba_func_cache[func] = numba_func + if cache_key not in _numba_func_cache: + _numba_func_cache[cache_key] = numba_func # Return the result as a DataFrame for concatenation later res = DataFrame(res, index=group.index, columns=group.columns) else: diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index c5b27b937a05b..b3842f3790ed1 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -8,6 +8,8 @@ from pandas._typing import FrameOrSeries from pandas.compat._optional import import_optional_dependency +_numba_func_cache: Dict[Tuple[Callable, str], Callable] = dict() + def check_kwargs_and_nopython( kwargs: Optional[Dict] = None, nopython: Optional[bool] = None diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 3fdf81c4bb570..5c976755d2525 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -38,6 +38,7 @@ from pandas.core.base import DataError, PandasObject, SelectionMixin, ShallowMixin import pandas.core.common as com from pandas.core.indexes.api import Index, ensure_index +from pandas.core.util.numba_ import 
_numba_func_cache from pandas.core.window.common import ( WindowGroupByMixin, _doc_template, @@ -93,7 +94,6 @@ def __init__( self.win_freq = None self.axis = obj._get_axis_number(axis) if axis is not None else None self.validate() - self._numba_func_cache: Dict[Optional[str], Callable] = dict() @property def _constructor(self): @@ -505,7 +505,7 @@ def calc(x): result = np.asarray(result) if use_numba_cache: - self._numba_func_cache[name] = func + _numba_func_cache[(name, "rolling_apply")] = func if center: result = self._center_window(result, window) @@ -1278,9 +1278,10 @@ def apply( elif engine == "numba": if raw is False: raise ValueError("raw must be `True` when using the numba engine") - if func in self._numba_func_cache: + cache_key = (func, "rolling_apply") + if cache_key in _numba_func_cache: # Return an already compiled version of roll_apply if available - apply_func = self._numba_func_cache[func] + apply_func = _numba_func_cache[cache_key] else: apply_func = generate_numba_apply_func( args, kwargs, func, engine_kwargs diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 96078d0aa3662..4d271b7a19561 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -4,6 +4,7 @@ from pandas import DataFrame import pandas._testing as tm +from pandas.core.util.numba_ import _numba_func_cache @td.skip_if_no("numba", "0.46.0") @@ -98,13 +99,13 @@ def func_2(values, index): expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) # func_1 should be in the cache now - assert func_1 in grouped._numba_func_cache + assert (func_1, "groupby_transform") in _numba_func_cache # Add func_2 to the cache result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") tm.assert_equal(result, expected) - assert func_2 in grouped._numba_func_cache + assert 
(func_2, "groupby_transform") in _numba_func_cache # Retest func_1 which should use the cache result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index cc8aef1779b46..36d493fb6341e 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -5,6 +5,7 @@ from pandas import Series import pandas._testing as tm +from pandas.core.util.numba_ import _numba_func_cache @td.skip_if_no("numba", "0.46.0") @@ -59,7 +60,7 @@ def func_2(x): tm.assert_series_equal(result, expected) # func_1 should be in the cache now - assert func_1 in roll._numba_func_cache + assert (func_1, "rolling_apply") in _numba_func_cache result = roll.apply( func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True From a475102b423683a63fda372d6cca26bba08057d4 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 18 Apr 2020 11:35:19 -0700 Subject: [PATCH 2/4] don't make variable private --- pandas/core/groupby/generic.py | 14 +++++++------- pandas/core/util/numba_.py | 2 +- pandas/core/window/rolling.py | 8 ++++---- pandas/tests/groupby/transform/test_numba.py | 6 +++--- pandas/tests/window/test_numba.py | 4 ++-- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 3376ca379786d..504de404b2509 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -76,7 +76,7 @@ from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.core.util.numba_ import ( - _numba_func_cache, + NUMBA_FUNC_CACHE, check_kwargs_and_nopython, get_jit_arguments, jit_user_function, @@ -504,7 +504,7 @@ def _transform_general( check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) cache_key = (func, "groupby_transform") - numba_func = _numba_func_cache.get( + numba_func = NUMBA_FUNC_CACHE.get( cache_key, jit_user_function(func, 
nopython, nogil, parallel) ) @@ -516,8 +516,8 @@ def _transform_general( if engine == "numba": values, index = split_for_numba(group) res = numba_func(values, index, *args) - if cache_key not in _numba_func_cache: - _numba_func_cache[cache_key] = numba_func + if cache_key not in NUMBA_FUNC_CACHE: + NUMBA_FUNC_CACHE[cache_key] = numba_func else: res = func(group, *args, **kwargs) @@ -1396,7 +1396,7 @@ def _transform_general( check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) cache_key = (func, "groupby_transform") - numba_func = _numba_func_cache.get( + numba_func = NUMBA_FUNC_CACHE.get( cache_key, jit_user_function(func, nopython, nogil, parallel) ) else: @@ -1408,8 +1408,8 @@ def _transform_general( if engine == "numba": values, index = split_for_numba(group) res = numba_func(values, index, *args) - if cache_key not in _numba_func_cache: - _numba_func_cache[cache_key] = numba_func + if cache_key not in NUMBA_FUNC_CACHE: + NUMBA_FUNC_CACHE[cache_key] = numba_func # Return the result as a DataFrame for concatenation later res = DataFrame(res, index=group.index, columns=group.columns) else: diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index b3842f3790ed1..af24189adbc27 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -8,7 +8,7 @@ from pandas._typing import FrameOrSeries from pandas.compat._optional import import_optional_dependency -_numba_func_cache: Dict[Tuple[Callable, str], Callable] = dict() +NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict() def check_kwargs_and_nopython( diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 5c976755d2525..0128fa78461eb 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -38,7 +38,7 @@ from pandas.core.base import DataError, PandasObject, SelectionMixin, ShallowMixin import pandas.core.common as com from pandas.core.indexes.api import Index, ensure_index -from pandas.core.util.numba_ import 
_numba_func_cache +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE from pandas.core.window.common import ( WindowGroupByMixin, _doc_template, @@ -505,7 +505,7 @@ def calc(x): result = np.asarray(result) if use_numba_cache: - _numba_func_cache[(name, "rolling_apply")] = func + NUMBA_FUNC_CACHE[(name, "rolling_apply")] = func if center: result = self._center_window(result, window) @@ -1279,9 +1279,9 @@ def apply( if raw is False: raise ValueError("raw must be `True` when using the numba engine") cache_key = (func, "rolling_apply") - if cache_key in _numba_func_cache: + if cache_key in NUMBA_FUNC_CACHE: # Return an already compiled version of roll_apply if available - apply_func = _numba_func_cache[cache_key] + apply_func = NUMBA_FUNC_CACHE[cache_key] else: apply_func = generate_numba_apply_func( args, kwargs, func, engine_kwargs diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 4d271b7a19561..28904b669ae56 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -4,7 +4,7 @@ from pandas import DataFrame import pandas._testing as tm -from pandas.core.util.numba_ import _numba_func_cache +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE @td.skip_if_no("numba", "0.46.0") @@ -99,13 +99,13 @@ def func_2(values, index): expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) # func_1 should be in the cache now - assert (func_1, "groupby_transform") in _numba_func_cache + assert (func_1, "groupby_transform") in NUMBA_FUNC_CACHE # Add func_2 to the cache result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") tm.assert_equal(result, expected) - assert (func_2, "groupby_transform") in _numba_func_cache + assert (func_2, "groupby_transform") in NUMBA_FUNC_CACHE # Retest func_1 which should use the cache result = 
grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 36d493fb6341e..8ecf64b171df4 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -5,7 +5,7 @@ from pandas import Series import pandas._testing as tm -from pandas.core.util.numba_ import _numba_func_cache +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE @td.skip_if_no("numba", "0.46.0") @@ -60,7 +60,7 @@ def func_2(x): tm.assert_series_equal(result, expected) # func_1 should be in the cache now - assert (func_1, "rolling_apply") in _numba_func_cache + assert (func_1, "rolling_apply") in NUMBA_FUNC_CACHE result = roll.apply( func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True From e319ac636e0feee391e94bab73aa032fe84ae265 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 18 Apr 2020 12:06:18 -0700 Subject: [PATCH 3/4] Explicitly pass the function to construct the cache key --- pandas/core/window/rolling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 0128fa78461eb..7dfc210eab901 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -505,7 +505,7 @@ def calc(x): result = np.asarray(result) if use_numba_cache: - NUMBA_FUNC_CACHE[(name, "rolling_apply")] = func + NUMBA_FUNC_CACHE[(kwargs["original_func"], "rolling_apply")] = func if center: result = self._center_window(result, window) @@ -1298,6 +1298,7 @@ def apply( name=func, use_numba_cache=engine == "numba", raw=raw, + original_func=func, ) def _generate_cython_apply_func(self, args, kwargs, raw, offset, func): From 5db6c6c4a04c143e9d026ffac795f38480e64524 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 18 Apr 2020 12:26:32 -0700 Subject: [PATCH 4/4] pop original func for rolling groupby --- pandas/core/window/common.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/pandas/core/window/common.py b/pandas/core/window/common.py index 40f17126fa163..ebc67d0a0e819 100644 --- a/pandas/core/window/common.py +++ b/pandas/core/window/common.py @@ -78,6 +78,7 @@ def _apply( performing the original function call on the grouped object. """ kwargs.pop("floor", None) + kwargs.pop("original_func", None) # TODO: can we de-duplicate with _dispatch? def f(x, name=name, *args):