Skip to content

REF: re-use maybe_cast_result in Series.combine #40909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
if TYPE_CHECKING:
from typing import Literal

from pandas import Series
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
Expand Down Expand Up @@ -375,7 +374,11 @@ def trans(x):


def maybe_cast_result(
result: ArrayLike, obj: Series, numeric_only: bool = False, how: str = ""
result: ArrayLike,
dtype: DtypeObj,
numeric_only: bool = False,
how: str = "",
same_dtype: bool = True,
) -> ArrayLike:
"""
Try casting result to a different type if appropriate
Expand All @@ -384,19 +387,20 @@ def maybe_cast_result(
----------
result : array-like
Result to cast.
obj : Series
dtype : np.dtype or ExtensionDtype
Input Series from which result was calculated.
numeric_only : bool, default False
Whether to cast only numerics or datetimes as well.
how : str, default ""
How the result was computed.
same_dtype : bool, default True
Specify dtype when calling _from_sequence

Returns
-------
result : array-like
result maybe casted to the dtype.
"""
dtype = obj.dtype
dtype = maybe_cast_result_dtype(dtype, how)

assert not is_scalar(result)
Expand All @@ -407,7 +411,10 @@ def maybe_cast_result(
# things like counts back to categorical

cls = dtype.construct_array_type()
result = maybe_cast_to_extension_array(cls, result, dtype=dtype)
if same_dtype:
result = maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = maybe_cast_to_extension_array(cls, result)

elif (numeric_only and is_numeric_dtype(dtype)) or not numeric_only:
result = maybe_downcast_to_dtype(result, dtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
result[label] = res

out = lib.maybe_convert_objects(result, try_float=False)
out = maybe_cast_result(out, obj, numeric_only=True)
out = maybe_cast_result(out, obj.dtype, numeric_only=True)

return out, counts

Expand Down
23 changes: 4 additions & 19 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,13 @@
from pandas.core.dtypes.cast import (
convert_dtypes,
maybe_box_native,
maybe_cast_to_extension_array,
maybe_cast_result,
validate_numeric_casting,
)
from pandas.core.dtypes.common import (
ensure_platform_int,
is_bool,
is_categorical_dtype,
is_dict_like,
is_extension_array_dtype,
is_integer,
is_iterator,
is_list_like,
Expand Down Expand Up @@ -3079,22 +3077,9 @@ def combine(self, other, func, fill_value=None) -> Series:
new_values = [func(lv, other) for lv in self._values]
new_name = self.name

if is_categorical_dtype(self.dtype):
pass
elif is_extension_array_dtype(self.dtype):
# TODO: can we do this for only SparseDtype?
# The function can return something of any type, so check
# if the type is compatible with the calling EA.

# error: Incompatible types in assignment (expression has type
# "Union[ExtensionArray, ndarray]", variable has type "List[Any]")
new_values = maybe_cast_to_extension_array( # type: ignore[assignment]
# error: Argument 2 to "maybe_cast_to_extension_array" has incompatible
# type "List[Any]"; expected "Union[ExtensionArray, ndarray]"
type(self._values),
new_values, # type: ignore[arg-type]
)
return self._constructor(new_values, index=new_index, name=new_name)
res_values = sanitize_array(new_values, None)
res_values = maybe_cast_result(res_values, self.dtype, same_dtype=False)
return self._constructor(res_values, index=new_index, name=new_name)

def combine_first(self, other) -> Series:
"""
Expand Down
17 changes: 11 additions & 6 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,18 @@ def _create_arithmetic_method(cls, op):
DecimalArrayWithoutCoercion._add_arithmetic_ops()


def test_combine_from_sequence_raises():
def test_combine_from_sequence_raises(monkeypatch):
# https://github.com/pandas-dev/pandas/issues/22850
ser = pd.Series(
DecimalArrayWithoutFromSequence(
[decimal.Decimal("1.0"), decimal.Decimal("2.0")]
)
)
cls = DecimalArrayWithoutFromSequence

@classmethod
def construct_array_type(cls):
return DecimalArrayWithoutFromSequence

monkeypatch.setattr(DecimalDtype, "construct_array_type", construct_array_type)

arr = cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
ser = pd.Series(arr)
result = ser.combine(ser, operator.add)

# note: object dtype
Expand Down