From 2e1843a0cb811a577da156bd11a7990a8ee0c8a7 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 24 Oct 2021 14:38:31 -0700 Subject: [PATCH] REF: share ExtensionIndex.insert-> Index.insert --- pandas/core/dtypes/missing.py | 4 ++++ pandas/core/indexes/base.py | 15 +++++++++++---- pandas/core/indexes/extension.py | 25 ------------------------- pandas/tests/dtypes/test_missing.py | 9 +++++++++ 4 files changed, 24 insertions(+), 29 deletions(-) diff --git a/pandas/core/dtypes/missing.py b/pandas/core/dtypes/missing.py index f5fbd4cc4a7fc..38553bc1be8d6 100644 --- a/pandas/core/dtypes/missing.py +++ b/pandas/core/dtypes/missing.py @@ -37,6 +37,7 @@ needs_i8_conversion, ) from pandas.core.dtypes.dtypes import ( + CategoricalDtype, ExtensionDtype, IntervalDtype, PeriodDtype, @@ -641,5 +642,8 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool: elif isinstance(dtype, IntervalDtype): return lib.is_float(obj) or obj is None or obj is libmissing.NA + elif isinstance(dtype, CategoricalDtype): + return is_valid_na_for_dtype(obj, dtype.categories.dtype) + # fallback, default to allowing NaN, None, NA, NaT return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal)) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 05047540c6ccd..e82bd61938f15 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -6432,14 +6432,21 @@ def insert(self, loc: int, item) -> Index: if is_valid_na_for_dtype(item, self.dtype) and self.dtype != object: item = self._na_value + arr = self._values + try: - item = self._validate_fill_value(item) - except TypeError: + if isinstance(arr, ExtensionArray): + res_values = arr.insert(loc, item) + return type(self)._simple_new(res_values, name=self.name) + else: + item = self._validate_fill_value(item) + except (TypeError, ValueError): + # e.g. trying to insert an integer into a DatetimeIndex + # We cannot keep the same dtype, so cast to the (often object) + # minimal shared dtype before doing the insert. dtype = self._find_common_type_compat(item) return self.astype(dtype).insert(loc, item) - arr = self._values - if arr.dtype != object or not isinstance( item, (tuple, np.datetime64, np.timedelta64) ): diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index ccd18f54da327..7c7f1b267b5be 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -134,31 +134,6 @@ class ExtensionIndex(Index): # --------------------------------------------------------------------- - def insert(self, loc: int, item) -> Index: - """ - Make new Index inserting new item at location. Follows - Python list.append semantics for negative values. - - Parameters - ---------- - loc : int - item : object - - Returns - ------- - new_index : Index - """ - try: - result = self._data.insert(loc, item) - except (ValueError, TypeError): - # e.g. trying to insert an integer into a DatetimeIndex - # We cannot keep the same dtype, so cast to the (often object) - # minimal shared dtype before doing the insert. - dtype = self._find_common_type_compat(item) - return self.astype(dtype).insert(loc, item) - else: - return type(self)._simple_new(result, name=self.name) - def _validate_fill_value(self, value): """ Convert value to be insertable to underlying array. diff --git a/pandas/tests/dtypes/test_missing.py b/pandas/tests/dtypes/test_missing.py index bf68c4b79bcea..55d0e5e73418e 100644 --- a/pandas/tests/dtypes/test_missing.py +++ b/pandas/tests/dtypes/test_missing.py @@ -18,6 +18,7 @@ is_scalar, ) from pandas.core.dtypes.dtypes import ( + CategoricalDtype, DatetimeTZDtype, IntervalDtype, PeriodDtype, @@ -739,3 +740,11 @@ def test_is_valid_na_for_dtype_interval(self): dtype = IntervalDtype("datetime64[ns]", "both") assert not is_valid_na_for_dtype(NaT, dtype) + + def test_is_valid_na_for_dtype_categorical(self): + dtype = CategoricalDtype(categories=[0, 1, 2]) + assert is_valid_na_for_dtype(np.nan, dtype) + + assert not is_valid_na_for_dtype(NaT, dtype) + assert not is_valid_na_for_dtype(np.datetime64("NaT", "ns"), dtype) + assert not is_valid_na_for_dtype(np.timedelta64("NaT", "ns"), dtype)