From 079706b704b5b28b5905e123484dce0f088a4595 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 26 Jan 2021 08:04:09 -0800 Subject: [PATCH 1/4] BUG: make Index.where behavior mirror Index.putmask behavior --- pandas/core/arrays/interval.py | 10 ++++ pandas/core/indexes/base.py | 51 ++++++++++++-------- pandas/core/indexes/interval.py | 32 ++++++------ pandas/core/indexes/numeric.py | 9 ++-- pandas/tests/series/indexing/test_setitem.py | 11 +++-- 5 files changed, 68 insertions(+), 45 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 882ca0955bc99..357a68c217cbd 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1413,6 +1413,16 @@ def to_tuples(self, na_tuple=True): # --------------------------------------------------------------------- + def putmask(self, mask: np.ndarray, value) -> None: + value_left, value_right = self._validate_setitem_value(value) + + if isinstance(self._left, np.ndarray): + np.putmask(self._left, mask, value_left) + np.putmask(self._right, mask, value_right) + else: + self._left.putmask(mask, value_left) + self._right.putmask(mask, value_right) + def delete(self: IntervalArrayT, loc) -> IntervalArrayT: if isinstance(self._left, np.ndarray): new_left = np.delete(self._left, loc) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 40215ea87f978..ea1f7e832995d 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -41,6 +41,7 @@ from pandas.core.dtypes.cast import ( find_common_type, + infer_dtype_from, maybe_cast_to_integer_array, maybe_promote, validate_numeric_casting, @@ -87,7 +88,7 @@ ABCTimedeltaIndex, ) from pandas.core.dtypes.inference import is_dict_like -from pandas.core.dtypes.missing import array_equivalent, isna +from pandas.core.dtypes.missing import array_equivalent, is_valid_nat_for_dtype, isna from pandas.core import missing, ops from pandas.core.accessor import CachedAccessor @@ -4317,19 +4318,8 @@ def where(self, cond, other=None): >>> idx.where(idx.isin(['car', 'train']), 'other') Index(['car', 'other', 'train', 'other'], dtype='object') """ - if other is None: - other = self._na_value - - values = self.values - - try: - self._validate_fill_value(other) - except (ValueError, TypeError): - return self.astype(object).where(cond, other) - - values = np.where(cond, values, other) - - return Index(values, name=self.name) + cond = np.asarray(cond, dtype=bool) + return self.putmask(~cond, other) # construction helpers @final @@ -4542,16 +4532,24 @@ def putmask(self, mask, value): numpy.ndarray.putmask : Changes elements of an array based on conditional and input values. """ - values = self._values.copy() + mask = np.asarray(mask, dtype=bool) + if mask.shape != self.shape: + raise ValueError("putmask: mask and data must be the same size") + if not mask.any(): + return self.copy() + + if value is None: + value = self._na_value try: converted = self._validate_fill_value(value) except (ValueError, TypeError) as err: if is_object_dtype(self): raise err - # coerces to object - return self.astype(object).putmask(mask, value) + dtype = self._find_common_type_compat(value) + return self.astype(dtype).putmask(mask, value) + values = self._values.copy() np.putmask(values, mask, converted) return self._shallow_copy(values) @@ -5189,18 +5187,31 @@ def _maybe_promote(self, other: Index): return self, other - def _find_common_type_compat(self, target: Index) -> DtypeObj: + @final + def _find_common_type_compat(self, target) -> DtypeObj: """ Implementation of find_common_type that adjusts for Index-specific special cases. """ - dtype = find_common_type([self.dtype, target.dtype]) + if is_interval_dtype(self.dtype) and is_valid_nat_for_dtype(target, self.dtype): + # e.g. setting NA value into IntervalArray[int64] + dtype = IntervalDtype(np.float64, closed=self.closed) + return dtype + + target_dtype, _ = infer_dtype_from(target, pandas_dtype=True) + dtype = find_common_type([self.dtype, target_dtype]) if dtype.kind in ["i", "u"]: # TODO: what about reversed with self being categorical? - if is_categorical_dtype(target.dtype) and target.hasnans: + if ( + isinstance(target, Index) + and is_categorical_dtype(target.dtype) + and target.hasnans + ): # FIXME: find_common_type incorrect with Categorical GH#38240 # FIXME: some cases where float64 cast can be lossy? dtype = np.dtype(np.float64) + if dtype.kind == "c": + dtype = np.dtype(object) return dtype @final diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 40413bfb40b4b..9841b63029f17 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -799,29 +799,22 @@ def length(self): return Index(self._data.length, copy=False) def putmask(self, mask, value): - arr = self._data.copy() + mask = np.asarray(mask, dtype=bool) + if mask.shape != self.shape: + raise ValueError("putmask: mask and data must be the same size") + if not mask.any(): + return self.copy() + try: - value_left, value_right = arr._validate_setitem_value(value) + self._validate_fill_value(value) except (ValueError, TypeError): - return self.astype(object).putmask(mask, value) + dtype = self._find_common_type_compat(value) + return self.astype(dtype).putmask(mask, value) - if isinstance(self._data._left, np.ndarray): - np.putmask(arr._left, mask, value_left) - np.putmask(arr._right, mask, value_right) - else: - # TODO: special case not needed with __array_function__ - arr._left.putmask(mask, value_left) - arr._right.putmask(mask, value_right) + arr = self._data.copy() + arr.putmask(mask, value) return type(self)._simple_new(arr, name=self.name) - @Appender(Index.where.__doc__) - def where(self, cond, other=None): - if other is None: - other = self._na_value - values = np.where(cond, self._values, other) - result = IntervalArray(values) - return type(self)._simple_new(result, name=self.name) - def insert(self, loc, item): """ Return a new IntervalIndex inserting new item at location. Follows @@ -998,6 +991,9 @@ def func(self, other, sort=sort): # -------------------------------------------------------------------- + def _validate_fill_value(self, value): + return self._data._validate_setitem_value(value) + @property def _is_all_dates(self) -> bool: """ diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index 26599bd6ab871..d4b1f206583fb 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -152,9 +152,12 @@ def _validate_fill_value(self, value): raise TypeError value = int(value) - elif hasattr(value, "dtype") and value.dtype.kind in ["m", "M"]: - # TODO: if we're checking arraylike here, do so systematically - raise TypeError + elif hasattr(value, "dtype"): + if value.dtype.kind in ["m", "M"]: + raise TypeError + if value.dtype.kind == "f" and self.dtype.kind in ["i", "u"]: + # TODO: maybe OK if value is castable? + raise TypeError return value diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index 7f469f361fec7..02689fafb775b 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -331,8 +331,7 @@ def test_index_where(self, obj, key, expected, request): mask = np.zeros(obj.shape, dtype=bool) mask[key] = True - if obj.dtype == bool and not mask.all(): - # When mask is all True, casting behavior does not apply + if obj.dtype == bool: msg = "Index/Series casting behavior inconsistent GH#38692" mark = pytest.mark.xfail(reason=msg) request.node.add_marker(mark) @@ -340,11 +339,15 @@ def test_index_where(self, obj, key, expected, request): res = Index(obj).where(~mask, np.nan) tm.assert_index_equal(res, Index(expected)) - @pytest.mark.xfail(reason="Index/Series casting behavior inconsistent GH#38692") - def test_index_putmask(self, obj, key, expected): + def test_index_putmask(self, obj, key, expected, request): mask = np.zeros(obj.shape, dtype=bool) mask[key] = True + if obj.dtype == bool: + msg = "Index/Series casting behavior inconsistent GH#38692" + mark = pytest.mark.xfail(reason=msg) + request.node.add_marker(mark) + res = Index(obj).putmask(mask, np.nan) tm.assert_index_equal(res, Index(expected)) From 1e42b33e52abb6c9c6986bd76805db4596e2eb70 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 26 Jan 2021 08:51:47 -0800 Subject: [PATCH 2/4] mypy fixup --- pandas/core/indexes/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index ea1f7e832995d..308c5a9cc2a3b 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -115,7 +115,7 @@ ) if TYPE_CHECKING: - from pandas import MultiIndex, RangeIndex, Series + from pandas import IntervalIndex, MultiIndex, RangeIndex, Series from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin @@ -5195,6 +5195,7 @@ def _find_common_type_compat(self, target) -> DtypeObj: """ if is_interval_dtype(self.dtype) and is_valid_nat_for_dtype(target, self.dtype): # e.g. setting NA value into IntervalArray[int64] + self = cast("IntervalIndex", self) dtype = IntervalDtype(np.float64, closed=self.closed) return dtype From a16ab2826ab23eda27d8b601d357a90aa2274527 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 26 Jan 2021 09:30:42 -0800 Subject: [PATCH 3/4] mypy fixup --- pandas/core/indexes/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 308c5a9cc2a3b..494c2c6d19e4a 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -5196,8 +5196,7 @@ def _find_common_type_compat(self, target) -> DtypeObj: if is_interval_dtype(self.dtype) and is_valid_nat_for_dtype(target, self.dtype): # e.g. setting NA value into IntervalArray[int64] self = cast("IntervalIndex", self) - dtype = IntervalDtype(np.float64, closed=self.closed) - return dtype + return IntervalDtype(np.float64, closed=self.closed) target_dtype, _ = infer_dtype_from(target, pandas_dtype=True) dtype = find_common_type([self.dtype, target_dtype]) From 7bac106f1188e07a5ad1353182611030078dcd75 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 27 Jan 2021 19:25:47 -0800 Subject: [PATCH 4/4] whatsnew --- doc/source/whatsnew/v1.3.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 5b56d43348957..adf13e40f0bb6 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -389,6 +389,7 @@ Other - Bug in :class:`Index` constructor sometimes silently ignorning a specified ``dtype`` (:issue:`38879`) - Bug in constructing a :class:`Series` from a list and a :class:`PandasDtype` (:issue:`39357`) - Bug in :class:`Styler` which caused CSS to duplicate on multiple renders. (:issue:`39395`) +- :meth:`Index.where` behavior now mirrors :meth:`Index.putmask` behavior, i.e. ``index.where(mask, other)`` matches ``index.putmask(~mask, other)`` (:issue:`39412`) - .. ---------------------------------------------------------------------------