diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index f62aa95e1e814..b68ec3c473a41 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -102,7 +102,6 @@ if TYPE_CHECKING: from typing import Literal - from pandas import Series from pandas.core.arrays import ( DatetimeArray, ExtensionArray, @@ -375,7 +374,11 @@ def trans(x): def maybe_cast_result( - result: ArrayLike, obj: Series, numeric_only: bool = False, how: str = "" + result: ArrayLike, + dtype: DtypeObj, + numeric_only: bool = False, + how: str = "", + same_dtype: bool = True, ) -> ArrayLike: """ Try casting result to a different type if appropriate @@ -384,19 +387,20 @@ def maybe_cast_result( ---------- result : array-like Result to cast. - obj : Series + dtype : np.dtype or ExtensionDtype Input Series from which result was calculated. numeric_only : bool, default False Whether to cast only numerics or datetimes as well. how : str, default "" How the result was computed. + same_dtype : bool, default True + Specify dtype when calling _from_sequence Returns ------- result : array-like result maybe casted to the dtype. """ - dtype = obj.dtype dtype = maybe_cast_result_dtype(dtype, how) assert not is_scalar(result) @@ -407,7 +411,10 @@ def maybe_cast_result( # things like counts back to categorical cls = dtype.construct_array_type() - result = maybe_cast_to_extension_array(cls, result, dtype=dtype) + if same_dtype: + result = maybe_cast_to_extension_array(cls, result, dtype=dtype) + else: + result = maybe_cast_to_extension_array(cls, result) elif (numeric_only and is_numeric_dtype(dtype)) or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index bc5318a1f367c..2e7031ab2888e 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -788,7 +788,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F): result[label] = res out = lib.maybe_convert_objects(result, try_float=False) - out = maybe_cast_result(out, obj, numeric_only=True) + out = maybe_cast_result(out, obj.dtype, numeric_only=True) return out, counts diff --git a/pandas/core/series.py b/pandas/core/series.py index 5ba68aaa5c16d..1c49a28ef93ed 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -60,15 +60,13 @@ from pandas.core.dtypes.cast import ( convert_dtypes, maybe_box_native, - maybe_cast_to_extension_array, + maybe_cast_result, validate_numeric_casting, ) from pandas.core.dtypes.common import ( ensure_platform_int, is_bool, - is_categorical_dtype, is_dict_like, - is_extension_array_dtype, is_integer, is_iterator, is_list_like, @@ -3079,22 +3077,9 @@ def combine(self, other, func, fill_value=None) -> Series: new_values = [func(lv, other) for lv in self._values] new_name = self.name - if is_categorical_dtype(self.dtype): - pass - elif is_extension_array_dtype(self.dtype): - # TODO: can we do this for only SparseDtype? - # The function can return something of any type, so check - # if the type is compatible with the calling EA. - - # error: Incompatible types in assignment (expression has type - # "Union[ExtensionArray, ndarray]", variable has type "List[Any]") - new_values = maybe_cast_to_extension_array( # type: ignore[assignment] - # error: Argument 2 to "maybe_cast_to_extension_array" has incompatible - # type "List[Any]"; expected "Union[ExtensionArray, ndarray]" - type(self._values), - new_values, # type: ignore[arg-type] - ) - return self._constructor(new_values, index=new_index, name=new_name) + res_values = sanitize_array(new_values, None) + res_values = maybe_cast_result(res_values, self.dtype, same_dtype=False) + return self._constructor(res_values, index=new_index, name=new_name) def combine_first(self, other) -> Series: """ diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 55f9d85574f94..7a3f88d0d6c41 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -362,13 +362,18 @@ def _create_arithmetic_method(cls, op): DecimalArrayWithoutCoercion._add_arithmetic_ops() -def test_combine_from_sequence_raises(): +def test_combine_from_sequence_raises(monkeypatch): # https://github.com/pandas-dev/pandas/issues/22850 - ser = pd.Series( - DecimalArrayWithoutFromSequence( - [decimal.Decimal("1.0"), decimal.Decimal("2.0")] - ) - ) + cls = DecimalArrayWithoutFromSequence + + @classmethod + def construct_array_type(cls): + return DecimalArrayWithoutFromSequence + + monkeypatch.setattr(DecimalDtype, "construct_array_type", construct_array_type) + + arr = cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")]) + ser = pd.Series(arr) result = ser.combine(ser, operator.add) # note: object dtype