Skip to content

ENH/TST: Add BaseParsingTests tests for ArrowExtensionArray #47536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")


def to_pyarrow_type(
    dtype: ArrowDtype | pa.DataType | Dtype | None,
) -> pa.DataType | None:
    """
    Convert dtype to a pyarrow type instance.

    Parameters
    ----------
    dtype : ArrowDtype, pa.DataType, Dtype or None
        The dtype specification to translate.

    Returns
    -------
    pa.DataType or None
        ``dtype.pyarrow_dtype`` when ``dtype`` is an ``ArrowDtype``;
        ``dtype`` itself when it already is a ``pa.DataType``;
        ``pa.from_numpy_dtype(dtype)`` for any other truthy value;
        ``None`` for any other falsy value.
    """
    if isinstance(dtype, ArrowDtype):
        return dtype.pyarrow_dtype
    if isinstance(dtype, pa.DataType):
        return dtype
    # The remaining case is decided by truthiness, not by an ``is None`` test.
    # NOTE(review): confirm that every non-None dtype expected here is truthy.
    return pa.from_numpy_dtype(dtype) if dtype else None


class ArrowExtensionArray(OpsMixin, ExtensionArray):
"""
Base class for ExtensionArray backed by Arrow ChunkedArray.
Expand All @@ -89,13 +106,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
"""
Construct a new ExtensionArray from a sequence of scalars.
"""
if isinstance(dtype, ArrowDtype):
pa_dtype = dtype.pyarrow_dtype
elif dtype:
pa_dtype = pa.from_numpy_dtype(dtype)
else:
pa_dtype = None

pa_dtype = to_pyarrow_type(dtype)
if isinstance(scalars, cls):
data = scalars._data
if pa_dtype:
Expand All @@ -113,7 +124,40 @@ def _from_sequence_of_strings(
"""
Construct a new ExtensionArray from a sequence of strings.
"""
return cls._from_sequence(strings, dtype=dtype, copy=copy)
pa_type = to_pyarrow_type(dtype)
if pa.types.is_timestamp(pa_type):
from pandas.core.tools.datetimes import to_datetime

scalars = to_datetime(strings, errors="raise")
elif pa.types.is_date(pa_type):
from pandas.core.tools.datetimes import to_datetime

scalars = to_datetime(strings, errors="raise").date
elif pa.types.is_duration(pa_type):
from pandas.core.tools.timedeltas import to_timedelta

scalars = to_timedelta(strings, errors="raise")
elif pa.types.is_time(pa_type):
from pandas.core.tools.times import to_time

# "coerce" to allow "null times" (None) to not raise
scalars = to_time(strings, errors="coerce")
elif pa.types.is_boolean(pa_type):
from pandas.core.arrays import BooleanArray

scalars = BooleanArray._from_sequence_of_strings(strings).to_numpy()
elif (
pa.types.is_integer(pa_type)
or pa.types.is_floating(pa_type)
or pa.types.is_decimal(pa_type)
):
from pandas.core.tools.numeric import to_numeric

scalars = to_numeric(strings, errors="raise")
else:
# Let pyarrow try to infer or raise
scalars = strings
return cls._from_sequence(scalars, dtype=pa_type, copy=copy)

def __getitem__(self, item: PositionalIndexer):
"""Select a subset of self.
Expand Down
26 changes: 22 additions & 4 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def test_setitem_loc_iloc_slice(self, data, using_array_manager, request):
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
Expand All @@ -728,7 +728,7 @@ def test_setitem_slice_array(self, data, request):
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_setitem_slice_array(data)
Expand All @@ -742,7 +742,7 @@ def test_setitem_with_expansion_dataframe_column(
if pa_version_under2p0 and tz not in (None, "UTC") and not is_null_slice:
request.node.add_marker(
pytest.mark.xfail(
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
elif (
Expand Down Expand Up @@ -780,7 +780,7 @@ def test_setitem_frame_2d_values(self, data, using_array_manager, request):
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
Expand All @@ -796,6 +796,24 @@ def test_setitem_preserves_views(self, data):
super().test_setitem_preserves_views(data)


class TestBaseParsing(base.BaseParsingTests):
    """Parsing tests inherited from ``base.BaseParsingTests``."""

    @pytest.mark.parametrize("engine", ["c", "python"])
    def test_EA_types(self, engine, data, request):
        """
        Run the inherited ``test_EA_types`` for each parser engine.

        An xfail marker is attached to the test node first when the
        pyarrow type of ``data`` is boolean, or is a timestamp carrying a
        timezone.
        """
        pa_dtype = data.dtype.pyarrow_dtype

        # Decide which (if any) expected-failure marker applies.
        xfail_marker = None
        if pa.types.is_boolean(pa_dtype):
            xfail_marker = pytest.mark.xfail(raises=TypeError, reason="GH 47534")
        elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
            xfail_marker = pytest.mark.xfail(
                raises=NotImplementedError,
                reason=f"Parameterized types with tz={pa_dtype.tz} not supported.",
            )

        if xfail_marker is not None:
            request.node.add_marker(xfail_marker)
        super().test_EA_types(engine, data)


def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
    """
    ``ArrowDtype.construct_from_string`` is expected to raise
    ``NotImplementedError`` for a timestamp string carrying a ``tz`` parameter.
    """
    parameterized_type = "timestamp[s, tz=UTC][pyarrow]"
    expected_message = "Passing pyarrow type"
    with pytest.raises(NotImplementedError, match=expected_message):
        ArrowDtype.construct_from_string(parameterized_type)