diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 2c5263f447951..db07b46771083 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -417,6 +417,10 @@ Plotting - Bug in :meth:`Series.plot` when invoked with ``color=None`` (:issue:`51953`) - +Groupby +- Bug in :meth:`GroupBy.mean`, :meth:`GroupBy.median`, :meth:`GroupBy.std`, :meth:`GroupBy.var`, :meth:`GroupBy.sem`, :meth:`GroupBy.prod`, :meth:`GroupBy.min`, :meth:`GroupBy.max` don't use corresponding methods of subclasses of :class:`Series` or :class:`DataFrame` (:issue:`51757`) +- + Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ - Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` in incorrectly allowing non-fixed ``freq`` when resampling on a :class:`TimedeltaIndex` (:issue:`51896`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 37ef04f17a2e5..b307e9eca82d3 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -290,7 +290,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) ) # result is a dict whose keys are the elements of result_index - result = Series(result, index=self.grouper.result_index) + result = self._obj_1d_constructor( + result, index=self.grouper.result_index + ) result = self._wrap_aggregated_output(result) return result @@ -703,7 +705,7 @@ def value_counts( # in a backward compatible way # GH38672 relates to categorical dtype ser = self.apply( - Series.value_counts, + self._obj_1d_constructor.value_counts, normalize=normalize, sort=sort, ascending=ascending, @@ -722,7 +724,9 @@ def value_counts( llab = lambda lab, inc: lab[inc] else: # lab is a Categorical with categories an IntervalIndex - cat_ser = cut(Series(val, copy=False), bins, include_lowest=True) + cat_ser = cut( + self.obj._constructor(val, copy=False), bins, include_lowest=True + ) cat_obj = cast("Categorical", cat_ser._values) lev = cat_obj.categories lab = lev.take( @@ -1406,9 +1410,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) elif relabeling: # this should be the only (non-raising) case with relabeling # used reordered index of columns - result = cast(DataFrame, result) + result = cast(self.obj._constructor, result) result = result.iloc[:, order] - result = cast(DataFrame, result) + result = cast(self.obj._constructor, result) # error: Incompatible types in assignment (expression has type # "Optional[List[str]]", variable has type # "Union[Union[Union[ExtensionArray, ndarray[Any, Any]], @@ -1451,7 +1455,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) else: # GH#32040, GH#35246 # e.g. test_groupby_as_index_select_column_sum_empty_df - result = cast(DataFrame, result) + result = cast(self._constructor, result) result.columns = self._obj_with_exclusions.columns.copy() if not self.as_index: @@ -1586,7 +1590,7 @@ def _wrap_applied_output_series( is_transform: bool, ) -> DataFrame | Series: kwargs = first_not_none._construct_axes_dict() - backup = Series(**kwargs) + backup = self._obj_1d_constructor(**kwargs) values = [x if (x is not None) else backup for x in values] all_indexed_same = all_indexes_same(x.index for x in values) @@ -1981,7 +1985,9 @@ def _apply_to_column_groupbys(self, func) -> DataFrame: if not len(results): # concat would raise - res_df = DataFrame([], columns=columns, index=self.grouper.result_index) + res_df = self.obj._constructor( + [], columns=columns, index=self.grouper.result_index + ) else: res_df = concat(results, keys=columns, axis=1) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index bdab641719ded..9434e02230b1a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1450,6 +1450,29 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs): res.index = default_index(len(res)) return res + def _use_subclass_method(func): + """ + Use the corresponding func method in the case of a + subclassed Series or DataFrame. + """ + + @wraps(func) + def inner(self, *args, **kwargs): + if not ( + getattr(type(self.obj), func.__name__) is getattr(Series, func.__name__) + or getattr(type(self.obj), func.__name__) + is getattr(DataFrame, func.__name__) + ): + result = self.agg( + lambda df: getattr(self.obj._constructor(df), func.__name__)( + *args, **kwargs + ) + ) + return result.__finalize__(self.obj, method="groupby") + return func(self, *args, **kwargs) + + return inner + # ----------------------------------------------------------------- # apply/agg/transform @@ -1879,6 +1902,7 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike: return self._reindex_output(result, fill_value=0) @final + @_use_subclass_method @Substitution(name="groupby") @Substitution(see_also=_common_see_also) def mean( @@ -1962,12 +1986,15 @@ def mean( else: result = self._cython_agg_general( "mean", - alt=lambda x: Series(x).mean(numeric_only=numeric_only), + alt=lambda x: self._obj_1d_constructor(x).mean( + numeric_only=numeric_only + ), numeric_only=numeric_only, ) return result.__finalize__(self.obj, method="groupby") @final + @_use_subclass_method def median(self, numeric_only: bool = False): """ Compute median of groups, excluding missing values. @@ -1990,12 +2017,13 @@ def median(self, numeric_only: bool = False): """ result = self._cython_agg_general( "median", - alt=lambda x: Series(x).median(numeric_only=numeric_only), + alt=lambda x: self._obj_1d_constructor(x).median(numeric_only=numeric_only), numeric_only=numeric_only, ) return result.__finalize__(self.obj, method="groupby") @final + @_use_subclass_method @Substitution(name="groupby") @Appender(_common_see_also) def std( @@ -2059,6 +2087,7 @@ def std( ) @final + @_use_subclass_method @Substitution(name="groupby") @Appender(_common_see_also) def var( @@ -2116,7 +2145,7 @@ def var( else: return self._cython_agg_general( "var", - alt=lambda x: Series(x).var(ddof=ddof), + alt=lambda x: self._obj_1d_constructor(x).var(ddof=ddof), numeric_only=numeric_only, ddof=ddof, ) @@ -2255,6 +2284,7 @@ def _value_counts( return result.__finalize__(self.obj, method="value_counts") @final + @_use_subclass_method def sem(self, ddof: int = 1, numeric_only: bool = False): """ Compute standard error of the mean of groups, excluding missing values. @@ -2324,6 +2354,7 @@ def size(self) -> DataFrame | Series: return result @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="sum", no=False, mc=0) def sum( self, @@ -2354,6 +2385,7 @@ def sum( return self._reindex_output(result, fill_value=0) @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="prod", no=False, mc=0) def prod(self, numeric_only: bool = False, min_count: int = 0): return self._agg_general( @@ -2361,6 +2393,7 @@ def prod(self, numeric_only: bool = False, min_count: int = 0): ) @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="min", no=False, mc=-1) def min( self, @@ -2382,6 +2415,7 @@ def min( ) @final + @_use_subclass_method @doc(_groupby_agg_method_template, fname="max", no=False, mc=-1) def max( self, diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 773c1e60e97af..72a2e82178760 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -103,3 +103,115 @@ def test_groupby_resample_preserves_subclass(obj): # Confirm groupby.resample() preserves dataframe type result = df.groupby("Buyer").resample("5D").sum() assert isinstance(result, obj) + + +def test_groupby_overridden_methods(): + class UnitSeries(Series): + @property + def _constructor(self): + return UnitSeries + + @property + def _constructor_expanddim(self): + return UnitDataFrame + + def mean(self, *args, **kwargs): + return 1 + + def median(self, *args, **kwargs): + return 2 + + def std(self, *args, **kwargs): + return 3 + + def var(self, *args, **kwargs): + return 4 + + def sem(self, *args, **kwargs): + return 5 + + def prod(self, *args, **kwargs): + return 6 + + def min(self, *args, **kwargs): + return 7 + + def max(self, *args, **kwargs): + return 8 + + class UnitDataFrame(DataFrame): + @property + def _constructor(self): + return UnitDataFrame + + @property + def _constructor_expanddim(self): + return UnitSeries + + def mean(self, *args, **kwargs): + return 1 + + def median(self, *args, **kwargs): + return 2 + + def std(self, *args, **kwargs): + return 3 + + def var(self, *args, **kwargs): + return 4 + + def sem(self, *args, **kwargs): + return 5 + + def prod(self, *args, **kwargs): + return 6 + + def min(self, *args, **kwargs): + return 7 + + def max(self, *args, **kwargs): + return 8 + + columns = ["a", "b"] + data = np.random.rand(4, 2) + udf = UnitDataFrame(data, columns=columns) + udf["group"] = np.ones(4, dtype=int) + udf.loc[2:, "group"] = 2 + + us = udf[["a", "group"]] + + assert np.all(udf.groupby("group").mean() == 1) + assert np.all(udf.groupby("group").median() == 2) + assert np.all(udf.groupby("group").std() == 3) + assert np.all(udf.groupby("group").var() == 4) + assert np.all(udf.groupby("group").sem() == 5) + assert np.all(udf.groupby("group").prod() == 6) + assert np.all(udf.groupby("group").min() == 7) + assert np.all(udf.groupby("group").max() == 8) + + assert np.all(us.groupby("group").mean() == 1) + assert np.all(us.groupby("group").median() == 2) + assert np.all(us.groupby("group").std() == 3) + assert np.all(us.groupby("group").var() == 4) + assert np.all(us.groupby("group").sem() == 5) + assert np.all(us.groupby("group").prod() == 6) + assert np.all(us.groupby("group").min() == 7) + assert np.all(us.groupby("group").max() == 8) + + assert np.all(udf.groupby("group").transform("mean") == 1) + assert np.all(udf.groupby("group").transform("median") == 2) + assert np.all(udf.groupby("group").transform("std") == 3) + assert np.all(udf.groupby("group").transform("var") == 4) + assert np.all(udf.groupby("group").transform("sem") == 5) + assert np.all(udf.groupby("group").transform("prod") == 6) + assert np.all(udf.groupby("group").transform("min") == 7) + assert np.all(udf.groupby("group").transform("max") == 8) + + assert np.all(us.groupby("group").transform("mean") == 1) + assert np.all(us.groupby("group").transform("median") == 2) + assert np.all(us.groupby("group").transform("std") == 3) + assert np.all(us.groupby("group").transform("var") == 4) + assert np.all(us.groupby("group").transform("sem") == 5) + assert np.all(us.groupby("group").transform("prod") == 6) + assert np.all(us.groupby("group").transform("min") == 7) + assert np.all(us.groupby("group").transform("max") == 8)