From 3232a4f8f3c1accaa0ecc9ff5d73f419388a84cb Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 24 Sep 2020 18:58:57 -0700 Subject: [PATCH] CLN: share setitem/getitem validators --- pandas/core/arrays/_mixins.py | 2 ++ pandas/core/arrays/categorical.py | 5 ++--- pandas/core/arrays/numpy_.py | 14 -------------- 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 2bf530eb2bad4..4d13a18c8ef0b 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -14,6 +14,7 @@ from pandas.core.algorithms import take, unique from pandas.core.array_algos.transforms import shift from pandas.core.arrays.base import ExtensionArray +from pandas.core.construction import extract_array from pandas.core.indexers import check_array_indexer _T = TypeVar("_T", bound="NDArrayBackedExtensionArray") @@ -197,6 +198,7 @@ def __getitem__(self, key): return result def _validate_getitem_key(self, key): + key = extract_array(key, extract_numpy=True) return check_array_indexer(self, key) @doc(ExtensionArray.fillna) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index d2f88b353e1c1..7445d99fd7374 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -47,7 +47,7 @@ from pandas.core.base import ExtensionArray, NoNewAttributesMixin, PandasObject import pandas.core.common as com from pandas.core.construction import array, extract_array, sanitize_array -from pandas.core.indexers import check_array_indexer, deprecate_ndim_indexing +from pandas.core.indexers import deprecate_ndim_indexing from pandas.core.missing import interpolate_2d from pandas.core.ops.common import unpack_zerodim_and_defer from pandas.core.sorting import nargsort @@ -1923,8 +1923,7 @@ def _validate_setitem_key(self, key): # else: array of True/False in Series or Categorical - key = check_array_indexer(self, key) - return key + return super()._validate_setitem_key(key) def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]: """ diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index f65b130b396da..6b982bf579f04 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -16,7 +16,6 @@ from pandas.core.array_algos import masked_reductions from pandas.core.arrays._mixins import NDArrayBackedExtensionArray from pandas.core.arrays.base import ExtensionOpsMixin -from pandas.core.construction import extract_array class PandasDtype(ExtensionDtype): @@ -244,19 +243,6 @@ def __array_ufunc__(self, ufunc, method: str, *inputs, **kwargs): # ------------------------------------------------------------------------ # Pandas ExtensionArray Interface - def _validate_getitem_key(self, key): - if isinstance(key, type(self)): - key = key._ndarray - - return super()._validate_getitem_key(key) - - def _validate_setitem_value(self, value): - value = extract_array(value, extract_numpy=True) - - if not lib.is_scalar(value): - value = np.asarray(value, dtype=self._ndarray.dtype) - return value - def isna(self) -> np.ndarray: return isna(self._ndarray)