From 822ae79b6600d1f388a90c9c681ec2e82c5b478f Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 11 Jan 2023 19:16:51 -0800 Subject: [PATCH 1/2] BUG: ArrowDtype.construct_from_string round-trip --- pandas/core/arrays/arrow/dtype.py | 39 ++++++++++++ pandas/tests/extension/base/dtype.py | 18 +++--- pandas/tests/extension/test_arrow.py | 91 ++++++---------------------- 3 files changed, 66 insertions(+), 82 deletions(-) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index f5f87bea83b8f..cebc920086b7c 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -148,6 +148,14 @@ def construct_from_string(cls, string: str) -> ArrowDtype: except ValueError as err: has_parameters = re.search(r"\[.*\]", base_type) if has_parameters: + + # Fallback to try common temporal types + try: + return cls._parse_temporal_dtype_string(base_type) + except (NotImplementedError, ValueError): + # Fall through to raise with nice exception message below + pass + raise NotImplementedError( "Passing pyarrow type specific parameters " f"({has_parameters.group()}) in the string is not supported. " @@ -157,6 +165,37 @@ def construct_from_string(cls, string: str) -> ArrowDtype: raise TypeError(f"'{base_type}' is not a valid pyarrow data type.") from err return cls(pa_dtype) + @classmethod + def _parse_temporal_dtype_string(cls, string: str) -> ArrowDtype: + """ + Construct a temporal ArrowDtype from string. + """ + # we assume + # 1) "[pyarrow]" has already been stripped from the end of our string. + # 2) we know "[" is present + head, tail = string.split("[", 1) + + if not tail.endswith("]"): + raise ValueError + tail = tail[:-1] + + if head == "timestamp": + if "," not in tail: + tz = None + unit = tail + else: + unit, tz = tail.split(",", 1) + unit = unit.strip() + tz = tz.strip() + if tz.startswith("tz="): + tz = tz[3:] + + pa_type = pa.timestamp(unit, tz=tz) + dtype = cls(pa_type) + return dtype + + raise NotImplementedError(string) + @property def _is_numeric(self) -> bool: """ diff --git a/pandas/tests/extension/base/dtype.py b/pandas/tests/extension/base/dtype.py index 2635343d73fd7..392a75f8a69a7 100644 --- a/pandas/tests/extension/base/dtype.py +++ b/pandas/tests/extension/base/dtype.py @@ -20,14 +20,6 @@ def test_kind(self, dtype): valid = set("biufcmMOSUV") assert dtype.kind in valid - def test_construct_from_string_own_name(self, dtype): - result = dtype.construct_from_string(dtype.name) - assert type(result) is type(dtype) - - # check OK as classmethod - result = type(dtype).construct_from_string(dtype.name) - assert type(result) is type(dtype) - def test_is_dtype_from_name(self, dtype): result = type(dtype).is_dtype(dtype.name) assert result is True @@ -97,9 +89,13 @@ def test_eq(self, dtype): assert dtype == dtype.name assert dtype != "anonther_type" - def test_construct_from_string(self, dtype): - dtype_instance = type(dtype).construct_from_string(dtype.name) - assert isinstance(dtype_instance, type(dtype)) + def test_construct_from_string_own_name(self, dtype): + result = dtype.construct_from_string(dtype.name) + assert type(result) is type(dtype) + + # check OK as classmethod + result = type(dtype).construct_from_string(dtype.name) + assert type(result) is type(dtype) def test_construct_from_string_another_type_raises(self, dtype): msg = f"Cannot construct a '{type(dtype).__name__}' from 'another_type'" diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 78c49ae066288..9b84443c54cbb 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -245,13 +245,9 @@ def test_astype_str(self, data, request): class TestConstructors(base.BaseConstructorsTests): def test_from_dtype(self, data, request): pa_dtype = data.dtype.pyarrow_dtype - if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string( - pa_dtype - ): - if pa.types.is_string(pa_dtype): - reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" - else: - reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}" + + if pa.types.is_string(pa_dtype): + reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" request.node.add_marker( pytest.mark.xfail( reason=reason, @@ -577,65 +573,24 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request): class TestBaseDtype(base.BaseDtypeTests): def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: - request.node.add_marker( - pytest.mark.xfail( - raises=NotImplementedError, - reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", - ) - ) - elif pa.types.is_string(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - "Still support StringDtype('pyarrow') " - "over ArrowDtype(pa.string())" - ), - ) - ) + + if pa.types.is_string(pa_dtype): + # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) + msg = r"string\[pyarrow\] should be constructed by StringDtype" + with pytest.raises(TypeError, match=msg): + dtype.construct_from_string(dtype.name) + + return + super().test_construct_from_string_own_name(dtype) def test_is_dtype_from_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: - request.node.add_marker( - pytest.mark.xfail( - raises=NotImplementedError, - reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", - ) - ) - elif pa.types.is_string(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - reason=( - "Still support StringDtype('pyarrow') " - "over ArrowDtype(pa.string())" - ), - ) - ) - super().test_is_dtype_from_name(dtype) - - def test_construct_from_string(self, dtype, request): - pa_dtype = dtype.pyarrow_dtype - if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: - request.node.add_marker( - pytest.mark.xfail( - raises=NotImplementedError, - reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", - ) - ) - elif pa.types.is_string(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - "Still support StringDtype('pyarrow') " - "over ArrowDtype(pa.string())" - ), - ) - ) - super().test_construct_from_string(dtype) + if pa.types.is_string(pa_dtype): + # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) + assert not type(dtype).is_dtype(dtype.name) + else: + super().test_is_dtype_from_name(dtype) def test_construct_from_string_another_type_raises(self, dtype): msg = r"'another_type' must end with '\[pyarrow\]'" @@ -720,13 +675,6 @@ def test_EA_types(self, engine, data, request): request.node.add_marker( pytest.mark.xfail(raises=TypeError, reason="GH 47534") ) - elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: - request.node.add_marker( - pytest.mark.xfail( - raises=NotImplementedError, - reason=f"Parameterized types with tz={pa_dtype.tz} not supported.", - ) - ) elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"): request.node.add_marker( pytest.mark.xfail( @@ -1354,8 +1302,9 @@ def test_invalid_other_comp(self, data, comparison_op): def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): - with pytest.raises(NotImplementedError, match="Passing pyarrow type"): - ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]") + dtype = ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]") + expected = ArrowDtype(pa.timestamp("s", "UTC")) + assert dtype == expected @pytest.mark.parametrize( From a2a236f5306f8a3d1802c2e7bd2887f1f9fd9443 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 17 Feb 2023 15:45:38 -0800 Subject: [PATCH 2/2] suggestions --- pandas/core/arrays/arrow/dtype.py | 17 +++++++---------- pandas/tests/extension/test_arrow.py | 4 ++++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 402281f3b312a..fdb9ac877831b 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -203,7 +203,6 @@ def construct_from_string(cls, string: str) -> ArrowDtype: except ValueError as err: has_parameters = re.search(r"\[.*\]", base_type) if has_parameters: - # Fallback to try common temporal types try: return cls._parse_temporal_dtype_string(base_type) @@ -220,6 +219,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype: raise TypeError(f"'{base_type}' is not a valid pyarrow data type.") from err return cls(pa_dtype) + # TODO(arrow#33642): This can be removed once supported by pyarrow @classmethod def _parse_temporal_dtype_string(cls, string: str) -> ArrowDtype: """ @@ -235,15 +235,12 @@ def _parse_temporal_dtype_string(cls, string: str) -> ArrowDtype: tail = tail[:-1] if head == "timestamp": - if "," not in tail: - tz = None - unit = tail - else: - unit, tz = tail.split(",", 1) - unit = unit.strip() - tz = tz.strip() - if tz.startswith("tz="): - tz = tz[3:] + assert "," in tail # otherwise type_for_alias should work + unit, tz = tail.split(",", 1) + unit = unit.strip() + tz = tz.strip() + if tz.startswith("tz="): + tz = tz[3:] pa_type = pa.timestamp(unit, tz=tz) dtype = cls(pa_type) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 87631bfe02496..c53b6e6658979 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1213,6 +1213,10 @@ def test_invalid_other_comp(self, data, comparison_op): def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): + with pytest.raises(NotImplementedError, match="Passing pyarrow type"): + ArrowDtype.construct_from_string("not_a_real_dype[s, tz=UTC][pyarrow]") + + # but as of GH#50689, timestamptz is supported dtype = ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]") expected = ArrowDtype(pa.timestamp("s", "UTC")) assert dtype == expected