From 8ca57f29818155a5ceab445cc9c6583761612288 Mon Sep 17 00:00:00 2001
From: Luke Manley
Date: Sun, 11 Jun 2023 09:21:27 -0400
Subject: [PATCH 1/3] ENH: Series.explode to support pyarrow-backed list types

---
 doc/source/whatsnew/v2.1.0.rst              |  1 +
 pandas/core/arrays/arrow/array.py           | 30 ++++++++++++++++-----
 pandas/core/series.py                       | 13 ++++++---
 pandas/tests/series/methods/test_explode.py | 24 +++++++++++++++++
 4 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst
index baacc8c421414..456325b7550bc 100644
--- a/doc/source/whatsnew/v2.1.0.rst
+++ b/doc/source/whatsnew/v2.1.0.rst
@@ -108,6 +108,7 @@ Other enhancements
 - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
 - Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
 - Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
+- :meth:`Series.explode` now supports pyarrow-backed list types (:issue:`#####`)
 
 .. ---------------------------------------------------------------------------
 .. _whatsnew_210.notable_bug_fixes:
diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index 817d5d0932744..a8017a0c3fc9a 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -347,9 +347,9 @@ def _box_pa(
         -------
         pa.Array or pa.ChunkedArray or pa.Scalar
         """
-        if is_list_like(value):
-            return cls._box_pa_array(value, pa_type)
-        return cls._box_pa_scalar(value, pa_type)
+        if isinstance(value, pa.Scalar) or not is_list_like(value):
+            return cls._box_pa_scalar(value, pa_type)
+        return cls._box_pa_array(value, pa_type)
 
     @classmethod
     def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
@@ -1549,6 +1549,24 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
 
         return result.as_py()
 
+    def _explode(self):
+        """
+        See Series.explode.__doc__.
+        """
+        values = self
+        counts = pa.compute.list_value_length(values._pa_array)
+        counts = counts.fill_null(1).to_numpy()
+        fill_value = pa.scalar([None], type=self._pa_array.type)
+        mask = counts == 0
+        if mask.any():
+            values = values.copy()
+            values[mask] = fill_value
+            counts = counts.copy()
+            counts[mask] = 1
+        values = values.fillna(fill_value)
+        values = type(self)(pa.compute.list_flatten(values._pa_array))
+        return values, counts
+
     def __setitem__(self, key, value) -> None:
         """Set one or more values inplace.
 
@@ -1591,10 +1609,10 @@ def __setitem__(self, key, value) -> None:
                 raise IndexError(
                     f"index {key} is out of bounds for axis 0 with size {n}"
                 )
-            if is_list_like(value):
-                raise ValueError("Length of indexer and values mismatch")
-            elif isinstance(value, pa.Scalar):
+            if isinstance(value, pa.Scalar):
                 value = value.as_py()
+            elif is_list_like(value):
+                raise ValueError("Length of indexer and values mismatch")
             chunks = [
                 *self._pa_array[:key].chunks,
                 pa.array([value], type=self._pa_array.type, from_pandas=True),
diff --git a/pandas/core/series.py b/pandas/core/series.py
index 9c7110cc21082..959c153561572 100644
--- a/pandas/core/series.py
+++ b/pandas/core/series.py
@@ -72,7 +72,10 @@
     pandas_dtype,
     validate_all_hashable,
 )
-from pandas.core.dtypes.dtypes import ExtensionDtype
+from pandas.core.dtypes.dtypes import (
+    ArrowDtype,
+    ExtensionDtype,
+)
 from pandas.core.dtypes.generic import ABCDataFrame
 from pandas.core.dtypes.inference import is_hashable
 from pandas.core.dtypes.missing import (
@@ -4267,12 +4270,14 @@ def explode(self, ignore_index: bool = False) -> Series:
         3      4
         dtype: object
         """
-        if not len(self) or not is_object_dtype(self.dtype):
+        if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list:
+            values, counts = self._values._explode()
+        elif len(self) and is_object_dtype(self.dtype):
+            values, counts = reshape.explode(np.asarray(self._values))
+        else:
             result = self.copy()
             return result.reset_index(drop=True) if ignore_index else result
 
-        values, counts = reshape.explode(np.asarray(self._values))
-
         if ignore_index:
             index = default_index(len(values))
         else:
diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py
index 886152326cf3e..3ea7b2700c8fd 100644
--- a/pandas/tests/series/methods/test_explode.py
+++ b/pandas/tests/series/methods/test_explode.py
@@ -1,6 +1,8 @@
 import numpy as np
 import pytest
 
+from pandas.compat import pa_version_under7p0
+
 import pandas as pd
 import pandas._testing as tm
 
@@ -141,3 +143,25 @@ def test_explode_scalars_can_ignore_index():
     result = s.explode(ignore_index=True)
     expected = pd.Series([1, 2, 3])
     tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.skipif(pa_version_under7p0, reason="minimum pyarrow not installed")
+def test_explode_pyarrow_list_type():
+    # GH #####
+    import pyarrow as pa
+
+    data = [
+        [None, None],
+        [1],
+        [],
+        [2, 3],
+        None,
+    ]
+    ser = pd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64())))
+    result = ser.explode()
+    expected = pd.Series(
+        data=[None, None, 1, None, 2, 3, None],
+        index=[0, 0, 1, 2, 3, 3, 4],
+        dtype=pd.ArrowDtype(pa.int64()),
+    )
+    tm.assert_series_equal(result, expected)

From c93ee7b30d0f47d86aa1272ad40d3ba50ccd8d86 Mon Sep 17 00:00:00 2001
From: Luke Manley
Date: Sun, 11 Jun 2023 09:24:38 -0400
Subject: [PATCH 2/3] gh refs

---
 doc/source/whatsnew/v2.1.0.rst              | 2 +-
 pandas/tests/series/methods/test_explode.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst
index 456325b7550bc..0249da8070511 100644
--- a/doc/source/whatsnew/v2.1.0.rst
+++ b/doc/source/whatsnew/v2.1.0.rst
@@ -102,13 +102,13 @@ Other enhancements
 - :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
 - :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
 - :meth:`DataFrameGroupby.agg` and :meth:`DataFrameGroupby.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`)
+- :meth:`Series.explode` now supports pyarrow-backed list types (:issue:`53602`)
 - :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
 - Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
 - Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
 - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
 - Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
 - Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
-- :meth:`Series.explode` now supports pyarrow-backed list types (:issue:`#####`)
 
 .. ---------------------------------------------------------------------------
 .. _whatsnew_210.notable_bug_fixes:
diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py
index 3ea7b2700c8fd..a01591fa3be4b 100644
--- a/pandas/tests/series/methods/test_explode.py
+++ b/pandas/tests/series/methods/test_explode.py
@@ -147,7 +147,7 @@ def test_explode_scalars_can_ignore_index():
 
 @pytest.mark.skipif(pa_version_under7p0, reason="minimum pyarrow not installed")
 def test_explode_pyarrow_list_type():
-    # GH #####
+    # GH 53602
    import pyarrow as pa
 
     data = [

From 264dc41fbdbff3e578117df0ed3b097cac802872 Mon Sep 17 00:00:00 2001
From: Luke Manley
Date: Mon, 12 Jun 2023 20:26:28 -0400
Subject: [PATCH 3/3] update test

---
 pandas/tests/series/methods/test_explode.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py
index a01591fa3be4b..c8a9eb6f89fde 100644
--- a/pandas/tests/series/methods/test_explode.py
+++ b/pandas/tests/series/methods/test_explode.py
@@ -1,8 +1,6 @@
 import numpy as np
 import pytest
 
-from pandas.compat import pa_version_under7p0
-
 import pandas as pd
 import pandas._testing as tm
 
@@ -145,10 +143,10 @@ def test_explode_scalars_can_ignore_index():
     tm.assert_series_equal(result, expected)
 
 
-@pytest.mark.skipif(pa_version_under7p0, reason="minimum pyarrow not installed")
-def test_explode_pyarrow_list_type():
+@pytest.mark.parametrize("ignore_index", [True, False])
+def test_explode_pyarrow_list_type(ignore_index):
     # GH 53602
-    import pyarrow as pa
+    pa = pytest.importorskip("pyarrow")
 
     data = [
         [None, None],
@@ -158,10 +156,10 @@ def test_explode_pyarrow_list_type():
         None,
     ]
     ser = pd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64())))
-    result = ser.explode()
+    result = ser.explode(ignore_index=ignore_index)
     expected = pd.Series(
         data=[None, None, 1, None, 2, 3, None],
-        index=[0, 0, 1, 2, 3, 3, 4],
+        index=None if ignore_index else [0, 0, 1, 2, 3, 3, 4],
         dtype=pd.ArrowDtype(pa.int64()),
     )
     tm.assert_series_equal(result, expected)
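
Reviewer note (not part of the patch series): a minimal usage sketch of the behaviour these patches add. It assumes a pandas build that includes this branch plus an installed pyarrow; the data and expected result mirror the new test above.

    import pandas as pd
    import pyarrow as pa

    ser = pd.Series(
        [[None, None], [1], [], [2, 3], None],
        dtype=pd.ArrowDtype(pa.list_(pa.int64())),
    )

    # Each list element becomes its own row; empty lists and null rows each
    # explode to a single missing value, and the original index label repeats
    # once per emitted element. Per the new test, the result holds
    # [NA, NA, 1, NA, 2, 3, NA] against index [0, 0, 1, 2, 3, 3, 4]
    # with dtype int64[pyarrow].
    print(ser.explode())

    # ignore_index=True replaces the repeated labels with a fresh RangeIndex.
    print(ser.explode(ignore_index=True))

And a rough illustration (again, not part of the patch) of why _explode fills null and empty rows before flattening: pa.compute.list_flatten emits nothing for them, so each such row is first replaced with [None] to guarantee one output slot, and the per-row counts are roughly what Series.explode uses to repeat the index labels.

    import numpy as np
    import pyarrow as pa
    import pyarrow.compute as pc

    arr = pa.array([[None, None], [1], [], [2, 3], None], type=pa.list_(pa.int64()))

    # Per-row element counts: null rows count as 1, empty lists bumped to 1,
    # mirroring the fill_null/mask handling in ArrowExtensionArray._explode.
    counts = pc.list_value_length(arr).fill_null(1).to_numpy(zero_copy_only=False).copy()
    counts[counts == 0] = 1
    print(counts)                                  # [2 1 1 2 1]
    print(np.repeat(np.arange(len(arr)), counts))  # [0 0 1 2 3 3 4]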