From df9b39a232a5e32fc7ba6ec3309a45a4c08e49fc Mon Sep 17 00:00:00 2001 From: Ormorod Date: Thu, 2 Mar 2023 15:54:58 +0000 Subject: [PATCH 01/26] a whole bunch of 1d constructors --- pandas/core/groupby/generic.py | 20 ++++++++++++-------- pandas/core/groupby/groupby.py | 6 ++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 60c43b6cf0ecd..dedeb2e24fdac 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -271,7 +271,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) result = self._aggregate_named(func, *args, **kwargs) # result is a dict whose keys are the elements of result_index - result = Series(result, index=self.grouper.result_index) + result = self._obj_1d_constructor( + result, index=self.grouper.result_index + ) result = self._wrap_aggregated_output(result) return result @@ -674,7 +676,7 @@ def value_counts( # in a backward compatible way # GH38672 relates to categorical dtype ser = self.apply( - Series.value_counts, + self._obj_1d_constructor.value_counts, normalize=normalize, sort=sort, ascending=ascending, @@ -693,7 +695,7 @@ def value_counts( llab = lambda lab, inc: lab[inc] else: # lab is a Categorical with categories an IntervalIndex - cat_ser = cut(Series(val), bins, include_lowest=True) + cat_ser = cut(self.obj._constructor(val), bins, include_lowest=True) cat_obj = cast("Categorical", cat_ser._values) lev = cat_obj.categories lab = lev.take( @@ -1276,9 +1278,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) elif relabeling: # this should be the only (non-raising) case with relabeling # used reordered index of columns - result = cast(DataFrame, result) + result = cast(self._obj_1d_constructor, result) result = result.iloc[:, order] - result = cast(DataFrame, result) + result = cast(self._obj_1d_constructor, result) # error: Incompatible types in assignment (expression has type # "Optional[List[str]]", variable has type # "Union[Union[Union[ExtensionArray, ndarray[Any, Any]], @@ -1321,7 +1323,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) else: # GH#32040, GH#35246 # e.g. test_groupby_as_index_select_column_sum_empty_df - result = cast(DataFrame, result) + result = cast(self._obj_1d_constructor, result) result.columns = self._obj_with_exclusions.columns.copy() if not self.as_index: @@ -1449,7 +1451,7 @@ def _wrap_applied_output_series( is_transform: bool, ) -> DataFrame | Series: kwargs = first_not_none._construct_axes_dict() - backup = Series(**kwargs) + backup = self._obj_1d_constructor(**kwargs) values = [x if (x is not None) else backup for x in values] all_indexed_same = all_indexes_same(x.index for x in values) @@ -1840,7 +1842,9 @@ def _apply_to_column_groupbys(self, func) -> DataFrame: if not len(results): # concat would raise - res_df = DataFrame([], columns=columns, index=self.grouper.result_index) + res_df = self.obj._constructor( + [], columns=columns, index=self.grouper.result_index + ) else: res_df = concat(results, keys=columns, axis=1) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 55e14bc11246b..91c6290c4c2db 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1629,7 +1629,7 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray: @final @property - def _obj_1d_constructor(self) -> Callable: + def _obj_1d_constructor(self): # GH28330 preserve subclassed Series/DataFrames if isinstance(self.obj, DataFrame): return self.obj._constructor_sliced @@ -1844,7 +1844,9 @@ def mean( else: result = self._cython_agg_general( "mean", - alt=lambda x: Series(x).mean(numeric_only=numeric_only), + alt=lambda x: self._obj_1d_constructor(x).mean( + numeric_only=numeric_only + ), numeric_only=numeric_only, ) return result.__finalize__(self.obj, method="groupby") From de8c2314c0665a8bf6930527468716b3f0c1449c Mon Sep 17 00:00:00 2001 From: Ormorod Date: Thu, 2 Mar 2023 20:11:48 +0000 Subject: [PATCH 02/26] got the first three of Lukas' assertions working --- pandas/core/groupby/groupby.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 91c6290c4c2db..e05559de000a8 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1837,6 +1837,16 @@ def mean( Name: B, dtype: float64 """ + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + print("df") + print(df) + return self.obj._constructor(df).mean() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_mean From d057cd08eda3c5c4f097c73728a74a5ce5d31b46 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Thu, 2 Mar 2023 20:17:23 +0000 Subject: [PATCH 03/26] remove print statements --- pandas/core/groupby/groupby.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e05559de000a8..f804f50d49b8f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1840,8 +1840,6 @@ def mean( if not (type(self.obj) == Series or type(self.obj) == DataFrame): def f(df, *args, **kwargs): - print("df") - print(df) return self.obj._constructor(df).mean() result = self.agg(f) From 5338d3f0bb09689842276b91fa845ee4b424b545 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 10:17:52 +0000 Subject: [PATCH 04/26] create test for overridden methods --- pandas/tests/groupby/test_groupby_subclass.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 773c1e60e97af..5c54ece9b290e 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -103,3 +103,46 @@ def test_groupby_resample_preserves_subclass(obj): # Confirm groupby.resample() preserves dataframe type result = df.groupby("Buyer").resample("5D").sum() assert isinstance(result, obj) + + +def test_groupby_overridden_methods(): + class UnitSeries(Series): + @property + def _constructor(self): + return UnitSeries + + @property + def _constructor_expanddim(self): + return UnitDataFrame + + def mean(self, *args, **kwargs): + return 1 + + def beans(self, *args, **kwargs): + return "series toast" + + class UnitDataFrame(DataFrame): + @property + def _constructor(self): + return UnitDataFrame + + @property + def _constructor_expanddim(self): + return UnitSeries + + def mean(self, *args, **kwargs): + return 1 + + def beans(self, *args, **kwargs): + return "df toast" + + params = ["a", "b"] + data = np.random.rand(4, 2) + udf = UnitDataFrame(data, columns=params) + udf["group"] = np.ones(4, dtype=int) + udf.loc[2:, "group"] = 2 + + assert udf.mean() == 1 + assert all(udf.groupby("group").mean() == 1) + assert udf.beans() == "df toast" + # print(udf.groupby('group').beans()) # AttributeError From 5fe7f821bfe4248c08a350859463dc6a18ba9d54 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 12:37:01 +0000 Subject: [PATCH 05/26] also attack median, std, var, sem, prod, sum, min and max --- pandas/core/groupby/groupby.py | 69 +++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f804f50d49b8f..bd48671ce25a1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1880,9 +1880,17 @@ def median(self, numeric_only: bool = False): Series or DataFrame Median of values within each group. """ + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).median() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + result = self._cython_agg_general( "median", - alt=lambda x: Series(x).median(numeric_only=numeric_only), + alt=lambda x: self._obj_1d_constructor(x).median(numeric_only=numeric_only), numeric_only=numeric_only, ) return result.__finalize__(self.obj, method="groupby") @@ -1938,6 +1946,14 @@ def std( Series or DataFrame Standard deviation of values within each group. """ + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).std() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_var @@ -2021,6 +2037,14 @@ def var( Series or DataFrame Variance of values within each group. """ + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).var() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_var @@ -2028,7 +2052,7 @@ def var( else: return self._cython_agg_general( "var", - alt=lambda x: Series(x).var(ddof=ddof), + alt=lambda x: self._obj_1d_constructor(x).var(ddof=ddof), numeric_only=numeric_only, ddof=ddof, ) @@ -2190,6 +2214,15 @@ def sem(self, ddof: int = 1, numeric_only: bool = False): Series or DataFrame Standard error of the mean of values within each group. """ + # TODO: think sem() needs considering more closely + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).sem() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if numeric_only and self.obj.ndim == 1 and not is_numeric_dtype(self.obj.dtype): raise TypeError( f"{type(self).__name__}.sem called with " @@ -2248,6 +2281,14 @@ def sum( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).sum() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_sum @@ -2272,6 +2313,14 @@ def sum( @final @doc(_groupby_agg_method_template, fname="prod", no=False, mc=0) def prod(self, numeric_only: bool = False, min_count: int = 0): + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).prod() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + return self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod ) @@ -2285,6 +2334,14 @@ def min( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).min() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_min_max @@ -2306,6 +2363,14 @@ def max( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): + if not (type(self.obj) == Series or type(self.obj) == DataFrame): + + def f(df, *args, **kwargs): + return self.obj._constructor(df).max() + + result = self.agg(f) + return result.__finalize__(self.obj, method="groupby") + if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_min_max From 4aa2b85dbf11aa146f2b65d67eccd0e5b782a161 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 18:01:17 +0000 Subject: [PATCH 06/26] add tests for test of methods --- pandas/tests/groupby/test_groupby_subclass.py | 53 +++++++++++++++++-- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 5c54ece9b290e..13a399c0c663f 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -118,8 +118,26 @@ def _constructor_expanddim(self): def mean(self, *args, **kwargs): return 1 - def beans(self, *args, **kwargs): - return "series toast" + def median(self, *args, **kwargs): + return 2 + + def std(self, *args, **kwargs): + return 3 + + def var(self, *args, **kwargs): + return 4 + + def sem(self, *args, **kwargs): + return 5 + + def prod(self, *args, **kwargs): + return 6 + + def min(self, *args, **kwargs): + return 7 + + def max(self, *args, **kwargs): + return 8 class UnitDataFrame(DataFrame): @property @@ -131,10 +149,29 @@ def _constructor_expanddim(self): return UnitSeries def mean(self, *args, **kwargs): + print("UnitDataFrame mean") return 1 - def beans(self, *args, **kwargs): - return "df toast" + def median(self, *args, **kwargs): + return 2 + + def std(self, *args, **kwargs): + return 3 + + def var(self, *args, **kwargs): + return 4 + + def sem(self, *args, **kwargs): + return 5 + + def prod(self, *args, **kwargs): + return 6 + + def min(self, *args, **kwargs): + return 7 + + def max(self, *args, **kwargs): + return 8 params = ["a", "b"] data = np.random.rand(4, 2) @@ -144,5 +181,11 @@ def beans(self, *args, **kwargs): assert udf.mean() == 1 assert all(udf.groupby("group").mean() == 1) - assert udf.beans() == "df toast" + assert all(udf.groupby("group").median() == 1) + assert all(udf.groupby("group").std() == 1) + assert all(udf.groupby("group").var() == 1) + assert all(udf.groupby("group").sem() == 1) + assert all(udf.groupby("group").prod() == 1) + assert all(udf.groupby("group").min() == 1) + assert all(udf.groupby("group").max() == 1) # print(udf.groupby('group').beans()) # AttributeError From 8d7346dadd2edbab978c652e4a88c6992af8da5f Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 18:09:00 +0000 Subject: [PATCH 07/26] change 1d constructors to constructors --- pandas/core/groupby/generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 8e3f7b165b745..db6f9d7b87025 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1291,9 +1291,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) elif relabeling: # this should be the only (non-raising) case with relabeling # used reordered index of columns - result = cast(self._obj_1d_constructor, result) + result = cast(self.obj._constructor, result) result = result.iloc[:, order] - result = cast(self._obj_1d_constructor, result) + result = cast(self.obj._constructor, result) # error: Incompatible types in assignment (expression has type # "Optional[List[str]]", variable has type # "Union[Union[Union[ExtensionArray, ndarray[Any, Any]], From 37ae233a780d5ed2ef80a53ef15796a1062c206f Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 18:15:53 +0000 Subject: [PATCH 08/26] tidy up --- pandas/tests/groupby/test_groupby_subclass.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 13a399c0c663f..90ca26ffecf6a 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -149,7 +149,6 @@ def _constructor_expanddim(self): return UnitSeries def mean(self, *args, **kwargs): - print("UnitDataFrame mean") return 1 def median(self, *args, **kwargs): @@ -179,13 +178,11 @@ def max(self, *args, **kwargs): udf["group"] = np.ones(4, dtype=int) udf.loc[2:, "group"] = 2 - assert udf.mean() == 1 assert all(udf.groupby("group").mean() == 1) - assert all(udf.groupby("group").median() == 1) - assert all(udf.groupby("group").std() == 1) - assert all(udf.groupby("group").var() == 1) - assert all(udf.groupby("group").sem() == 1) - assert all(udf.groupby("group").prod() == 1) - assert all(udf.groupby("group").min() == 1) - assert all(udf.groupby("group").max() == 1) - # print(udf.groupby('group').beans()) # AttributeError + assert all(udf.groupby("group").median() == 2) + assert all(udf.groupby("group").std() == 3) + assert all(udf.groupby("group").var() == 4) + assert all(udf.groupby("group").sem() == 5) + assert all(udf.groupby("group").prod() == 6) + assert all(udf.groupby("group").min() == 7) + assert all(udf.groupby("group").max() == 8) From aa57cc2257570cd71dd4f5e947b45220ec383deb Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 18:22:45 +0000 Subject: [PATCH 09/26] change to np.all --- pandas/tests/groupby/test_groupby_subclass.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 90ca26ffecf6a..b20ed0e3a21c1 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -116,7 +116,7 @@ def _constructor_expanddim(self): return UnitDataFrame def mean(self, *args, **kwargs): - return 1 + return 2 def median(self, *args, **kwargs): return 2 @@ -149,7 +149,7 @@ def _constructor_expanddim(self): return UnitSeries def mean(self, *args, **kwargs): - return 1 + return 2 def median(self, *args, **kwargs): return 2 @@ -178,11 +178,11 @@ def max(self, *args, **kwargs): udf["group"] = np.ones(4, dtype=int) udf.loc[2:, "group"] = 2 - assert all(udf.groupby("group").mean() == 1) - assert all(udf.groupby("group").median() == 2) - assert all(udf.groupby("group").std() == 3) - assert all(udf.groupby("group").var() == 4) - assert all(udf.groupby("group").sem() == 5) - assert all(udf.groupby("group").prod() == 6) - assert all(udf.groupby("group").min() == 7) - assert all(udf.groupby("group").max() == 8) + assert np.all(udf.groupby("group").mean() == 1) + assert np.all(udf.groupby("group").median() == 2) + assert np.all(udf.groupby("group").std() == 3) + assert np.all(udf.groupby("group").var() == 4) + assert np.all(udf.groupby("group").sem() == 5) + assert np.all(udf.groupby("group").prod() == 6) + assert np.all(udf.groupby("group").min() == 7) + assert np.all(udf.groupby("group").max() == 8) From b3df075a4691ada5e0ccb7e81aeab13e0a68c3c3 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 19:17:00 +0000 Subject: [PATCH 10/26] remove deliberate test failure --- pandas/tests/groupby/test_groupby_subclass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index b20ed0e3a21c1..4b444d6759fed 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -116,7 +116,7 @@ def _constructor_expanddim(self): return UnitDataFrame def mean(self, *args, **kwargs): - return 2 + return 1 def median(self, *args, **kwargs): return 2 @@ -149,7 +149,7 @@ def _constructor_expanddim(self): return UnitSeries def mean(self, *args, **kwargs): - return 2 + return 1 def median(self, *args, **kwargs): return 2 From adc132abfff72ffbe39f95955e3404e2f5c9cd25 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 19:20:31 +0000 Subject: [PATCH 11/26] check that self._obj_1d_constructor is Series --- pandas/core/groupby/generic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index db6f9d7b87025..24d1308a1e274 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -682,6 +682,12 @@ def value_counts( index_names = self.grouper.names + [self.obj.name] + constructor_1d = ( + self._obj_1d_constructor + if isinstance(self._obj_1d_constructor, Series) + else Series + ) + if is_categorical_dtype(val.dtype) or ( bins is not None and not np.iterable(bins) ): @@ -689,7 +695,7 @@ def value_counts( # in a backward compatible way # GH38672 relates to categorical dtype ser = self.apply( - self._obj_1d_constructor.value_counts, + constructor_1d.value_counts, normalize=normalize, sort=sort, ascending=ascending, From 31868ffd75aba9cb8efb6ce96f8b9f541965ef17 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Fri, 3 Mar 2023 19:29:25 +0000 Subject: [PATCH 12/26] add entry to docs --- doc/source/whatsnew/v2.1.0.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index c0ca5b2320338..696207611eaa8 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -185,6 +185,10 @@ Plotting - - +Groupby +- Bug in :meth:`GroupBy.mean`, :meth:`GroupBy.median`, :meth:`GroupBy.std`, :meth:`GroupBy.var`, :meth:`GroupBy.sem`, :meth:`GroupBy.prod`, :meth:`GroupBy.min`, :meth:`GroupBy.max` don't use corresponding methods of subclasses of :class:`Series` or :class:`DataFrame` (:issue:`51757`) +- + Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ - Bug in :meth:`DataFrameGroupBy.idxmin`, :meth:`SeriesGroupBy.idxmin`, :meth:`DataFrameGroupBy.idxmax`, :meth:`SeriesGroupBy.idxmax` return wrong dtype when used on empty DataFrameGroupBy or SeriesGroupBy (:issue:`51423`) From 2efa052dead0ce6bd66e2b5499792619d5f56b83 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 6 Mar 2023 13:46:23 +0000 Subject: [PATCH 13/26] check for equality of mean methods --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 63df9181a8bc6..956efecedc488 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1837,7 +1837,7 @@ def mean( Name: B, dtype: float64 """ - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.mean is Series.mean or self.obj.mean is DataFrame.mean): def f(df, *args, **kwargs): return self.obj._constructor(df).mean() From 1505a1c98d10e9a384f6af81405c14a93be9f49d Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 6 Mar 2023 14:06:14 +0000 Subject: [PATCH 14/26] repeat for other methods --- pandas/core/groupby/groupby.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 956efecedc488..e805ed1e40aa0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1880,7 +1880,9 @@ def median(self, numeric_only: bool = False): Series or DataFrame Median of values within each group. """ - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not ( + self.obj.median is Series.median or self.obj.median is DataFrame.median + ): def f(df, *args, **kwargs): return self.obj._constructor(df).median() @@ -1946,7 +1948,7 @@ def std( Series or DataFrame Standard deviation of values within each group. """ - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.std is Series.std or self.obj.std is DataFrame.std): def f(df, *args, **kwargs): return self.obj._constructor(df).std() @@ -2037,7 +2039,7 @@ def var( Series or DataFrame Variance of values within each group. """ - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.var is Series.var or self.obj.var is DataFrame.var): def f(df, *args, **kwargs): return self.obj._constructor(df).var() @@ -2215,7 +2217,7 @@ def sem(self, ddof: int = 1, numeric_only: bool = False): Standard error of the mean of values within each group. """ # TODO: think sem() needs considering more closely - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.sem is Series.sem or self.obj.sem is DataFrame.sem): def f(df, *args, **kwargs): return self.obj._constructor(df).sem() @@ -2281,7 +2283,7 @@ def sum( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.sum is Series.sum or self.obj.sum is DataFrame.sum): def f(df, *args, **kwargs): return self.obj._constructor(df).sum() @@ -2313,7 +2315,7 @@ def f(df, *args, **kwargs): @final @doc(_groupby_agg_method_template, fname="prod", no=False, mc=0) def prod(self, numeric_only: bool = False, min_count: int = 0): - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.prod is Series.prod or self.obj.prod is DataFrame.prod): def f(df, *args, **kwargs): return self.obj._constructor(df).prod() @@ -2334,7 +2336,7 @@ def min( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.min is Series.min or self.obj.min is DataFrame.min): def f(df, *args, **kwargs): return self.obj._constructor(df).min() @@ -2363,7 +2365,7 @@ def max( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not (type(self.obj) == Series or type(self.obj) == DataFrame): + if not (self.obj.max is Series.max or self.obj.max is DataFrame.max): def f(df, *args, **kwargs): return self.obj._constructor(df).max() From af9ac261d1bde1c55bdf0e74c2c82b1b43b3be15 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 6 Mar 2023 14:35:01 +0000 Subject: [PATCH 15/26] also test Series --- pandas/tests/groupby/test_groupby_subclass.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 4b444d6759fed..544fc424001e2 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -186,3 +186,12 @@ def max(self, *args, **kwargs): assert np.all(udf.groupby("group").prod() == 6) assert np.all(udf.groupby("group").min() == 7) assert np.all(udf.groupby("group").max() == 8) + for useries in udf: + assert np.all(useries.groupby("group").mean() == 1) + assert np.all(useries.groupby("group").median() == 2) + assert np.all(useries.groupby("group").std() == 3) + assert np.all(useries.groupby("group").var() == 4) + assert np.all(useries.groupby("group").sem() == 5) + assert np.all(useries.groupby("group").prod() == 6) + assert np.all(useries.groupby("group").min() == 7) + assert np.all(useries.groupby("group").max() == 8) From f4bc5483249935c340c1ab4e827ed11c20a43b77 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 6 Mar 2023 16:48:56 +0000 Subject: [PATCH 16/26] pass through numeric_only --- pandas/core/groupby/groupby.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e805ed1e40aa0..66f19656f0bec 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1840,7 +1840,7 @@ def mean( if not (self.obj.mean is Series.mean or self.obj.mean is DataFrame.mean): def f(df, *args, **kwargs): - return self.obj._constructor(df).mean() + return self.obj._constructor(df).mean(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -1885,7 +1885,7 @@ def median(self, numeric_only: bool = False): ): def f(df, *args, **kwargs): - return self.obj._constructor(df).median() + return self.obj._constructor(df).median(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -1951,7 +1951,7 @@ def std( if not (self.obj.std is Series.std or self.obj.std is DataFrame.std): def f(df, *args, **kwargs): - return self.obj._constructor(df).std() + return self.obj._constructor(df).std(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -2042,7 +2042,7 @@ def var( if not (self.obj.var is Series.var or self.obj.var is DataFrame.var): def f(df, *args, **kwargs): - return self.obj._constructor(df).var() + return self.obj._constructor(df).var(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -2220,7 +2220,7 @@ def sem(self, ddof: int = 1, numeric_only: bool = False): if not (self.obj.sem is Series.sem or self.obj.sem is DataFrame.sem): def f(df, *args, **kwargs): - return self.obj._constructor(df).sem() + return self.obj._constructor(df).sem(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -2286,7 +2286,7 @@ def sum( if not (self.obj.sum is Series.sum or self.obj.sum is DataFrame.sum): def f(df, *args, **kwargs): - return self.obj._constructor(df).sum() + return self.obj._constructor(df).sum(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -2318,7 +2318,7 @@ def prod(self, numeric_only: bool = False, min_count: int = 0): if not (self.obj.prod is Series.prod or self.obj.prod is DataFrame.prod): def f(df, *args, **kwargs): - return self.obj._constructor(df).prod() + return self.obj._constructor(df).prod(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -2339,7 +2339,7 @@ def min( if not (self.obj.min is Series.min or self.obj.min is DataFrame.min): def f(df, *args, **kwargs): - return self.obj._constructor(df).min() + return self.obj._constructor(df).min(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") @@ -2368,7 +2368,7 @@ def max( if not (self.obj.max is Series.max or self.obj.max is DataFrame.max): def f(df, *args, **kwargs): - return self.obj._constructor(df).max() + return self.obj._constructor(df).max(numeric_only=numeric_only) result = self.agg(f) return result.__finalize__(self.obj, method="groupby") From 12a9fa81a2b47f57317e375808712e0703a853e5 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 6 Mar 2023 17:52:09 +0000 Subject: [PATCH 17/26] reinstate type hinting --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 66f19656f0bec..8cf280dc97b23 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1629,7 +1629,7 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray: @final @property - def _obj_1d_constructor(self): + def _obj_1d_constructor(self) -> Callable: # GH28330 preserve subclassed Series/DataFrames if isinstance(self.obj, DataFrame): return self.obj._constructor_sliced From f46eea95aba4f4647e9b54a02f367a2b87ff0fde Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 6 Mar 2023 17:57:09 +0000 Subject: [PATCH 18/26] add type() to method comparison --- pandas/core/groupby/groupby.py | 35 +++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8cf280dc97b23..b6734eec87374 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1837,7 +1837,9 @@ def mean( Name: B, dtype: float64 """ - if not (self.obj.mean is Series.mean or self.obj.mean is DataFrame.mean): + if not ( + type(self.obj).mean is Series.mean or type(self.obj).mean is DataFrame.mean + ): def f(df, *args, **kwargs): return self.obj._constructor(df).mean(numeric_only=numeric_only) @@ -1881,7 +1883,8 @@ def median(self, numeric_only: bool = False): Median of values within each group. """ if not ( - self.obj.median is Series.median or self.obj.median is DataFrame.median + type(self.obj).median is Series.median + or type(self.obj).median is DataFrame.median ): def f(df, *args, **kwargs): @@ -1948,7 +1951,9 @@ def std( Series or DataFrame Standard deviation of values within each group. """ - if not (self.obj.std is Series.std or self.obj.std is DataFrame.std): + if not ( + type(self.obj).std is Series.std or type(self.obj).std is DataFrame.std + ): def f(df, *args, **kwargs): return self.obj._constructor(df).std(numeric_only=numeric_only) @@ -2039,7 +2044,9 @@ def var( Series or DataFrame Variance of values within each group. """ - if not (self.obj.var is Series.var or self.obj.var is DataFrame.var): + if not ( + type(self.obj).var is Series.var or type(self.obj).var is DataFrame.var + ): def f(df, *args, **kwargs): return self.obj._constructor(df).var(numeric_only=numeric_only) @@ -2217,7 +2224,9 @@ def sem(self, ddof: int = 1, numeric_only: bool = False): Standard error of the mean of values within each group. """ # TODO: think sem() needs considering more closely - if not (self.obj.sem is Series.sem or self.obj.sem is DataFrame.sem): + if not ( + type(self.obj).sem is Series.sem or type(self.obj).sem is DataFrame.sem + ): def f(df, *args, **kwargs): return self.obj._constructor(df).sem(numeric_only=numeric_only) @@ -2283,7 +2292,9 @@ def sum( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not (self.obj.sum is Series.sum or self.obj.sum is DataFrame.sum): + if not ( + type(self.obj).sum is Series.sum or type(self.obj).sum is DataFrame.sum + ): def f(df, *args, **kwargs): return self.obj._constructor(df).sum(numeric_only=numeric_only) @@ -2315,7 +2326,9 @@ def f(df, *args, **kwargs): @final @doc(_groupby_agg_method_template, fname="prod", no=False, mc=0) def prod(self, numeric_only: bool = False, min_count: int = 0): - if not (self.obj.prod is Series.prod or self.obj.prod is DataFrame.prod): + if not ( + type(self.obj).prod is Series.prod or type(self.obj).prod is DataFrame.prod + ): def f(df, *args, **kwargs): return self.obj._constructor(df).prod(numeric_only=numeric_only) @@ -2336,7 +2349,9 @@ def min( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not (self.obj.min is Series.min or self.obj.min is DataFrame.min): + if not ( + type(self.obj).min is Series.min or type(self.obj).min is DataFrame.min + ): def f(df, *args, **kwargs): return self.obj._constructor(df).min(numeric_only=numeric_only) @@ -2365,7 +2380,9 @@ def max( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not (self.obj.max is Series.max or self.obj.max is DataFrame.max): + if not ( + type(self.obj).max is Series.max or type(self.obj).max is DataFrame.max + ): def f(df, *args, **kwargs): return self.obj._constructor(df).max(numeric_only=numeric_only) From 185a3c1ed72f9c812764f23f2414d60e304f44e7 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Tue, 7 Mar 2023 10:49:21 +0000 Subject: [PATCH 19/26] test transform --- pandas/tests/groupby/test_groupby_subclass.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 544fc424001e2..2cff6445f5cd1 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -172,9 +172,9 @@ def min(self, *args, **kwargs): def max(self, *args, **kwargs): return 8 - params = ["a", "b"] + columns = ["a", "b"] data = np.random.rand(4, 2) - udf = UnitDataFrame(data, columns=params) + udf = UnitDataFrame(data, columns=columns) udf["group"] = np.ones(4, dtype=int) udf.loc[2:, "group"] = 2 @@ -186,12 +186,12 @@ def max(self, *args, **kwargs): assert np.all(udf.groupby("group").prod() == 6) assert np.all(udf.groupby("group").min() == 7) assert np.all(udf.groupby("group").max() == 8) - for useries in udf: - assert np.all(useries.groupby("group").mean() == 1) - assert np.all(useries.groupby("group").median() == 2) - assert np.all(useries.groupby("group").std() == 3) - assert np.all(useries.groupby("group").var() == 4) - assert np.all(useries.groupby("group").sem() == 5) - assert np.all(useries.groupby("group").prod() == 6) - assert np.all(useries.groupby("group").min() == 7) - assert np.all(useries.groupby("group").max() == 8) + + assert np.all(udf.groupby("group").transform("mean") == 1) + assert np.all(udf.groupby("group").transform("median") == 2) + assert np.all(udf.groupby("group").transform("std") == 3) + assert np.all(udf.groupby("group").transform("var") == 4) + assert np.all(udf.groupby("group").transform("sem") == 5) + assert np.all(udf.groupby("group").transform("prod") == 6) + assert np.all(udf.groupby("group").transform("min") == 7) + assert np.all(udf.groupby("group").transform("max") == 8) From bf9bde6f69d18335ac027ce2aef1373b052e2b07 Mon Sep 17 00:00:00 2001 From: Ormorod Date: Tue, 7 Mar 2023 10:53:00 +0000 Subject: [PATCH 20/26] correct _constructor --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 09aae725ffbe4..10b88d8d473ac 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1343,7 +1343,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) else: # GH#32040, GH#35246 # e.g. test_groupby_as_index_select_column_sum_empty_df - result = cast(self._obj_1d_constructor, result) + result = cast(self._constructor, result) result.columns = self._obj_with_exclusions.columns.copy() if not self.as_index: From a6be1eaf2f052688d4d4e811cf523f4aa1afc76a Mon Sep 17 00:00:00 2001 From: Ormorod Date: Mon, 27 Mar 2023 13:45:22 +0100 Subject: [PATCH 21/26] remove unnecessary(?) if statement --- pandas/core/groupby/generic.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index e8d69a4fa2ffb..d4f023f90cf52 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -683,12 +683,6 @@ def value_counts( index_names = self.grouper.names + [self.obj.name] - constructor_1d = ( - self._obj_1d_constructor - if isinstance(self._obj_1d_constructor, Series) - else Series - ) - if is_categorical_dtype(val.dtype) or ( bins is not None and not np.iterable(bins) ): @@ -696,7 +690,7 @@ def value_counts( # in a backward compatible way # GH38672 relates to categorical dtype ser = self.apply( - constructor_1d.value_counts, + self._obj_1d_constructor.value_counts, normalize=normalize, sort=sort, ascending=ascending, From 6631b1ec4b5b181210850ac7132dccaecae98aba Mon Sep 17 00:00:00 2001 From: AdamOrmondroyd Date: Tue, 25 Apr 2023 11:33:12 +0100 Subject: [PATCH 22/26] first pass at decorator --- pandas/core/groupby/groupby.py | 35 ++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8cb04e81c4510..24a84cc5e7243 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1459,6 +1459,30 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs): res.index = default_index(len(res)) return res + def _use_subclass_method(func): + """ + Use the corresponding func method in the case of a + subclassed Series or DataFrame. + """ + + def inner(self, *args, **kwargs): + if not ( + getattr(type(self.obj), func.__name__) is getattr(Series, func.__name__) + or getattr(type(self.obj), func.__name__) + is getattr(DataFrame, func.__name__) + ): + print("is a subclass") + result = self.agg( + lambda df: getattr(self.obj._constructor(df), func.__name__)( + *args, **kwargs + ) + ) + return result.__finalize__(self.obj, method="groupby") + print("not a subclass") + return func(self, *args, **kwargs) + + return inner + # ----------------------------------------------------------------- # apply/agg/transform @@ -1898,6 +1922,7 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike: return self._reindex_output(result, fill_value=0) @final + @_use_subclass_method @Substitution(name="groupby") @Substitution(see_also=_common_see_also) def mean( @@ -1974,16 +1999,6 @@ def mean( Name: B, dtype: float64 """ - if not ( - type(self.obj).mean is Series.mean or type(self.obj).mean is DataFrame.mean - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).mean(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_mean From 94dc18610353ce2145fc742da0a27ae1397d7113 Mon Sep 17 00:00:00 2001 From: AdamOrmondroyd Date: Tue, 25 Apr 2023 11:48:27 +0100 Subject: [PATCH 23/26] add decorator to other methods --- pandas/core/groupby/groupby.py | 91 +++------------------------------- 1 file changed, 7 insertions(+), 84 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 24a84cc5e7243..28335baf8d56c 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1471,14 +1471,12 @@ def inner(self, *args, **kwargs): or getattr(type(self.obj), func.__name__) is getattr(DataFrame, func.__name__) ): - print("is a subclass") result = self.agg( lambda df: getattr(self.obj._constructor(df), func.__name__)( *args, **kwargs ) ) return result.__finalize__(self.obj, method="groupby") - print("not a subclass") return func(self, *args, **kwargs) return inner @@ -2014,6 +2012,7 @@ def mean( return result.__finalize__(self.obj, method="groupby") @final + @_use_subclass_method def median(self, numeric_only: bool = False): """ Compute median of groups, excluding missing values. @@ -2034,17 +2033,6 @@ def median(self, numeric_only: bool = False): Series or DataFrame Median of values within each group. """ - if not ( - type(self.obj).median is Series.median - or type(self.obj).median is DataFrame.median - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).median(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - result = self._cython_agg_general( "median", alt=lambda x: self._obj_1d_constructor(x).median(numeric_only=numeric_only), @@ -2053,6 +2041,7 @@ def f(df, *args, **kwargs): return result.__finalize__(self.obj, method="groupby") @final + @_use_subclass_method @Substitution(name="groupby") @Appender(_common_see_also) def std( @@ -2103,16 +2092,6 @@ def std( Series or DataFrame Standard deviation of values within each group. """ - if not ( - type(self.obj).std is Series.std or type(self.obj).std is DataFrame.std - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).std(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_var @@ -2126,6 +2105,7 @@ def f(df, *args, **kwargs): ) @final + @_use_subclass_method @Substitution(name="groupby") @Appender(_common_see_also) def var( @@ -2176,16 +2156,6 @@ def var( Series or DataFrame Variance of values within each group. """ - if not ( - type(self.obj).var is Series.var or type(self.obj).var is DataFrame.var - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).var(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_var @@ -2332,6 +2302,7 @@ def _value_counts( return result.__finalize__(self.obj, method="value_counts") @final + @_use_subclass_method def sem(self, ddof: int = 1, numeric_only: bool = False): """ Compute standard error of the mean of groups, excluding missing values. @@ -2357,17 +2328,6 @@ def sem(self, ddof: int = 1, numeric_only: bool = False): Series or DataFrame Standard error of the mean of values within each group. """ - # TODO: think sem() needs considering more closely - if not ( - type(self.obj).sem is Series.sem or type(self.obj).sem is DataFrame.sem - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).sem(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if numeric_only and self.obj.ndim == 1 and not is_numeric_dtype(self.obj.dtype): raise TypeError( f"{type(self).__name__}.sem called with " @@ -2412,6 +2372,7 @@ def size(self) -> DataFrame | Series: return result @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="sum", no=False, mc=0) def sum( self, @@ -2420,16 +2381,6 @@ def sum( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not ( - type(self.obj).sum is Series.sum or type(self.obj).sum is DataFrame.sum - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).sum(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_sum @@ -2452,23 +2403,15 @@ def f(df, *args, **kwargs): return self._reindex_output(result, fill_value=0) @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="prod", no=False, mc=0) def prod(self, numeric_only: bool = False, min_count: int = 0): - if not ( - type(self.obj).prod is Series.prod or type(self.obj).prod is DataFrame.prod - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).prod(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - return self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod ) @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="min", no=False, mc=-1) def min( self, @@ -2477,16 +2420,6 @@ def min( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not ( - type(self.obj).min is Series.min or type(self.obj).min is DataFrame.min - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).min(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_min_max @@ -2508,16 +2441,6 @@ def max( engine: str | None = None, engine_kwargs: dict[str, bool] | None = None, ): - if not ( - type(self.obj).max is Series.max or type(self.obj).max is DataFrame.max - ): - - def f(df, *args, **kwargs): - return self.obj._constructor(df).max(numeric_only=numeric_only) - - result = self.agg(f) - return result.__finalize__(self.obj, method="groupby") - if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_min_max From 310c339aa494a355232729413c165a9d99ca84d6 Mon Sep 17 00:00:00 2001 From: AdamOrmondroyd Date: Tue, 25 Apr 2023 11:54:18 +0100 Subject: [PATCH 24/26] missed max() --- pandas/core/groupby/groupby.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 28335baf8d56c..f0e45c4a53938 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2433,6 +2433,7 @@ def min( ) @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="max", no=False, mc=-1) def max( self, From 27c4ed9a991fdb32d428501a76bfe309705d1699 Mon Sep 17 00:00:00 2001 From: AdamOrmondroyd Date: Tue, 25 Apr 2023 16:42:12 +0100 Subject: [PATCH 25/26] add @wraps --- pandas/core/groupby/groupby.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f0e45c4a53938..f4140f9b00f00 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1465,6 +1465,7 @@ def _use_subclass_method(func): subclassed Series or DataFrame. """ + @wraps(func) def inner(self, *args, **kwargs): if not ( getattr(type(self.obj), func.__name__) is getattr(Series, func.__name__) From 8a9f30f99f56d5abda25c0d58cbf658f9ef1422f Mon Sep 17 00:00:00 2001 From: AdamOrmondroyd Date: Wed, 26 Apr 2023 13:24:35 +0100 Subject: [PATCH 26/26] add tests for series example --- pandas/tests/groupby/test_groupby_subclass.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 5dc59ca128753..94c7c362380d3 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -182,6 +182,8 @@ def max(self, *args, **kwargs): udf["group"] = np.ones(4, dtype=int) udf.loc[2:, "group"] = 2 + us = udf[["a", "group"]] + assert np.all(udf.groupby("group").mean() == 1) assert np.all(udf.groupby("group").median() == 2) assert np.all(udf.groupby("group").std() == 3) @@ -191,6 +193,15 @@ def max(self, *args, **kwargs): assert np.all(udf.groupby("group").min() == 7) assert np.all(udf.groupby("group").max() == 8) + assert np.all(us.groupby("group").mean() == 1) + assert np.all(us.groupby("group").median() == 2) + assert np.all(us.groupby("group").std() == 3) + assert np.all(us.groupby("group").var() == 4) + assert np.all(us.groupby("group").sem() == 5) + assert np.all(us.groupby("group").prod() == 6) + assert np.all(us.groupby("group").min() == 7) + assert np.all(us.groupby("group").max() == 8) + assert np.all(udf.groupby("group").transform("mean") == 1) assert np.all(udf.groupby("group").transform("median") == 2) assert np.all(udf.groupby("group").transform("std") == 3) @@ -199,3 +210,12 @@ def max(self, *args, **kwargs): assert np.all(udf.groupby("group").transform("prod") == 6) assert np.all(udf.groupby("group").transform("min") == 7) assert np.all(udf.groupby("group").transform("max") == 8) + + assert np.all(us.groupby("group").transform("mean") == 1) + assert np.all(us.groupby("group").transform("median") == 2) + assert np.all(us.groupby("group").transform("std") == 3) + assert np.all(us.groupby("group").transform("var") == 4) + assert np.all(us.groupby("group").transform("sem") == 5) + assert np.all(us.groupby("group").transform("prod") == 6) + assert np.all(us.groupby("group").transform("min") == 7) + assert np.all(us.groupby("group").transform("max") == 8)