Skip to content

Commit bb1ef46

Browse files
authored
REF: re-use maybe_cast_result in Series.combine (#40909)
1 parent eb0978f commit bb1ef46

File tree

4 files changed

+28
-31
lines changed

4 files changed

+28
-31
lines changed

pandas/core/dtypes/cast.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@
102102
if TYPE_CHECKING:
103103
from typing import Literal
104104

105-
from pandas import Series
106105
from pandas.core.arrays import (
107106
DatetimeArray,
108107
ExtensionArray,
@@ -375,7 +374,11 @@ def trans(x):
375374

376375

377376
def maybe_cast_result(
378-
result: ArrayLike, obj: Series, numeric_only: bool = False, how: str = ""
377+
result: ArrayLike,
378+
dtype: DtypeObj,
379+
numeric_only: bool = False,
380+
how: str = "",
381+
same_dtype: bool = True,
379382
) -> ArrayLike:
380383
"""
381384
Try casting result to a different type if appropriate
@@ -384,19 +387,20 @@ def maybe_cast_result(
384387
----------
385388
result : array-like
386389
Result to cast.
387-
obj : Series
390+
dtype : np.dtype or ExtensionDtype
388391
Input Series from which result was calculated.
389392
numeric_only : bool, default False
390393
Whether to cast only numerics or datetimes as well.
391394
how : str, default ""
392395
How the result was computed.
396+
same_dtype : bool, default True
397+
Specify dtype when calling _from_sequence
393398
394399
Returns
395400
-------
396401
result : array-like
397402
result maybe casted to the dtype.
398403
"""
399-
dtype = obj.dtype
400404
dtype = maybe_cast_result_dtype(dtype, how)
401405

402406
assert not is_scalar(result)
@@ -407,7 +411,10 @@ def maybe_cast_result(
407411
# things like counts back to categorical
408412

409413
cls = dtype.construct_array_type()
410-
result = maybe_cast_to_extension_array(cls, result, dtype=dtype)
414+
if same_dtype:
415+
result = maybe_cast_to_extension_array(cls, result, dtype=dtype)
416+
else:
417+
result = maybe_cast_to_extension_array(cls, result)
411418

412419
elif (numeric_only and is_numeric_dtype(dtype)) or not numeric_only:
413420
result = maybe_downcast_to_dtype(result, dtype)

pandas/core/groupby/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
787787
result[label] = res
788788

789789
out = lib.maybe_convert_objects(result, try_float=False)
790-
out = maybe_cast_result(out, obj, numeric_only=True)
790+
out = maybe_cast_result(out, obj.dtype, numeric_only=True)
791791

792792
return out, counts
793793

pandas/core/series.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,13 @@
6060
from pandas.core.dtypes.cast import (
6161
convert_dtypes,
6262
maybe_box_native,
63-
maybe_cast_to_extension_array,
63+
maybe_cast_result,
6464
validate_numeric_casting,
6565
)
6666
from pandas.core.dtypes.common import (
6767
ensure_platform_int,
6868
is_bool,
69-
is_categorical_dtype,
7069
is_dict_like,
71-
is_extension_array_dtype,
7270
is_integer,
7371
is_iterator,
7472
is_list_like,
@@ -3087,22 +3085,9 @@ def combine(self, other, func, fill_value=None) -> Series:
30873085
new_values = [func(lv, other) for lv in self._values]
30883086
new_name = self.name
30893087

3090-
if is_categorical_dtype(self.dtype):
3091-
pass
3092-
elif is_extension_array_dtype(self.dtype):
3093-
# TODO: can we do this for only SparseDtype?
3094-
# The function can return something of any type, so check
3095-
# if the type is compatible with the calling EA.
3096-
3097-
# error: Incompatible types in assignment (expression has type
3098-
# "Union[ExtensionArray, ndarray]", variable has type "List[Any]")
3099-
new_values = maybe_cast_to_extension_array( # type: ignore[assignment]
3100-
# error: Argument 2 to "maybe_cast_to_extension_array" has incompatible
3101-
# type "List[Any]"; expected "Union[ExtensionArray, ndarray]"
3102-
type(self._values),
3103-
new_values, # type: ignore[arg-type]
3104-
)
3105-
return self._constructor(new_values, index=new_index, name=new_name)
3088+
res_values = sanitize_array(new_values, None)
3089+
res_values = maybe_cast_result(res_values, self.dtype, same_dtype=False)
3090+
return self._constructor(res_values, index=new_index, name=new_name)
31063091

31073092
def combine_first(self, other) -> Series:
31083093
"""

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,13 +362,18 @@ def _create_arithmetic_method(cls, op):
362362
DecimalArrayWithoutCoercion._add_arithmetic_ops()
363363

364364

365-
def test_combine_from_sequence_raises():
365+
def test_combine_from_sequence_raises(monkeypatch):
366366
# https://github.com/pandas-dev/pandas/issues/22850
367-
ser = pd.Series(
368-
DecimalArrayWithoutFromSequence(
369-
[decimal.Decimal("1.0"), decimal.Decimal("2.0")]
370-
)
371-
)
367+
cls = DecimalArrayWithoutFromSequence
368+
369+
@classmethod
370+
def construct_array_type(cls):
371+
return DecimalArrayWithoutFromSequence
372+
373+
monkeypatch.setattr(DecimalDtype, "construct_array_type", construct_array_type)
374+
375+
arr = cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
376+
ser = pd.Series(arr)
372377
result = ser.combine(ser, operator.add)
373378

374379
# note: object dtype

0 commit comments

Comments
 (0)