diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 58847528d2183..b732db4c66003 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -9,7 +9,7 @@ from pandas._config import get_option -from pandas._libs import NaT, algos as libalgos, hashtable as htable +from pandas._libs import NaT, algos as libalgos, hashtable as htable, lib from pandas._typing import ArrayLike, Dtype, Ordered, Scalar from pandas.compat.numpy import function as nv from pandas.util._decorators import cache_readonly, deprecate_kwarg, doc @@ -1868,14 +1868,6 @@ def __repr__(self) -> str: # ------------------------------------------------------------------ - def _maybe_coerce_indexer(self, indexer): - """ - return an indexer coerced to the codes dtype - """ - if isinstance(indexer, np.ndarray) and indexer.dtype.kind == "i": - indexer = indexer.astype(self._codes.dtype) - return indexer - def __getitem__(self, key): """ Return an item. @@ -1905,6 +1897,11 @@ def __setitem__(self, key, value): If (one or more) Value is not in categories or if a assigned `Categorical` does not have the same categories """ + key = self._validate_setitem_key(key) + value = self._validate_setitem_value(value) + self._ndarray[key] = value + + def _validate_setitem_value(self, value): value = extract_array(value, extract_numpy=True) # require identical categories set @@ -1934,12 +1931,19 @@ def __setitem__(self, key, value): "category, set the categories first" ) - # set by position - if isinstance(key, (int, np.integer)): + lindexer = self.categories.get_indexer(rvalue) + if isinstance(lindexer, np.ndarray) and lindexer.dtype.kind == "i": + lindexer = lindexer.astype(self._ndarray.dtype) + + return lindexer + + def _validate_setitem_key(self, key): + if lib.is_integer(key): + # set by position pass - # tuple of indexers (dataframe) elif isinstance(key, tuple): + # tuple of indexers (dataframe) # only allow 1 dimensional slicing, but can # in a 2-d case be passed (slice(None),....) if len(key) == 2: @@ -1951,17 +1955,14 @@ def __setitem__(self, key, value): else: raise AssertionError("invalid slicing for a 1-ndim categorical") - # slicing in Series or Categorical elif isinstance(key, slice): + # slicing in Series or Categorical pass # else: array of True/False in Series or Categorical - lindexer = self.categories.get_indexer(rvalue) - lindexer = self._maybe_coerce_indexer(lindexer) - key = check_array_indexer(self, key) - self._codes[key] = lindexer + return key def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]: """ diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index a218745db0a44..2626890c2dbe5 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -546,6 +546,15 @@ def __getitem__(self, key): return self._box_func(result) return self._simple_new(result, dtype=self.dtype) + key = self._validate_getitem_key(key) + result = self._ndarray[key] + if lib.is_scalar(result): + return self._box_func(result) + + freq = self._get_getitem_freq(key) + return self._simple_new(result, dtype=self.dtype, freq=freq) + + def _validate_getitem_key(self, key): if com.is_bool_indexer(key): # first convert to boolean, because check_array_indexer doesn't # allow object dtype @@ -560,12 +569,7 @@ def __getitem__(self, key): pass else: key = check_array_indexer(self, key) - - freq = self._get_getitem_freq(key) - result = self._ndarray[key] - if lib.is_scalar(result): - return self._box_func(result) - return self._simple_new(result, dtype=self.dtype, freq=freq) + return key def _get_getitem_freq(self, key): """