@@ -478,7 +478,7 @@ def maybe_cast_pointwise_result(
478
478
return result
479
479
480
480
481
- def maybe_cast_to_pyarrow_result (result : ArrayLike ) -> ArrayLike :
481
+ def maybe_cast_to_pyarrow_result (result : ArrayLike , obj_dtype : DtypeObj ) -> ArrayLike :
482
482
"""
483
483
Try casting result of a pointwise operation to its pyarrow dtype
484
484
and arrow extension array if appropriate. If not possible,
@@ -501,12 +501,14 @@ def maybe_cast_to_pyarrow_result(result: ArrayLike) -> ArrayLike:
501
501
stripped_result = result [~ isna (result )]
502
502
npvalues = lib .maybe_convert_objects (stripped_result , try_float = False )
503
503
504
+ if stripped_result .size == 0 :
505
+ return maybe_cast_pointwise_result (npvalues , obj_dtype , numeric_only = True )
506
+
504
507
try :
505
508
dtype = convert_dtypes (npvalues , dtype_backend = "pyarrow" )
506
509
out = pd_array (result , dtype = dtype )
507
510
except (TypeError , ValueError , np .ComplexWarning ):
508
511
out = npvalues
509
-
510
512
return out
511
513
512
514
@@ -1194,29 +1196,34 @@ def _infer_pyarrow_dtype(
1194
1196
input_array : ArrayLike ,
1195
1197
inferred_dtype : str ,
1196
1198
) -> DtypeObj :
1197
- if inferred_dtype not in ["time" , "date" , "decimal" , "bytes" ]:
1198
- return input_array .dtype
1199
-
1200
- # For a limited set of dtype
1201
- # Let pyarrow infer dtype from input_array
1202
1199
import pyarrow as pa
1203
- from pyarrow import (
1204
- ArrowInvalid ,
1205
- ArrowMemoryError ,
1206
- ArrowNotImplementedError ,
1207
- )
1208
1200
1209
- try :
1210
- pyarrow_array = pa .array (input_array )
1211
- return ArrowDtype (pyarrow_array .type )
1212
- except (
1213
- TypeError ,
1214
- ValueError ,
1215
- ArrowInvalid ,
1216
- ArrowMemoryError ,
1217
- ArrowNotImplementedError ,
1218
- ):
1219
- return input_array .dtype
1201
+ if inferred_dtype == "date" :
1202
+ return ArrowDtype (pa .date32 ())
1203
+ elif inferred_dtype == "time" :
1204
+ return ArrowDtype (pa .time64 ("us" ))
1205
+ elif inferred_dtype == "bytes" :
1206
+ return ArrowDtype (pa .binary ())
1207
+ elif inferred_dtype == "decimal" :
1208
+ from pyarrow import (
1209
+ ArrowInvalid ,
1210
+ ArrowMemoryError ,
1211
+ ArrowNotImplementedError ,
1212
+ )
1213
+
1214
+ try :
1215
+ pyarrow_array = pa .array (input_array )
1216
+ return ArrowDtype (pyarrow_array .type )
1217
+ except (
1218
+ TypeError ,
1219
+ ValueError ,
1220
+ ArrowInvalid ,
1221
+ ArrowMemoryError ,
1222
+ ArrowNotImplementedError ,
1223
+ ):
1224
+ return input_array .dtype
1225
+
1226
+ return input_array .dtype
1220
1227
1221
1228
1222
1229
def maybe_infer_to_datetimelike (
0 commit comments