Skip to content

Commit bae0bf0

Browse files
authored
REF: prune groupby paths (#51187)
* REF: avoid handling corner cases in op_via_apply
* simplify _wrap_aggregated_output
* REF: remove _wrap_transformed_output
* final
* mypy fixup
* remove unnecessary
1 parent eec200d commit bae0bf0

File tree

4 files changed

+40
-73
lines changed

4 files changed

+40
-73
lines changed

pandas/core/groupby/generic.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,16 @@ def _wrap_applied_output(
390390
"""
391391
if len(values) == 0:
392392
# GH #6265
393+
if is_transform:
394+
# GH#47787 see test_group_on_empty_multiindex
395+
res_index = data.index
396+
else:
397+
res_index = self.grouper.result_index
398+
393399
return self.obj._constructor(
394400
[],
395401
name=self.obj.name,
396-
index=self.grouper.result_index,
402+
index=res_index,
397403
dtype=data.dtype,
398404
)
399405
assert values is not None
@@ -1146,14 +1152,12 @@ def cov(
11461152
@property
11471153
@doc(Series.is_monotonic_increasing.__doc__)
11481154
def is_monotonic_increasing(self) -> Series:
1149-
result = self._op_via_apply("is_monotonic_increasing")
1150-
return result
1155+
return self.apply(lambda ser: ser.is_monotonic_increasing)
11511156

11521157
@property
11531158
@doc(Series.is_monotonic_decreasing.__doc__)
11541159
def is_monotonic_decreasing(self) -> Series:
1155-
result = self._op_via_apply("is_monotonic_decreasing")
1156-
return result
1160+
return self.apply(lambda ser: ser.is_monotonic_decreasing)
11571161

11581162
@doc(Series.hist.__doc__)
11591163
def hist(
@@ -1191,8 +1195,7 @@ def hist(
11911195
@property
11921196
@doc(Series.dtype.__doc__)
11931197
def dtype(self) -> Series:
1194-
result = self._op_via_apply("dtype")
1195-
return result
1198+
return self.apply(lambda ser: ser.dtype)
11961199

11971200
@doc(Series.unique.__doc__)
11981201
def unique(self) -> Series:
@@ -1438,9 +1441,13 @@ def _wrap_applied_output(
14381441
):
14391442

14401443
if len(values) == 0:
1441-
result = self.obj._constructor(
1442-
index=self.grouper.result_index, columns=data.columns
1443-
)
1444+
if is_transform:
1445+
# GH#47787 see test_group_on_empty_multiindex
1446+
res_index = data.index
1447+
else:
1448+
res_index = self.grouper.result_index
1449+
1450+
result = self.obj._constructor(index=res_index, columns=data.columns)
14441451
result = result.astype(data.dtypes, copy=False)
14451452
return result
14461453

@@ -1729,18 +1736,11 @@ def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame:
17291736
# iterate through columns, see test_transform_exclude_nuisance
17301737
# gets here with non-unique columns
17311738
output = {}
1732-
inds = []
17331739
for i, (colname, sgb) in enumerate(self._iterate_column_groupbys(obj)):
17341740
output[i] = sgb.transform(wrapper)
1735-
inds.append(i)
1736-
1737-
if not output:
1738-
raise TypeError("Transform function invalid for data types")
1739-
1740-
columns = obj.columns.take(inds)
17411741

17421742
result = self.obj._constructor(output, index=obj.index)
1743-
result.columns = columns
1743+
result.columns = obj.columns
17441744
return result
17451745

17461746
def filter(self, func, dropna: bool = True, *args, **kwargs):
@@ -2693,8 +2693,8 @@ def hist(
26932693
@property
26942694
@doc(DataFrame.dtypes.__doc__)
26952695
def dtypes(self) -> Series:
2696-
result = self._op_via_apply("dtypes")
2697-
return result
2696+
# error: Incompatible return value type (got "DataFrame", expected "Series")
2697+
return self.apply(lambda df: df.dtypes) # type: ignore[return-value]
26982698

26992699
@doc(DataFrame.corrwith.__doc__)
27002700
def corrwith(

pandas/core/groupby/groupby.py

Lines changed: 19 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -956,9 +956,6 @@ def __getattr__(self, attr: str):
956956
def _op_via_apply(self, name: str, *args, **kwargs):
957957
"""Compute the result of an operation by using GroupBy's apply."""
958958
f = getattr(type(self._obj_with_exclusions), name)
959-
if not callable(f):
960-
return self.apply(lambda self: getattr(self, name))
961-
962959
sig = inspect.signature(f)
963960

964961
# a little trickery for aggregation functions that need an axis
@@ -980,9 +977,6 @@ def curried(x):
980977
return self.apply(curried)
981978

982979
is_transform = name in base.transformation_kernels
983-
# Transform needs to keep the same schema, including when empty
984-
if is_transform and self._obj_with_exclusions.empty:
985-
return self._obj_with_exclusions
986980
result = self._python_apply_general(
987981
curried,
988982
self._obj_with_exclusions,
@@ -1105,6 +1099,7 @@ def _set_result_index_ordered(
11051099

11061100
return result
11071101

1102+
@final
11081103
def _insert_inaxis_grouper(self, result: Series | DataFrame) -> DataFrame:
11091104
if isinstance(result, Series):
11101105
result = result.to_frame()
@@ -1131,30 +1126,22 @@ def _indexed_output_to_ndframe(
11311126
@final
11321127
def _wrap_aggregated_output(
11331128
self,
1134-
output: Series | DataFrame | Mapping[base.OutputKey, ArrayLike],
1129+
result: Series | DataFrame,
11351130
qs: npt.NDArray[np.float64] | None = None,
11361131
):
11371132
"""
11381133
Wraps the output of GroupBy aggregations into the expected result.
11391134
11401135
Parameters
11411136
----------
1142-
output : Series, DataFrame, or Mapping[base.OutputKey, ArrayLike]
1143-
Data to wrap.
1137+
result : Series, DataFrame
11441138
11451139
Returns
11461140
-------
11471141
Series or DataFrame
11481142
"""
1149-
1150-
if isinstance(output, (Series, DataFrame)):
1151-
# We get here (for DataFrameGroupBy) if we used Manager.grouped_reduce,
1152-
# in which case our columns are already set correctly.
1153-
# ATM we do not get here for SeriesGroupBy; when we do, we will
1154-
# need to require that result.name already match self.obj.name
1155-
result = output
1156-
else:
1157-
result = self._indexed_output_to_ndframe(output)
1143+
# ATM we do not get here for SeriesGroupBy; when we do, we will
1144+
# need to require that result.name already match self.obj.name
11581145

11591146
if not self.as_index:
11601147
# `not self.as_index` is only relevant for DataFrameGroupBy,
@@ -1183,36 +1170,6 @@ def _wrap_aggregated_output(
11831170

11841171
return self._reindex_output(result, qs=qs)
11851172

1186-
@final
1187-
def _wrap_transformed_output(
1188-
self, output: Mapping[base.OutputKey, ArrayLike]
1189-
) -> Series | DataFrame:
1190-
"""
1191-
Wraps the output of GroupBy transformations into the expected result.
1192-
1193-
Parameters
1194-
----------
1195-
output : Mapping[base.OutputKey, ArrayLike]
1196-
Data to wrap.
1197-
1198-
Returns
1199-
-------
1200-
Series or DataFrame
1201-
Series for SeriesGroupBy, DataFrame for DataFrameGroupBy
1202-
"""
1203-
if isinstance(output, (Series, DataFrame)):
1204-
result = output
1205-
else:
1206-
result = self._indexed_output_to_ndframe(output)
1207-
1208-
if self.axis == 1:
1209-
# Only relevant for DataFrameGroupBy
1210-
result = result.T
1211-
result.columns = self.obj.columns
1212-
1213-
result.index = self.obj.index
1214-
return result
1215-
12161173
def _wrap_applied_output(
12171174
self,
12181175
data,
@@ -1456,7 +1413,8 @@ def _python_agg_general(self, func, *args, **kwargs):
14561413
output: dict[base.OutputKey, ArrayLike] = {}
14571414

14581415
if self.ngroups == 0:
1459-
# agg_series below assumes ngroups > 0
1416+
# e.g. test_evaluate_with_empty_groups different path gets different
1417+
# result dtype in empty case.
14601418
return self._python_apply_general(f, self._selected_obj, is_agg=True)
14611419

14621420
for idx, obj in enumerate(self._iterate_slices()):
@@ -1466,9 +1424,11 @@ def _python_agg_general(self, func, *args, **kwargs):
14661424
output[key] = result
14671425

14681426
if not output:
1427+
# e.g. test_groupby_crash_on_nunique, test_margins_no_values_no_cols
14691428
return self._python_apply_general(f, self._selected_obj)
14701429

1471-
return self._wrap_aggregated_output(output)
1430+
res = self._indexed_output_to_ndframe(output)
1431+
return self._wrap_aggregated_output(res)
14721432

14731433
@final
14741434
def _agg_general(
@@ -1837,6 +1797,7 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
18371797
# If we are grouping on categoricals we want unobserved categories to
18381798
# return zero, rather than the default of NaN which the reindexing in
18391799
# _wrap_agged_manager() returns. GH 35028
1800+
# e.g. test_dataframe_groupby_on_2_categoricals_when_observed_is_false
18401801
with com.temp_setattr(self, "observed", True):
18411802
result = self._wrap_agged_manager(new_mgr)
18421803

@@ -2555,6 +2516,7 @@ def ohlc(self) -> DataFrame:
25552516
)
25562517
return self._reindex_output(result)
25572518

2519+
# TODO: 2023-02-05 all tests that get here have self.as_index
25582520
return self._apply_to_column_groupbys(
25592521
lambda x: x.ohlc(), self._obj_with_exclusions
25602522
)
@@ -2832,7 +2794,13 @@ def blk_func(values: ArrayLike) -> ArrayLike:
28322794
if isinstance(new_obj, Series):
28332795
new_obj.name = obj.name
28342796

2835-
return self._wrap_transformed_output(new_obj)
2797+
if self.axis == 1:
2798+
# Only relevant for DataFrameGroupBy
2799+
new_obj = new_obj.T
2800+
new_obj.columns = self.obj.columns
2801+
2802+
new_obj.index = self.obj.index
2803+
return new_obj
28362804

28372805
@final
28382806
@Substitution(name="groupby")

pandas/core/groupby/ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,6 @@ def _aggregate_series_pure_python(
10281028
) -> npt.NDArray[np.object_]:
10291029
ids, _, ngroups = self.group_info
10301030

1031-
counts = np.zeros(ngroups, dtype=int)
10321031
result = np.empty(ngroups, dtype="O")
10331032
initialized = False
10341033

@@ -1044,7 +1043,6 @@ def _aggregate_series_pure_python(
10441043
libreduction.check_result_array(res, group.dtype)
10451044
initialized = True
10461045

1047-
counts[i] = group.shape[0]
10481046
result[i] = res
10491047

10501048
return result

pandas/core/resample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def _groupby_and_aggregate(self, how, *args, **kwargs):
435435
try:
436436
if isinstance(obj, ABCDataFrame) and callable(how):
437437
# Check if the function is reducing or not.
438+
# e.g. test_resample_apply_with_additional_args
438439
result = grouped._aggregate_item_by_item(how, *args, **kwargs)
439440
else:
440441
result = grouped.aggregate(how, *args, **kwargs)

0 commit comments

Comments (0)