Skip to content

REF: make maybe_cast_result_dtype a WrappedCythonOp method #41065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,42 +406,6 @@ def maybe_cast_pointwise_result(
return result


def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
"""
Get the desired dtype of a result based on the
input dtype and how it was computed.

Parameters
----------
dtype : DtypeObj
Input dtype.
how : str
How the result was computed.

Returns
-------
DtypeObj
The desired dtype of the result.
"""
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import Float64Dtype
from pandas.core.arrays.integer import (
Int64Dtype,
_IntegerDtype,
)

if how in ["add", "cumsum", "sum", "prod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
return Int64Dtype()
elif how in ["mean", "median", "var"] and isinstance(
dtype, (BooleanDtype, _IntegerDtype)
):
return Float64Dtype()
return dtype


def maybe_cast_to_extension_array(
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
) -> ArrayLike:
Expand Down
58 changes: 50 additions & 8 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

from pandas.core.dtypes.cast import (
maybe_cast_pointwise_result,
maybe_cast_result_dtype,
maybe_downcast_to_dtype,
)
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -262,6 +261,41 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
out_dtype = "object"
return np.dtype(out_dtype)

def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
"""
Get the desired dtype of a result based on the
input dtype and how it was computed.

Parameters
----------
dtype : np.dtype or ExtensionDtype
Input dtype.

Returns
-------
np.dtype or ExtensionDtype
The desired dtype of the result.
"""
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import Float64Dtype
from pandas.core.arrays.integer import (
Int64Dtype,
_IntegerDtype,
)

how = self.how

if how in ["add", "cumsum", "sum", "prod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
return Int64Dtype()
elif how in ["mean", "median", "var"] and isinstance(
dtype, (BooleanDtype, _IntegerDtype)
):
return Float64Dtype()
return dtype

def uses_mask(self) -> bool:
return self.how in self._MASKED_CYTHON_FUNCTIONS

Expand Down Expand Up @@ -564,7 +598,14 @@ def get_group_levels(self) -> list[Index]:

@final
def _ea_wrap_cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
self,
cy_op: WrappedCythonOp,
kind: str,
values,
how: str,
axis: int,
min_count: int = -1,
**kwargs,
) -> ArrayLike:
"""
If we have an ExtensionArray, unwrap, call _cython_operation, and
Expand Down Expand Up @@ -601,7 +642,7 @@ def _ea_wrap_cython_operation(
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = maybe_cast_result_dtype(orig_values.dtype, how)
dtype = cy_op.get_result_dtype(orig_values.dtype)
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
# has no attribute "construct_array_type"
cls = dtype.construct_array_type() # type: ignore[union-attr]
Expand All @@ -618,7 +659,7 @@ def _ea_wrap_cython_operation(
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = maybe_cast_result_dtype(orig_values.dtype, how)
dtype = cy_op.get_result_dtype(orig_values.dtype)
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
# has no attribute "construct_array_type"
cls = dtype.construct_array_type() # type: ignore[union-attr]
Expand All @@ -631,6 +672,7 @@ def _ea_wrap_cython_operation(
@final
def _masked_ea_wrap_cython_operation(
self,
cy_op: WrappedCythonOp,
kind: str,
values: BaseMaskedArray,
how: str,
Expand All @@ -651,7 +693,7 @@ def _masked_ea_wrap_cython_operation(
res_values = self._cython_operation(
kind, arr, how, axis, min_count, mask=mask, **kwargs
)
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
dtype = cy_op.get_result_dtype(orig_values.dtype)
assert isinstance(dtype, BaseMaskedDtype)
cls = dtype.construct_array_type()

Expand Down Expand Up @@ -694,11 +736,11 @@ def _cython_operation(
if is_extension_array_dtype(dtype):
if isinstance(values, BaseMaskedArray) and func_uses_mask:
return self._masked_ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
cy_op, kind, values, how, axis, min_count, **kwargs
)
else:
return self._ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
cy_op, kind, values, how, axis, min_count, **kwargs
)

elif values.ndim == 1:
Expand Down Expand Up @@ -797,7 +839,7 @@ def _cython_operation(
if how not in cy_op.cast_blocklist:
# e.g. if we are int64 and need to restore to datetime64/timedelta64
# "rank" is the only member of cast_blocklist we get here
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
dtype = cy_op.get_result_dtype(orig_values.dtype)
op_result = maybe_downcast_to_dtype(result, dtype)
else:
op_result = result
Expand Down