From e8c555159601636fb79dc0d142b72666f8fd87b2 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 20 Apr 2021 19:24:56 -0700 Subject: [PATCH 1/3] REF: make maybe_cast_result_dtype a WrappedCythonOp method --- pandas/core/dtypes/cast.py | 36 -------------------------- pandas/core/groupby/ops.py | 53 +++++++++++++++++++++++++++++++++----- 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index e91927d87d318..d739b46620032 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -406,42 +406,6 @@ def maybe_cast_pointwise_result( return result -def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj: - """ - Get the desired dtype of a result based on the - input dtype and how it was computed. - - Parameters - ---------- - dtype : DtypeObj - Input dtype. - how : str - How the result was computed. - - Returns - ------- - DtypeObj - The desired dtype of the result. - """ - from pandas.core.arrays.boolean import BooleanDtype - from pandas.core.arrays.floating import Float64Dtype - from pandas.core.arrays.integer import ( - Int64Dtype, - _IntegerDtype, - ) - - if how in ["add", "cumsum", "sum", "prod"]: - if dtype == np.dtype(bool): - return np.dtype(np.int64) - elif isinstance(dtype, (BooleanDtype, _IntegerDtype)): - return Int64Dtype() - elif how in ["mean", "median", "var"] and isinstance( - dtype, (BooleanDtype, _IntegerDtype) - ): - return Float64Dtype() - return dtype - - def maybe_cast_to_extension_array( cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None ) -> ArrayLike: diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 6eddf8e9e8773..8630de2b7ac5f 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -37,7 +37,6 @@ from pandas.core.dtypes.cast import ( maybe_cast_pointwise_result, - maybe_cast_result_dtype, maybe_downcast_to_dtype, ) from pandas.core.dtypes.common import ( @@ -256,6 +255,41 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype: out_dtype = "object" return np.dtype(out_dtype) + def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: + """ + Get the desired dtype of a result based on the + input dtype and how it was computed. + + Parameters + ---------- + dtype : DtypeObj + Input dtype. + + Returns + ------- + DtypeObj + The desired dtype of the result. + """ + from pandas.core.arrays.boolean import BooleanDtype + from pandas.core.arrays.floating import Float64Dtype + from pandas.core.arrays.integer import ( + Int64Dtype, + _IntegerDtype, + ) + + how = self.how + + if how in ["add", "cumsum", "sum", "prod"]: + if dtype == np.dtype(bool): + return np.dtype(np.int64) + elif isinstance(dtype, (BooleanDtype, _IntegerDtype)): + return Int64Dtype() + elif how in ["mean", "median", "var"] and isinstance( + dtype, (BooleanDtype, _IntegerDtype) + ): + return Float64Dtype() + return dtype + class BaseGrouper: """ @@ -555,7 +589,14 @@ def get_group_levels(self) -> list[Index]: @final def _ea_wrap_cython_operation( - self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs + self, + cy_op: WrappedCythonOp, + kind: str, + values, + how: str, + axis: int, + min_count: int = -1, + **kwargs, ) -> ArrayLike: """ If we have an ExtensionArray, unwrap, call _cython_operation, and @@ -592,7 +633,7 @@ def _ea_wrap_cython_operation( # other cast_blocklist methods dont go through cython_operation return res_values - dtype = maybe_cast_result_dtype(orig_values.dtype, how) + dtype = cy_op.get_result_dtype(orig_values.dtype) # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" # has no attribute "construct_array_type" cls = dtype.construct_array_type() # type: ignore[union-attr] @@ -609,7 +650,7 @@ def _ea_wrap_cython_operation( # other cast_blocklist methods dont go through cython_operation return res_values - dtype = maybe_cast_result_dtype(orig_values.dtype, how) + dtype = cy_op.get_result_dtype(orig_values.dtype) # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" # has no attribute "construct_array_type" cls = dtype.construct_array_type() # type: ignore[union-attr] @@ -647,7 +688,7 @@ def _cython_operation( if is_extension_array_dtype(dtype): return self._ea_wrap_cython_operation( - kind, values, how, axis, min_count, **kwargs + cy_op, kind, values, how, axis, min_count, **kwargs ) elif values.ndim == 1: @@ -731,7 +772,7 @@ def _cython_operation( if how not in cy_op.cast_blocklist: # e.g. if we are int64 and need to restore to datetime64/timedelta64 # "rank" is the only member of cast_blocklist we get here - dtype = maybe_cast_result_dtype(orig_values.dtype, how) + dtype = cy_op.get_result_dtype(orig_values.dtype) op_result = maybe_downcast_to_dtype(result, dtype) else: op_result = result From 01c918da328dae44b34ae77f970ea2fc88c61735 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 20 Apr 2021 19:26:47 -0700 Subject: [PATCH 2/3] improve docstring --- pandas/core/groupby/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 8630de2b7ac5f..cab0b547e28c8 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -262,12 +262,12 @@ def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: Parameters ---------- - dtype : DtypeObj + dtype : np.dtype or ExtensionDtype Input dtype. Returns ------- - DtypeObj + np.dtype or ExtensionDtype The desired dtype of the result. """ from pandas.core.arrays.boolean import BooleanDtype From aba54a9341e5b0485d6b5c4eace30bd32c8de87e Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 21 Apr 2021 08:37:44 -0700 Subject: [PATCH 3/3] post-merge fixup --- pandas/core/groupby/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index fb80e7d892e24..2db92bc696f02 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -672,6 +672,7 @@ def _ea_wrap_cython_operation( @final def _masked_ea_wrap_cython_operation( self, + cy_op: WrappedCythonOp, kind: str, values: BaseMaskedArray, how: str, @@ -692,7 +693,7 @@ def _masked_ea_wrap_cython_operation( res_values = self._cython_operation( kind, arr, how, axis, min_count, mask=mask, **kwargs ) - dtype = self.get_result_dtype(orig_values.dtype) + dtype = cy_op.get_result_dtype(orig_values.dtype) assert isinstance(dtype, BaseMaskedDtype) cls = dtype.construct_array_type()