From 0d336de321ac93c650336deefed261c98910317d Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 9 Sep 2021 16:48:21 -0700 Subject: [PATCH] REF: avoid unnecessary argument to groupby numba funcs --- pandas/core/groupby/generic.py | 6 ++++-- pandas/core/groupby/groupby.py | 6 +----- pandas/core/groupby/numba_.py | 18 ++++++++++-------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index c57f64e306199..1a413e48708b5 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -228,9 +228,10 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) if maybe_use_numba(engine): with group_selection_context(self): data = self._selected_obj - result, index = self._aggregate_with_numba( + result = self._aggregate_with_numba( data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs ) + index = self._group_keys_index return self.obj._constructor(result.ravel(), index=index, name=data.name) relabeling = func is None @@ -926,9 +927,10 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) if maybe_use_numba(engine): with group_selection_context(self): data = self._selected_obj - result, index = self._aggregate_with_numba( + result = self._aggregate_with_numba( data, func, *args, engine_kwargs=engine_kwargs, **kwargs ) + index = self._group_keys_index return self.obj._constructor(result, index=index, columns=data.columns) relabeling, func, columns, order = reconstruct_func(func, **kwargs) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 85e7c9a62b2d4..0bb408f7040ce 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1269,7 +1269,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) data and indices into a Numba jitted function. """ starts, ends, sorted_index, sorted_data = self._numba_prep(func, data) - group_keys = self.grouper.group_keys_seq numba_transform_func = numba_.generate_numba_transform_func( kwargs, func, engine_kwargs @@ -1279,7 +1278,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) sorted_index, starts, ends, - len(group_keys), len(data.columns), *args, ) @@ -1302,7 +1300,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) data and indices into a Numba jitted function. """ starts, ends, sorted_index, sorted_data = self._numba_prep(func, data) - index = self._group_keys_index numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs) result = numba_agg_func( @@ -1310,7 +1307,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) sorted_index, starts, ends, - len(index), len(data.columns), *args, ) @@ -1319,7 +1315,7 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) if cache_key not in NUMBA_FUNC_CACHE: NUMBA_FUNC_CACHE[cache_key] = numba_agg_func - return result, index + return result # ----------------------------------------------------------------- # apply/agg/transform diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index ad78280c5d835..beb77360d5a3f 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -59,9 +59,7 @@ def generate_numba_agg_func( kwargs: dict[str, Any], func: Callable[..., Scalar], engine_kwargs: dict[str, bool] | None, -) -> Callable[ - [np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray -]: +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]: """ Generate a numba jitted agg function specified by values from engine_kwargs. @@ -100,10 +98,13 @@ def group_agg( index: np.ndarray, begin: np.ndarray, end: np.ndarray, - num_groups: int, num_columns: int, *args: Any, ) -> np.ndarray: + + assert len(begin) == len(end) + num_groups = len(begin) + result = np.empty((num_groups, num_columns)) for i in numba.prange(num_groups): group_index = index[begin[i] : end[i]] @@ -119,9 +120,7 @@ def generate_numba_transform_func( kwargs: dict[str, Any], func: Callable[..., np.ndarray], engine_kwargs: dict[str, bool] | None, -) -> Callable[ - [np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray -]: +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]: """ Generate a numba jitted transform function specified by values from engine_kwargs. @@ -160,10 +159,13 @@ def group_transform( index: np.ndarray, begin: np.ndarray, end: np.ndarray, - num_groups: int, num_columns: int, *args: Any, ) -> np.ndarray: + + assert len(begin) == len(end) + num_groups = len(begin) + result = np.empty((len(values), num_columns)) for i in numba.prange(num_groups): group_index = index[begin[i] : end[i]]