Skip to content

Commit 612d7d0

Browse files
author
Kei
committed
Update impl to fix tests
1 parent 9181eaf commit 612d7d0

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

pandas/core/dtypes/cast.py

Lines changed: 30 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,7 @@ def maybe_cast_pointwise_result(
478478
return result
479479

480480

481-
def maybe_cast_to_pyarrow_result(result: ArrayLike) -> ArrayLike:
481+
def maybe_cast_to_pyarrow_result(result: ArrayLike, obj_dtype: DtypeObj) -> ArrayLike:
482482
"""
483483
Try casting result of a pointwise operation to its pyarrow dtype
484484
and arrow extension array if appropriate. If not possible,
@@ -501,12 +501,14 @@ def maybe_cast_to_pyarrow_result(result: ArrayLike) -> ArrayLike:
501501
stripped_result = result[~isna(result)]
502502
npvalues = lib.maybe_convert_objects(stripped_result, try_float=False)
503503

504+
if stripped_result.size == 0:
505+
return maybe_cast_pointwise_result(npvalues, obj_dtype, numeric_only=True)
506+
504507
try:
505508
dtype = convert_dtypes(npvalues, dtype_backend="pyarrow")
506509
out = pd_array(result, dtype=dtype)
507510
except (TypeError, ValueError, np.ComplexWarning):
508511
out = npvalues
509-
510512
return out
511513

512514

@@ -1194,29 +1196,34 @@ def _infer_pyarrow_dtype(
11941196
input_array: ArrayLike,
11951197
inferred_dtype: str,
11961198
) -> DtypeObj:
1197-
if inferred_dtype not in ["time", "date", "decimal", "bytes"]:
1198-
return input_array.dtype
1199-
1200-
# For a limited set of dtype
1201-
# Let pyarrow infer dtype from input_array
12021199
import pyarrow as pa
1203-
from pyarrow import (
1204-
ArrowInvalid,
1205-
ArrowMemoryError,
1206-
ArrowNotImplementedError,
1207-
)
12081200

1209-
try:
1210-
pyarrow_array = pa.array(input_array)
1211-
return ArrowDtype(pyarrow_array.type)
1212-
except (
1213-
TypeError,
1214-
ValueError,
1215-
ArrowInvalid,
1216-
ArrowMemoryError,
1217-
ArrowNotImplementedError,
1218-
):
1219-
return input_array.dtype
1201+
if inferred_dtype == "date":
1202+
return ArrowDtype(pa.date32())
1203+
elif inferred_dtype == "time":
1204+
return ArrowDtype(pa.time64("us"))
1205+
elif inferred_dtype == "bytes":
1206+
return ArrowDtype(pa.binary())
1207+
elif inferred_dtype == "decimal":
1208+
from pyarrow import (
1209+
ArrowInvalid,
1210+
ArrowMemoryError,
1211+
ArrowNotImplementedError,
1212+
)
1213+
1214+
try:
1215+
pyarrow_array = pa.array(input_array)
1216+
return ArrowDtype(pyarrow_array.type)
1217+
except (
1218+
TypeError,
1219+
ValueError,
1220+
ArrowInvalid,
1221+
ArrowMemoryError,
1222+
ArrowNotImplementedError,
1223+
):
1224+
return input_array.dtype
1225+
1226+
return input_array.dtype
12201227

12211228

12221229
def maybe_infer_to_datetimelike(

pandas/core/groupby/ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -919,7 +919,7 @@ def agg_series(
919919
result = self._aggregate_series_pure_python(obj, func)
920920

921921
if isinstance(obj._values, ArrowExtensionArray):
922-
return maybe_cast_to_pyarrow_result(result)
922+
return maybe_cast_to_pyarrow_result(result, obj.dtype)
923923

924924
if not isinstance(obj._values, np.ndarray) and not isinstance(
925925
obj._values, ArrowExtensionArray

0 commit comments

Comments
 (0)