diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index c7f587b35f557..9c43e3714c332 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -6,7 +6,6 @@ from shutil import get_terminal_size from typing import ( TYPE_CHECKING, - Any, Hashable, Sequence, TypeVar, @@ -38,10 +37,6 @@ Dtype, NpDtype, Ordered, - PositionalIndexer2D, - PositionalIndexerTuple, - ScalarIndexer, - SequenceIndexer, Shape, npt, type_t, @@ -102,7 +97,10 @@ take_nd, unique1d, ) -from pandas.core.arrays._mixins import NDArrayBackedExtensionArray +from pandas.core.arrays._mixins import ( + NDArrayBackedExtensionArray, + ravel_compat, +) from pandas.core.base import ( ExtensionArray, NoNewAttributesMixin, @@ -113,7 +111,6 @@ extract_array, sanitize_array, ) -from pandas.core.indexers import deprecate_ndim_indexing from pandas.core.ops.common import unpack_zerodim_and_defer from pandas.core.sorting import nargsort from pandas.core.strings.object_array import ObjectStringArrayMixin @@ -1484,6 +1481,7 @@ def _validate_scalar(self, fill_value): # ------------------------------------------------------------- + @ravel_compat def __array__(self, dtype: NpDtype | None = None) -> np.ndarray: """ The numpy array interface. @@ -1934,7 +1932,10 @@ def __iter__(self): """ Returns an Iterator over the values of this Categorical. """ - return iter(self._internal_get_values().tolist()) + if self.ndim == 1: + return iter(self._internal_get_values().tolist()) + else: + return (self[n] for n in range(len(self))) def __contains__(self, key) -> bool: """ @@ -2053,27 +2054,6 @@ def __repr__(self) -> str: # ------------------------------------------------------------------ - @overload - def __getitem__(self, key: ScalarIndexer) -> Any: - ... - - @overload - def __getitem__( - self: CategoricalT, - key: SequenceIndexer | PositionalIndexerTuple, - ) -> CategoricalT: - ... - - def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any: - """ - Return an item. - """ - result = super().__getitem__(key) - if getattr(result, "ndim", 0) > 1: - result = result._ndarray - deprecate_ndim_indexing(result) - return result - def _validate_listlike(self, value): # NB: here we assume scalar-like tuples have already been excluded value = extract_array(value, extract_numpy=True) @@ -2311,7 +2291,19 @@ def _concat_same_type( ) -> CategoricalT: from pandas.core.dtypes.concat import union_categoricals - return union_categoricals(to_concat) + result = union_categoricals(to_concat) + + # in case we are concatenating along axis != 0, we need to reshape + # the result from union_categoricals + first = to_concat[0] + if axis >= first.ndim: + raise ValueError + if axis == 1: + if not all(len(x) == len(first) for x in to_concat): + raise ValueError + # TODO: Will this get contiguity wrong? + result = result.reshape(-1, len(to_concat), order="F") + return result # ------------------------------------------------------------------ @@ -2699,6 +2691,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray: """ dtype_equal = is_dtype_equal(values.dtype, categories.dtype) + if values.ndim > 1: + flat = values.ravel() + codes = _get_codes_for_values(flat, categories) + return codes.reshape(values.shape) + if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values): # Support inferring the correct extension dtype from an array of # scalar objects. e.g. diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index e9dc63e9bd903..6a1a9512bc036 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -303,3 +303,14 @@ def test_not_equal_with_na(self, categories): class TestParsing(base.BaseParsingTests): pass + + +class Test2DCompat(base.Dim2CompatTests): + def test_repr_2d(self, data): + # Categorical __repr__ doesn't include "Categorical", so we need + # to special-case + res = repr(data.reshape(1, -1)) + assert res.count("\nCategories") == 1 + + res = repr(data.reshape(-1, 1)) + assert res.count("\nCategories") == 1