diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 18506b871bda6..b2041a3e1380a 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -266,7 +266,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) func = maybe_mangle_lambdas(func) ret = self._aggregate_multiple_funcs(func) if relabeling: - ret.columns = columns + # error: Incompatible types in assignment (expression has type + # "Optional[List[str]]", variable has type "Index") + ret.columns = columns # type: ignore[assignment] + return ret + else: cyfunc = com.get_cython_func(func) if cyfunc and not args and not kwargs: @@ -282,33 +286,21 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) # see test_groupby.test_basic result = self._aggregate_named(func, *args, **kwargs) - index = Index(sorted(result), name=self.grouper.names[0]) - ret = create_series_with_explicit_dtype( - result, index=index, dtype_if_empty=object - ) - - if not self.as_index: # pragma: no cover - print("Warning, ignoring as_index=True") - - if isinstance(ret, dict): - from pandas import concat - - ret = concat(ret.values(), axis=1, keys=[key.label for key in ret.keys()]) - return ret + index = Index(sorted(result), name=self.grouper.names[0]) + return create_series_with_explicit_dtype( + result, index=index, dtype_if_empty=object + ) agg = aggregate - def _aggregate_multiple_funcs(self, arg): + def _aggregate_multiple_funcs(self, arg) -> DataFrame: if isinstance(arg, dict): # show the deprecation, but only if we # have not shown a higher level one # GH 15931 - if isinstance(self._selected_obj, Series): - raise SpecificationError("nested renamer is not supported") + raise SpecificationError("nested renamer is not supported") - columns = list(arg.keys()) - arg = arg.items() elif any(isinstance(x, (tuple, list)) for x in arg): arg = [(x, x) if not isinstance(x, (tuple, list)) else x for x in arg] @@ -335,8 +327,14 @@ def _aggregate_multiple_funcs(self, arg): results[base.OutputKey(label=name, position=idx)] = obj.aggregate(func) if any(isinstance(x, DataFrame) for x in results.values()): - # let higher level handle - return results + from pandas import concat + + res_df = concat( + results.values(), axis=1, keys=[key.label for key in results.keys()] + ) + # error: Incompatible return value type (got "Union[DataFrame, Series]", + # expected "DataFrame") + return res_df # type: ignore[return-value] indexed_output = {key.position: val for key, val in results.items()} output = self.obj._constructor_expanddim(indexed_output, index=None) @@ -1000,6 +998,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) result = op.agg() if not is_dict_like(func) and result is not None: return result + elif relabeling and result is not None: + # this should be the only (non-raising) case with relabeling + # used reordered index of columns + result = result.iloc[:, order] + result.columns = columns if result is None: @@ -1039,12 +1042,6 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) [sobj.columns.name] * result.columns.nlevels ).droplevel(-1) - if relabeling: - - # used reordered index of columns - result = result.iloc[:, order] - result.columns = columns - if not self.as_index: self._insert_inaxis_grouper_inplace(result) result.index = np.arange(len(result)) @@ -1389,9 +1386,7 @@ def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame: if not output: raise TypeError("Transform function invalid for data types") - columns = obj.columns - if len(output) < len(obj.columns): - columns = columns.take(inds) + columns = obj.columns.take(inds) return self.obj._constructor(output, index=obj.index, columns=columns) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index c5ef18c51a533..d222bc1c083ed 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2174,7 +2174,9 @@ def backfill(self, limit=None): @final @Substitution(name="groupby") @Substitution(see_also=_common_see_also) - def nth(self, n: int | list[int], dropna: str | None = None) -> DataFrame: + def nth( + self, n: int | list[int], dropna: Literal["any", "all", None] = None + ) -> DataFrame: """ Take the nth row from each group if n is an int, or a subset of rows if n is a list of ints. @@ -2187,9 +2189,9 @@ def nth(self, n: int | list[int], dropna: str | None = None) -> DataFrame: ---------- n : int or list of ints A single nth value for the row or a list of nth values. - dropna : None or str, optional + dropna : {'any', 'all', None}, default None Apply the specified dropna operation before counting which row is - the nth row. Needs to be None, 'any' or 'all'. + the nth row. Returns -------