From 8ba5a074385fc733621644c63468f0309f523b30 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 25 Oct 2021 21:29:15 -0700 Subject: [PATCH] ENH: implement EA._where --- pandas/core/arrays/_mixins.py | 2 +- pandas/core/arrays/base.py | 25 +++++++++++++++ pandas/core/arrays/sparse/array.py | 7 ++++ pandas/core/internals/blocks.py | 32 ++----------------- .../indexes/categorical/test_indexing.py | 2 +- pandas/tests/series/indexing/test_where.py | 2 +- 6 files changed, 37 insertions(+), 33 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index cf9820c3aa8f8..ddab05bbfa393 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -320,7 +320,7 @@ def putmask(self, mask: np.ndarray, value) -> None: np.putmask(self._ndarray, mask, value) - def where( + def _where( self: NDArrayBackedExtensionArrayT, mask: np.ndarray, value ) -> NDArrayBackedExtensionArrayT: """ diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 5536a4665fd79..46b505e7384b4 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1411,6 +1411,31 @@ def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT: return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]]) + def _where( + self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value + ) -> ExtensionArrayT: + """ + Analogue to np.where(mask, self, value) + + Parameters + ---------- + mask : np.ndarray[bool] + value : scalar or listlike + + Returns + ------- + same type as self + """ + result = self.copy() + + if is_list_like(value): + val = value[~mask] + else: + val = value + + result[~mask] = val + return result + @classmethod def _empty(cls, shape: Shape, dtype: ExtensionDtype): """ diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 87fcf54ed684b..8260846ae7dc7 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -1305,6 +1305,13 @@ def to_dense(self) -> np.ndarray: _internal_get_values = to_dense + def _where(self, mask, value): + # NB: may not preserve dtype, e.g. result may be Sparse[float64] + # while self is Sparse[int64] + naive_implementation = np.where(mask, self, value) + result = type(self)._from_sequence(naive_implementation) + return result + # ------------------------------------------------------------------------ # IO # ------------------------------------------------------------------------ diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index de612b367f78f..1ea6e9fb2cd77 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -49,7 +49,6 @@ is_dtype_equal, is_extension_array_dtype, is_list_like, - is_sparse, is_string_dtype, ) from pandas.core.dtypes.dtypes import ( @@ -1618,34 +1617,7 @@ def where(self, other, cond, errors="raise") -> list[Block]: # for the type other = self.dtype.na_value - if is_sparse(self.values): - # TODO(SparseArray.__setitem__): remove this if condition - # We need to re-infer the type of the data after doing the - # where, for cases where the subtypes don't match - dtype = None - else: - dtype = self.dtype - - result = self.values.copy() - icond = ~cond - if lib.is_scalar(other): - set_other = other - else: - set_other = other[icond] - try: - result[icond] = set_other - except (NotImplementedError, TypeError): - # NotImplementedError for class not implementing `__setitem__` - # TypeError for SparseArray, which implements just to raise - # a TypeError - if isinstance(result, Categorical): - # TODO: don't special-case - raise - - result = type(self.values)._from_sequence( - np.where(cond, self.values, other), dtype=dtype - ) - + result = self.values._where(cond, other) return [self.make_block_same_class(result)] def _unstack( @@ -1736,7 +1708,7 @@ def where(self, other, cond, errors="raise") -> list[Block]: cond = extract_bool_array(cond) try: - res_values = arr.T.where(cond, other).T + res_values = arr.T._where(cond, other).T except (ValueError, TypeError): return Block.where(self, other, cond, errors=errors) diff --git a/pandas/tests/indexes/categorical/test_indexing.py b/pandas/tests/indexes/categorical/test_indexing.py index 798aa7188cb9a..6f8b18f449779 100644 --- a/pandas/tests/indexes/categorical/test_indexing.py +++ b/pandas/tests/indexes/categorical/test_indexing.py @@ -325,7 +325,7 @@ def test_where_non_categories(self): msg = "Cannot setitem on a Categorical with a new category" with pytest.raises(TypeError, match=msg): # Test the Categorical method directly - ci._data.where(mask, 2) + ci._data._where(mask, 2) class TestContains: diff --git a/pandas/tests/series/indexing/test_where.py b/pandas/tests/series/indexing/test_where.py index ed1ba11c5fd55..0adc1810a6c47 100644 --- a/pandas/tests/series/indexing/test_where.py +++ b/pandas/tests/series/indexing/test_where.py @@ -508,7 +508,7 @@ def test_where_datetimelike_categorical(tz_naive_fixture): tm.assert_index_equal(res, dr) # DatetimeArray.where - res = lvals._data.where(mask, rvals) + res = lvals._data._where(mask, rvals) tm.assert_datetime_array_equal(res, dr._data) # Series.where