From 359b2890cbc060dcbe24665c4977cacc2d744cc2 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 4 Dec 2023 16:18:45 -0800 Subject: [PATCH 1/4] ENH: Implement str.extract for ArrowDtype --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/arrays/arrow/array.py | 16 ++++++++++--- pandas/tests/extension/test_arrow.py | 35 ++++++++++++++++++++++++---- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 5ee2bb1778cb1..1f37dd3e2fd50 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -226,6 +226,7 @@ Other enhancements - Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`) - DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`) - Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`) +- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`) - Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as "BMS" (:issue:`56243`) - Improved error message when constructing :class:`Period` with invalid offsets such as "QS" (:issue:`55785`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d162b66e5d369..7d52334e12bbd 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2296,9 +2296,19 @@ def _str_encode(self, encoding: str, errors: str = "strict"): return type(self)(pa.chunked_array(result)) def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): - raise NotImplementedError( - "str.extract not supported with pd.ArrowDtype(pa.string())." 
- ) + if flags: + raise NotImplementedError("Only flags=0 is implemented.") + groups = re.compile(pat).groupindex.keys() + if len(groups) == 0: + raise ValueError(f"{pat=} must contain a symbolic group name.") + result = pc.extract_regex(self._pa_array, pat) + if expand: + return { + col: type(self)(pc.struct_field(result, i)) + for col, i in zip(groups, range(result.type.num_fields)) + } + else: + return type(self)(pc.struct_field(result, 0)) def _str_findall(self, pat: str, flags: int = 0): regex = re.compile(pat, flags=flags) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 7131a50956a7d..5e7fc1701ed72 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2150,14 +2150,39 @@ def test_str_rsplit(): tm.assert_frame_equal(result, expected) -def test_str_unsupported_extract(): - ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - with pytest.raises( - NotImplementedError, match="str.extract not supported with pd.ArrowDtype" - ): +def test_str_extract_non_symbolic(): + ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) + with pytest.raises(ValueError, match="pat=.* must contain a symbolic group name."): ser.str.extract(r"[ab](\d)") +def test_str_extract(): + ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) + result = ser.str.extract(r"(?P<letter>[ab])(?P<digit>\d)") + expected = pd.DataFrame( + { + "letter": ArrowExtensionArray(pa.array(["a", "b", None])), + "digit": ArrowExtensionArray(pa.array(["1", "2", None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_extract_expand(): + ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) + result = ser.str.extract(r"[ab](?P<digit>\d)", expand=True) + expected = pd.DataFrame( + { + "digit": ArrowExtensionArray(pa.array(["1", "2", None])), + } + ) + tm.assert_frame_equal(result, expected) + + result = ser.str.extract(r"[ab](?P<digit>\d)", expand=False) + expected =
pd.Series(ArrowExtensionArray(pa.array(["1", "2", None])), name="digit") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"]) def test_duration_from_strings_with_nat(unit): # GH51175 From 2dda5cdcfb9c0e3c6ae733fc6d4d7af0d919617f Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 4 Dec 2023 16:58:07 -0800 Subject: [PATCH 2/4] Parameterize on expand --- pandas/tests/extension/test_arrow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 5e7fc1701ed72..8e555cfe64629 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2156,9 +2156,10 @@ def test_str_extract_non_symbolic(): ser.str.extract(r"[ab](\d)") -def test_str_extract(): +@pytest.mark.parametrize("expand", [True, False]) +def test_str_extract(expand): ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) - result = ser.str.extract(r"(?P<letter>[ab])(?P<digit>\d)") + result = ser.str.extract(r"(?P<letter>[ab])(?P<digit>\d)", expand=expand) expected = pd.DataFrame( { "letter": ArrowExtensionArray(pa.array(["a", "b", None])), From daf12ea218d0955aae7535120779549de46d8056 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:26:51 -0800 Subject: [PATCH 3/4] Min version compat --- pandas/core/arrays/arrow/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 7d52334e12bbd..379d967daac67 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2304,7 +2304,7 @@ def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): result = pc.extract_regex(self._pa_array, pat) if expand: return { - col: type(self)(pc.struct_field(result, i)) + col: type(self)(pc.struct_field(result, [i])) for
col, i in zip(groups, range(result.type.num_fields)) } else: From ea8ffb11109a8581d3a0cf33a8b359bf56d85649 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:20:26 -0800 Subject: [PATCH 4/4] Min version compat again --- pandas/core/arrays/arrow/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 379d967daac67..e8cb490062fb1 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2308,7 +2308,7 @@ def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): for col, i in zip(groups, range(result.type.num_fields)) } else: - return type(self)(pc.struct_field(result, 0)) + return type(self)(pc.struct_field(result, [0])) def _str_findall(self, pat: str, flags: int = 0): regex = re.compile(pat, flags=flags)