diff --git a/pandas/core/generic.py b/pandas/core/generic.py index c20b2840a40ab..93c597d738501 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8933,7 +8933,7 @@ def _where( # align the cond to same shape as myself cond = com.apply_if_callable(cond, self) if isinstance(cond, NDFrame): - cond, _ = cond.align(self, join="right", broadcast_axis=1) + cond, _ = cond.align(self, join="right", broadcast_axis=1, copy=False) else: if not hasattr(cond, "shape"): cond = np.asanyarray(cond) @@ -8961,6 +8961,7 @@ def _where( cond = cond.astype(bool) cond = -cond if inplace else cond + cond = cond.reindex(self._info_axis, axis=self._info_axis_number, copy=False) # try to align with other if isinstance(other, NDFrame): @@ -8997,7 +8998,7 @@ def _where( "cannot align with a higher dimensional NDFrame" ) - if not isinstance(other, (MultiIndex, NDFrame)): + elif not isinstance(other, (MultiIndex, NDFrame)): # mainly just catching Index here other = extract_array(other, extract_numpy=True) @@ -9029,11 +9030,6 @@ def _where( else: align = self._get_axis_number(axis) == 1 - if isinstance(cond, NDFrame): - cond = cond.reindex( - self._info_axis, axis=self._info_axis_number, copy=False - ) - if inplace: # we may have different type blocks come out of putmask, so # reconstruct the block manager @@ -9049,7 +9045,6 @@ def _where( cond=cond, align=align, errors=errors, - axis=axis, ) result = self._constructor(new_data) return result.__finalize__(self) diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 34b3d83c066c2..435d2421ccade 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -521,7 +521,7 @@ def quantile( axes = [qs, self._axes[1]] return type(self)(new_arrs, axes) - def where(self, other, cond, align: bool, errors: str, axis: int) -> ArrayManager: + def where(self, other, cond, align: bool, errors: str) -> ArrayManager: if align: align_keys = ["other", "cond"] else: @@ -534,7 +534,6 @@ def where(self, other, cond, align: bool, errors: str, axis: int) -> ArrayManage other=other, cond=cond, errors=errors, - axis=axis, ) # TODO what is this used for? diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 174ea8760b0db..f38202f1e3476 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1286,7 +1286,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo return [self.make_block(new_values)] - def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: + def where(self, other, cond, errors="raise") -> List[Block]: """ evaluate the block; return result block(s) from the result @@ -1297,7 +1297,6 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: errors : str, {'raise', 'ignore'}, default 'raise' - ``raise`` : allow exceptions to be raised - ``ignore`` : suppress exceptions. On error return original object - axis : int, default 0 Returns ------- @@ -1305,6 +1304,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: """ import pandas.core.computation.expressions as expressions + assert cond.ndim == self.ndim assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame)) assert errors in ["raise", "ignore"] @@ -1317,7 +1317,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: icond, noop = validate_putmask(values, ~cond) - if is_valid_na_for_dtype(other, self.dtype) and not self.is_object: + if is_valid_na_for_dtype(other, self.dtype) and self.dtype != _dtype_obj: other = self.fill_value if noop: @@ -1330,7 +1330,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: # we cannot coerce, return a compat dtype # we are explicitly ignoring errors block = self.coerce_to_target_dtype(other) - blocks = block.where(orig_other, cond, errors=errors, axis=axis) + blocks = block.where(orig_other, cond, errors=errors) return self._maybe_downcast(blocks, "infer") # error: Argument 1 to "setitem_datetimelike_compat" has incompatible type @@ -1359,7 +1359,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: cond = ~icond axis = cond.ndim - 1 cond = cond.swapaxes(axis, 0) - mask = np.array([cond[i].all() for i in range(cond.shape[0])], dtype=bool) + mask = cond.all(axis=1) result_blocks: List[Block] = [] for m in [mask, ~mask]: @@ -1670,7 +1670,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo new_values = self.values.shift(periods=periods, fill_value=fill_value) return [self.make_block_same_class(new_values)] - def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: + def where(self, other, cond, errors="raise") -> List[Block]: cond = extract_bool_array(cond) assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame)) @@ -1837,7 +1837,7 @@ def putmask(self, mask, new) -> List[Block]: arr.T.putmask(mask, new) return [self] - def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: + def where(self, other, cond, errors="raise") -> List[Block]: # TODO(EA2D): reshape unnecessary with 2D EAs arr = self.array_values().reshape(self.shape) @@ -1846,7 +1846,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]: try: res_values = arr.T.where(cond, other).T except (ValueError, TypeError): - return super().where(other, cond, errors=errors, axis=axis) + return super().where(other, cond, errors=errors) # TODO(EA2D): reshape not needed with 2D EAs res_values = res_values.reshape(self.values.shape) diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index da78fc5dfba76..f4eafd882b62b 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -574,8 +574,7 @@ def quantile( return type(self)(blocks, new_axes) - def where(self, other, cond, align: bool, errors: str, axis: int) -> BlockManager: - axis = self._normalize_axis(axis) + def where(self, other, cond, align: bool, errors: str) -> BlockManager: if align: align_keys = ["other", "cond"] else: @@ -588,7 +587,6 @@ def where(self, other, cond, align: bool, errors: str, axis: int) -> BlockManage other=other, cond=cond, errors=errors, - axis=axis, ) def setitem(self, indexer, value) -> BlockManager: