diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 6f956a3dcc9b6..ebb9d82766c1b 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -19,6 +19,7 @@ Iterable, List, Mapping, + Optional, Sequence, Tuple, Type, @@ -30,7 +31,7 @@ import numpy as np from pandas._libs import lib -from pandas._typing import FrameOrSeries +from pandas._typing import FrameOrSeries, FrameOrSeriesUnion from pandas.util._decorators import Appender, Substitution, doc from pandas.core.dtypes.cast import ( @@ -413,12 +414,31 @@ def _wrap_transformed_output( assert isinstance(result, Series) return result - def _wrap_applied_output(self, keys, values, not_indexed_same=False): + def _wrap_applied_output( + self, keys: Index, values: Optional[List[Any]], not_indexed_same: bool = False + ) -> FrameOrSeriesUnion: + """ + Wrap the output of SeriesGroupBy.apply into the expected result. + + Parameters + ---------- + keys : Index + Keys of groups that Series was grouped by. + values : Optional[List[Any]] + Applied output for each group. + not_indexed_same : bool, default False + Whether the applied outputs are not indexed the same as the group axes. + + Returns + ------- + DataFrame or Series + """ if len(keys) == 0: # GH #6265 return self.obj._constructor( [], name=self._selection_name, index=keys, dtype=np.float64 ) + assert values is not None def _get_index() -> Index: if self.grouper.nkeys > 1: @@ -430,7 +450,7 @@ def _get_index() -> Index: if isinstance(values[0], dict): # GH #823 #24880 index = _get_index() - result = self._reindex_output( + result: FrameOrSeriesUnion = self._reindex_output( self.obj._constructor_expanddim(values, index=index) ) # if self.observed is False, @@ -438,11 +458,7 @@ def _get_index() -> Index: result = result.stack(dropna=self.observed) result.name = self._selection_name return result - - if isinstance(values[0], Series): - return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) - elif isinstance(values[0], DataFrame): - # possible that Series -> DataFrame by applied function + elif isinstance(values[0], (Series, DataFrame)): return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) else: # GH #6265 #24880