diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index a15da861cfbec..a31e6868f53ac 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -42,6 +42,7 @@ Other enhancements - :meth:`DataFrame.corrwith` now accepts ``min_periods`` as optional arguments, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`) - :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`) - :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`) +- :meth:`GroupBy.transform` now accepts list-like arguments and dictionary arguments similar to :meth:`GroupBy.agg`, and supports :class:`NamedAgg` (:issue:`58318`) - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index a20577e8d3df9..45e3640e7541b 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -75,6 +75,7 @@ all_indexes_same, default_index, ) +from pandas.core.reshape.concat import concat from pandas.core.series import Series from pandas.core.sorting import get_group_index from pandas.core.util.numba_ import maybe_use_numba @@ -1863,15 +1864,145 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs): 3 5 9 4 5 8 5 5 9 + + List-like arguments + + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) + >>> df.groupby("col").transform(["sum", "min"]) + val other_val + sum min sum min + 0 1 0 1 0 + 1 1 0 1 0 + 2 2 2 2 2 + + .. versionchanged:: 3.0.0 + + Dictionary arguments + + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) + >>> df.groupby("col").transform({"val": "sum", "other_val": "min"}) + val other_val + 0 1 0 + 1 1 0 + 2 2 2 + + .. versionchanged:: 3.0.0 + + Named aggregation + + >>> df = pd.DataFrame({"col": list("aab"), "val": range(3), "other_val": range(3)}) + >>> df.groupby("col").transform( + ... val_sum=pd.NamedAgg(column="val", aggfunc="sum"), + ... other_min=pd.NamedAgg(column="other_val", aggfunc="min") + ... ) + val_sum other_min + 0 1 0 + 1 1 0 + 2 2 2 + + .. versionchanged:: 3.0.0 """ ) @Substitution(klass="DataFrame", example=__examples_dataframe_doc) @Appender(_transform_template) - def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): - return self._transform( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + def transform( + self, + func: None + | (Callable | str | list[Callable | str] | dict[str, NamedAgg]) = None, + *args, + engine: str | None = None, + engine_kwargs: dict | None = None, + **kwargs, + ) -> DataFrame: + if func is None: + transformed_func = dict(kwargs.items()) + return self._transform_multiple_funcs( + transformed_func, *args, engine=engine, engine_kwargs=engine_kwargs + ) + elif isinstance(func, dict): + return self._transform_multiple_funcs( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + elif isinstance(func, list): + func = maybe_mangle_lambdas(func) + return self._transform_multiple_funcs( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + else: + return self._transform( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + + def _transform_multiple_funcs( + self, + func: Any, + *args, + engine: str | None = None, + engine_kwargs: dict | None = None, + **kwargs, + ) -> DataFrame: + if isinstance(func, dict): + results = [] + for name, agg in func.items(): + if isinstance(agg, NamedAgg): + column_name = agg.column + agg_func = agg.aggfunc + else: + column_name = name + agg_func = agg + result = self._transform_single_column( + column_name, + agg_func, + *args, + engine=engine, + engine_kwargs=engine_kwargs, + **kwargs, + ) + result.name = name + results.append(result) + output = concat(results, axis=1) + elif isinstance(func, list): + results = [] + col_order = [] + keys_list = list(self.keys) if isinstance(self.keys, list) else [self.keys] + for column in self.obj.columns: + if column in keys_list: + continue + column_results = [ + self._transform_single_column( + column, + agg_func, + *args, + engine=engine, + engine_kwargs=engine_kwargs, + **kwargs, + ).rename((column, agg_func)) + for agg_func in func + ] + for col_result in column_results: + results.append(col_result) + col_order.append(col_result.name) + output = concat(results, ignore_index=True, axis=1) + arrays = [list(x) for x in zip(*col_order)] + output.columns = MultiIndex.from_arrays(arrays) + + return output + + def _transform_single_column( + self, + column_name: Hashable, + agg_func: Callable | str, + *args, + engine: str | None = None, + engine_kwargs: dict | None = None, + **kwargs, + ) -> Series: + data = self._gotitem(column_name, ndim=1) + result = data.transform( + agg_func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs ) + return result def _define_paths(self, func, *args, **kwargs): if isinstance(func, str): diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index d6d545a8c4834..6dbfc0fd94e2d 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -19,6 +19,7 @@ date_range, ) import pandas._testing as tm +from pandas.core.groupby import NamedAgg from pandas.tests.groupby import get_groupby_method_args @@ -84,6 +85,60 @@ def demean(arr): tm.assert_frame_equal(result, expected) +def test_transform_with_list_like(): + df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)}) + result = df.groupby("col").transform(["sum", "min"]) + expected = DataFrame( + { + ("val", "sum"): [1, 1, 2], + ("val", "min"): [0, 0, 2], + ("another", "sum"): [1, 1, 2], + ("another", "min"): [0, 0, 2], + } + ) + expected.columns = MultiIndex.from_tuples( + [("val", "sum"), ("val", "min"), ("another", "sum"), ("another", "min")] + ) + tm.assert_frame_equal(result, expected) + + +def test_transform_with_dict(): + df = DataFrame({"col": list("aab"), "val": range(3), "another": range(3)}) + result = df.groupby("col").transform({"val": "sum", "another": "min"}) + expected = DataFrame({"val": [1, 1, 2], "another": [0, 0, 2]}) + tm.assert_frame_equal(result, expected) + + +def test_transform_with_namedagg(): + df = DataFrame({"A": list("aaabbbccc"), "B": range(9), "D": range(9, 18)}) + result = df.groupby("A").transform( + b_min=NamedAgg(column="B", aggfunc="min"), + d_sum=NamedAgg(column="D", aggfunc="sum"), + ) + expected = DataFrame( + { + "b_min": [0, 0, 0, 3, 3, 3, 6, 6, 6], + "d_sum": [30, 30, 30, 39, 39, 39, 48, 48, 48], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_transform_with_duplicate_columns(): + df = DataFrame({"A": list("aaabbbccc"), "B": range(9, 18)}) + result = df.groupby("A").transform( + b_min=NamedAgg(column="B", aggfunc="min"), + b_max=NamedAgg(column="B", aggfunc="max"), + ) + expected = DataFrame( + { + "b_min": [9, 9, 9, 12, 12, 12, 15, 15, 15], + "b_max": [11, 11, 11, 14, 14, 14, 17, 17, 17], + } + ) + tm.assert_frame_equal(result, expected) + + def test_transform_fast(): df = DataFrame( {