From 471b7776229d8d88a45b7c33b1984dd6da08a90c Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 21 Jan 2023 15:39:28 -0800 Subject: [PATCH] ENH: support min/max/sum for pyarrow duration dtypes --- pandas/core/arrays/arrow/array.py | 12 +++++++++++- pandas/tests/extension/base/groupby.py | 1 + pandas/tests/extension/test_arrow.py | 4 ---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index e2a74ea6f5351..a6a83feb24435 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1033,8 +1033,15 @@ def pyarrow_meth(data, skip_nulls, **kwargs): if pyarrow_meth is None: # Let ExtensionArray._reduce raise the TypeError return super()._reduce(name, skipna=skipna, **kwargs) + + data_to_reduce = self._data + + pa_dtype = self._data.type + if name in ["min", "max", "sum"] and pa.types.is_duration(pa_dtype): + data_to_reduce = self._data.cast(pa.int64()) + try: - result = pyarrow_meth(self._data, skip_nulls=skipna, **kwargs) + result = pyarrow_meth(data_to_reduce, skip_nulls=skipna, **kwargs) except (AttributeError, NotImplementedError, TypeError) as err: msg = ( f"'{type(self).__name__}' with dtype {self.dtype} " @@ -1045,6 +1052,9 @@ def pyarrow_meth(data, skip_nulls, **kwargs): raise TypeError(msg) from err if pc.is_null(result).as_py(): return self.dtype.na_value + + if name in ["min", "max", "sum"] and pa.types.is_duration(pa_dtype): + result = result.cast(pa_dtype) return result.as_py() def __setitem__(self, key, value) -> None: diff --git a/pandas/tests/extension/base/groupby.py b/pandas/tests/extension/base/groupby.py index 3a9dbe9dfb384..200a494997116 100644 --- a/pandas/tests/extension/base/groupby.py +++ b/pandas/tests/extension/base/groupby.py @@ -137,6 +137,7 @@ def test_in_numeric_groupby(self, data_for_grouping): or is_string_dtype(dtype) or is_period_dtype(dtype) or is_object_dtype(dtype) + or dtype.kind == "m" # in particular duration[*][pyarrow] ): expected = pd.Index(["B", "C"]) result = df.groupby("A").sum().columns diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index a7c243cdfe74f..91a725234359a 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -539,10 +539,6 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request): "sem", ] and pa.types.is_temporal(pa_dtype): request.node.add_marker(xfail_mark) - elif all_numeric_reductions in ["sum", "min", "max"] and pa.types.is_duration( - pa_dtype - ): - request.node.add_marker(xfail_mark) elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in { "sem", "std",