From f8797b43b0813823826c19ad4235c65e4f9e3723 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 28 Aug 2023 16:11:13 -0700 Subject: [PATCH] ENH: allow EADtype to specify _supports_2d --- pandas/core/dtypes/base.py | 27 +++++++++++++++++++++++++ pandas/core/dtypes/common.py | 8 +------- pandas/core/dtypes/dtypes.py | 8 ++++++++ pandas/tests/extension/base/__init__.py | 1 + pandas/tests/extension/base/dim2.py | 11 ++++++++++ 5 files changed, 48 insertions(+), 7 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index a055afe6ec0ae..6567ca7155b0d 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -418,6 +418,33 @@ def index_class(self) -> type_t[Index]: return Index + @property + def _supports_2d(self) -> bool: + """ + Do ExtensionArrays with this dtype support 2D arrays? + + Historically ExtensionArrays were limited to 1D. By returning True here, + authors can indicate that their arrays support 2D instances. This can + improve performance in some cases, particularly operations with `axis=1`. + + Arrays that support 2D values should: + + - implement Array.reshape + - subclass the Dim2CompatTests in tests.extension.base + - _concat_same_type should support `axis` keyword + - _reduce and reductions should support `axis` keyword + """ + return False + + @property + def _can_fast_transpose(self) -> bool: + """ + Is transposing an array with this dtype zero-copy? + + Only relevant for cases where _supports_2d is True. + """ + return False + class StorageExtensionDtype(ExtensionDtype): """ExtensionDtype that may be backed by more than one implementation.""" diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 3db36fc50e343..63c59297f454a 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1256,13 +1256,7 @@ def is_1d_only_ea_dtype(dtype: DtypeObj | None) -> bool: """ Analogue to is_extension_array_dtype but excluding DatetimeTZDtype. """ - # Note: if other EA dtypes are ever held in HybridBlock, exclude those - # here too. - # NB: need to check DatetimeTZDtype and not is_datetime64tz_dtype - # to exclude ArrowTimestampUSDtype - return isinstance(dtype, ExtensionDtype) and not isinstance( - dtype, (DatetimeTZDtype, PeriodDtype) - ) + return isinstance(dtype, ExtensionDtype) and not dtype._supports_2d def is_extension_array_dtype(arr_or_dtype) -> bool: diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index f43865465cb28..8ab002445633e 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -210,6 +210,8 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype): base = np.dtype("O") _metadata = ("categories", "ordered") _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} + _supports_2d = False + _can_fast_transpose = False def __init__(self, categories=None, ordered: Ordered = False) -> None: self._finalize(categories, ordered, fastpath=False) @@ -727,6 +729,8 @@ class DatetimeTZDtype(PandasExtensionDtype): _metadata = ("unit", "tz") _match = re.compile(r"(datetime64|M8)\[(?P.+), (?P.+)\]") _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} + _supports_2d = True + _can_fast_transpose = True @property def na_value(self) -> NaTType: @@ -970,6 +974,8 @@ class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype): _cache_dtypes: dict[BaseOffset, int] = {} # type: ignore[assignment] __hash__ = PeriodDtypeBase.__hash__ _freq: BaseOffset + _supports_2d = True + _can_fast_transpose = True def __new__(cls, freq): """ @@ -1432,6 +1438,8 @@ class NumpyEADtype(ExtensionDtype): """ _metadata = ("_dtype",) + _supports_2d = False + _can_fast_transpose = False def __init__(self, dtype: npt.DTypeLike | NumpyEADtype | None) -> None: if isinstance(dtype, NumpyEADtype): diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 82b61722f5e96..6efaa95aef1b5 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -85,6 +85,7 @@ class ExtensionTests( BaseReduceTests, BaseReshapingTests, BaseSetitemTests, + Dim2CompatTests, ): pass diff --git a/pandas/tests/extension/base/dim2.py b/pandas/tests/extension/base/dim2.py index a0c24ee068e81..85b3dd2d2b267 100644 --- a/pandas/tests/extension/base/dim2.py +++ b/pandas/tests/extension/base/dim2.py @@ -20,6 +20,17 @@ class Dim2CompatTests: # Note: these are ONLY for ExtensionArray subclasses that support 2D arrays. # i.e. not for pyarrow-backed EAs. + @pytest.fixture(autouse=True) + def skip_if_doesnt_support_2d(self, dtype, request): + if not dtype._supports_2d: + node = request.node + # In cases where we are mixed in to ExtensionTests, we only want to + # skip tests that are defined in Dim2CompatTests + test_func = node._obj + if test_func.__qualname__.startswith("Dim2CompatTests"): + # TODO: is there a less hacky way of checking this? + pytest.skip("Test is only for EAs that support 2D.") + def test_transpose(self, data): arr2d = data.repeat(2).reshape(-1, 2) shape = arr2d.shape