Skip to content

Commit 8d40358

Browse files
committed
smaller-diff implementation
1 parent 2c2e783 commit 8d40358

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

pandas/core/groupby/generic.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -219,9 +219,16 @@ def apply(self, func, *args, **kwargs) -> Series:
219219
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
220220

221221
if maybe_use_numba(engine):
222-
return self._aggregate_with_numba(
223-
func, *args, engine_kwargs=engine_kwargs, **kwargs
222+
data = self._obj_with_exclusions
223+
result = self._aggregate_with_numba(
224+
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
224225
)
226+
index = self.grouper.result_index
227+
result = self.obj._constructor(result.ravel(), index=index, name=data.name)
228+
if not self.as_index:
229+
result = self._insert_inaxis_grouper(result)
230+
result.index = default_index(len(result))
231+
return result
225232

226233
relabeling = func is None
227234
columns = None
@@ -1257,9 +1264,16 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
12571264
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
12581265

12591266
if maybe_use_numba(engine):
1260-
return self._aggregate_with_numba(
1261-
func, *args, engine_kwargs=engine_kwargs, **kwargs
1267+
data = self._obj_with_exclusions
1268+
result = self._aggregate_with_numba(
1269+
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
12621270
)
1271+
index = self.grouper.result_index
1272+
result = self.obj._constructor(result, index=index, columns=data.columns)
1273+
if not self.as_index:
1274+
result = self._insert_inaxis_grouper(result)
1275+
result.index = default_index(len(result))
1276+
return result
12631277

12641278
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
12651279
func = maybe_mangle_lambdas(func)

pandas/core/groupby/groupby.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,18 +1270,17 @@ def _transform_with_numba(
12701270
return result.take(np.argsort(sorted_index), axis=0)
12711271

12721272
@final
1273-
def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
1273+
def _aggregate_with_numba(
1274+
self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
1275+
):
12741276
"""
12751277
Perform groupby aggregation routine with the numba engine.
12761278
12771279
This routine mimics the data splitting routine of the DataSplitter class
12781280
to generate the indices of each group in the sorted data and then passes the
12791281
data and indices into a Numba jitted function.
12801282
"""
1281-
data = self._obj_with_exclusions
1282-
df = data if data.ndim == 2 else data.to_frame()
1283-
1284-
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
1283+
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
12851284
numba_.validate_udf(func)
12861285
numba_agg_func = numba_.generate_numba_agg_func(
12871286
func, **get_jit_arguments(engine_kwargs, kwargs)
@@ -1291,18 +1290,10 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
12911290
sorted_index,
12921291
starts,
12931292
ends,
1294-
len(df.columns),
1293+
len(data.columns),
12951294
*args,
12961295
)
1297-
1298-
index = self.grouper.result_index
1299-
if data.ndim == 1:
1300-
result_kwargs = {"name": data.name}
1301-
result = result.ravel()
1302-
else:
1303-
result_kwargs = {"columns": data.columns}
1304-
result = data._constructor(result, index=index, **result_kwargs)
1305-
return self._wrap_aggregated_output(result)
1296+
return result
13061297

13071298
# -----------------------------------------------------------------
13081299
# apply/agg/transform

0 commit comments

Comments
 (0)