diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 6745203d5beb7..c281fda71a7ec 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -77,11 +77,8 @@ from pandas.core.series import Series from pandas.core.util.numba_ import ( NUMBA_FUNC_CACHE, - check_kwargs_and_nopython, - get_jit_arguments, - jit_user_function, + generate_numba_func, split_for_numba, - validate_udf, ) from pandas.plotting import boxplot_frame_groupby @@ -507,12 +504,8 @@ def _transform_general( """ if engine == "numba": - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - check_kwargs_and_nopython(kwargs, nopython) - validate_udf(func) - cache_key = (func, "groupby_transform") - numba_func = NUMBA_FUNC_CACHE.get( - cache_key, jit_user_function(func, nopython, nogil, parallel) + numba_func, cache_key = generate_numba_func( + func, engine_kwargs, kwargs, "groupby_transform" ) klass = type(self._selected_obj) @@ -1407,12 +1400,8 @@ def _transform_general( obj = self._obj_with_exclusions gen = self.grouper.get_iterator(obj, axis=self.axis) if engine == "numba": - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - check_kwargs_and_nopython(kwargs, nopython) - validate_udf(func) - cache_key = (func, "groupby_transform") - numba_func = NUMBA_FUNC_CACHE.get( - cache_key, jit_user_function(func, nopython, nogil, parallel) + numba_func, cache_key = generate_numba_func( + func, engine_kwargs, kwargs, "groupby_transform" ) else: fast_path, slow_path = self._define_paths(func, *args, **kwargs) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index f799baf354794..d67811988d0f8 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -56,11 +56,8 @@ ) from pandas.core.util.numba_ import ( NUMBA_FUNC_CACHE, - check_kwargs_and_nopython, - get_jit_arguments, - jit_user_function, + generate_numba_func, split_for_numba, - validate_udf, ) @@ -689,12 +686,8 @@ def _aggregate_series_pure_python( ): if engine == "numba": - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - check_kwargs_and_nopython(kwargs, nopython) - validate_udf(func) - cache_key = (func, "groupby_agg") - numba_func = NUMBA_FUNC_CACHE.get( - cache_key, jit_user_function(func, nopython, nogil, parallel) + numba_func, cache_key = generate_numba_func( + func, engine_kwargs, kwargs, "groupby_agg" ) group_index, _, ngroups = self.group_info diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index c2e4b38ad5b4d..c3f60ea7cc217 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -167,3 +167,43 @@ def f(values, index, ...): f"The first {min_number_args} arguments to {func.__name__} must be " f"{expected_args}" ) + + +def generate_numba_func( + func: Callable, + engine_kwargs: Optional[Dict[str, bool]], + kwargs: dict, + cache_key_str: str, +) -> Tuple[Callable, Tuple[Callable, str]]: + """ + Return a JITed function and cache key for the NUMBA_FUNC_CACHE + + This _may_ be specific to groupby (as it's only used there currently). + + Parameters + ---------- + func : function + user defined function + engine_kwargs : dict or None + numba.jit arguments + kwargs : dict + kwargs for func + cache_key_str : str + string representing the second part of the cache key tuple + + Returns + ------- + (JITed function, cache key) + + Raises + ------ + NumbaUtilError + """ + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + check_kwargs_and_nopython(kwargs, nopython) + validate_udf(func) + cache_key = (func, cache_key_str) + numba_func = NUMBA_FUNC_CACHE.get( + cache_key, jit_user_function(func, nopython, nogil, parallel) + ) + return numba_func, cache_key